From 935cfb65b246ff7e07138ec42cf6ccbe51c76b55 Mon Sep 17 00:00:00 2001 From: Afif <37773945+affifboudaoud@users.noreply.github.com> Date: Tue, 2 Dec 2025 18:38:15 +0100 Subject: [PATCH 01/17] Machine Learning Integration for DaCe (Autodiff - ONNX - PyTorch) (#2164) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # Pull Request: Machine Learning Integration for DaCe ## Overview This PR adds comprehensive machine learning capabilities to DaCe through three tightly integrated components: 1. **Automatic Differentiation (AD)** - Reverse-mode gradient computation for SDFGs 2. **ONNX Integration** - Import and execute neural network models 3. **PyTorch Integration** - Bidirectional interoperability with PyTorch's autograd system Together, these components enable DaCe to optimize and accelerate machine learning workloads, particularly neural network training and inference. ## High-Level Architecture ``` PyTorch Model ↓ ONNX Export ↓ DaCe SDFG (Forward) ↓ Automatic Differentiation ↓ DaCe SDFG (Backward) ↓ Compiled Code Generation ↓ PyTorch Operator (with Autograd) ``` ## Component 1: Automatic Differentiation (`dace/autodiff/`) ### Purpose Provides **reverse-mode automatic differentiation** for SDFGs, enabling gradient computation for any DaCe program. This is the foundation for neural network training and gradient-based optimization. ### Key Capabilities - **Full SDFG Support**: Differentiates maps, tasklets, nested SDFGs, loops, and library nodes - **Control Flow**: Handles loops (LoopRegion) and conditionals - **ONNX Operations**: 50+ backward implementations for ONNX operators - **Data Forwarding**: Flexible strategies (store vs. recompute) for memory/compute tradeoffs - **Extensible Registry**: Plugin-based system for adding backward rules ### Core Algorithm 1. **Forward Pass Execution**: Run original computation and identify required intermediates 2. **Backward Pass Generation**: Traverse computation graph in reverse, accumulating gradients 3. **Node Reversal**: Each forward node (Map, Tasklet, ONNXOp) has a registered backward implementation 4. **Gradient Accumulation**: Use write-conflict resolution (WCR) for multi-path gradients ### Key Files | File | Lines | Purpose | |------|-------|---------| | `backward_pass_generator.py` | ~800 | Core AD engine that orchestrates backward pass generation | | `implementations/onnx_ops.py` | ~2000 | Backward implementations for 50+ ONNX operations | | `implementations/dace_nodes.py` | ~600 | Backward rules for core SDFG elements (Tasklet, Map, etc.) | | `data_forwarding/manager.py` | ~300 | Store vs. recompute strategy coordination | --- ## Component 2: ONNX Integration (`dace/libraries/onnx/`) ### Purpose Enables **importing and executing ONNX neural network models** within DaCe. Converts ONNX graphs to optimized DaCe SDFGs for efficient execution on CPU/GPU. ### Key Capabilities - **Model Import**: Load ONNX models from files or protobuf objects - **100+ Operations**: Dynamically generated node classes for all ONNX ops - **Shape Inference**: Automatic symbolic and concrete shape computation - **Multi-Strategy Implementations**: Pure (correctness), optimized (performance), hardware-specific - **Type Safety**: Schema-based validation and type checking ### Core Architecture **Dynamic Node Generation**: - Registry system generates Python classes for all ONNX operations at import time - Each operation has schema, properties, connectors, and implementations - Example: `ONNXConv`, `ONNXMatMul`, `ONNXSoftmax` (100+ generated classes) **Implementation Strategies**: 1. **Pure Implementations** (`pure_implementations.py`): Reference implementations in Python/NumPy 2. **Optimized Implementations** (`img_op_implementations.py`): Hand-crafted SDFGs for performance 3. **Hardware-Specific**: Future GPU/FPGA specialized implementations **Import Pipeline**: ``` ONNX Model → Validation → Shape Inference → Simplification → SDFG Construction → Compilation ``` ### Key Files | File | Lines | Purpose | |------|-------|---------| | `onnx_importer.py` | 711 | Main entry point, orchestrates import pipeline | | `op_implementations/pure_implementations.py` | 3052 | Reference implementations for 40+ operations | | `nodes/onnx_op_registry.py` | 325 | Dynamic node class generation | | `schema.py` | 390 | Type system and validation | | `shape_inference/symbolic_shape_infer.py` | 1976 | Symbolic shape inference (Microsoft-sourced) | --- ## Component 3: PyTorch Integration (`dace/libraries/torch/`) ### Purpose Provides **bidirectional integration** between PyTorch and DaCe. Enables optimizing PyTorch models with DaCe while maintaining PyTorch's autograd compatibility. ### Key Capabilities - **Model Optimization**: Convert `torch.nn.Module` to optimized DaCe SDFGs - **Autograd Integration**: Backward pass generation integrates with PyTorch's autograd - **Dual Dispatch**: C++ extension (performance) or CTypes (flexibility) - **Zero-Copy Tensors**: DLPack protocol for efficient memory sharing - **Training Support**: Full forward + backward pass compilation ### Core Architecture **Integration Flow**: ``` PyTorch Model → ONNX Export → DaCe SDFG → Backward Generation → Compilation → PyTorch Operator ``` **Dispatcher Strategies**: 1. **C++ Extension** (`cpp_torch_extension.py`): Native PyTorch operator with autograd - High performance - 64 parameter limit - Slower compilation 2. **CTypes Module** (`ctypes_module.py`): Pure Python dispatcher - Unlimited parameters - Faster compilation - Slight overhead **Zero-Copy Memory Sharing**: - DLPack protocol enables PyTorch tensors to view DaCe memory without copying - Bidirectional: DaCe → PyTorch (outputs) and PyTorch → DaCe (inputs) ### Key Files | File | Lines | Purpose | |------|-------|---------| | `dispatchers/cpp_torch_extension.py` | 717 | C++ code generation for PyTorch operators | | `dispatchers/ctypes_module.py` | 230 | CTypes-based dispatcher | | `dlpack.py` | 199 | Zero-copy tensor sharing via DLPack | | `environments/pytorch_env.py` | 94 | CMake build configuration | --- ## How Components Work Together ### Example: Training a PyTorch Model with DaCe ```python import torch from dace.frontend.python import DaceModule # 1. Define PyTorch model model = MyNeuralNetwork() optimizer = torch.optim.Adam(model.parameters()) # 2. Wrap with DaCe (compiles on first call) dace_model = DaceModule(model, dummy_inputs, backward=True) # 3. Training loop (standard PyTorch code) for inputs, labels in dataloader: optimizer.zero_grad() outputs = dace_model(inputs) # DaCe-optimized forward pass loss = criterion(outputs, labels) loss.backward() # DaCe-optimized backward pass optimizer.step() ``` **What Happens Internally**: 1. **First Call**: PyTorch model → ONNX export → DaCe SDFG (via ONNX integration) 2. **Backward Generation**: Forward SDFG → Backward SDFG (via autodiff) 3. **Compilation**: Both SDFGs compiled to optimized code 4. **Dispatcher**: C++ extension or CTypes wrapper created 5. **Forward Pass**: DaCe executes optimized forward computation 6. **Backward Pass**: DaCe executes generated backward computation 7. **Gradient Return**: Gradients flow back to PyTorch optimizer ### Data Flow ``` PyTorch Tensor (input) ↓ Zero-copy (DLPack) DaCe Array ↓ Optimized SDFG Execution DaCe Array (output) ↓ Zero-copy (DLPack) PyTorch Tensor (output) ↓ loss.backward() PyTorch Tensor (grad_output) ↓ Zero-copy (DLPack) DaCe Array (backward pass input) ↓ Backward SDFG Execution DaCe Array (grad_input) ↓ Zero-copy (DLPack) PyTorch Tensor (grad_input) ``` --- ## Testing Strategy ### Test Organization ``` tests/ ├── autodiff/ # AD correctness tests │ ├── test_single_state.py # Basic AD operations │ └── torch/ # PyTorch integration tests │ ├── test_training.py # End-to-end training │ ├── test_bert_encoder_backward.py # BERT model │ └── test_llama_decoder_backward.py # LLaMA model │ ├── onnx/ # ONNX import tests │ ├── test_python_frontend.py # Basic operations │ ├── test_bert_subgraphs.py # Real model subgraphs │ └── test_input_outputs.py # I/O handling │ └── torch/ # PyTorch integration tests │ ├── test_lenet.py # Simple CNN │ ├── test_bert_encoder.py # Transformer encoder │ └── test_llama_decoder.py # Decoder architecture │ └── npbench/ # AD tests on NPBench kernels ``` ### Test Coverage | Component | Test Files | Coverage | |-----------|-----------|----------| | Autodiff Core | 15+ files | Tasklets, maps, loops, nested SDFGs | | ONNX Integration | 20+ files | Import, execution, type handling | | PyTorch Integration | 15+ files | Forward, backward, training loops | ### Running Tests ```bash # All basic tests (excluding hardware-specific) pytest -m "(autodiff or torch or onnx) and not long" tests/ # AD tests only pytest tests/autodiff/ # ONNX tests only pytest tests/onnx/ # PyTorch tests only pytest tests/torch/ ``` --- ## Known Limitations and Future Work ### Current Limitations 1. **Recompute Strategy**: Experimental, not production-ready 2. **Control Flow**: Conditionals are inlined into state machine (not reversed as ControlFlowRegions) 3. **Second-Order Gradients**: Not yest tested --- ## Documentation Each component has detailed design documentation: - [`dace/autodiff/autodiff.md`](dace/autodiff/autodiff.md) - Complete AD system design - [`dace/libraries/onnx/onnx.md`](dace/libraries/onnx/onnx.md) - ONNX integration architecture - [`dace/libraries/torch/torch.md`](dace/libraries/torch/torch.md) - PyTorch integration details These documents provide: - Detailed component descriptions - Algorithm explanations - Code walkthrough - Extension points - Implementation notes --- ## Impact on DaCe ### Code Additions | Component | Lines of Code | Files | |-----------|--------------|-------| | Autodiff | ~8,000 | 15+ files | | ONNX | ~7,000 | 20+ files | | PyTorch | ~1,500 | 10+ files | | **Total** | **~16,500** | **45+ files** | ### Dependencies New dependencies (already in `setup.py`): - `onnx` - ONNX model format - `onnxsim` - ONNX graph simplification - `torch` - PyTorch framework (optional) - `protobuf` - Protocol buffers (for ONNX) - `jax` - For gradient numerical validation tests -`transformers` - For testing the Pytorch/ONNX frontends - `efficientnet_pytorch`- For testing EfficientNet --- --------- Co-authored-by: Oliver Rausch --- .github/workflows/copilot-setup-steps.yml | 2 +- .github/workflows/fpga-ci.yml | 2 +- .github/workflows/general-ci.yml | 2 +- .github/workflows/gpu-ci.yml | 2 +- .github/workflows/ml-ci.yml | 62 + .gitignore | 5 + dace/__init__.py | 11 + dace/autodiff/__init__.py | 58 + dace/autodiff/analysis.py | 103 + dace/autodiff/autodiff.md | 821 +++++++ dace/autodiff/autodiff.py | 83 + dace/autodiff/backward_pass_generator.py | 2056 +++++++++++++++++ dace/autodiff/base_abc.py | 183 ++ dace/autodiff/data_forwarding/__init__.py | 19 + dace/autodiff/data_forwarding/manager.py | 388 ++++ dace/autodiff/data_forwarding/recompute.py | 298 +++ dace/autodiff/data_forwarding/store.py | 683 ++++++ dace/autodiff/implementations/__init__.py | 46 + dace/autodiff/implementations/dace_nodes.py | 487 ++++ .../implementations/dace_reduction_nodes.py | 307 +++ dace/autodiff/implementations/onnx_ops.py | 1045 +++++++++ dace/autodiff/implementations/pytorch_ops.py | 128 + dace/autodiff/library/__init__.py | 31 + dace/autodiff/library/library.py | 286 +++ dace/autodiff/library/torch_integration.py | 39 + dace/autodiff/torch.py | 124 + dace/autodiff/utils.py | 910 ++++++++ dace/codegen/codegen.py | 2 +- dace/codegen/common.py | 11 + dace/codegen/targets/cpp.py | 3 +- dace/codegen/targets/cpu.py | 1 - dace/frontend/ml/__init__.py | 13 + dace/frontend/ml/onnx/__init__.py | 5 + dace/frontend/ml/onnx/importer.py | 794 +++++++ dace/frontend/{ => ml}/tensorflow/__init__.py | 0 .../{ => ml}/tensorflow/tensorflow.py | 4 +- .../tensorflow/transformations/__init__.py | 0 .../transformations/redundant_array.py | 0 dace/frontend/{ => ml}/tensorflow/winograd.py | 0 dace/frontend/ml/torch/__init__.py | 6 + dace/frontend/ml/torch/interface.py | 89 + dace/frontend/ml/torch/module.py | 581 +++++ dace/frontend/python/interface.py | 2 +- .../python/replacements/torch_autodiff.py | 163 ++ dace/libraries/blas/blas_helpers.py | 2 +- dace/libraries/onnx/__init__.py | 61 + dace/libraries/onnx/converters.py | 247 ++ .../onnx/forward_implementation_abc.py | 105 + dace/libraries/onnx/nodes/__init__.py | 2 + dace/libraries/onnx/nodes/node_utils.py | 90 + dace/libraries/onnx/nodes/onnx_op.py | 295 +++ dace/libraries/onnx/nodes/onnx_op_registry.py | 351 +++ dace/libraries/onnx/onnx.md | 993 ++++++++ .../onnx/op_implementations/__init__.py | 11 + .../onnx/op_implementations/array_ops.py | 681 ++++++ .../onnx/op_implementations/common.py | 11 + .../criteria_implementations.py | 90 + .../op_implementations/elementwise_ops.py | 212 ++ .../onnx/op_implementations/image_ops.py | 443 ++++ .../img_op_implementations.py | 563 +++++ .../onnx/op_implementations/linalg_ops.py | 359 +++ .../op_implementations/normalization_ops.py | 281 +++ .../onnx/op_implementations/reduction_ops.py | 304 +++ .../onnx/op_implementations/utils.py | 223 ++ dace/libraries/onnx/schema.py | 333 +++ dace/libraries/torch/__init__.py | 23 + dace/libraries/torch/dispatchers/__init__.py | 22 + dace/libraries/torch/dispatchers/common.py | 112 + .../torch/dispatchers/cpp_torch_extension.py | 699 ++++++ .../torch/dispatchers/ctypes_module.py | 222 ++ dace/libraries/torch/dlpack.py | 188 ++ dace/libraries/torch/environments/__init__.py | 2 + .../torch/environments/pytorch_env.py | 100 + dace/libraries/torch/torch.md | 1254 ++++++++++ dace/ml/__init__.py | 16 + dace/sdfg/sdfg.py | 2 +- dace/sdfg/utils.py | 86 + dace/transformation/auto/auto_optimize.py | 2 +- dace/transformation/onnx/__init__.py | 10 + dace/transformation/onnx/constant_folding.py | 158 ++ dace/transformation/onnx/optimize.py | 65 + .../onnx/parameter_to_transient.py | 83 + dace/transformation/onnx/replacement.py | 159 ++ .../transformation/passes/scalar_to_symbol.py | 11 +- .../simplification/control_flow_raising.py | 4 +- .../transformation/subgraph/stencil_tiling.py | 5 + pytest.ini | 4 + setup.py | 15 +- tests/.gitignore | 3 + tests/autodiff/test_multi_state.py | 304 +++ tests/autodiff/test_nested.py | 174 ++ tests/autodiff/test_single_state.py | 635 +++++ .../test_dont_compute_input_grads.py | 68 + tests/autodiff/torch_backward/test_dropout.py | 69 + .../test_extremal_reduction_backward.py | 160 ++ .../test_full_training_graph.py | 227 ++ .../test_llama_decoder_backward.py | 107 + .../test_llama_for_causalLM_backward.py | 105 + .../torch_backward/test_multi_output_ad.py | 64 + tests/autodiff/torch_backward/test_pytorch.py | 305 +++ .../autodiff/torch_backward/test_training.py | 124 + tests/conftest.py | 11 + .../npbench/deep_learning/conv2d_bias_test.py | 91 + tests/npbench/deep_learning/lenet_test.py | 135 +- tests/npbench/deep_learning/mlp_test.py | 74 + tests/npbench/deep_learning/resnet_test.py | 114 + tests/npbench/deep_learning/softmax_test.py | 45 + tests/npbench/misc/cavity_flow_test.py | 121 + tests/npbench/misc/compute_test.py | 1 + tests/npbench/misc/go_fast_test.py | 53 + tests/npbench/polybench/adi_test.py | 126 +- tests/npbench/polybench/atax_test.py | 43 + tests/npbench/polybench/bicg_test.py | 48 + tests/npbench/polybench/cholesky_test.py | 70 + tests/npbench/polybench/correlation_test.py | 61 + tests/npbench/polybench/covariance_test.py | 56 + tests/npbench/polybench/deriche_test.py | 101 + tests/npbench/polybench/doitgen_test.py | 51 + tests/npbench/polybench/durbin_test.py | 83 + tests/npbench/polybench/fdtd_2d_test.py | 61 + tests/npbench/polybench/gemm_npbench_test.py | 47 +- tests/npbench/polybench/gemver_test.py | 71 +- tests/npbench/polybench/gesummv_test.py | 47 +- tests/npbench/polybench/gramschmidt_test.py | 76 + tests/npbench/polybench/heat_3d_test.py | 66 + tests/npbench/polybench/jacobi_1d_test.py | 52 +- tests/npbench/polybench/jacobi_2d_test.py | 54 + tests/npbench/polybench/k2mm_test.py | 59 +- tests/npbench/polybench/k3mm_test.py | 47 +- tests/npbench/polybench/lu_test.py | 77 + tests/npbench/polybench/ludcmp_test.py | 108 + tests/npbench/polybench/mvt_test.py | 50 +- tests/npbench/polybench/seidel_2d_test.py | 75 + tests/npbench/polybench/symm_test.py | 71 +- tests/npbench/polybench/syr2k_test.py | 95 +- tests/npbench/polybench/syrk_test.py | 82 + tests/npbench/polybench/trisolv_test.py | 58 +- tests/npbench/polybench/trmm_test.py | 73 +- tests/npbench/weather_stencils/hdiff_test.py | 75 + tests/npbench/weather_stencils/vadv_test.py | 147 +- .../pure_expansions/test_conv_expansion.py | 61 + .../pure_expansions/test_expansion_utils.py | 42 + tests/onnx/pure_expansions/test_expansions.py | 561 +++++ tests/onnx/test_bert_subgraphs.py | 106 + tests/onnx/test_input_outputs.py | 229 ++ tests/onnx/test_models/test_bert.py | 92 + tests/onnx/test_name_shadowing.py | 37 + tests/onnx/test_onnx_return_scalars.py | 61 + tests/onnx/test_python_frontend.py | 31 + tests/onnx/test_shared_input_output.py | 111 + tests/onnx/test_variadic.py | 53 + tests/tensorflow/callback_test.py | 2 +- tests/tensorflow/compile_test.py | 2 +- tests/tensorflow/conv_test.py | 2 +- tests/tensorflow/fbn_test.py | 2 +- tests/tensorflow/ops_test.py | 8 +- tests/tensorflow/pool_test.py | 2 +- tests/tensorflow/simple_test.py | 2 +- tests/torch_forward/test_attn.py | 39 + tests/torch_forward/test_conv2d.py | 55 + tests/torch_forward/test_cpp_extension.py | 120 + tests/torch_forward/test_debug_transients.py | 36 + tests/torch_forward/test_dlpack.py | 26 + .../torch_forward/test_efficientnet_block.py | 114 + .../test_img_op_implementations.py | 95 + tests/torch_forward/test_lenet.py | 60 + .../torch_forward/test_module_dace_program.py | 63 + tests/torch_forward/test_multi_output.py | 44 + tests/torch_forward/test_reshape.py | 38 + tests/utils.py | 60 + 170 files changed, 25973 insertions(+), 55 deletions(-) create mode 100644 .github/workflows/ml-ci.yml create mode 100644 dace/autodiff/__init__.py create mode 100644 dace/autodiff/analysis.py create mode 100644 dace/autodiff/autodiff.md create mode 100644 dace/autodiff/autodiff.py create mode 100644 dace/autodiff/backward_pass_generator.py create mode 100644 dace/autodiff/base_abc.py create mode 100644 dace/autodiff/data_forwarding/__init__.py create mode 100644 dace/autodiff/data_forwarding/manager.py create mode 100644 dace/autodiff/data_forwarding/recompute.py create mode 100644 dace/autodiff/data_forwarding/store.py create mode 100644 dace/autodiff/implementations/__init__.py create mode 100644 dace/autodiff/implementations/dace_nodes.py create mode 100644 dace/autodiff/implementations/dace_reduction_nodes.py create mode 100644 dace/autodiff/implementations/onnx_ops.py create mode 100644 dace/autodiff/implementations/pytorch_ops.py create mode 100644 dace/autodiff/library/__init__.py create mode 100644 dace/autodiff/library/library.py create mode 100644 dace/autodiff/library/torch_integration.py create mode 100644 dace/autodiff/torch.py create mode 100644 dace/autodiff/utils.py create mode 100644 dace/frontend/ml/__init__.py create mode 100644 dace/frontend/ml/onnx/__init__.py create mode 100644 dace/frontend/ml/onnx/importer.py rename dace/frontend/{ => ml}/tensorflow/__init__.py (100%) rename dace/frontend/{ => ml}/tensorflow/tensorflow.py (99%) rename dace/frontend/{ => ml}/tensorflow/transformations/__init__.py (100%) rename dace/frontend/{ => ml}/tensorflow/transformations/redundant_array.py (100%) rename dace/frontend/{ => ml}/tensorflow/winograd.py (100%) create mode 100644 dace/frontend/ml/torch/__init__.py create mode 100644 dace/frontend/ml/torch/interface.py create mode 100644 dace/frontend/ml/torch/module.py create mode 100644 dace/frontend/python/replacements/torch_autodiff.py create mode 100644 dace/libraries/onnx/__init__.py create mode 100644 dace/libraries/onnx/converters.py create mode 100644 dace/libraries/onnx/forward_implementation_abc.py create mode 100644 dace/libraries/onnx/nodes/__init__.py create mode 100644 dace/libraries/onnx/nodes/node_utils.py create mode 100644 dace/libraries/onnx/nodes/onnx_op.py create mode 100644 dace/libraries/onnx/nodes/onnx_op_registry.py create mode 100644 dace/libraries/onnx/onnx.md create mode 100644 dace/libraries/onnx/op_implementations/__init__.py create mode 100644 dace/libraries/onnx/op_implementations/array_ops.py create mode 100644 dace/libraries/onnx/op_implementations/common.py create mode 100644 dace/libraries/onnx/op_implementations/criteria_implementations.py create mode 100644 dace/libraries/onnx/op_implementations/elementwise_ops.py create mode 100644 dace/libraries/onnx/op_implementations/image_ops.py create mode 100644 dace/libraries/onnx/op_implementations/img_op_implementations.py create mode 100644 dace/libraries/onnx/op_implementations/linalg_ops.py create mode 100644 dace/libraries/onnx/op_implementations/normalization_ops.py create mode 100644 dace/libraries/onnx/op_implementations/reduction_ops.py create mode 100644 dace/libraries/onnx/op_implementations/utils.py create mode 100644 dace/libraries/onnx/schema.py create mode 100644 dace/libraries/torch/__init__.py create mode 100644 dace/libraries/torch/dispatchers/__init__.py create mode 100644 dace/libraries/torch/dispatchers/common.py create mode 100644 dace/libraries/torch/dispatchers/cpp_torch_extension.py create mode 100644 dace/libraries/torch/dispatchers/ctypes_module.py create mode 100644 dace/libraries/torch/dlpack.py create mode 100644 dace/libraries/torch/environments/__init__.py create mode 100644 dace/libraries/torch/environments/pytorch_env.py create mode 100644 dace/libraries/torch/torch.md create mode 100644 dace/ml/__init__.py create mode 100644 dace/transformation/onnx/__init__.py create mode 100644 dace/transformation/onnx/constant_folding.py create mode 100644 dace/transformation/onnx/optimize.py create mode 100644 dace/transformation/onnx/parameter_to_transient.py create mode 100644 dace/transformation/onnx/replacement.py create mode 100644 tests/autodiff/test_multi_state.py create mode 100644 tests/autodiff/test_nested.py create mode 100644 tests/autodiff/test_single_state.py create mode 100644 tests/autodiff/torch_backward/test_dont_compute_input_grads.py create mode 100644 tests/autodiff/torch_backward/test_dropout.py create mode 100644 tests/autodiff/torch_backward/test_extremal_reduction_backward.py create mode 100644 tests/autodiff/torch_backward/test_full_training_graph.py create mode 100644 tests/autodiff/torch_backward/test_llama_decoder_backward.py create mode 100644 tests/autodiff/torch_backward/test_llama_for_causalLM_backward.py create mode 100644 tests/autodiff/torch_backward/test_multi_output_ad.py create mode 100644 tests/autodiff/torch_backward/test_pytorch.py create mode 100644 tests/autodiff/torch_backward/test_training.py create mode 100644 tests/onnx/pure_expansions/test_conv_expansion.py create mode 100644 tests/onnx/pure_expansions/test_expansion_utils.py create mode 100644 tests/onnx/pure_expansions/test_expansions.py create mode 100644 tests/onnx/test_bert_subgraphs.py create mode 100644 tests/onnx/test_input_outputs.py create mode 100644 tests/onnx/test_models/test_bert.py create mode 100644 tests/onnx/test_name_shadowing.py create mode 100644 tests/onnx/test_onnx_return_scalars.py create mode 100644 tests/onnx/test_python_frontend.py create mode 100644 tests/onnx/test_shared_input_output.py create mode 100644 tests/onnx/test_variadic.py create mode 100644 tests/torch_forward/test_attn.py create mode 100644 tests/torch_forward/test_conv2d.py create mode 100644 tests/torch_forward/test_cpp_extension.py create mode 100644 tests/torch_forward/test_debug_transients.py create mode 100644 tests/torch_forward/test_dlpack.py create mode 100644 tests/torch_forward/test_efficientnet_block.py create mode 100644 tests/torch_forward/test_img_op_implementations.py create mode 100644 tests/torch_forward/test_lenet.py create mode 100644 tests/torch_forward/test_module_dace_program.py create mode 100644 tests/torch_forward/test_multi_output.py create mode 100644 tests/torch_forward/test_reshape.py create mode 100644 tests/utils.py diff --git a/.github/workflows/copilot-setup-steps.yml b/.github/workflows/copilot-setup-steps.yml index fe60e4867c..c3d1a4088e 100644 --- a/.github/workflows/copilot-setup-steps.yml +++ b/.github/workflows/copilot-setup-steps.yml @@ -35,6 +35,6 @@ jobs: - name: Install DaCe in development mode run: | - python -m pip install --editable ".[testing,linting]" + python -m pip install --editable ".[testing,linting,ml]" pre-commit install pre-commit run diff --git a/.github/workflows/fpga-ci.yml b/.github/workflows/fpga-ci.yml index 926d4c69e9..21ad7c1ac2 100644 --- a/.github/workflows/fpga-ci.yml +++ b/.github/workflows/fpga-ci.yml @@ -32,7 +32,7 @@ jobs: python -m pip install --upgrade pip pip install pytest-xdist flake8 coverage click pip uninstall -y dace - pip install -e ".[testing]" + pip install -e ".[testing,ml]" curl -Os https://uploader.codecov.io/latest/linux/codecov chmod +x codecov diff --git a/.github/workflows/general-ci.yml b/.github/workflows/general-ci.yml index 1d9dc3fa79..59e0aae179 100644 --- a/.github/workflows/general-ci.yml +++ b/.github/workflows/general-ci.yml @@ -59,7 +59,7 @@ jobs: else export DACE_optimizer_automatic_simplification=${{ matrix.simplify }} fi - pytest -n auto --cov-report=xml --cov=dace --tb=short --timeout_method thread --timeout=300 -m "not gpu and not verilator and not tensorflow and not mkl and not sve and not papi and not mlir and not lapack and not fpga and not mpi and not rtl_hardware and not scalapack and not datainstrument and not long and not sequential" + pytest -n auto --cov-report=xml --cov=dace --tb=short --timeout_method thread --timeout=300 -m "not gpu and not autodiff and not torch and not onnx and not verilator and not tensorflow and not mkl and not sve and not papi and not mlir and not lapack and not fpga and not mpi and not rtl_hardware and not scalapack and not datainstrument and not long and not sequential" ./codecov - name: Test OpenBLAS LAPACK diff --git a/.github/workflows/gpu-ci.yml b/.github/workflows/gpu-ci.yml index a7a28d7d91..68cfabfa16 100644 --- a/.github/workflows/gpu-ci.yml +++ b/.github/workflows/gpu-ci.yml @@ -37,7 +37,7 @@ jobs: pip install mpi4py pip install cupy pip uninstall -y dace - pip install -e ".[testing]" + pip install -e ".[testing,ml]" curl -Os https://uploader.codecov.io/latest/linux/codecov chmod +x codecov diff --git a/.github/workflows/ml-ci.yml b/.github/workflows/ml-ci.yml new file mode 100644 index 0000000000..b36890d562 --- /dev/null +++ b/.github/workflows/ml-ci.yml @@ -0,0 +1,62 @@ +name: Machine Learning and Autodiff Tests + +on: + push: + branches: [ main, ci-fix ] + pull_request: + branches: [ main, ci-fix ] + merge_group: + branches: [ main, ci-fix ] + +concurrency: + group: ${{github.workflow}}-${{github.ref}} + cancel-in-progress: true + +jobs: + test: + if: "!contains(github.event.pull_request.labels.*.name, 'no-ci')" + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ['3.13'] + simplify: [0,1,autoopt] + + steps: + - uses: actions/checkout@v4 + with: + submodules: 'recursive' + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + sudo apt-get update + sudo apt-get install -y libyaml-dev cmake + sudo apt-get install -y libblas-dev libopenblas-dev liblapacke-dev + python -m pip install --upgrade pip + pip install flake8 pytest-xdist coverage + pip install -e ".[ml-testing,ml]" + curl -Os https://uploader.codecov.io/latest/linux/codecov + chmod +x codecov + + - name: Test with pytest + run: | + export NOSTATUSBAR=1 + export DACE_testing_serialization=1 + export DACE_testing_deserialize_exception=1 + export DACE_cache=unique + if [ "${{ matrix.simplify }}" = "autoopt" ]; then + export DACE_optimizer_automatic_simplification=1 + export DACE_optimizer_autooptimize=1 + echo "Auto-optimization heuristics" + else + export DACE_optimizer_automatic_simplification=${{ matrix.simplify }} + fi + pytest --cov-report=xml --cov=dace --tb=short --timeout_method thread --timeout=300 -v -m "(torch or onnx or autodiff) and not gpu" + ./codecov + + - uses: codecov/codecov-action@v4 + with: + token: ${{ secrets.CODECOV_TOKEN }} + verbose: true diff --git a/.gitignore b/.gitignore index 5a54e2df44..3662b78b1f 100644 --- a/.gitignore +++ b/.gitignore @@ -195,3 +195,8 @@ _build/ # Ignoring the test junk _all_tests/ + + +# Ignore downloaded ONNX models +/*.onnx +/*.bin diff --git a/dace/__init__.py b/dace/__init__.py index 823abb9111..98a44bd217 100644 --- a/dace/__init__.py +++ b/dace/__init__.py @@ -35,6 +35,17 @@ sys.path.insert(0, __external_transformations_path__) +# Lazy loading for ml module to avoid eager TensorFlow/PyTorch imports +def __getattr__(name): + if name == 'ml': + import importlib + ml_module = importlib.import_module('.ml', package='dace') + # Cache the module to avoid re-importing + globals()['ml'] = ml_module + return ml_module + raise AttributeError(f"module '{__name__}' has no attribute '{name}'") + + # Hack that enables using @dace as a decorator # See https://stackoverflow.com/a/48100440/6489142 class DaceModule(sys.modules[__name__].__class__): diff --git a/dace/autodiff/__init__.py b/dace/autodiff/__init__.py new file mode 100644 index 0000000000..2e0e9bf746 --- /dev/null +++ b/dace/autodiff/__init__.py @@ -0,0 +1,58 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +DaCe Automatic Differentiation (AD) System. + +This module provides reverse-mode automatic differentiation for DaCe programs, +enabling automatic computation of gradients for optimized numerical kernels. + +Main Components +--------------- +- **add_backward_pass**: Main entry point for adding backward pass to an SDFG +- **BackwardPassGenerator**: Core algorithm for generating backward passes +- **BackwardImplementation**: ABC for implementing operation-specific backward rules +- **BackwardContext**: Context information for backward pass generation +- **BackwardResult**: Result of backward pass generation with forward/backward SDFGs +- **AutoDiffException**: Base exception for autodiff errors + +Key Features +------------ +- Support for control flow (loops, conditionals) +- Data forwarding strategies (store vs recompute tradeoffs) +- Extensible backward implementations for library nodes +- Integration with PyTorch autograd +- Automatic memory management for intermediate values + + +""" + +from .base_abc import BackwardImplementation, BackwardContext, BackwardResult, AutoDiffException +from .backward_pass_generator import BackwardPassGenerator +from .autodiff import add_backward_pass + +try: + from .torch import make_backward_function + TORCH_INTEGRATION_AVAILABLE = True +except ImportError: + make_backward_function = None + TORCH_INTEGRATION_AVAILABLE = False + +import sys +from . import library + +__all__ = [ + # Main API + "add_backward_pass", + # Core classes + "BackwardPassGenerator", + "BackwardContext", + "BackwardResult", + # Extension points + "BackwardImplementation", + # Exceptions + "AutoDiffException", + # Submodules + "library", +] + +if TORCH_INTEGRATION_AVAILABLE: + __all__.append("make_backward_function") diff --git a/dace/autodiff/analysis.py b/dace/autodiff/analysis.py new file mode 100644 index 0000000000..224f0db9f8 --- /dev/null +++ b/dace/autodiff/analysis.py @@ -0,0 +1,103 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +Analysis helpers for autodiff +""" +from typing import Dict, Set, Tuple, Optional +import collections + +import networkx as nx + +from dace.sdfg import SDFG, SDFGState, nodes, utils as sdfg_utils +from dace.transformation.passes import analysis +from dace.sdfg.state import FunctionCallRegion + +AccessSets = Dict[SDFGState, Tuple[Set[str], Set[str]]] + + +def dependency_analysis(sdfg: SDFG) -> Dict[str, Set[str]]: + """ + Analyze read dependencies of arrays in the SDFG. + + :param sdfg: SDFG to analyze + :return: A dictionary mapping array names to a list of read dependencies. + """ + + # FIXME can be made more efficient + dependencies = nx.DiGraph() + for sdfg_node in sdfg.nodes(): + if isinstance(sdfg_node, SDFGState): + for node in sdfg_node.data_nodes(): + for edge in sdfg_node.edge_bfs(node, reverse=True): + dependencies.add_edge(node.data, edge.data.data) + elif isinstance(sdfg_node, FunctionCallRegion): + for state in sdfg_node.nodes(): + assert isinstance(state, SDFGState) + for node in state.data_nodes(): + for edge in state.edge_bfs(node, reverse=True): + dependencies.add_edge(node.data, edge.data.data) + + dependencies = nx.transitive_closure(dependencies) + result = {} + for array in dependencies: + result[array] = {nbr for nbr in dependencies.neighbors(array)} + return result + + +def inverse_reachability(sdfg: SDFG) -> Dict[SDFGState, Set[SDFGState]]: + + reachability = analysis.StateReachability().apply_pass(sdfg, {}) + inverse_reachability = collections.defaultdict(set) + # iterate over cfg_ids + for cfg_id in reachability.keys(): + for pred, successors in reachability[cfg_id].items(): + for successor in successors: + inverse_reachability[successor].add(pred) + + return inverse_reachability + + +def is_previously_written(sdfg: SDFG, + state: SDFGState, + node: nodes.Node, + array_name: str, + access_sets: Optional[AccessSets] = None) -> bool: + """ + Determine whether the given array name was written before the current node. + + :param sdfg: the sdfg containing the node + :param state: the state containing the node + :param node: the node to check + :param array_name: the array name to check + :return: True if the array was written before the node, False otherwise. + """ + + if access_sets is None: + access_sets = analysis.AccessSets().apply_pass(sdfg, {}) + + reachable = inverse_reachability(sdfg) + + # Check the current state + for subgraph in sdfg_utils.concurrent_subgraphs(state): + if node in subgraph.nodes(): + # Get all the access nodes in the subgraph to the same data + for other_node in subgraph.data_nodes(): + if other_node != node and other_node.data == array_name: + # Check if this is a write node + for in_edge in subgraph.in_edges(other_node): + if in_edge.data.data == array_name: + # Check if there's a path to our node, + # since we only care about writes that happen before the current node + if nx.has_path(state.nx, other_node, node): + return True + else: + # This is not our current subgraph, check the write states + _, write_set = subgraph.read_and_write_sets() + if array_name in write_set: + return True + + # Check other states + for predecessor in reachable[state]: + _, write_set = access_sets[predecessor] + if array_name in write_set: + return True + return False diff --git a/dace/autodiff/autodiff.md b/dace/autodiff/autodiff.md new file mode 100644 index 0000000000..cdec31941c --- /dev/null +++ b/dace/autodiff/autodiff.md @@ -0,0 +1,821 @@ +Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +# DaCe Automatic Differentiation (AD) System - Design Document + +## Table of Contents + +1. [Introduction](#1-introduction) +2. [Directory Structure](#2-directory-structure) +3. [Core Components](#3-core-components) +4. [Data Forwarding System](#4-data-forwarding-system) +5. [Backward Implementations](#5-backward-implementations) +6. [Library Integration](#6-library-integration) +7. [PyTorch Integration](#7-pytorch-integration) +8. [Gradient Accumulation and Clearing](#8-gradient-accumulation-and-clearing) + +--- + +## 1. Introduction + +### 1.1 Purpose + +The DaCe Automatic Differentiation (AD) module provides **reverse-mode automatic differentiation** for Stateful DataFlow Graphs (SDFGs). It enables automatic computation of gradients for optimized numerical kernels, making it possible to differentiate DaCe programs for machine learning, optimization, and scientific computing applications. + +### 1.2 Reverse-Mode AD Fundamentals + +Reverse-mode automatic differentiation (also known as backpropagation) computes gradients by: + +1. **Forward Pass**: Execute the original computation and record intermediate values +2. **Backward Pass**: Traverse the computation graph in reverse, accumulating gradients using the chain rule + +For a function `f: ℝⁿ → ℝᵐ`, reverse-mode AD efficiently computes the gradient when `m << n` (common in ML where loss is scalar). + +**Example**: For a composite function `y = f(g(h(x)))`: +- **Forward pass**: Compute and store intermediate values: + - `a = h(x)` + - `b = g(a)` + - `y = f(b)` +- **Backward pass**: Apply chain rule in reverse order. Given seed gradient `∂L/∂y`: + - Compute `∂L/∂b = ∂L/∂y · (∂f/∂b)` where `(∂f/∂b)` is evaluated at the stored value of `b` + - Compute `∂L/∂a = ∂L/∂b · (∂g/∂a)` where `(∂g/∂a)` is evaluated at the stored value of `a` + - Compute `∂L/∂x = ∂L/∂a · (∂h/∂x)` where `(∂h/∂x)` is evaluated at the input value `x` + +### 1.3 Key Features + +- **Control Flow Support**: Handles loops (`LoopRegion`) and conditionals +- **Data Forwarding Strategies**: Flexible tradeoff between memory (store intermediates) and computation (recompute on demand) +- **Extensible Backward Implementations**: Registry-based system for adding backward rules for new operations +- **ONNX Integration**: Differentiate ONNX neural network models imported into DaCe +- **PyTorch Compatibility**: Integration with PyTorch's autograd system via `torch.autograd.Function` +- **Library Node Support**: Backward implementations for DaCe standard library (BLAS, reductions, etc.) +- **Nested SDFG Differentiation**: Recursive backward pass generation for nested SDFGs + +### 1.4 Use Cases + +1. **Machine Learning Training**: Compute gradients for neural network parameters +2. **Sensitivity Analysis**: Determine how outputs change with respect to inputs +3. **Optimization**: Gradient-based optimization of physical simulations +4. **Inverse Problems**: Solve inverse problems by differentiating forward models +5. **Scientific Computing**: Adjoint methods for PDEs and large-scale simulations + +### 1.5 Component Interaction Flow + +``` +Input: Forward SDFG + Output Arrays + Input Arrays + ▼ +1. add_backward_pass() - Entry point + • Validate SDFG + • Simplify (optional) + • Inline control flow (conditionals) + ▼ +2. BackwardPassGenerator.__init__() + • Convert AccessNodes/strings to internal format + • Initialize mappings (reverse_map, array_grad_map, etc.) + • Set data forwarding strategy + ▼ +3. BackwardPassGenerator.backward() + • Reverse states in topological order + • For each state: + a. Reverse nodes (AccessNode, Tasklet, Map, etc.) + b. Find backward implementation via registry + c. Call implementation.backward() + d. Connect gradients with WCR + ▼ +4. DataForwardingManager.forward_data_to_backward_pass() + • Identify intermediates needed in backward pass + • Check if overwritten + • Apply strategy (store or recompute) + ▼ +5. Simplify and validate (optional) + ▼ +Output: Backward SDFG with gradients computed +``` + +--- + +## 2. Directory Structure + +### 2.1 File Organization + +``` +dace/autodiff/ +├── __init__.py # Main API exports +│ └── Exports: add_backward_pass, BackwardPassGenerator, +│ BackwardImplementation, AutoDiffException, etc. +│ +├── autodiff.py # Entry point +│ └── add_backward_pass() - High-level API +│ +├── base_abc.py # Abstract base classes +│ ├── BackwardImplementation (ABC) +│ ├── BackwardContext (dataclass) +│ ├── BackwardResult (dataclass) +│ ├── AutoDiffException +│ └── find_backward_implementation() +| +├── backward_pass_generator.py # Core AD engine +│ └── BackwardPassGenerator class - Main differentiation algorithm +│ +├── analysis.py # SDFG analysis +│ ├── dependency_analysis() +│ ├── inverse_reachability() +│ └── is_previously_written() +│ +├── utils.py # Utility functions +│ ├── Descriptor management +│ ├── Symbolic differentiation +│ ├── Graph traversal +│ └── Loop analysis +│ +├── torch.py # PyTorch integration +│ └── make_backward_function() - Convert ONNX to PyTorch differentiable +│ +├── data_forwarding/ # Store or recompute strategies +│ ├── __init__.py # Package exports +│ ├── manager.py # Strategy coordinator +│ │ └── DataForwardingManager +│ ├── store.py # Store strategy +│ │ └── resolve_overwrite_with_store() +│ └── recompute.py # Recompute strategy +│ └── resolve_overwrite_with_recomputation() +│ get_recomputation_nsdfg() +│ +├── implementations/ # Backward rules for node types +│ ├── __init__.py # Package exports (46 lines) +│ ├── dace_nodes.py # Pure SDFG elements (487 lines) +│ │ └── DaceNodeBackwardImplementations +│ │ ├── _reverse_AccessNode() +│ │ ├── _reverse_Tasklet() +│ │ ├── _reverse_MapEntry() +│ │ ├── _reverse_MapExit() +│ │ └── _reverse_NestedSDFG() +│ ├── dace_reduction_nodes.py # Reduction operations (307 lines) +│ │ ├── ReverseReduce +│ │ └── ... (reduction backward implementations) +│ ├── onnx_ops.py # ONNX operations (1045 lines) +│ │ ├── ONNXConvBackward +│ │ ├── ONNXMatMulBackward +│ │ └── ... (50+ ONNX ops) +│ └── pytorch_ops.py # PyTorch operations (128 lines) +│ └── Depthwise convolution backward pass +│ +└── library/ # Library integrations + ├── __init__.py # Package exports (31 lines) + ├── library.py # BackwardPass node (286 lines) + │ ├── ParameterArray (data descriptor) + │ ├── BackwardPass (LibraryNode) + │ └── ExpandBackwardPass (expansion) + └── torch_integration.py # PyTorch hooks (39 lines) +``` + + +## 3. Core Components + +### 3.1 Entry Point: `autodiff.py` + +**Location**: [autodiff.py](autodiff.py) + +The main entry point for users to add backward passes to SDFGs. + + +#### 3.1.1 Workflow + +``` +┌─────────────────────┐ +│ 1. Validate SDFG │ +└──────────┬──────────┘ + ▼ + ┌───────────────┐ + │ 2. Simplify │ + └──────┬────────┘ + ▼ +┌─────────────────────────────────┐ +│ 3. Inline Control Flow │ +│ (conditionals, not loops) │ +└──────────┬──────────────────────┘ + ▼ +┌─────────────────────────────────────┐ +│ 4. Create Backward SDFG │ +│ (if separate_sdfgs flag is True) │ +└──────────┬──────────────────────────┘ + ▼ +┌─────────────────────────────────┐ +│ 5. Initialize BackwardPass- │ +│ Generator │ +└──────────┬──────────────────────┘ + ▼ +┌─────────────────────────────────┐ +│ 6. generator.backward() │ +│ (main differentiation) │ +└──────────┬──────────────────────┘ + ▼ +┌─────────────────────┐ +│ 7. Validate SDFG │ +└──────────┬──────────┘ + ▼ + ┌───────────────┐ + │ 8. Simplify │ + └──────┬────────┘ + ▼ +┌─────────────────────┐ +│ 9. Return SDFG │ +└─────────────────────┘ +``` + +#### 3.1.2 Key Constraints + +- **Supported Nodes**: + - Maps, AccessNodes, Tasklets, LoopRegions, ControlFlowRegions (inlined into state machine) + - Reductions (Sum, Min, Max) + - ONNXOps (with registered backward implementations) + - NestedSDFGs + +--- + +### 3.2 BackwardPassGenerator: The Core AD Engine + +**Location**: [backward_pass_generator.py](backward_pass_generator.py) + +The `BackwardPassGenerator` class is the core of the AD system. It orchestrates the entire backward pass generation process. + +#### 3.2.1 Key Data Structures + +The generator maintains several mappings and data structures: + +- **Configuration**: + - `sdfg`: Forward SDFG + - `backward_sdfg`: Backward SDFG (can be same or separate) + - `given_gradients_data`: Output arrays (seed gradients provided) + - `required_gradients_data`: Input arrays (gradients to compute) + - `data_forwarding_strategy`: "store_all", "recompute_all", "user_defined" + +- **Generated Mappings**: + - `reverse_map: Dict[Node, Node]`: Forward node → backward node + - `reversed_states_map: Dict[SDFGState, SDFGState]`: Forward state → backward state + - `array_grad_map: Dict[str, str]`: Array name → gradient array name + - `result_map: Dict[Node, BackwardResult]`: Forward node → BackwardResult + +- **Analysis Results**: + - `read_only_arrays`: Arrays never written to + - `backward_grad_arrays`: Gradient array descriptors + - `backward_input_arrays`: Forward values needed in backward pass + - `data_to_forward`: List of data to forward from forward to backward + +#### 3.2.2 Main Algorithm: `backward()` + +**Steps**: + +1. **Initialize gradient arrays** for all required outputs +2. **Compute state order** (topological sort of SDFG states) +3. **Extract the Critical Computation Subgraph (CCS) of each state** +4. **Reverse the CCS of states** in reverse topological order: + - Create backward state + - Reverse nodes within CCS of the state + - Connect gradients between reversed nodes +5. **Reverse loop regions** by generating loop regions in the backward pass +6. **Handle data forwarding** (store or recompute intermediates) +7. **Create interstate edges** to reverse control flow and connect all reversed components +8. **Return** backward result with gradient mappings + +#### 3.2.3 State Reversal + +For each forward state, the generator: + +1. Creates a corresponding backward state +2. For each node in the CCS of the state: + - Finds appropriate backward implementation from registry + - Determines given/required gradients + - Calls `implementation.backward()` + - Stores mapping and result +3. Connects gradients between reversed nodes + +--- + +### 3.3 Abstract Base Classes: `base_abc.py` + +**Location**: [base_abc.py](base_abc.py) + +#### 3.3.1 BackwardImplementation (ABC) + +The abstract base class for all backward implementations. + +```python```python +@dace.registry.make_registry +class BackwardImplementation(abc.ABC): + + @staticmethod + def backward_can_be_applied(node: nd.Node, state: SDFGState, + sdfg: SDFG) -> bool: + """Check if this implementation can be applied to the node.""" + return True + + @staticmethod + @abc.abstractmethod + def backward( + forward_node: nd.Node, + context: BackwardContext, + given_gradients: List[Optional[str]], + required_gradients: List[Optional[str]] + ) -> Tuple[nd.Node, BackwardResult]: + """Generate the backward pass for this node.""" + ... +``` + +**Registration Example**: + +```python + +# For ONNX operations +@dace.registry.autoregister_params(op="MatMul", name="pure") +class MatMulBackward(BackwardImplementation): + ... +``` + +#### 3.3.2 BackwardContext (Dataclass) + +Contains all context information needed by backward implementations: + +```python +@dataclasses.dataclass +class BackwardContext: + forward_sdfg: SDFG # The forward SDFG + backward_sdfg: SDFG # The backward SDFG + backward_generator: BackwardPassGenerator # The generator (for utilities) +``` + +#### 3.3.3 BackwardResult (Dataclass) + +Returns information about the generated backward node: + +```python +@dataclasses.dataclass +class BackwardResult: + """Result of differentiating a node.""" + + # Mapping from forward input connector → gradient connector name + required_grad_names: Dict[Optional[str], Optional[str]] + + # Mapping from forward output connector → gradient connector name + given_grad_names: Dict[Optional[str], Optional[str]] + + # Which gradients should be zero-initialized + zero_init: Dict[Optional[str], Optional[bool]] +``` + +#### 3.3.4 find_backward_implementation() + +Looks up the registered backward implementation for a node by: + +1. Querying `BackwardImplementation.extensions()` registry +2. Filtering by `node_type` (for DaCe nodes) or `op` (for ONNX) +3. Checking `backward_can_be_applied()` for each candidate +4. Returning first valid implementation + +--- + +### 3.4 Analysis Utilities: `analysis.py` + +**Location**: [analysis.py](analysis.py) + +Provides SDFG analysis functions used by the AD engine: + +#### 3.4.1 dependency_analysis() + +Computes transitive read dependencies for each array. For example, if `C = A + B`, then `dependencies["C"] = {"A", "B"}`. Uses graph traversal and transitive closure to build a complete dependency map. + +#### 3.4.2 inverse_reachability() + +For each state, computes the set of predecessor states that can reach it. Uses DaCe's `StateReachability` analysis pass. + +#### 3.4.3 is_previously_written() + +Determines if an array was written before a given node in a state. Used by data forwarding to determine if an intermediate value needs to be stored (because it will be overwritten). + +Checks both: +1. Current state (concurrent subgraphs) +2. Predecessor states + +--- + +### 3.5 Utility Functions: `utils.py` + +**Location**: [utils.py](utils.py) + +The `utils.py` module contains many helper functions organized into categories: + +#### 3.5.1 Descriptor Management + +- `add_backward_desc()`: Add gradient array descriptor to backward SDFG +- `add_backward_desc_for_connector()`: Add backward descriptor for specific connector +- Helper functions for managing array descriptors and data types + +#### 3.5.2 Symbolic Differentiation + +- `differentiate_tasklet()`: Symbolically differentiates tasklet code using AST parsing and SymPy +- Converts tasklet expressions to symbolic form, computes derivatives, and generates backward code + +#### 3.5.3 Graph Traversal + +- `get_all_path_edges()`: Gets all edges on paths from source to target +- `concurrent_subgraphs()`: Finds concurrent execution regions in a state +- Helper functions for navigating SDFG structures + +#### 3.5.4 Loop Analysis + +- `state_within_loop()`: Checks if a state is inside a loop region +- `get_loop_carried_dependencies()`: Finds arrays with loop-carried dependencies +- Loop-specific helper functions + +--- + +## 4. Data Forwarding System + +### 4.1 The Core Problem + +During backward pass generation, we often need access to intermediate values from the forward pass: + +**Example**: +```python +# Forward +y = sigmoid(x) +z = y * y +L = z # Identity loss function + +# Backward (to compute dL/dx) +dL/dy = dL/dz * 2y # Need y from forward pass! +dL/dx = dL/dy * y * (1 - y) # Need y again! +``` + +**Two strategies**: +1. **Store**: Save `y` during forward pass, load it during backward + - **Pro**: Fast backward pass (no recomputation) + - **Con**: High memory usage +2. **Recompute**: Recompute `y = sigmoid(x)` during the backward pass + - **Pro**: Low memory usage (no storage required) + - **Con**: Slower backward pass due to recomputation cost + +--- + +### 4.2 DataForwardingManager: `manager.py` + +**Location**: [data_forwarding/manager.py](data_forwarding/manager.py) + +Coordinates the data forwarding strategy. + +#### 4.2.1 Strategy Selection + +The manager provides three strategies: + +1. **`store_all`** (default): Store all intermediate values + - Fastest backward pass + - Highest memory usage + - Best for memory-rich environments + +2. **`recompute_all`**: Recompute all intermediate values + - Experimental feature to test recomputation capabilities + +3. **`user_defined`**: User specifies which arrays to recompute + - Balanced approach + - Requires domain knowledge + - Allows fine-grained control + +#### 4.2.2 Main Algorithm + +For each data item that needs to be forwarded: + +1. Determine if the data is overwritten before backward pass needs it +2. If overwritten, choose resolution strategy (store or recompute) +3. Apply strategy: + - **Store**: Create copy before overwrite, load in backward + - **Recompute**: Extract computation subgraph, inline in backward + +#### 4.2.3 Overwrite Detection Algorithm + +**Problem**: Determine if an intermediate value is overwritten before the backward pass needs it + +**Algorithm**: +``` +is_overwritten(array, state, node): + 1. Check if array is written in concurrent subgraphs + 2. Check if array is written in successor states + 3. If either is true, the array is overwritten + 4. Apply data forwarding strategy (store or recompute) +``` + +**Uses**: `is_previously_written()` from `analysis.py` + +--- + +### 4.3 Store Strategy: `store.py` + +**Location**: [data_forwarding/store.py](data_forwarding/store.py) + +**Key Function**: `resolve_overwrite_with_store()` + +**Approach**: + +``` +Forward Pass State Backward Pass State +┌──────────────┐ ┌──────────────┐ +│ Compute x │ │ │ +│ Store x_copy │ │ Load x_copy │ +│ Overwrite x │ │ Use in grad │ +└──────────────┘ └──────────────┘ +``` + +**Steps**: +1. Create a storage descriptor for the intermediate value +2. Add a copy operation in the forward state (before overwrite) +3. Add a load operation in the backward state (when needed) +4. Update memlets to use the stored copy + +--- + +### 4.4 Recompute Strategy: `recompute.py` (Experimental!) + +**Location**: [data_forwarding/recompute.py](data_forwarding/recompute.py) + +**Key Function**: `resolve_overwrite_with_recomputation()` + +**Approach**: + +``` +Forward Pass State Backward Pass State +┌──────────────┐ ┌──────────────┐ +│ Compute x │ │ Recompute x │ +│ │ → │ Use in grad │ +│ Overwrite x │ │ │ +└──────────────┘ └──────────────┘ +``` + +**Steps**: +1. Extract the computation subgraph that produces the value +2. Create a nested SDFG containing the recomputation logic +3. Inline the nested SDFG in the backward state +4. Connect inputs and outputs appropriately + +**Subgraph Extraction** (`get_recomputation_nsdfg()`): +- Performs backward breadth-first search (BFS) from the data node to find all dependencies +- Copies nodes and edges into a new nested SDFG +- Handles map scopes and connectors +- Ensures all dependencies are included + +--- + +## 5. Backward Implementations + +### 5.1 DaCe Core Nodes: `dace_nodes.py` + +**Location**: [implementations/dace_nodes.py](implementations/dace_nodes.py) + +Implements backward passes for core SDFG elements. + +#### 5.1.1 AccessNode + +**Purpose**: Create gradient AccessNode + +**Approach**: +- Forward: `AccessNode("x")` +- Backward: `AccessNode("grad_x")` with `setzero=True` + +Also handles view connectors for arrays with views or subsets. + +#### 5.1.2 Tasklet + +**Purpose**: Symbolically differentiate tasklet code + +**Approach**: +1. Parse tasklet code to AST +2. Extract output expressions +3. Use SymPy to compute symbolic derivatives +4. Generate backward code: `grad_input = grad_output * derivative` + +**Example**: +- Forward: `y = x * x` +- Backward: `grad_x = grad_y * 2 * x` + +#### 5.1.3 Maps + +**Purpose**: Reverse map structure + +Maps are special: `MapEntry` and `MapExit` nodes are swapped in the backward pass. + +**Forward**: +``` +AccessNode → MapEntry → [Tasklet in scope] → MapExit → AccessNode +``` + +**Backward**: +``` +AccessNode → MapEntry (reversed) → [Tasklet_grad in scope] → MapExit (reversed) → AccessNode +``` + +**Approach**: +- `MapEntry` → `MapExit` in backward pass +- `MapExit` → `MapEntry` in backward pass +- Connectors inverted: `IN_X` ↔ `OUT_X` +- Same map object used for both + +#### 5.1.4 NestedSDFG + +**Purpose**: Recursively differentiate nested SDFGs + +**Approach**: +1. Recursively call `add_backward_pass()` on nested SDFG +2. Map forward connectors to backward connectors +3. Handle symbols and interstate edges +4. Ensure proper gradient flow through nested boundaries + +#### 5.1.5 LoopRegions + +**Purpose**: Reverse loops in the forward SDFG + +**Approach**: +Loops are reversed by creating a backward loop that iterates in the reverse direction to process gradients. + + +``` +# Forward loop: +for i in range(N): + y[i+1] = f(x[i]) + +# Backward loop: +for i in reversed(range(N)): + grad_x[i] = grad_f(x[i]) * grad_y[i+1] +``` + +--- + +### 5.2 DaCe Reduction Nodes: `dace_reduction_nodes.py` + +**Location**: [implementations/dace_reduction_nodes.py](implementations/dace_reduction_nodes.py) + +Implements backward passes for DaCe reduction operations (307 lines). + +#### 5.2.1 Key Implementations + +| Operation | Backward Implementation | Notes | +|-----------|------------------------|-------| +| **Reduce (Sum)** | Broadcast gradient to match input shape | Handles axis reduction | +| **Reduce (Max/Min)** | Gradient flows only to max/min elements | Requires forward values | + +--- + +### 5.3 ONNX Operations: `onnx_ops.py` + +**Location**: [implementations/onnx_ops.py](implementations/onnx_ops.py) + +Implements backward passes for 50+ ONNX operations. Each implementation follows the ONNX operator specification for gradient computation. + +**Categories**: + +- **Element-wise**: Add, Sub, Mul, Div, Sqrt, Exp, Log, Pow, etc. +- **Activation**: Relu, Sigmoid, Tanh, Softmax, etc. +- **Matrix**: MatMul, Gemm, BatchMatMul +- **Convolution**: Conv, ConvTranspose +- **Pooling**: MaxPool, AveragePool, GlobalAveragePool +- **Normalization**: BatchNormalization, LayerNormalization +- **Reduction**: ReduceSum, ReduceMean, ReduceMax, etc. +- **Shape**: Reshape, Transpose, Concat, Split, Squeeze, Unsqueeze +- **Advanced**: Gather, Scatter, Einsum, etc. + +Each ONNX backward implementation is registered with `@dace.registry.autoregister_params(op="OpName")`. + +--- + +### 5.4 PyTorch Operations: `pytorch_ops.py` + +**Location**: [implementations/pytorch_ops.py](implementations/pytorch_ops.py) + +Implements backward passes using PyTorch's optimized CUDA kernels (128 lines). + +#### 5.4.1 Key Implementations + +| Operation | Backward Implementation | Notes | +|-----------|------------------------|-------| +| **Conv (depthwise)** | `PyTorchConvBackward` | Uses `at::thnn_conv_depthwise2d_backward_out` | + +This implementation leverages PyTorch's C++ ATen library for GPU-accelerated depthwise convolution backward passes. + +--- + +## 6. Library Integration + +### 6.1 BackwardPass Library Node: `library.py` + +**Location**: [library/library.py](library/library.py) + +Provides a library node for encapsulating backward passes as reusable components. + +#### 6.1.1 ParameterArray + +A special data descriptor for gradient accumulation buffers that mimics PyTorch Parameters. + +#### 6.1.2 BackwardPass + +A library node that wraps a backward pass SDFG, allowing backward passes to be composed and reused like other library operations. + +#### 6.1.3 ExpandBackwardPass + +Expands the `BackwardPass` library node into the full SDFG. Handles: +- Gradient initialization (zero or provided seed) +- Parameter gradient accumulation + +--- + +## 7. PyTorch Integration + +### 7.1 Overview: `torch.py` + +**Location**: [torch.py](torch.py) + +Enables the integration between DaCe AD and PyTorch's autograd system. + +### 7.2 make_backward_function() + +**Purpose**: Convert ONNX model to PyTorch-differentiable function + +**Signature**: +```python +def make_backward_function( + forward_sdfg: SDFG, + inputs: List[str], + outputs: List[str], + parameters: Optional[List[str]] = None +) -> Type[torch.autograd.Function]: +``` + +**Returns**: PyTorch `autograd.Function` subclass with: +- `forward()`: Compiles and runs forward SDFG +- `backward()`: Compiles and runs backward SDFG +- Handles PyTorch tensor ↔ DaCe array conversion +- Supports scalar inputs/outputs +- Manages parameter gradients + +### 7.3 Integration Flow + +``` +PyTorch Model + ↓ +DaCe ONNX Import + ↓ +Forward SDFG + ↓ +add_backward_pass() + ↓ +Backward SDFG + ↓ +make_backward_function() + ↓ +torch.autograd.Function + ↓ +Use in PyTorch training loop +``` + +--- + +## 8. Gradient Accumulation and Clearing + +### 8.1 Gradient Accumulation + +**Problem**: Multiple paths can contribute to same gradient + +**Example**: +``` + ┌─→ y1 ─┐ + x ──┤ ├─→ z + └─→ y2 ─┘ +``` + +Both `y1` and `y2` contribute to `grad_x`. + +**Solution**: Write-Conflict Resolution (WCR) + +When connecting gradients, use WCR on memlets: +```python +memlet.wcr = "lambda a, b: a + b" +``` + +This ensures multiple gradient contributions are summed correctly. + +### 8.2 Gradient Clearing + +**Problem**: Overwritten arrays in the forward pass require clearing the gradients of the corresponding gradient arrays to allow the always-accumulate solution presented above. + +**When to Clear Gradients**: +- In the backward pass, at the corresponding point where arrays in the forward pass were overwritten. + +**Implementation Strategies**: + +1. **Zero Initialization for all intermediate arrays**: Set all gradient arrays to zero before backward pass + ```python + # In DaCe, gradient arrays can be initialized with setzero=True + grad_array = AccessNode("grad_x", setzero=True) + ``` + +2. **Manual Clearing**: Explicitly zero out gradient arrays if necessary + ```python + # Reset gradients if an overwrite is detected in dace/autodiff/backward_pass_generator.py + self._zero_out_gradient(forward_state=forward_state, + forward_node=node, + memlet=edge.data) + ``` diff --git a/dace/autodiff/autodiff.py b/dace/autodiff/autodiff.py new file mode 100644 index 0000000000..cfd686c224 --- /dev/null +++ b/dace/autodiff/autodiff.py @@ -0,0 +1,83 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +from typing import List, Union, Optional + +from dace.autodiff.backward_pass_generator import BackwardPassGenerator + +from dace.sdfg import SDFG, nodes +from dace.sdfg.utils import inline_control_flow_regions +from dace.sdfg.state import LoopRegion + + +def add_backward_pass(sdfg: SDFG, + outputs: List[Union[nodes.AccessNode, str]], + inputs: List[Union[nodes.AccessNode, str]], + data_forwarding_strategy: str = "store_all", + data_to_recompute: Optional[List[str]] = None, + simplify: bool = True, + separate_sdfgs: bool = False) -> Optional[SDFG]: + """ Experimental: Add a backward pass to `state` using reverse-mode automatic differentiation. + + ``inputs``, ``outputs`` and ``grads`` can be provided either as ``AccessNode`` nodes, or as ``str``, in which + case the graph will be searched for exactly one matching ``AccessNode`` with data matching the ``str``. + + The SDFG may contain the following nodes: + + * Maps + * AccessNodes + * Reductions (Sum, Min, Max) + * ONNXOps + * Multiple states + * LoopRegions + * NestedSDFGs (subject to the same constraints) + + When differentiating an :class:`~dace.libraries.onnx.nodes.onnx_op.ONNXOp`, the ONNXBackward registry will be checked + for any matching backward pass implementations. If none are found, the ONNXForward registry will be checked for + matching pure implementations. If one is found, symbolic differentiation of the pure implementation will be + attempted. If this fails, or no pure forward implementation is found, the method will fail. + + .. note:: + This function modifies the input SDFG in-place. Even if ``separate_sdfgs`` is ``True``, modifications + such as storing intermediate results and inlining ControlFlowRegions can be applied to the original SDFG. + + :param sdfg: the SDFG to add the backward pass to. + :param outputs: the forward pass outputs of the function to differentiate. + :param inputs: the inputs w.r.t. which the gradient will be returned. + :param data_forwarding_strategy: strategy for forwarding data to the backward pass. Could be one of: + * "store_all": store all intermediate data (default, uses most memory, fastest). + * "recompute_all": recompute all intermediate data. + * "user_defined": store all intermediates except for ones specified in `data_to_recompute`. + :param data_to_recompute: list of data arrays to recompute instead of storing. Only used if + `data_forwarding_strategy` is "user_defined". + :param simplify: whether to apply the simplify pass to the forward and backward SDFGs. + :param separate_sdfgs: whether to create a separate SDFG for the backward pass. + :return: the backward SDFG if separate_sdfgs is True, the original SDFG (which now also contains the backward pass) otherwise. + """ + # Validate the SDFG + sdfg.validate() + + if simplify: + sdfg.simplify() + + # Inline conditional blocks but keep loops + inline_control_flow_regions(sdfg, ignore_region_types=[LoopRegion]) + + if separate_sdfgs: + backward_sdfg = SDFG(sdfg.name + "_backward") + else: + backward_sdfg = sdfg + + # Add backward pass + gen = BackwardPassGenerator(sdfg=sdfg, + given_gradients=outputs, + required_gradients=inputs, + backward_sdfg=backward_sdfg, + data_forwarding_strategy=data_forwarding_strategy, + data_to_recompute=data_to_recompute) + gen.backward() + sdfg.validate() + + if simplify: + sdfg.simplify() + sdfg.validate() + + return backward_sdfg diff --git a/dace/autodiff/backward_pass_generator.py b/dace/autodiff/backward_pass_generator.py new file mode 100644 index 0000000000..6857749643 --- /dev/null +++ b/dace/autodiff/backward_pass_generator.py @@ -0,0 +1,2056 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import copy +from typing import List, Tuple, Set, Dict, Union, Optional, Sequence +import sympy as sp + +# DaCe imports +import dace +from dace.properties import CodeBlock +import dace.sdfg.nodes as nodes +import dace.transformation.transformation as xf +from dace import dtypes, data as dt +from dace.sdfg import SDFG, SDFGState, state as dstate, utils as dace_utils +from dace.sdfg.state import LoopRegion +from dace.memlet import Memlet + +try: + from dace.libraries.onnx.forward_implementation_abc import ONNXForward + from dace.libraries.onnx.nodes.onnx_op import ONNXOp + ONNX_AVAILABLE = True +except ImportError: + ONNXForward = None + ONNXOp = None + ONNX_AVAILABLE = False + +# Autodiff imports +from dace.autodiff.base_abc import (BackwardContext, BackwardResult, AutoDiffException, find_backward_implementation, + ExpansionTemplate) +import dace.autodiff.utils as ad_utils +from dace.autodiff.implementations.dace_nodes import DaceNodeBackwardImplementations +from dace.autodiff.data_forwarding.manager import DataForwardingManager + + +class BackwardPassGenerator: + """Generator for automatic differentiation backward passes on DaCe SDFGs. + + This class orchestrates the creation of backward passes for automatic differentiation + using reverse-mode AD. It handles gradient computation, data forwarding between + forward and backward passes, and complex control flow structures. + + :param sdfg: The forward SDFG to differentiate. + :param given_gradients: Output arrays for which gradients are provided (seed gradients). + :param required_gradients: Input arrays for which gradients should be computed. + :param backward_sdfg: SDFG to contain the backward pass. Can be same as forward SDFG. + :param array_grad_map: Optional mapping from array names to gradient array names. + :param conflicted_gradient_buffers: Arrays with potential write conflicts requiring special handling. + :param data_forwarding_strategy: Strategy for forwarding data ('store_all', 'recompute_all', 'user_defined'). + :param data_to_recompute: Arrays to recompute instead of storing (when strategy='user_defined'). + :raises AutoDiffException: If the backward pass generation fails. + + Example:: + + gen = BackwardPassGenerator( + sdfg=forward_sdfg, + given_gradients=['loss'], + required_gradients=['weights', 'input'] + ) + gen.backward() + """ + + def __init__( + self, + *, + sdfg: SDFG, + given_gradients: Sequence[Union[nodes.AccessNode, str]], + required_gradients: Sequence[Union[nodes.AccessNode, str]], + backward_sdfg: SDFG, # This can be the same as sdfg + array_grad_map: Optional[Dict[str, str]] = None, + conflicted_gradient_buffers: Optional[Set[str]] = None, + data_forwarding_strategy: str = "store_all", + data_to_recompute: Optional[List[str]] = None, + ): + + self.sdfg: SDFG = sdfg + self.data_to_recompute = data_to_recompute + self.backward_sdfg: SDFG = backward_sdfg + + given_gradients = [ + n if isinstance(n, nodes.AccessNode) else self._str_to_access(n, "outputs") for n in given_gradients + ] + required_gradients = [ + n if isinstance(n, nodes.AccessNode) else self._str_to_access(n, "inputs") for n in required_gradients + ] + required_gradients = [n for n in required_gradients if n is not None] + + self.given_gradients_data = {n.data for n in given_gradients} + self.required_gradients_data = {n.data for n in required_gradients} + + self.input_names = {n.data for n in required_gradients} + self.output_names = {n.data for n in given_gradients} + + #: Arrays descriptors for the gradients + self.backward_grad_arrays: Dict[str, dt.Array] = {} + + #: Arrays descriptors for inputs that are required from the forward pass + self.backward_input_arrays: Dict[str, dt.Array] = {} + + #: Mapping from forward node -> backward node, and forward map -> backward map + self.reverse_map: Dict[nodes.Node, Union[nodes.Node, nodes.Map]] = {} + + #: Mapping from forward state -> backward state + self.reversed_states_map: Dict[SDFGState, SDFGState] = {} + + #: Mapping from forward LoopRegion -> backward LoopRegion + self.reversed_loops_map: Dict[LoopRegion, LoopRegion] = {} + + #: Mapping from forward state -> backward state for loop states + self.reversed_loop_states_map: Dict[nodes.Node, nodes.Node] = {} + + #: Mapping between states and their subgraph views for AD processing + self.states_view_map: Dict[SDFGState, dstate.StateSubgraphView] = {} + + #: Mapping between loop states and their subgraph views for AD processing + self.loop_states_view_map: Dict[SDFGState, dstate.StateSubgraphView] = {} + + #: Mapping between the map entry of a conditional assignment block and its zero-out AN + self.conditional_block_entry: Dict[nodes.MapEntry, nodes.AccessNode] = {} + + #: Mapping from forward_node -> BackwardResult for that node + self.result_map: Dict[nodes.Node, BackwardResult] = {} + + #: Mapping from forward name to gradient name for arrays + self.array_grad_map: Dict[str, str] = array_grad_map or {} + + #: Mapping from the backward access nodes that will be zeroed out + # to the transients that contain the values before they are zeroed out + self.zeroed_out: Dict[nodes.AccessNode, List[nodes.AccessNode]] = {} + + #: The read-only arrays of the forward SDFG. Used in data forwarding decisions + self.read_only_arrays: Set[str] = ad_utils.get_read_only_arrays(self.sdfg) + + #: Mapping from overwritten input name to storing AccessNode + self.stored_inputs: Dict[str, nodes.AccessNode] = {} + + # Variable to check if backward has already been applied + self._applied = False + + self.data_forwarding_strategy = data_forwarding_strategy + + # Topological ordering of the states + self.state_order = ad_utils.get_state_topological_order(self.sdfg) + self.conflicted_gradient_buffers: Set[str] = conflicted_gradient_buffers or set() + + self.interstate_symbols: Dict[str, str] = {} + for edge in self.sdfg.all_interstate_edges(): + for assign_symbol, assignment in edge.data.assignments.items(): + self.interstate_symbols[assign_symbol] = assignment + + # Validate parameters and setup SDFG configuration + self._validate_gradients() + self._setup_sdfg_configuration(sdfg, backward_sdfg, given_gradients) + + # DaCe nodes backward implementations + self.dace_node_impl = DaceNodeBackwardImplementations(self) + + #: List containing information about all the data to be forwarded to the backward pass + self.data_to_forward: List[Tuple[SDFGState, SDFGState, nodes.AccessNode, nodes.Node, + dstate.MultiConnectorEdge]] = [] + + # Data forwarding manager + self.data_forwarding_manager = DataForwardingManager(self) + + def _validate_gradients(self) -> None: + """Validate that gradient arrays exist in the SDFG. + + Raises: + AutoDiffException: If gradient arrays are not found in SDFG arrays. + """ + # Check outputs (given gradients) + for outp in self.given_gradients_data: + if outp not in self.sdfg.arrays: + raise AutoDiffException(f"Could not find output '{outp}' in SDFG array descriptors") + + # Check inputs (required gradients) + for inp in self.required_gradients_data: + if inp not in self.sdfg.arrays: + raise AutoDiffException(f"Could not find input '{inp}' in SDFG array descriptors") + + def _setup_sdfg_configuration(self, sdfg: SDFG, backward_sdfg: SDFG, + given_gradients: List[nodes.AccessNode]) -> None: + """Setup SDFG configuration for separate or combined forward/backward passes. + + :param sdfg: Forward SDFG. + :param backward_sdfg: Backward SDFG. + :param given_gradients: List of gradient output nodes. + :raises AutoDiffException: If configuration is invalid for combined SDFG mode. + """ + if sdfg is backward_sdfg: + # Combined mode requires single scalar output + if len(given_gradients) != 1: + raise AutoDiffException("When forward and backward SDFGs are the same, exactly one output is required, " + f"got {len(given_gradients)}") + + output_array = sdfg.arrays[given_gradients[0].data] + if not ad_utils.is_int_eq_value(output_array.total_size, 1): + raise AutoDiffException("When forward and backward SDFGs are the same, output must be a single scalar") + + self.separate_sdfgs = False + else: + self.separate_sdfgs = True + + def create_child_generator(self, **kwargs) -> 'BackwardPassGenerator': + """Create a child generator for nested SDFG differentiation. + + This factory method creates a new BackwardPassGenerator instance for differentiating + nested SDFGs, propagating relevant configuration from the parent generator. + + :param kwargs: Parameters to pass to the child generator constructor. + Required: sdfg, given_gradients, required_gradients, backward_sdfg. + :return: A new BackwardPassGenerator instance configured for the nested SDFG. + """ + defaults = { + 'data_forwarding_strategy': self.data_forwarding_strategy, + 'data_to_recompute': self.data_to_recompute, + } + defaults.update(kwargs) + return BackwardPassGenerator(**defaults) + + def backward(self) -> Tuple[BackwardResult, Dict[str, dt.Array], Dict[str, dt.Array]]: + """Generate the backward pass in backward_sdfg.""" + return self.reverse_sdfg() + + def reverse_sdfg(self) -> Tuple[BackwardResult, Dict[str, dt.Array], Dict[str, dt.Array]]: + """Generate the backward pass by reversing all SDFG states. + + Processes all states in the SDFG and creates their backward counterparts, + connecting them with appropriate control flow for gradient computation. + + :return: A tuple containing: + + * ``BackwardResult`` - Contains gradient mappings and metadata. + * ``Dict[str, dt.Array]`` - Gradient array descriptors (backward pass outputs). + * ``Dict[str, dt.Array]`` - Forward pass arrays required by backward pass. + + :raises AutoDiffException: If backward pass was already applied to this generator. + """ + + if self._applied: + raise AutoDiffException("Backward may only be called once. Instantiate a new BackwardPassGenerator.") + + # Create state views mapping and expand all the SDFG nodes + self._create_stateviews_mapping() + + # Reverse each state in the graph + self._reverse_states() + + # Connect the new reversed states to the other states correctly + self._connect_reversed_states() + + # Fill the interstate edges with the correct conditions + self._fill_interstate_edge_conditions() + + # Add interstate assignments for control flow decisions + self._add_interstate_edge_assignments() + + # Forward required data by the backward pass according to a user defined strategy + self.data_forwarding_manager.forward_data_to_backward_pass() + + # In some cases (accessnode -> accessnode), the descriptors for the gradients of the function outputs are not + # added yet. Add them now. + for given_grad in sorted(self.given_gradients_data): + if self.array_grad_name(given_grad) not in self.backward_sdfg.arrays: + self._add_gradient_data_descriptor(given_grad) + + # Prepare the output + required_grad_names = {name: self.array_grad_name(name) for name in self.required_gradients_data} + given_grad_names = {name: self.array_grad_name(name) for name in self.given_gradients_data} + + # Set mapping from gradient name to whether it should be zeroed out on initialization + zero_init: Dict[str, bool] = {} + for node, bres in self.result_map.items(): + forward_state = self._get_node_state(node=node) + for zname, zinit in bres.zero_init.items(): + # Reverse lookup + cname = next(k for k, v in bres.required_grad_names.items() if v == zname) + + for e in forward_state.in_edges_by_connector(node, cname): + zero_init[e.data.data] = zinit + for e in forward_state.out_edges_by_connector(node, cname): + zero_init[e.data.data] = zinit + + self._applied = True + result = BackwardResult(required_grad_names=required_grad_names, + given_grad_names=given_grad_names, + zero_init=zero_init) + return result, self.backward_grad_arrays, self.backward_input_arrays + + def _create_stateviews_mapping(self) -> None: + """Map each state in the SDFG to views that indicate what to differentiate.""" + self._find_subgraph_to_differentiate() + # Expand until there is nothing left to expand + while self._expand_nodes(): + # Nodes have been expanded again on the expanded graph; recalculate the forward graph + self._find_subgraph_to_differentiate() + + def _reverse_states(self) -> None: + """Go through all states of the forward SDFG, reverse them and add them to the backward SDFG.""" + # For reversal we want to iterate through the states in reverse topological order + for state in reversed(self.state_order): + # Get all the views of this state + if state not in self.states_view_map: + raise AutoDiffException(f"State {state} not found in states view map") + state_subgraph_views = [self.states_view_map[state]] + + # In case this is a state loop + state_subgraph_loop_view = [] + if state in self.loop_states_view_map: + loop_view = self.loop_states_view_map[state] + state_subgraph_loop_view.append(loop_view) + + for state_subgraph_view in state_subgraph_views: + + # Make sure this state has not already been reversed + if state in self.reversed_states_map: + raise AutoDiffException(f"State {state} has already been reversed") + + # Create the new reversed state label + if state_subgraph_view in state_subgraph_loop_view: + reversed_state_label = f"{state.label}_loop_reversed" if state.label else None + else: + reversed_state_label = f"{state.label}_reversed" if state.label else None + + # Create new state for reversal + # At the moment we add all states to the backward_sdfg directly + # This will later be modified when connecting the states + reversed_state = self.backward_sdfg.add_state(label=reversed_state_label) + + # Add the new state to the reversed map dict + if state_subgraph_view in state_subgraph_loop_view: + self.reversed_loop_states_map[state] = reversed_state + else: + self.reversed_states_map[state] = reversed_state + + # Check that all edges are float, int, or boolean + ad_utils.check_edges_type_in_state(state_subgraph_view) + + # Recursively reverse the subgraph + self._reverse_subgraph(forward_state=state, backward_state=reversed_state, subgraph=state_subgraph_view) + + # We also reverse all the LoopRegions in the graph + for node in self.sdfg.nodes(): + if not isinstance(node, LoopRegion): + continue + self._reverse_loop_region(node) + + def _connect_reversed_states(self) -> None: + """Connect backward states corresponding to forward SDFG states. + + All incoming edges of a forward state become outgoing edges in the backward SDFG. + """ + + for state in self.state_order: + # All states should be reversed already + if state not in self.reversed_states_map: + raise AutoDiffException(f"State {state} not found in reversed states map") + backward_state = self.reversed_states_map[state] + + # Get all the out edges of the forward state + parent_graph = state.parent_graph + state_out_edges = parent_graph.out_edges(state) + + # If there are no outgoing connections + if len(state_out_edges) == 0: + # This is an end-state and it needs to be connected to its reversed state + # we do this only if the backward sdfg is the same as the forward one + if parent_graph == self.sdfg and not self.separate_sdfgs: + self.backward_sdfg.add_edge(src=state, dst=backward_state, data=dace.InterstateEdge()) + + # Get all the in connections of the forward state + forward_state_in_edges = parent_graph.in_edges(state) + + # Get the backward state again + # We need to do this in case the state is linked to an initialization state + # For outgoing edges, we connect the actual state not its initialization + backward_state = self.reversed_states_map[state] + + for edge in forward_state_in_edges: + # Each incoming edge to a forward state will add an outgoing edge to a backward state + fwd_src = edge.src + if isinstance(fwd_src, SDFGState): + bwd_src = self.reversed_states_map[fwd_src] + elif isinstance(fwd_src, LoopRegion): + bwd_src = self.reversed_loops_map[fwd_src] + + graph = bwd_src.parent_graph + graph.add_edge(src=backward_state, dst=bwd_src, data=dace.InterstateEdge()) + + # Connect all the loops + for loop in self.reversed_loops_map.keys(): + + # Get the loop parent + parent_graph = loop.parent_graph + + # Get the reversed loop + reversed_loop = self.reversed_loops_map[loop] + + # Get all the out edges of the forward state + loop_out_edges = parent_graph.out_edges(loop) + + # If there are no outgoing connections + if len(loop_out_edges) == 0: + # This is an end-region and it needs to be connected to its reversed region + # We do this only if the backward sdfg is the same as the forward one + if parent_graph == self.sdfg and not self.separate_sdfgs: + self.backward_sdfg.add_edge(src=state, dst=backward_state, data=dace.InterstateEdge()) + + # Get all the in edges + loop_in_edges = parent_graph.in_edges(loop) + + for edge in loop_in_edges: + + # A loop region could be connected to a state or another loop region + fwd_src = edge.src + if isinstance(fwd_src, SDFGState): + bwd_src = self.reversed_states_map[fwd_src] + elif isinstance(fwd_src, LoopRegion): + bwd_src = self.reversed_loops_map[fwd_src] + + # Get the graph to add the edge to + if isinstance(parent_graph, LoopRegion): + bwd_parent_graph = self.reversed_loops_map[parent_graph] + else: + bwd_parent_graph = self.backward_sdfg + + bwd_parent_graph.add_edge(src=reversed_loop, dst=bwd_src, data=dace.InterstateEdge()) + + def _fill_interstate_edge_conditions_in_scope(self, graph: Union[SDFG, LoopRegion]) -> None: + """ + Get all the nodes within this graph in topological order, + Connect the states and call the function recursively on the nested scopes. + """ + # A dictionary that keeps track of the conditions necessary to reach a state in the forward passs + conditions_map: dict[SDFGState, str] = {} + + # Iterate through all the nodes in topological order + nodes = dace_utils.dfs_topological_sort(graph, graph.source_nodes()) + for node in nodes: + # A list of the conditions on all the in edges for this state + in_edges_conditions: List[str] = [] + if isinstance(node, SDFG) or isinstance(node, LoopRegion): + # if this is not a reversed loop region + if not node in self.reversed_loops_map: + continue + self._fill_interstate_edge_conditions_in_scope(node) + else: + + if not isinstance(node, SDFGState): + raise AutoDiffException(f"Expected SDFGState, got {type(node)}") + forward_state = node + parent_graph = forward_state.parent_graph + + # if this is not a reversed state + if node not in self.reversed_states_map: + continue + + # We will iterate through all the incoming edges to the forward state + edges_list = parent_graph.in_edges(forward_state) + + # If there are none, this is a start state + # If there is only one incoming edge, no condition necessary + if len(edges_list) < 2: + conditions_map[forward_state] = "1" + + for edge in edges_list: + # Get the src state + src_state = edge.src + + # Get the condition to get to the source state in the forward pass + src_state_condition = conditions_map[src_state] + + # Add the condition in the current edge + current_edge_condition = edge.data.condition.as_string + + # New backward edge condition + # Handle "1" (unconditional) to avoid creating expressions like "1 and condition" + if src_state_condition == "1" and current_edge_condition == "1": + new_bwd_edge_condition = "1" + elif src_state_condition == "1": + new_bwd_edge_condition = current_edge_condition + elif current_edge_condition == "1": + new_bwd_edge_condition = src_state_condition + else: + new_bwd_edge_condition = f"({src_state_condition}) and ({current_edge_condition})" + + bwd_edge = self._get_backward_state_edge(edge) + + # Add the condition to the edge + bwd_edge.data.condition = CodeBlock(new_bwd_edge_condition) + + # If there is a special case for the first iteration of the backward state + if forward_state in self.loop_states_view_map: + + # Get the corresponding edge between the loop states + bwd_loop_edge = self._get_backward_loop_state_edge(edge) + + # Add the same condition to the edge + bwd_loop_edge.data.condition = CodeBlock(new_bwd_edge_condition) + + # Add the forward condition to the list to update the conditions_map dict + if new_bwd_edge_condition != "1": + # Only add the condition if it exists + in_edges_conditions.append(new_bwd_edge_condition) + + # Update the conditions mapping + # This will be the logical or of all the saved conditions + # because we can reach this state by taking any of the incoming edges + if len(in_edges_conditions) == 0: + condition_for_state = "1" + else: + condition_for_state = in_edges_conditions[0] + for i in range(1, len(in_edges_conditions)): + condition_for_state += f" or {in_edges_conditions[i]}" + + # Since we are doing topological sort before iterating + conditions_map[node] = condition_for_state + + def _fill_interstate_edge_conditions(self) -> None: + """ + Go through all of the states in the forward graph and fill the necessary conditions in the backward states. + Each edge in the backward SDFG will be the logical AND between the equivalent edge in the forward SDFG and + all of the conditions that are necessary to get to this state in the forward pass. + """ + self._fill_interstate_edge_conditions_in_scope(self.sdfg) + + # Iterate through all the loop regions and connect the loop states if necessary + for loop in self.sdfg.all_control_flow_regions(): + # Only iterate over loop regions + if not isinstance(loop, LoopRegion): + continue + # Get the start state + loop_start_state = loop.start_block + if not isinstance(loop_start_state, SDFGState): + # This would be the case for perfectly nested loops + # Nothing to do in this case + continue + + if not loop_start_state in self.reversed_loop_states_map: + # There are no extra states to connect + continue + + # If there are loop states to connect + # Prepare the condition for the new state + loop_it = loop.loop_variable + reversed_loop = self.reversed_loops_map[loop] + start, _ = self._extract_loop_region_info(reversed_loop) + + # We only want the loop state to execute + # in the first iteration of the reversed loop + first_state_condition = f"{loop_it} == {start}" + first_state_condition = CodeBlock(first_state_condition) + + leftover_loop_state = self.reversed_loop_states_map[loop_start_state] + + # Get the reversed loop start state + reversed_loop_start_state = self.reversed_states_map[loop_start_state] + + # Add a state to the reversed loop region + new_start_state = reversed_loop.add_state_before(reversed_loop_start_state, + is_start_block=True, + condition=first_state_condition) + + # The condition for this interstate edge should be all iterations expect the fist + leftover_iterations_condition = f"not {first_state_condition.as_string}" + + # Add a connection between this new start state and the first iteration state + reversed_loop.add_edge(src=new_start_state, + dst=leftover_loop_state, + data=dace.InterstateEdge(condition=leftover_iterations_condition)) + + def _add_interstate_edge_assignments(self) -> None: + """ + We will need to add interstate assignments at the start of the backward SDFG + This is necessary to make sure the control flow in the backward pass is correctly preserved. + """ + # We will add an empty state to the backward pass which will have all the assignments + + new_assignments = {} + # Get all the interstate edges in the forward sdfg + for edge in self.sdfg.all_interstate_edges(): + if edge.data.assignments: + # There are assignments to be added to the start of the backward pass + new_assignments = {**new_assignments, **edge.data.assignments} + + # We need to check if any data needs to be used in these assignment + # This is important in the case of a NSDFG where data will need to be forwarded + for _, rhs in edge.data.assignments.items(): + # If any of the sdfg arrays are in the rhs assignment + assignment_arrays = [array for array in self.sdfg.arrays.keys() if array in rhs] + if assignment_arrays and self.separate_sdfgs: + # We need to forward this data to the backward pass + for array in assignment_arrays: + if array not in self.backward_input_arrays: + self.backward_input_arrays[array] = self.sdfg.arrays[array] + # Special case if this is a symbol that is doesn't have a descriptor yet + if array not in self.backward_sdfg.arrays: + # We add it now + self.backward_sdfg.add_datadesc(array, copy.deepcopy(self.sdfg.arrays[array])) + + if new_assignments: + # Add the new state to the backward pass + # First we get the start block of the backward pass + if self.separate_sdfgs: + bwd_start_block = self.backward_sdfg.start_block + else: + fwd_start_state = self.sdfg.start_block + if isinstance(fwd_start_state, LoopRegion): + bwd_start_block = self.reversed_loops_map[fwd_start_state] + elif isinstance(fwd_start_state, SDFGState): + bwd_start_block = self.reversed_states_map[fwd_start_state] + else: + raise AutoDiffException("Need to add an assignments state but can't find the start block") + # TODO would this work on a loop region? + self.backward_sdfg.add_state_before(state=bwd_start_block, + label="_bwd_interstate_assignments_state", + assignments=new_assignments) + + def is_within_map(self, state: SDFGState, node: nodes.AccessNode) -> bool: + # Get the scope dictionary for the state + scope_dict = state.scope_dict() + + # Check if the node is within the scope of a map + scope_entry = scope_dict.get(node, None) + while scope_entry is not None: + if isinstance(scope_entry, nodes.MapEntry): + return True + scope_entry = scope_dict.get(scope_entry, None) + + return False + + def _zero_out_gradient(self, forward_state: SDFGState, forward_node: nodes.AccessNode, memlet: Memlet) -> None: + """ + Zero out gradients for overwritten arrays in the forward pass. + + Overwritten arrays need their gradients zeroed for gradient accumulation + to work correctly. This method: + + 1. Copies current gradient values to a temporary array (for one last use + in the backward pass) + 2. Zeros out the overwritten access in the backward pass + 3. Updates the read mapping to use the temporary instead of the original + + The operation is skipped when possible to optimize performance. + + :param forward_state: The state in the forward pass containing the write. + :param forward_node: The access node being overwritten. + :param memlet: The memlet describing the write operation. + """ + # Extra checks to only do this if necessary + # If this access node is not written to in the forward pass except for this one time, we don't need to zero it out + # An exception is made for required gradients that can be read outside the scope of the SDFG + clear_out_gradients = forward_node.data in self.required_gradients_data + + # Get the write instances in the forward sdfg to this node that happen in states before the current state + # These will represent the reads that will happen after this AccessNode + # This should avoid unnecessary zeroing out of dace generated temporaries + for state in self.state_order[0:self.state_order.index(forward_state) + 1]: + state_view = self.states_view_map[state] + for node, parent in state_view.all_nodes_recursive(): + if isinstance(node, nodes.AccessNode) and node.data == forward_node.data: + if parent.in_degree(node) > 0: + # We need to check if the the forward node is inside a map scope or a LoopRegion + within_loop, _ = ad_utils.state_within_loop(state) + within_map = self.is_within_map(state, node) + if node != forward_node or (node == forward_node and (within_loop or within_map)): + clear_out_gradients = True + break + + # We can avoid clearing out the gradients + if not clear_out_gradients: + return + + # Get the backward state + backward_state: SDFGState = self.reversed_states_map[forward_state] + + # Get the backward node + backward_node: nodes.AccessNode = self.reverse_map[forward_node] + + # Get the original array + array_desc = self.backward_sdfg.arrays[backward_node.data] + + if dtypes.can_access(dtypes.ScheduleType.CPU_Multicore, array_desc.storage): + cuda = False + elif dtypes.can_access(dtypes.ScheduleType.GPU_Default, array_desc.storage): + cuda = True + else: + raise ValueError(f"Unsupported storage {array_desc.storage}") + + # Careful! The order of the ifs here matters since ArrayView is a subclass of Array + if isinstance(array_desc, dt.View): + # No need to initialize: the viewed array will always be visited + # (since a view can never be a required grad), and thus the viewed array will be initialized. + pass + elif isinstance(array_desc, (dt.Array, dt.Scalar)): + # Create a new memlet to write to the gradient arrays + map_exit_memlet = copy.deepcopy(memlet) + map_exit_memlet.data = backward_node.data + + # Create the tasklet to zero out only the section in the memlet + # First, Get the range that the zeroout map should iterate over + # TODO: We are looking at writes in the forward pass, + # We should take the dst_subset of the memlet + # Are there cases where dst_subset is None? + ranges = [] + for iteration in map_exit_memlet.dst_subset: + if isinstance(iteration, tuple): + # The end of the range is inclusive in the loop + # We add 1 to get the upper bound for the map + ranges.append((iteration[0], iteration[1] + 1)) + elif isinstance(iteration, sp.Number): + # This covers the case of a single element being written + ranges.append((int(iteration), int(iteration) + 1)) + else: + raise AutoDiffException(f"Unsupported subset type {type(iteration)} in memlet {memlet}") + + # Create the indices dict + indices = {f"i{i}": f"{start}:{end}" for i, (start, end) in enumerate(ranges)} + + # Create the tasklet memlet from the indices + tasklet_memlet = dace.Memlet.simple(backward_node.data, ", ".join(indices.keys())) + + # Create the tasklet + _, map_entry, map_exit = backward_state.add_mapped_tasklet( + "_clear_" + backward_node.data + "_", + indices, {}, + f"__out = 0", { + "__out": tasklet_memlet, + }, + schedule=dtypes.ScheduleType.GPU_Device if cuda else dtypes.ScheduleType.Default, + external_edges=True) + + # Get the edge from the map exit to the backward node + edge = backward_state.out_edges(map_exit)[0] + + # Get the cleared out AN + cleared_out_node = edge.dst + if not isinstance(cleared_out_node, nodes.AccessNode): + raise AutoDiffException(f"Expected AccessNode as cleared out node, got {type(cleared_out_node)}") + + # Create a copy of new memlet that will keep its other subset + # We want to copy the elements to their same indices in the new tmp array + # Create a new memlet that copies what memlet is writing to to the tmp + new_memlet_subset = memlet.subset if memlet.data == forward_node.data else memlet.other_subset + original_to_tmp_memlet = dace.Memlet(data=backward_node.data, + subset=new_memlet_subset, + other_subset=new_memlet_subset) + + # Remove the src_subset of the new memlet and replace the memlet in the edge + map_exit_memlet.subset = memlet.subset if memlet.data == forward_node.data else memlet.other_subset + map_exit_memlet.other_subset = None + edge.data = map_exit_memlet + + # Add an edge from the backward_node to the new map entry + backward_state.add_edge(backward_node, None, map_entry, None, dace.Memlet()) + + # A race will happen unless we make sure the data is being copied being it is zeroed out + # There is a read from the same array + # We need to add a transient that reads the content from forward pass before it is zeroed out + # Create a new array descriptor for the transient + transient_desc = copy.deepcopy(array_desc) + transient_desc.transient = True + + # Add the new array to the sdfg + transient_name = self.array_grad_name(forward_node.data) + "_tmp" + + # Check if the array is already in the backward sdfg + if transient_name not in self.backward_sdfg.arrays: + self.backward_sdfg.add_datadesc(transient_name, transient_desc) + + # Create an AcessNode for this transient and add it to backward state + transient_node = backward_state.add_read(transient_name) + + # Add a read from the backward node to the transient + backward_state.add_edge(backward_node, None, transient_node, None, original_to_tmp_memlet) + + # Add an empty edge from the transient to the map entry + backward_state.add_edge(transient_node, None, map_entry, None, dace.Memlet()) + if backward_node not in self.zeroed_out: + self.zeroed_out[backward_node] = [transient_node] + else: + self.zeroed_out[backward_node].append(transient_node) + else: + raise AutoDiffException("Unsupported data descriptor {}".format(array_desc)) + + def _remove_onnx_attribute_accessnodes(self, nodes_list: List[nodes.Node], state: SDFGState) -> None: + """Remove ONNX attribute AccessNodes that don't need gradient tracking. + + For some ONNX operators, nodes have attributes as input connectors even if the inputs are actually constant. + Examples of such attributes are `axis` and `keepdims` in `ReduceSum`. + Gradients for these attributes should not be tracked since they represent control flow and not data flow. + """ + attribute_to_remove = {"axis", "keepdims", "axes", "p", "dilations", "kernel_shape", "strides"} + for node in nodes_list[:]: # Iterate over a copy of the list to avoid modification issues + if isinstance(node, nodes.AccessNode): + out_edges = state.out_edges(node) + if out_edges and all( + ONNX_AVAILABLE and isinstance(edge.dst, ONNXOp) and edge.dst_conn in attribute_to_remove + for edge in out_edges): + nodes_list.remove(node) + + def _remove_maps_without_input_connectors(self, nodes_list: List[nodes.Node], state: SDFGState) -> None: + """Remove maps that don't have any input connectors from the nodes_list. + + These are maps that won't have an output in the backward pass and thus can be skipped from the reversal process. + Note that we do not remove the AccessNode that the no-input map writes to. + This is because we might need to zero out the gradient of this node. + If no zeroing out is necessary, the node will be removed in the reverse_subgraph function cleanup at the end. + """ + for node in nodes_list[:]: # Iterate over a copy of the list to avoid modification issues + if isinstance(node, nodes.MapEntry) and len(node.in_connectors) == 0: + nodes_list.remove(node) + # Remove the MapExit and everything in between + # Get the equivalent map exit for the map entry + map_exit = state.exit_node(node) + nodes_list.remove(map_exit) + + # Get all the nodes between the map entry and exit + for state_node in state.nodes(): + # Check the scope of the node if it is within the map + if state_node in state.scope_dict() and state.scope_dict( + )[state_node] == node and state_node in nodes_list: + nodes_list.remove(state_node) + + def _find_subgraph_to_differentiate(self) -> None: + """Determine which nodes we need to reverse; this forms the subgraph we will differentiate. + + We do a reverse BFS from the target output node. + In the case where a state is within a loop, this may result in different subgraphs + depending on the loop iteration. + + To calculate the gradients for a node x in ``required_gradients``, we need to sum up the gradient + contributions from every node y where x is used as an input. + """ + backward_nodes: set[nodes.Node] = set() + given_gradients_all_states = set(self.given_gradients_data) + + required_gradients_all_states = {n for n in self.required_gradients_data} + given_gradients_all_states = given_gradients_all_states | required_gradients_all_states + + # Do the backward BFS iteratively + for state in reversed(self.state_order): + state_given_gradients: List[nodes.AccessNode] = [] + + for node in state: + if isinstance(node, nodes.AccessNode) and node.data in given_gradients_all_states: + state_given_gradients.append(node) + + backward_nodes = {n for e in state.edge_bfs(state_given_gradients, reverse=True) for n in [e.src, e.dst]} + nodes_list = list(backward_nodes) + + # Clean up unwanted elements + self._remove_maps_without_input_connectors(nodes_list, state) + self._remove_onnx_attribute_accessnodes(nodes_list, state) + + state_subgraph = dstate.StateSubgraphView(state, nodes_list) + + state_subgraph = self._add_missing_nested_sdfg_connectors_to_view(state=state, + state_subgraph=state_subgraph, + view_nodes=nodes_list) + + # Add mapping + self.states_view_map[state] = state_subgraph + + # In the case where this state is within a for loop + within_loop, _ = ad_utils.state_within_loop(state) + if within_loop: + # Other elements that are not within state_subgraph will need to be reversed + # We create a separate mapping for these elements + + # Get all the access nodes that are used in the previous view + subgraph_an = [node.data for node in state_subgraph.nodes() if isinstance(node, nodes.AccessNode)] + + # For each access node in this view + for state_node in state: + if isinstance(state_node, nodes.AccessNode) and state_node.data in subgraph_an: + state_given_gradients.append(state_node) + + # Do reverse BFS starting from this new set of nodes + backward_nodes = { + n + for e in state.edge_bfs(state_given_gradients, reverse=True) + for n in [e.src, e.dst] + } + + view_nodes = list(backward_nodes) + self._remove_maps_without_input_connectors(nodes_list, state) + + loop_state_subgraph = dstate.StateSubgraphView(state, view_nodes) + + loop_state_subgraph = self._add_missing_nested_sdfg_connectors_to_view( + state=state, state_subgraph=loop_state_subgraph, view_nodes=view_nodes) + + # If the two views are different + # Here we only check if the number of nodes is the same + # Since states_view_map[state] is a subset of loop_states_view_map[state] + if len(state_subgraph) != len(loop_state_subgraph): + self.loop_states_view_map[state] = loop_state_subgraph + + # Update the list of given gradients to use for states + for node in backward_nodes: + if isinstance(node, nodes.AccessNode) and node.data not in given_gradients_all_states: + # We want all of the backward AccessNodes that made it to the intersection + given_gradients_all_states.add(node.data) + + def array_grad_name(self, forward_name: str) -> str: + """Return the gradient name of a name from the forward pass.""" + if forward_name not in self.array_grad_map: + self.array_grad_map[forward_name] = \ + self.backward_sdfg._find_new_name("gradient_" + forward_name) + + return self.array_grad_map[forward_name] + + def _add_gradient_data_descriptor(self, data_name: str) -> dt.Array: + """Add the data descriptor for the gradient for `data_name`. + + :param data_name: The name of the forward descriptor. + """ + grad_name = self.array_grad_name(data_name) + + if grad_name in self.backward_sdfg.arrays: + raise AutoDiffException(f"descriptor for gradient of {data_name} ({grad_name}) already exists") + + array = self.sdfg.arrays[data_name] + + if not isinstance(array, (dt.Scalar, dt.Array, dt.View)): + raise AutoDiffException("Unsupported data descriptor {}".format(array)) + + cloned_datadesc = copy.deepcopy(array) + + # only the grads of the inputs and the outputs are not transient + cloned_datadesc.transient = data_name not in self.input_names and data_name not in self.output_names + + # Store references + self.backward_grad_arrays[grad_name] = cloned_datadesc + self.backward_sdfg.arrays[grad_name] = cloned_datadesc + return cloned_datadesc + + def _reverse_loop_conditional(self, loop: LoopRegion) -> str: + """Given a loop region as a parameter, create the conditional for the reversed version of this loop.""" + + # Get the loop iterator + it = loop.loop_variable + + # Get the loop start + start, _ = ad_utils.extract_loop_region_info(loop) + + # Get the stride sign + stride_sign = ad_utils.get_stride_sign(loop) + + # Reverse the conditional to end at the start of the original loop + # This will be incremented or decremented depending on the stride + if stride_sign > 0: + reversed_condition = f"{it} > {start}-1" + else: + reversed_condition = f"{it} < {start}+1" + + return reversed_condition + + def _reverse_loop_initial_statement(self, loop: LoopRegion) -> str: + """Given a loop region as a parameter, create the initialization statement for the reversed version of this loop.""" + # Get the loop iterator + it = loop.loop_variable + + stride_sign = ad_utils.get_stride_sign(loop) + + # Get the loop end + _, end = ad_utils.extract_loop_region_info(loop) + + # Reverse the initialization to start from the end of the forward loop + # This will be incremented or decremented depending on the stride + if stride_sign > 0: + init_expr = f"{it} = {end}-1" + else: + init_expr = f"{it} = {end}+1" + + return init_expr + + def _reverse_loop_update_statement(self, loop: LoopRegion) -> str: + """Given a loop region as a parameter, create the update statement for the reversed version of this loop.""" + + # Get the original update statement + fwd_update = loop.update_statement.as_string + + stride_sign = ad_utils.get_stride_sign(loop) + + # If the stride is positive + if stride_sign > 0: + update_statement = fwd_update.replace("+", "-") + else: + # If the stride is negative + update_statement = fwd_update.replace("-", "+") + + return update_statement + + def _match_loop_region(self, fwd_loop: LoopRegion) -> LoopRegion: + """Create the backward LoopRegion and fill it with the reversal of the forward LoopRegion.""" + + init_expr = self._reverse_loop_initial_statement(fwd_loop) + reversed_condition = self._reverse_loop_conditional(fwd_loop) + update_statement = self._reverse_loop_update_statement(fwd_loop) + + # Create the label + reversed_label = f"{fwd_loop.label}_reversed" + + # Create the loop object and return it + reversed_loop = LoopRegion(label=reversed_label, + initialize_expr=init_expr, + condition_expr=reversed_condition, + update_expr=update_statement, + loop_var=fwd_loop.loop_variable) + + return reversed_loop + + def _reverse_loop_region(self, loop: LoopRegion): + """Given a LoopRegion as a parameter, reverse it, add the loop states that belong in this region.""" + + # Create the reversed loop region + reversed_loop = self._match_loop_region(fwd_loop=loop) + self.reversed_loops_map[loop] = reversed_loop + + # Add the reversed loop directly + parent_graph = self._get_reversed_parent_graph(loop) + parent_graph.add_node(reversed_loop) + + # Add all the loop nodes to the graph and recursivly reverse child loop regions + for node in loop.nodes(): + if isinstance(node, LoopRegion): + + # This node shouldn't be reversed already since we're going top-down + if node in self.reversed_loops_map: + raise AutoDiffException(f"Loop {node} has already been reversed") + self._reverse_loop_region(node) + elif isinstance(node, SDFGState): + + # Get the backward_node + bwd_node = self.reversed_states_map[node] + + # Remove from the backward SDFG + self.backward_sdfg.remove_node(bwd_node) + + # Add it to the loop region + reversed_loop.add_node(bwd_node) + + # Also add loop states if any + if node in self.reversed_loop_states_map: + # Get the backward_node + bwd_node = self.reversed_loop_states_map[node] + + def _add_missing_nested_sdfg_connectors_to_view(self, state: SDFGState, state_subgraph: dstate.StateSubgraphView, + view_nodes: List[nodes.Node]): + """Add missing NestedSDFG connectors to the view for correctness. + + There is a special case for NestedSDFGs that we need to fix + in the case where a NestedSDFG has an inout connector, + but we only care about one of those connectors for the sake of AD. + We need to add the missing connector for correctness. + TODO: This is only a problem if the said connector is written to + inside the NestedSDFG. + """ + # In the case where a NestedSDFG has an inout connector, + # but we only care about one of those connectors for the sake of AD + # we need to add the missing connector for correctness + # TODO: this is only a problem if the said connector is written to + # inside the NestedSDFG + # Iterate over the nested SDFGs in the view + for g in state_subgraph.nodes(): + if isinstance(g, nodes.NestedSDFG): + + inout_connectors = set(g.in_connectors).intersection(set(g.out_connectors)) + # If there are any inout connectors + if len(inout_connectors) > 0: + out_connectors = {edge.src_conn: edge for edge in state.out_edges(g)} + in_connectors = {edge.dst_conn: edge for edge in state.in_edges(g)} + view_out_connectors = {edge.src_conn: edge for edge in state_subgraph.out_edges(g)} + view_in_connectors = {edge.dst_conn: edge for edge in state_subgraph.in_edges(g)} + for con in inout_connectors: + # Check if it is missing in the out or in connectors of the view + if con in view_out_connectors and con not in view_in_connectors: + # Get the equivalent in node and connector + edge = in_connectors[con] + if not isinstance(edge.src, nodes.AccessNode): + raise AutoDiffException(f"Expected AccessNode as source, got {type(edge.src)}") + view_nodes.append(edge.src) + if con not in view_out_connectors and con in view_in_connectors: + # Add the corresponding edge to the view + edge = out_connectors[con] + if not isinstance(edge.dst, nodes.AccessNode): + raise AutoDiffException(f"Expected AccessNode as destination, got {type(edge.dst)}") + view_nodes.append(edge.dst) + + return dstate.StateSubgraphView(state, view_nodes) + + def _compare_memlet_accesses_to_array_size(self, data_name: str, memlet: Memlet) -> int: + """Compare the memlet range with the size of the array to see if the array is being overwritten.""" + total_size = self.backward_sdfg.arrays[data_name].total_size + try: + if total_size > memlet.num_accesses: + return 1 + elif memlet.num_accesses == total_size: + return 0 + + # Something is wrong here raise an exception + raise AutoDiffException(f"Memlet {memlet} has more accesses than the size of the data {data_name}") + + # If the comparison can not be made, return None + except TypeError: + return None + + def _get_reversed_parent_graph(self, forward_node: nodes.Node): + """Given a node in the SDFG, get the reversed parent of this node.""" + fwd_parent_graph = forward_node.parent_graph + + if fwd_parent_graph == self.sdfg: + parent_graph = self.backward_sdfg + elif isinstance(fwd_parent_graph, SDFGState): + parent_graph = self.reversed_states_map[fwd_parent_graph] + elif isinstance(fwd_parent_graph, LoopRegion): + parent_graph = self.reversed_loops_map[fwd_parent_graph] + + return parent_graph + + def _get_backward_loop_state_edge(self, forward_edge: dace.InterstateEdge) -> dace.InterstateEdge: + """Given an edge from the forward pass, return the equivalent edge in the backward pass.""" + # Get the source and destination states + forward_src = forward_edge.src + forward_dst = forward_edge.dst + + if isinstance(forward_src, LoopRegion): + fwd_src_is_loop = True + if forward_src not in self.reversed_loops_map: + raise AutoDiffException(f"Forward loop {forward_src} not found in reversed loops map") + else: + fwd_src_is_loop = False + if forward_src not in self.reversed_states_map: + raise AutoDiffException(f"Forward state {forward_src} not found in reversed states map") + + if isinstance(forward_dst, LoopRegion): + fwd_dst_is_loop = True + if forward_dst not in self.reversed_loops_map: + raise AutoDiffException(f"Forward loop {forward_dst} not found in reversed loops map") + else: + fwd_dst_is_loop = False + if forward_dst not in self.reversed_states_map: + raise AutoDiffException(f"Forward state {forward_dst} not found in reversed states map") + + # Note that the source will become the destination + backward_dst = self.reversed_states_map[forward_src] if not fwd_src_is_loop else self.reversed_loops_map[ + forward_src] + backward_src = self.reversed_states_map[forward_dst] if not fwd_dst_is_loop else self.reversed_loops_map[ + forward_dst] + + # Each one of these in edges needs to have an equivalent + # out edge in the backward part of the SDFG + bwd_edge = None + connection_state = backward_dst + + # Find the equivalent edge in the backward SDFG + for b_edge in connection_state.parent_graph.in_edges(connection_state): + if b_edge.src == backward_src: + bwd_edge = b_edge + break + + if not bwd_edge: + raise AutoDiffException(f"Can't find the equivalent edge of {forward_edge} in the backward pass") + + return bwd_edge + + def _get_backward_state_edge(self, forward_edge: dace.InterstateEdge) -> dace.InterstateEdge: + """Given an edge from the forward pass, return the equivalent edge in the backward pass.""" + # Get the source and destination states + forward_state_src = forward_edge.src + forward_state_dst = forward_edge.dst + + # Get the equivalent states in the backward pass + if (forward_state_src not in self.reversed_states_map and forward_state_src not in self.reversed_loops_map): + raise AutoDiffException(f"Forward state source {forward_state_src} not found in reversed maps") + if (forward_state_dst not in self.reversed_states_map and forward_state_src not in self.reversed_loops_map): + raise AutoDiffException(f"Forward state destination {forward_state_dst} not found in reversed maps") + + # Note that the src will become the destination + backward_state_dst = self.reversed_states_map[ + forward_state_src] if forward_state_src in self.reversed_states_map else self.reversed_loops_map[ + forward_state_src] + backward_state_src = self.reversed_states_map[ + forward_state_dst] if forward_state_dst in self.reversed_states_map else self.reversed_loops_map[ + forward_state_dst] + + # Each one of these in edges needs to have an equivalent + # out edge in the backward part of the SDFG + bwd_edge = None + connection_state = backward_state_dst + + # Find the equivalent edge in the backward SDFG + for b_edge in connection_state.parent_graph.in_edges(connection_state): + if b_edge.src == backward_state_src: + bwd_edge = b_edge + break + + if not bwd_edge: + raise AutoDiffException(f"Can't find the equivalent edge of {forward_edge} in the backward pass") + + return bwd_edge + + def _str_to_access(self, data: str, source: str) -> nodes.AccessNode: + """Given a string containing the name of the accessed array, return the AccessNode in the state. + + Given a string containing the name of the accessed array, return the AccessNode in the state + that points to this array. + If there are multiple AccessNodes, the behavior will depend on whether we want + an output or input AccessNode. + Input: We will return the first occurrence of this node in the state and make sure there are + only outgoing edges from this node. + Output: We will return the last occurrence of this node in the state + where the node only has incoming edges. + """ + matches = [(node, state) for state in self.sdfg.states() for node in state.nodes() + if isinstance(node, nodes.AccessNode) and node.data == data] + # Unused in model + if len(matches) == 0: + return None + + # there is only a single AccessNode with this name + if len(matches) == 1: + return matches[0][0] + + # len(matches) > 1 + else: + # There are multiple occurrences of the same AccessNode + if source == "inputs": + # We return the first node with this data + input_node: nodes.AccessNode = matches[0][0] + return input_node + + if source == "outputs": + # Go through the list of matches in reverse + for output_node, output_node_state in reversed(matches): + # We want the first node that has at least one incoming edge to it + # This represents the last time the output data was modified + in_edges = output_node_state.in_edges(output_node) + if len(in_edges) > 0: + return output_node + + raise AutoDiffException( + f"The specified output {data} was not written to by any AccessNode in this state") + + raise AutoDiffException(f"There are multiple nodes with data {data} " + f" but the source (inputs or outputs) was not specified correctly") + + def _expand_nodes(self) -> bool: + """Expand all library nodes in the sdfg to pure implementations. + + Returns whether something was expanded. + """ + expanded_something = False + for state_view in self.states_view_map.values(): + for node, parent_graph in state_view.all_nodes_recursive(): + if isinstance(parent_graph, dstate.StateSubgraphView): + parent_graph = parent_graph.graph + + # Check if the node exists in the backward implementation repository + if find_backward_implementation(parent_graph.parent_graph, parent_graph, node) is not None: + continue + + # Only check others if we didn't break out of the above loop + if ONNX_AVAILABLE and isinstance(node, ONNXOp): + impls = ONNXForward.registered_implementations(node.schema.name) + + # Order the implementations so that implementations containing "pure" are tried first + impls = [i for name, i in impls if "pure" in name] + [i for name, i in impls if "pure" not in name] + for impl in impls: + if impl.forward_can_be_applied(node, parent_graph, self.sdfg): + # Configure the module-level expansion class + ExpansionTemplate.environments = impl.environments if hasattr(impl, "environments") else [] + ExpansionTemplate._impl = impl + ExpansionTemplate._match_node = xf.PatternNode(type(node)) + ExpansionTemplate.apply_to(parent_graph.parent, verify=False, _match_node=node) + expanded_something = True + break + + # This could later on be changed to check if the expansion is differentiable and if not, move + # on to the next expansion. For now we will just apply the first one that matches, prioritizing ones that + # have "pure" in the name + if isinstance(node, nodes.LibraryNode) and not (ONNX_AVAILABLE and isinstance(node, ONNXOp)): + # Try to select an expansion + if hasattr(node, "implementations"): + implementations = node.implementations + + pure_candidates = [name for name, _ in sorted(implementations.items()) if "pure" in name] + if len(pure_candidates) > 0: + expansion = pure_candidates[0] + else: + expansion = node.implementation + else: + expansion = node.implementation + + node.implementation = expansion + node.expand(parent_graph.parent, parent_graph) + expanded_something = True + + return expanded_something + + def _get_node_state(self, node: nodes.Node) -> SDFGState: + """Return the SDFG state that contains this node.""" + matches = [] + for state in self.sdfg.states(): + if node in state.nodes(): + matches.append(state) + + if len(matches) != 1: + raise AutoDiffException(f"Expected exactly one match, got {len(matches)}") + return matches[0] + + def _connect_conditional_map_exist(self, forward_state: SDFGState, backward_state: SDFGState, + backward_map_exit: nodes.MapExit, fwd_tasklet: nodes.Tasklet): + """Connect the map exit of a conditional tasklet to a new access node which will zero out the gradient. + """ + + if len(backward_map_exit.in_connectors) != 0: + raise AutoDiffException( + f"Expected no input connectors on backward map exit, got {len(backward_map_exit.in_connectors)}") + + # Add the in and out connectors for the zero-out operation + backward_map_exit.add_in_connector("IN_zero_out") + backward_map_exit.add_out_connector("OUT_zero_out") + + # Get the memlet data for the edge from the tasklet to the map exist + tasklet_out_edge = forward_state.out_edges(fwd_tasklet) + if len(tasklet_out_edge) != 1: + raise AutoDiffException(f"Expected exactly one tasklet output edge, got {len(tasklet_out_edge)}") + tasklet_out_edge = tasklet_out_edge[0] + tasklet_memlet_path = forward_state.memlet_path(tasklet_out_edge) + if len(tasklet_memlet_path) != 2: + raise AutoDiffException(f"Expected tasklet memlet path of length 2, got {len(tasklet_memlet_path)}") + + # Copy the memlet and change the data name + memlet_data = copy.deepcopy(tasklet_memlet_path[0].data) + memlet_data.data = self.array_grad_map[memlet_data.data] + + # Get the reversed tasklet + bwd_tasklet = self.reverse_map[fwd_tasklet] + + # Connect this map exist to the tasklet + backward_state.add_edge(bwd_tasklet, "__zero_out_conn__", backward_map_exit, "IN_zero_out", memlet_data) + + # Replicate the target accedd node and connect it + fwd_target_an: nodes.AccessNode = tasklet_memlet_path[-1].dst + if not isinstance(fwd_target_an, nodes.AccessNode): + raise AutoDiffException(f"Expected AccessNode for forward target, got {type(fwd_target_an)}") + if fwd_target_an not in self.reverse_map: + raise AutoDiffException(f"Forward target AccessNode {fwd_target_an} not found in reverse map") + bwd_target_an = self.reverse_map[fwd_target_an] + + replicated_bwd_target_an = copy.deepcopy(bwd_target_an) + backward_state.add_node(replicated_bwd_target_an) + + an_memlet_data: nodes.AccessNode = copy.deepcopy(tasklet_memlet_path[1].data) + an_memlet_data.data = self.array_grad_map[an_memlet_data.data] + backward_state.add_edge(backward_map_exit, "OUT_zero_out", replicated_bwd_target_an, None, an_memlet_data) + + # We need to get the map entry that starts the conditional block + # First get the conditional tasklet + conditional_block = self._extract_conditional_array_assignment_block( + forward_state=forward_state, tasklet_node=fwd_tasklet, subgraph=self.states_view_map[forward_state]) + # Get the map entry of the conditional bloc + map_entries = [n for n in conditional_block if isinstance(n, nodes.MapEntry)] + + if len(map_entries) != 1: + raise AutoDiffException( + f"Expected a single MapEntry node in the conditional block, found {len(map_entries)}") + else: + map_entry = map_entries[0] + + # Add the new access node to a dictionary in case it needs to be connected + self.conditional_block_entry[map_entry] = replicated_bwd_target_an + + def _conditional_tasklet(self, tasklet_node: nodes.Tasklet): + """Check if this tasklet contains a conditional. + + This only happens in conditional array assignments and requires special treatment in reversing the graph. + """ + # sanity check + if not isinstance(tasklet_node, nodes.Tasklet): + raise AutoDiffException(f"Expected Tasklet node, got {type(tasklet_node)}") + + # get the code string and check if there is an if + # TODO: How to more accurately check this? + return "if" in tasklet_node.code.as_string + + def _conditional_nested_sdfg(self, forward_state: SDFGState, node: nodes.NestedSDFG): + """Check if this NestedSDFG contains a conditional. + + This only happens in conditional array assignments and requires special treatment in reversing the graph. + """ + # sanity check + if not isinstance(node, nodes.NestedSDFG): + raise AutoDiffException(f"Expected NestedSDFG node, got {type(node)}") + + # get the incoming edges to the sdfg + in_edges = forward_state.in_edges(node) + + # check if any of the incoming edges are boolean edges + for edge in in_edges: + if self.sdfg.arrays[edge.data.data].dtype == dace.bool: + return True + + # get the code string and check if there is an if + return False + + def _extract_conditional_array_assignment_block(self, forward_state: SDFGState, tasklet_node: nodes.Node, + subgraph: dstate.SubgraphView): + """Extract a conditional array assignment block. + + Given a conditional tasklet, check if this is a conditional array assignment of the type + A[A>=0 and A<=5] = cst. At the moment the function only supports constant assignments. + """ + try: + + if not isinstance(tasklet_node, nodes.Tasklet): + raise AutoDiffException(f"Expected Tasklet node, got {type(tasklet_node)}") + # This applies to both Tasklet and NestedSDFG nodes + # get the AccessNode containing the boolean values for this assignment + tasklet_in_edges = forward_state.in_edges(tasklet_node) + tasklet_boolean_edge = None + single_boolean_edge_found = False + for edge in tasklet_in_edges: + edge_type = self.sdfg.arrays[edge.data.data].dtype + if edge_type == dace.bool: + # sanity check + if single_boolean_edge_found: + # we expect there to be a single AccessNode where the booleans come from + raise AutoDiffException( + "Multiple boolean edges found for conditional assignment. Expected only one.") + tasklet_boolean_edge = edge + single_boolean_edge_found = True + + if tasklet_boolean_edge is None: + raise AutoDiffException("Expected to find a boolean edge for conditional assignment") + tasklet_in_memlet_path = forward_state.memlet_path(tasklet_boolean_edge) + # the first element in the path is the boolean AN + bools_an = tasklet_in_memlet_path[0].src + if not isinstance(bools_an, nodes.AccessNode): + raise AutoDiffException(f"Expected AccessNode for boolean values, got {type(bools_an)}") + + # save all the nodes in the path to the assignment block list + conditional_assingement_block_nodes = { + n + for e in forward_state.edge_bfs(bools_an, reverse=True) + for n in [e.src, e.dst] + } + + # if any of the nodes in the block are required for gradient tracking + nodes_to_keep_tracking: set[nodes.Node] = self._get_gradient_nodes_to_track( + forward_state=forward_state, block_nodes=conditional_assingement_block_nodes, subgraph=subgraph) + for node in nodes_to_keep_tracking: + # we get the reverse bfs of this node and remove it from block nodes to avoid skipping these nodes + node_subgraph = {n for e in forward_state.edge_bfs(node, reverse=True) for n in [e.src, e.dst]} + + # add the node itself + node_subgraph.add(node) + conditional_assingement_block_nodes = conditional_assingement_block_nodes.difference(node_subgraph) + + except Exception as e: + # if this is not the structure we are expecting, fail + raise AutoDiffException(f"The boolean datatype in edges is limited to conditional array assingements." + f" This stucture is not supported.") from e + + return conditional_assingement_block_nodes + + def _get_gradient_nodes_to_track(self, forward_state: SDFGState, block_nodes: List[nodes.Node], + subgraph: dstate.SubgraphView): + """Get gradient nodes that need tracking in conditional assignments. + + When extracting the block for a conditional assignment, we need to make sure we keep tracking + the required gradient AccessNodes. + This function checks all the required access nodes that are in the conditional block. + At the moment this is just the target access node. + """ + nodes_to_track: List[nodes.AccessNode] = [] + gradient_nodes = [n for n in self.required_gradients_data] + gradient_nodes += [n for n in self.given_gradients_data] + + # get the subgraph difference + difference = set(subgraph.nodes()).difference(set(block_nodes)) + + # go through all the access nodes in the conditional block + for node in block_nodes: + if not isinstance(node, nodes.AccessNode): + continue + + # we always want to track the gradient nodes + if node.data in gradient_nodes: + nodes_to_track.append(node) + continue + # if this access node has multiple edges and any of them are outside the block + + node_out_edges = forward_state.out_edges(node) + if len(node_out_edges) > 1: + for edge in node_out_edges: + if edge.dst in difference: + nodes_to_track.append(node) + data = node.data + + # search for this array in the graph difference + for d_node in difference: + if not isinstance(d_node, nodes.AccessNode): + continue + if d_node.data == data: + nodes_to_track.append(node) + return nodes_to_track + + def _reverse_subgraph(self, forward_state: SDFGState, backward_state: SDFGState, + subgraph: dstate.StateSubgraphView) -> None: + """Reverse a given subgraph by reversing all nodes within it. + + :param forward_state: The forward state containing the subgraph. + :param backward_state: The backward state to add reversed nodes to. + :param subgraph: The subgraph view containing nodes to reverse. + """ + + # Conditional assignment nodes + conditional_assignment_nodes: List[nodes.Node] = [] + + # A reversed topological sort is a topological sort on the reverse graph + for node in reversed(list(dace_utils.dfs_topological_sort(subgraph, subgraph.source_nodes()))): + + try: + # If this node is a part of the conditional assignment block, we skip it + if node in conditional_assignment_nodes: + continue + + # Output names on the forward node + # (for which the gradient will be connected as an input on the reverse node) + given_gradients = [ + edge.src_conn for edge in subgraph.out_edges(node) + if ad_utils.path_src_node_in_subgraph(edge, subgraph) + ] + + # Input names on the forward node that gradients should be generated for + # note that the edge for the conditional is not included + required_gradients = [ + edge.dst_conn for edge in subgraph.in_edges(node) + if ad_utils.path_src_node_in_subgraph(edge, subgraph) + and self.sdfg.arrays[edge.data.data].dtype != dace.bool + ] + + reversed_node, backward_result = self._get_reverse_node(forward_state, backward_state, node, + given_gradients, required_gradients) + + self.reverse_map[node] = reversed_node + self.result_map[node] = backward_result + + # Connect the required inputs of the reverse node: + # the gradients ... + self._connect_given_gradients(forward_state=forward_state, + backward_state=backward_state, + subgraph=subgraph, + forward_node=node) + + # ... and any required input values from the forward pass + #################################### + # Determine which forward inputs we need to connect. + # these are the in_connectors on the reverse node, minus what has already been connected. + already_connected = {e.dst_conn for e in backward_state.in_edges(reversed_node)} + required_inputs = set(reversed_node.in_connectors).difference(already_connected) + required_inputs = {c: c for c in required_inputs} + self._connect_forward_inputs(forward_state, backward_state, node, reversed_node, required_inputs) + + if isinstance(node, nodes.AccessNode): + + # this means we are writing out a grad to an array. + # initialize the gradient if it hasn't been initialized already (this can also happen in + # _connect_given_gradients + array_grad_name = self.array_grad_name(node.data) + if array_grad_name not in self.backward_sdfg.arrays: + # this grad hasn't been written before: initialize it + self._add_gradient_data_descriptor(node.data) + + # we need to make all incoming gradients sum + if backward_state.in_degree(reversed_node) > 1: + + # Add a wcr to all the writes to the AccessNode + for edge in backward_state.in_edges(reversed_node): + # Add wcr to the memlet + for tree_edge in backward_state.memlet_tree(edge): + tree_edge.data.wcr = "lambda x, y: x + y" + + # If this node is a tasklet with a condition, we add some modification to the backward state + elif (isinstance(node, nodes.Tasklet) + and self._conditional_tasklet(node)) or (isinstance(node, nodes.NestedSDFG) + and self._conditional_nested_sdfg(forward_state, node)): + # extract the conditional assignment block or fail if this is an unexpected structure + conditional_block = self._extract_conditional_array_assignment_block(forward_state=forward_state, + tasklet_node=node, + subgraph=subgraph) + + # add these nodes to be skipped in the future + conditional_assignment_nodes.extend(conditional_block) + + # If the node is an AccessNode and it is being overwritten in the forward pass, + # we need to zero-out the gradients of the overwritten values + if isinstance(node, nodes.AccessNode): + # Check if there is an incoming edge to this node + incoming_edges = forward_state.in_edges(node) + + # If there is an incoming edge, we need to zero-out the gradient + for edge in incoming_edges: + + # Check, if possible, if the written subset is not zero + write_size = edge.data.subset.num_elements() + + # Check if the node doesn't have a WCR + # If it does, this is not an overwrite and the gradients should not be cleared + has_wcr = edge.data.wcr is not None + + # Check if the edge is dynamic, this means not all values are overwritten + # We will skip zeroing out the gradient in this case + if edge.data.dynamic: + Warning("Dynamic memlets are not fully supported in the reverse pass. " + "The gradient of the overwritten values may not be zeroed out.") + if not has_wcr and not edge.data.dynamic: + # Determine if we need to zero out the gradient + zero_out = not (isinstance(write_size, int) and write_size == 0) + + # We need to zero out the same memlet accesses in the backward pass + if zero_out: + self._zero_out_gradient(forward_state=forward_state, + forward_node=node, + memlet=edge.data) + + # Cleanup of isolated nodes + # We will have an isolated node if it is not connected to any other node in the state view + # And it has not been cleared out if it is an AccessNode + # Isolated nodes should only appear from clearing out gradients + # Check if this is an isolated node and remove it if it is + if backward_state.out_degree(reversed_node) == 0 and backward_state.in_degree(reversed_node) == 0: + if isinstance(node, nodes.AccessNode) and node not in self.zeroed_out: + backward_state.remove_node(reversed_node) + + except AutoDiffException as e: + raise AutoDiffException("Failed at node {}: {}".format(node, str(e))) from e + + def _set_wcr_if_needed(self, backward_state: SDFGState, backward_node: nodes.Node, + edge: dstate.MultiConnectorEdge) -> None: + """Set write-conflict resolution (WCR) for gradient accumulation if needed. + + If this AccessNode represents a gradient that has already been used elsewhere, + we want to accumulate the gradients rather than overwrite them. + + :param backward_state: The backward state containing the edge. + :param backward_node: The backward node (should be AccessNode for gradients). + :param edge: The edge that may need WCR for gradient accumulation. + """ + + # Check if the forward node is an AccessNode + if not isinstance(backward_node, nodes.AccessNode): + return + + # Otherwise, we add up the gradients, not overwrite them + for tree_edge in backward_state.memlet_tree(edge): + tree_edge.data.wcr = "lambda x, y: x + y" + + def _connect_given_gradients(self, forward_state: SDFGState, backward_state: SDFGState, + subgraph: dstate.StateSubgraphView, forward_node: nodes.Node) -> Optional[SDFGState]: + """Connect output gradients of forward_node as inputs to the corresponding reverse node. + + :param forward_state: The forward state containing the node. + :param backward_state: The backward state to add connections to. + :param subgraph: The subgraph view for the current operation. + :param forward_node: The forward node whose output gradients to connect. + :return: The backward state (possibly modified) or None. + """ + new_backward_state = None + # First, create the data descriptor if this is an access node and it hasn't been added before + if isinstance(forward_node, nodes.AccessNode): + grad_name = self.array_grad_name(forward_node.data) + if grad_name not in self.backward_sdfg.arrays: + # This grad hasn't been written before: initialize it + self._add_gradient_data_descriptor(forward_node.data) + + for edge in subgraph.out_edges(forward_node): + if not ad_utils.path_src_node_in_subgraph(edge, subgraph) or edge.dst not in self.reverse_map: + if edge.dst in self.conditional_block_entry: + backward_node = self.reverse_map[edge.src] + if not isinstance(edge.dst, nodes.MapEntry): + raise AutoDiffException(f"Expected MapEntry in conditional block, got {type(edge.dst)}") + conditional_zero_out_an = self.conditional_block_entry[edge.dst] + # Add an empty edge to skip the conditional block + backward_state.add_edge(conditional_zero_out_an, None, backward_node, None, Memlet()) + # skip connecting edges for which we don't need to generate grads. + continue + + # Skip connecting boolean edges + if self.sdfg.arrays[edge.data.data].dtype == dace.bool: + # we also need to remove this connector otherwise it will be dangling + backward_node = self.reverse_map[edge.src] + if not (isinstance(backward_node, nodes.MapEntry) or isinstance(backward_node, nodes.MapExit)): + # If this is not a map entry or exit, the boolean gradients will not be added + # No need to remove the connector in this case + continue + + conn_to_remove = ad_utils.invert_map_connector(edge.src_conn) + assert conn_to_remove in backward_node.in_connectors + assert backward_node.remove_in_connector(conn_to_remove) + if len(backward_node.in_connectors) == 0: + self._connect_conditional_map_exist(forward_state=forward_state, + backward_state=backward_state, + backward_map_exit=backward_node, + fwd_tasklet=edge.dst) + continue + + _, output_conn, dest_node, input_conn, fwd_memlet = edge + + memlet = copy.deepcopy(fwd_memlet) + + # Remove the WCR since these are now read edges + memlet.wcr = None + + grad_name = self.array_grad_name(memlet.data) + if grad_name not in self.backward_sdfg.arrays: + # This grad hasn't been written before: initialize it + self._add_gradient_data_descriptor(memlet.data) + + # We should not rely on the memlet data because that depends on the subset and other subset attibutes + # If this is an access node, and the memlet data is not the same as the AN data + memlet.data = grad_name + + # Check of the values have been zeroed out + backward_dst_node = self.reverse_map[dest_node] + if backward_dst_node in self.zeroed_out: + # The values will be zeroed out in the backward node + # We use the transient array instead + copied_zeroed_nodes = self.zeroed_out[backward_dst_node] + if len(copied_zeroed_nodes) == 1: + backward_dst_node = copied_zeroed_nodes[0] + else: + for node in copied_zeroed_nodes: + # Get the memlet to this node + zero_in_dege = backward_state.in_edges(node) + assert len(zero_in_dege) == 1 + zeroed_memlet = zero_in_dege[0].data + if zeroed_memlet.subset == edge.data.subset: + backward_dst_node = node + break + + memlet.data = backward_dst_node.data + + # We also need to Add an empty edge from the cleared node to where the data will be used + tmp_clear_node_out_edges = backward_state.out_edges(backward_dst_node) + for e in tmp_clear_node_out_edges: + if e.data.data is None and e.data.subset is None and e.data.other_subset is None: + clearing_map_entry = e.dst + assert isinstance(clearing_map_entry, nodes.MapEntry) + clearing_map_exit = backward_state.exit_node(clearing_map_entry) + assert isinstance(clearing_map_exit, nodes.MapExit) + # Check that this only has a single output edge and get the destination + assert backward_state.out_degree(clearing_map_exit) == 1 + cleared_out_node = backward_state.out_edges(clearing_map_exit)[0].dst + backward_node = self.reverse_map[forward_node] + backward_state.add_edge(cleared_out_node, None, backward_node, None, dace.Memlet()) + + # If this is a connection between two access nodes we need to flip the memlet subsets + if isinstance(forward_node, nodes.AccessNode): + # Special case for when the two access nodes are the same + if forward_node.data == dest_node.data and fwd_memlet.other_subset is not None: + new_memlet = dace.Memlet(data=self.reverse_map[forward_node].data, + subset=fwd_memlet.other_subset, + other_subset=fwd_memlet.subset) + else: + new_memlet = dace.Memlet(data=self.reverse_map[forward_node].data, + subset=fwd_memlet.subset + if fwd_memlet.data == forward_node.data else fwd_memlet.other_subset, + other_subset=fwd_memlet.other_subset + if fwd_memlet.data == forward_node.data else fwd_memlet.subset) + memlet = new_memlet + if input_conn not in self.result_map[dest_node].required_grad_names: + continue + new_edge = backward_state.add_edge( + backward_dst_node, + self._lookup_required_grad_name(dest_node, input_conn), + self.reverse_map[forward_node], + self._lookup_given_grad_name(forward_node, output_conn), + memlet, + ) + + # Change the access data in the memlet path if it has been zeroed out + # Calling the memlet path while reversing will raise an error + # Because the map has not been completely added for the backward state yet + # We also don't need to do anything for an AccessNode -> AccessNode connection + if (not isinstance(forward_node, + (nodes.MapExit, nodes.MapEntry))) and not (isinstance(forward_node, nodes.AccessNode) + and isinstance(dest_node, nodes.AccessNode)): + # Check if we can call the memlet path on new_edge safely + path = backward_state.memlet_path(new_edge) + + # Get the source access node in the path + source_access_node = list(path)[0].src + if isinstance(source_access_node, nodes.AccessNode): + # Check if this is a zeroed out node + in_values = any(source_access_node in values for values in self.zeroed_out.values()) + if source_access_node.data != memlet.data and in_values: + memlet.data = source_access_node.data + self._set_wcr_if_needed(backward_state=backward_state, + backward_node=self.reverse_map[forward_node], + edge=new_edge) + + return new_backward_state + + def _get_accessnode_to_forward(self, forward_state: SDFGState, forward_node: nodes.AccessNode): + """ + Check if this AccessNode is at the base level of the state. If yes, this is the node we want to connect + Otherwise, in the case the AN is encolsed by maps, we walk up the maps until we find the source AN. + """ + scope_dict = forward_state.scope_dict()[forward_node] + is_base_level = scope_dict is None + if is_base_level: + return forward_node + else: + # The node is within a map nest + # It should have an in edge leading to the original AN + in_edges = forward_state.in_edges(forward_node) + assert len(in_edges) == 1 + + # Get the memlet path and the original AN + memlet_path = forward_state.memlet_path(in_edges[0]) + original_an = memlet_path[0] + assert isinstance(original_an, nodes.AccessNode) + + # This should be a base level AN + assert forward_state.scope_dict()[original_an] is None + return original_an + + def _connect_forward_inputs(self, state: SDFGState, backward_state: SDFGState, forward_node: nodes.Node, + backward_node: nodes.Node, required_inputs: Dict[str, str]) -> None: + """Connect the reversed node to all required non-gradient inputs. + + This function handles non-trivial routing scenarios: + 1. When reading from an AccessNode in forward pass, route through maps in backward pass + 2. Save connector values to arrays when backward pass needs to read them + + Currently supports two strategies: store-all and recompute-all. + + :param state: Forward state containing the forward node. + :param backward_state: Backward state containing the backward node. + :param forward_node: The forward pass node. + :param backward_node: The backward pass node (not necessarily a reversed node). + :param required_inputs: Maps forward pass connector names to backward pass connector names. + :raises AutoDiffException: If required connectors don't exist on forward node. + """ + + if set(required_inputs).difference(forward_node.in_connectors): + missing_connectors = \ + set(required_inputs).difference(forward_node.in_connectors) + raise AutoDiffException(f"Cannot connect connectors {missing_connectors} to {backward_node} " + f"because they don't exist on the corresponding forward node {forward_node}") + + # note we use forward state here: we might need to connect inputs that are not in the + # forward pass + input_edges_to_connect = (edge for edge in state.in_edges(forward_node) if edge.dst_conn in required_inputs) + + for edge in input_edges_to_connect: + # Boolean to decide if the source of this edge needs to be replicated + replicate_node = False + + # Boolean to decide if the connection to the replicated node is required + # This is set to False if the connection has already been established + connect_replicated_node = True + edge_src = edge.src + next_required_inputs: Dict[Optional[str], Optional[str]] + replicated_edge_src: nodes.Node + replicated_edge_src_conn: str + + if isinstance(edge_src, nodes.MapEntry): + # In the map case, we need to connect the AN at the start of this memlet path + memlet_path = state.memlet_path(edge) + + # Get the AccessNode at the start of this path + starting_edge = memlet_path[0] + starting_an = starting_edge.src + if not isinstance(starting_an, nodes.AccessNode): + raise AutoDiffException(f"Expected AccessNode at start of memlet path, got {type(starting_an)}") + + # Save the information about the data to be forwarded + # to call the function to connect this required AccessNode + # after the reversal + self.data_to_forward.append((state, backward_state, starting_an, forward_node, edge)) + + # No further recusrive calls are required + # in this branch; next_required_inputs stays empty + next_required_inputs = {} + + # Everything will be done in the connect forward accessnode function + replicate_node = False + connect_replicated_node = False + + elif isinstance(edge_src, nodes.AccessNode): + # Get the AccessNode to connect + an_to_connect = self._get_accessnode_to_forward(state, edge_src) + + # Save the information about the data to be forwarded + # to call the function to connect this required AccessNode + # after the reversal + self.data_to_forward.append((state, backward_state, an_to_connect, forward_node, edge)) + + # No further recusrive calls are required + # in this branch; next_required_inputs stays empty + next_required_inputs = {} + + # Everything will be done in the connect forward accessnode function + replicate_node = False + connect_replicated_node = False + + elif isinstance(edge_src, nodes.Tasklet): + + replicate_node = True + # In the tasklet case, we need to connect all inputs + next_required_inputs = {c: c for c in edge_src.in_connectors} + else: + raise AutoDiffException("Unsupported node") + + if replicate_node: + replicated_edge_src_conn = edge.src_conn + + # always replicate the access node + replicated_edge_src = copy.deepcopy(edge_src) + backward_state.add_node(replicated_edge_src) + + if connect_replicated_node: + new_edge_data = copy.deepcopy(edge.data) + if isinstance(edge.src, nodes.CodeNode) and isinstance(edge.dst, nodes.CodeNode): + # code->code edges have a small special case: + # we need to copy the descriptor + data_name = new_edge_data.data + data_desc = copy.deepcopy(self.sdfg.arrays[data_name]) + if self.separate_sdfgs: + self.backward_sdfg.add_datadesc(data_name, data_desc) + else: + new_data_name = self.backward_sdfg.add_datadesc(data_name, data_desc, find_new_name=True) + new_edge_data.data = new_data_name + + if isinstance(edge_src, nodes.AccessNode) and isinstance(data_desc, dt.View): + if self.separate_sdfgs: + # Remove the view connector + assert replicated_edge_src.remove_in_connector("views") + else: + # If this is a view, we need to connect it to the AccessNode it is viewing + edge_src_in_edge = state.in_edges(edge_src) + + # A view should only have one incoming edge + assert len(edge_src_in_edge) == 1 + edge_src_in_edge = edge_src_in_edge[0] + + # Replicate the viewed node and its memlet and connect it + view_origin = edge_src_in_edge.src + replicated_view = copy.deepcopy(view_origin) + view_memlet = copy.deepcopy(edge_src_in_edge.data) + if self.separate_sdfgs: + # If the SDFGs are separate, we need to add the descriptor for this data + origin_desc = self.sdfg.arrays[view_origin.data] + origin_desc.transient = False + backward_state.sdfg.add_datadesc(view_origin.data, origin_desc) + backward_state.add_edge(replicated_view, None, replicated_edge_src, "views", view_memlet) + + # Add the new edge + backward_state.add_edge(replicated_edge_src, replicated_edge_src_conn, backward_node, + required_inputs[edge.dst_conn], new_edge_data) + + if next_required_inputs: + # If there are any required inputs on the new node, we need to + # recursively call + self._connect_forward_inputs(state, backward_state, edge.src, replicated_edge_src, next_required_inputs) + + def _lookup_required_grad_name(self, node: nodes.Node, connector: str) -> str: + """Look up the required gradient name for a given node and connector. + + :param node: The forward pass node. + :param connector: The connector name to look up. + :return: The required gradient name for the connector. + :raises AutoDiffException: If the node's backward result is not available. + """ + if node not in self.result_map: + raise AutoDiffException(f"Attempted to access required gradient of {node} " + f"before the backward node was created") + return self.result_map[node].required_grad_names[connector] + + def _lookup_given_grad_name(self, node: nodes.Node, connector: str) -> str: + """Look up the given gradient name for a given node and connector. + + :param node: The forward pass node. + :param connector: The connector name to look up. + :return: The given gradient name for the connector. + :raises AutoDiffException: If the node's backward result is not available. + """ + if node not in self.result_map: + raise AutoDiffException(f"Attempted to access given gradient of {node} " + f"before the backward node was created") + return self.result_map[node].given_grad_names[connector] + + def _find_backward_entry_node_for_map_entry(self, backward_state: SDFGState, + entry_node: nodes.MapEntry) -> nodes.MapEntry: + """Find the entry node in the backward pass corresponding to a forward pass entry node. + + :param backward_state: The backward state to search in. + :param entry_node: The MapEntry node from the forward pass. + :return: The corresponding MapEntry node in the backward pass. + :raises AutoDiffException: If exactly one corresponding node is not found. + """ + src_candidates = [ + node for node in backward_state.nodes() + if isinstance(node, nodes.MapEntry) and node.map == self.reverse_map[entry_node.map] + ] + if len(src_candidates) != 1: + raise AutoDiffException(f"Expected exactly one backward MapEntry for forward MapEntry {entry_node}, " + f"but found {len(src_candidates)} candidates") + + return src_candidates[0] + + def _get_reverse_node(self, state: SDFGState, backward_state: SDFGState, node: nodes.Node, + given_gradients: List[str], + required_gradients: List[str]) -> Tuple[nodes.Node, BackwardResult]: + """Add the reverse node for a node from the forward pass to the backward pass. + + Resolution order: + 1) Check for methods on this class + 2) Check the backward pass repository + + :param state: Forward state containing the node. + :param backward_state: Backward state to add the reverse node to. + :param node: Node from the forward pass to reverse. + :param given_gradients: Output names on the forward node for gradient input connections. + :param required_gradients: Input names on the forward node that need gradients generated. + :return: Tuple of (reversed node, BackwardResult with gradient connector names). + :raises AutoDiffException: If no backward implementation is found for the node type. + """ + + # (1) + if hasattr(self.dace_node_impl, "_reverse_" + type(node).__name__): + reverse_method = getattr(self.dace_node_impl, f"_reverse_{type(node).__name__}") + return reverse_method(state, backward_state, node, given_gradients, required_gradients) + + # (2) + impl = find_backward_implementation(self.sdfg, forward_state=state, node=node) + if impl is not None: + backward_node, backward_result = impl.backward(forward_node=node, + context=BackwardContext( + forward_state=state, + forward_sdfg=self.sdfg, + backward_state=backward_state, + backward_sdfg=self.backward_sdfg, + backward_generator=self, + ), + given_gradients=given_gradients, + required_gradients=required_gradients) + if isinstance(backward_node, nodes.CodeNode): + backward_node.schedule = node.schedule + return backward_node, backward_result + + raise AutoDiffException(f"Unable to differentiate node type {type(node)}. " + f"Either add a pure forward implementation or a backward implementation to progress.") diff --git a/dace/autodiff/base_abc.py b/dace/autodiff/base_abc.py new file mode 100644 index 0000000000..c47cc83d0d --- /dev/null +++ b/dace/autodiff/base_abc.py @@ -0,0 +1,183 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +Abstract Base Classes for Autodiff +""" +import abc +import dataclasses +import typing +from typing import TYPE_CHECKING + +import dace.registry +from dace import config +from dace.sdfg import SDFG, SDFGState, nodes as nd +import dace.transformation.transformation as xf + +if TYPE_CHECKING: + from dace.autodiff.backward_pass_generator import BackwardPassGenerator + +try: + from dace.libraries.onnx.nodes.onnx_op import ONNXOp + ONNX_AVAILABLE = True +except ImportError: + ONNXOp = None + ONNX_AVAILABLE = False + + +class AutoDiffException(Exception): + """Base class for all exceptions related to automatic differentiation failures.""" + pass + + +@dataclasses.dataclass +class BackwardContext: + """A tuple holding the graph context required to construct reverse nodes.""" + forward_sdfg: SDFG #: The forward SDFG + forward_state: SDFGState #: The forward SDFG state + backward_sdfg: SDFG #: The backward SDFG + backward_state: SDFGState #: The backward SDFG state + backward_generator: 'BackwardPassGenerator' #: The backward pass generator + + +@dataclasses.dataclass +class BackwardResult: + """The return type of a differentiated node. It contains the names of the gradients the node calculates and + requires. + """ + + #: Mapping from names of output connectors to the connector name of the gradient for that connector. + required_grad_names: typing.Dict[typing.Optional[str], typing.Optional[str]] + + #: Mapping from names of input connectors to the connector name of the gradient for that connector. + given_grad_names: typing.Dict[typing.Optional[str], typing.Optional[str]] + + #: Mapping from names of gradients to whether they should be zeroed out on initialization. + zero_init: typing.Dict[typing.Optional[str], typing.Optional[bool]] + + def __init__(self, required_grad_names, given_grad_names, zero_init=None): + self.required_grad_names = required_grad_names + self.given_grad_names = given_grad_names + self.zero_init = zero_init or {} + + @staticmethod + def empty(): + """Create an empty BackwardResult with no gradients.""" + return BackwardResult(given_grad_names={}, required_grad_names={}, zero_init={}) + + +@dace.registry.make_registry +class BackwardImplementation(abc.ABC): + """ABC for backward implementations. + + This registry accepts two types of registrations. + The register function expects an argument ``node_type=TYPE`` where ``TYPE`` is the type of node that this + backward implementation supports. + It can also take an argument ``op=node_name`` where ``node_name`` is the string of the ONNX op it supports, + e.g. ``"Conv"``. + + It also expects a ``name`` argument that names the implementation. + """ + + @staticmethod + def backward_can_be_applied(node: nd.Node, state: SDFGState, sdfg: SDFG) -> bool: + """Return whether this expansion can be applied. + + :param node: The candidate node. + :param state: The candidate state. + :param sdfg: The candidate SDFG. + :return: True if the implementation can be applied, False otherwise. + """ + return True + + @staticmethod + @abc.abstractmethod + def backward(forward_node: nd.Node, context: BackwardContext, given_gradients: typing.List[typing.Optional[str]], + required_gradients: typing.List[typing.Optional[str]]) -> typing.Tuple[nd.Node, BackwardResult]: + """Add the reverse node for a node from the forward pass to the backward pass, and return it. + + For each input connector with name ``n`` of the forward in required_gradients, the returned backward node must + add an output connector with name ``required_gradients[n]`` that will output the gradient for that input. + + If any input from the forward pass is required, simply add a connector with the same name as the connector + on the forward node. The input will later be connected as required. + + :param forward_node: The node for which the backward pass should be generated. + :param context: The context for this node (see + :class:`~dace.autodiff.backward_implementation.BackwardContext`). + :param given_gradients: The names of outputs of the node that gradients will be connected for. + :param required_gradients: The names of connectors that gradients should be generated for. + :return: The reverse node and gradient names + (see :class:`~dace.autodiff.backward_implementation.BackwardResult`). + """ + ... + + +# Register the implementations +import dace.autodiff.implementations + + +def find_backward_implementation(forward_sdfg: SDFG, forward_state: SDFGState, + node: nd.Node) -> typing.Optional[BackwardImplementation]: + """Try to find the backward implementation for ``node``. + + :param forward_sdfg: The parent SDFG of the node. + :param forward_state: The parent SDFG state of the node. + :param node: The node to find the implementation for. + :return: The BackwardImplementation for node if one is registered and can be applied, else None. + """ + valid_impls = [] + for impl, args in BackwardImplementation.extensions().items(): + if "name" not in args: + raise ValueError(f"Expected name in arguments of implementation {impl}.") + + if "node_type" in args and isinstance(node, args["node_type"]) or (ONNX_AVAILABLE and isinstance(node, ONNXOp) + and "op" in args + and node.schema.name == args["op"]): + + if impl.backward_can_be_applied(node, forward_state, forward_sdfg): + valid_impls.append((args["name"], impl)) + + if ONNX_AVAILABLE and isinstance(node, ONNXOp) and node.backward_implementation: + + implementation = node.backward_implementation + elif ONNX_AVAILABLE and isinstance(node, ONNXOp) and node.default_backward_implementation: + implementation = node.default_backward_implementation + else: + implementation = None + + if implementation: + filtered_impls = [i for name, i in valid_impls if name == implementation] + if filtered_impls: + return filtered_impls[0] + + if config.Config.get_bool('debugprint'): + print(f"Warning: Set backward_implementation {node.backward_implementation} on {node}, but it could not be" + f" applied. Falling back to default selection.") + if valid_impls: + return valid_impls[0][1] + else: + return None + + +class ExpansionTemplate(xf.ExpandTransformation): + """Module-level expansion class for operations during autodiff. + + This class is used by BackwardPassGenerator._expand_nodes to expand operations + that don't have backward implementations. It needs to be at module level for serialization. + + The class is dynamically configured before use by setting: + - environments: List of required environments + - _impl: The implementation object containing the forward method + - _match_node: The pattern node to match + """ + environments = [] + _impl = None + + @classmethod + def expansion(cls, node, state, sdfg): + if cls._impl is None: + raise RuntimeError("_ONNXExpansion._impl must be set before expansion") + return cls._impl.forward(node, state, sdfg) + + @staticmethod + def annotates_memlets() -> bool: + return True diff --git a/dace/autodiff/data_forwarding/__init__.py b/dace/autodiff/data_forwarding/__init__.py new file mode 100644 index 0000000000..da93544194 --- /dev/null +++ b/dace/autodiff/data_forwarding/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +Data Forwarding Strategies for Automatic Differentiation. + +This package manages the tradeoff between storing intermediate values and +recomputing them during the backward pass. This is a fundamental memory-time +tradeoff in automatic differentiation. +""" + +from .manager import DataForwardingManager +from .store import resolve_overwrite_with_store +from .recompute import get_recomputation_nsdfg, resolve_overwrite_with_recomputation + +__all__ = [ + "DataForwardingManager", + "resolve_overwrite_with_store", + "resolve_overwrite_with_recomputation", + "get_recomputation_nsdfg", +] diff --git a/dace/autodiff/data_forwarding/manager.py b/dace/autodiff/data_forwarding/manager.py new file mode 100644 index 0000000000..cb510632ea --- /dev/null +++ b/dace/autodiff/data_forwarding/manager.py @@ -0,0 +1,388 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import copy +from typing import List, Tuple, Optional + +# DaCe imports +import dace.sdfg.nodes as nodes +from dace import config, data as dt +from dace.sdfg import SDFGState, graph as dgraph + +# Autodiff imports +from dace.autodiff.base_abc import AutoDiffException +import dace.autodiff.utils as ad_utils +import dace.autodiff.data_forwarding as data_forwarding + + +class DataForwardingManager: + + def __init__(self, bwd_generator: 'BackwardPassGenerator'): + + # The user specified strategy for forwarding + # Whether to forward data through separate SDFGs + self.bwd_generator: 'BackwardPassGenerator' = bwd_generator + + def forward_data_to_backward_pass(self) -> None: + """ + Iterate through all the data that needs to be forwarded to the backward pass states. + """ + # Get the strategy decision for each data that needs to be forwarded to the backward pass + strategy_choice, recomputation_nsdfgs = self._get_overwrite_resolution_strategy() + + # Make the connection according to the chosen strategy + for index, (forward_state, backward_state, access_node, node, + edge) in enumerate(self.bwd_generator.data_to_forward): + self._connect_forward_accessnode(forward_state, backward_state, access_node, node, edge, + recomputation_nsdfgs[index], strategy_choice[index]) + + def _get_overwrite_resolution_strategy(self) -> Tuple[List[str], List[Optional[nodes.NestedSDFG]]]: + """ + Choose a strategy for resolving overwritten data that we need to forward to the backward pass. + If the user wants a specific strategy, we use it. + Otherwise, we evaluate what strategy is best for this specific node. + """ + strategy_choice: List[str] = [] + recomputation_nsdfgs: List[Optional[nodes.NestedSDFG]] = [] + + # As preprocessing step, + # We will store all of the global program inputs, + # if they are required for the backward pass + # NOTE: This can be relaxed since if an input is not overwritten + # it can be recomputed + to_remove = [] + for i, (forward_state, backward_state, access_node, node, + edge) in enumerate(self.bwd_generator.data_to_forward): + if access_node.data not in self.bwd_generator.sdfg.arg_names: + continue + + # Store the input + self._connect_forward_accessnode(forward_state, backward_state, access_node, node, edge, None, "store") + + # Remove this element from the list of the data to forward + to_remove.append(i) + + # Remove elements from the list of data to be forwarded (in reverse order to maintain indices) + for idx in sorted(to_remove, reverse=True): + del self.bwd_generator.data_to_forward[idx] + + if self.bwd_generator.data_forwarding_strategy == "store_all": + strategy_choice = ["store"] * len(self.bwd_generator.data_to_forward) + + # A recomputation block is not necessary + recomputation_nsdfgs = [None] * len(self.bwd_generator.data_to_forward) + elif self.bwd_generator.data_forwarding_strategy == "recompute_all": + strategy_choice = ["recompute"] * len(self.bwd_generator.data_to_forward) + + # We will delay getting the recomputation block for now + recomputation_nsdfgs = [None] * len(self.bwd_generator.data_to_forward) + elif self.bwd_generator.data_forwarding_strategy == "user_defined": + if self.bwd_generator.data_to_recompute is None: + raise AutoDiffException("The overwrite resolution strategy is User Defined " + "but no recomputation list has been provided." + "Please set the data_to_recompute parameter.") + + for forward_state, backward_state, access_node, node, edge in self.bwd_generator.data_to_forward: + + if access_node.data in self.bwd_generator.data_to_recompute: + try: + nsdfg = data_forwarding.get_recomputation_nsdfg(self.bwd_generator, forward_state, access_node) + choice = "recompute" + except Exception as e: + # If anything goes wrong, print a warning and fall back to storing + if config.Config.get_bool('debugprint'): + print( + f"Warning: Couldn't get the recomputation nested SDFG for {access_node.label} because {e}" + ) + nsdfg = None + choice = "store" + recomputation_nsdfgs.append(nsdfg) + strategy_choice.append(choice) + else: + # We store everything else + recomputation_nsdfgs.append(None) + strategy_choice.append("store") + else: + raise AutoDiffException("Please specify a valid overwrite resolution strategy. " + "Expected either store_all, recompute_all, or user_defined " + f"but got {self.bwd_generator.data_forwarding_strategy}") + return strategy_choice, recomputation_nsdfgs + + def _connect_forward_accessnode(self, forward_state: SDFGState, backward_state: SDFGState, + forward_node: nodes.AccessNode, target_node: nodes.Node, + starting_edge: dgraph.MultiConnectorEdge, + recomputation_nsdfg: Optional[nodes.NestedSDFG], strategy: str): + """ + We need to forward an array from the forward pass to the backward pass. + To do this we first check if this array has been overwritten or not. + If the array has not been overwritten, we just need to replicate it + in the backward pass and then forward it. + If the array has been overwritten, we pick a strategy for this AccessNode: + - Store strategy: + - We modify the forward pass to save the values in a new array + - Connect this new array to the node in the backward pass + - Recomputation: + - Add the recomputation as a NestedSDFG + - Connect the output of the NestedSDFG to the node in the backward pass + """ + + # First, we check if the node has been overwritten + overwritten, recomputable = self._check_node_overwrite(forward_state=forward_state, node=forward_node) + + # Boolean indicating whether we should fall back to storing + fallback = False + if strategy == "recompute" and recomputable: + try: + if recomputation_nsdfg is None: + recomputation_nsdfg = data_forwarding.get_recomputation_nsdfg(self.bwd_generator, + forward_state, + target_an=forward_node) + data_forwarding.resolve_overwrite_with_recomputation(recomputation_nsdfg=recomputation_nsdfg, + forward_state=forward_state, + backward_state=backward_state, + target_an=forward_node, + target_node=target_node, + starting_edge=starting_edge) + except Exception as e: + # If anything goes bad, print a warning and fall back to storing + if config.Config.get_bool('debugprint'): + print(f"Warning: Failed to recompute {forward_node.data}: {e}. Falling back to storing") + fallback = True + + if strategy == "store" or (strategy == "recompute" and not recomputable) or fallback: + # We store if: + # - This was the specified strategy + # - We tried to recompute a program input + # - We tried to recompute something that didn't work and we're falling back to storing + + # The data has been overwritten + if not overwritten: + # We still have access to this data + self._connect_forward_accessnode_not_overwritten(forward_state, backward_state, forward_node, + target_node, starting_edge) + return + + data_forwarding.resolve_overwrite_with_store(bwd_generator=self.bwd_generator, + forward_state=forward_state, + backward_state=backward_state, + forward_node=forward_node, + target_node=target_node, + starting_edge=starting_edge) + + def _check_node_overwrite(self, forward_state: SDFGState, node: nodes.AccessNode) -> Tuple[bool, bool]: + """ + Given an AccessNode from the forward state, check if the data of this node has changed. + We look at all the AccessNodes with the same data that occur after the 'node' parameter + if any of them has an incoming edge, return the node has been overwritten. + + :param node: the AccessNode to perform the check for. + :return: a tuple of whether this node has been overwritten, and if it can be recomputed + """ + overwritten = False + decided = False + recomputable = False + + # Get the descendant and ascendant states to look in for an overwrite + if forward_state not in self.bwd_generator.state_order: + raise AutoDiffException(f"Forward state {forward_state} not found in state order") + index = self.bwd_generator.state_order.index(forward_state) + descendant_states = self.bwd_generator.state_order[index:] + + # Check if this access node is a view + if isinstance(node.desc(self.bwd_generator.sdfg), dt.ArrayView): + # The view should have one incoming edge from the original access node + in_edges = forward_state.in_edges(node) + + # Sanity checks + if len(in_edges) != 1: + raise AutoDiffException(f"Expected exactly one incoming edge for view node {node}, got {len(in_edges)}") + if "views" not in node.in_connectors: + raise AutoDiffException(f"Expected 'views' connector in node {node}, but not found") + + # We want to check if the source has been overwritten + node = in_edges[0].src + + # Get all the AccessNodes with the same data + matches = [] + for d_state in descendant_states: + matches += [(nd, parent) for nd, parent in d_state.all_nodes_recursive() + if isinstance(nd, nodes.AccessNode) and nd.data == node.data] + + # There needs to be at least one occurrence which is the node passed as a parameter + if len(matches) == 0 or (node, forward_state) not in matches: + raise AutoDiffException(f"Node {node} not found in descendant states") + + # If there is only one occurrence of this data, it will not be overwritten later in the graph + if len(matches) == 1: + overwritten = False + decided = True + + # Get the index of the parameter node + index = matches.index((node, forward_state)) + + # If the parameter node is the last occurrence in the descendant states, + # it will not be overwritten + if len(matches) - 1 == index: + overwritten = False + decided = True + + # If we haven't already confirmed that this node has not been overwritten + if not decided: + # Iterate through all the successor occurrences + for nd, parent in matches[index + 1:]: + # Check if this node has an incoming edge + if len(parent.in_edges(nd)) > 0: + overwritten = True + + if not overwritten: + # There is no overwrite so far + # Check if this state is within a loop + is_in_loop, loop = ad_utils.state_within_loop(forward_state) + if is_in_loop: + + # Check if there is any write to this access node within the loop + loop_matches = [(nd, parent) for nd, parent in loop.all_nodes_recursive() + if isinstance(nd, nodes.AccessNode) and nd.data == node.data] + for match, match_parent in loop_matches: + # Check if this node has an incoming edge + if len(match_parent.in_edges(match)) > 0: + overwritten = True + + if overwritten and len(matches) == 1: + # Check if the overwrite is from constant arrays + # This means that the same value will be assigned at each iteration of the loop + # And no storing is necessary + match, match_parent = loop_matches[0] + all_read_only = True + for edge in match_parent.edge_bfs(match, reverse=True): + if edge.data.subset is not None and len(edge.data.subset.free_symbols) != 0: + all_read_only = False + break + if isinstance(edge.src, nodes.AccessNode): + # The memlet needs to be constant + if edge.src.data not in self.bwd_generator.read_only_arrays: + all_read_only = False + break + # Check if the data is read only + if all_read_only: + overwritten = False + + # Iterate through all the predecessor occurrences + for nd, parent in matches[:index + 1]: + # Check if this node has an incoming edge + if len(parent.in_edges(nd)) > 0: + recomputable = True + return overwritten, recomputable + + def _connect_forward_accessnode_not_overwritten(self, + forward_state: SDFGState, + backward_state: SDFGState, + forward_node: nodes.AccessNode, + target_node: nodes.Node, + starting_edge: dgraph.MultiConnectorEdge, + replicated_node: Optional[nodes.AccessNode] = None): + """ + Replicate and connect the forward AccessNode to the requesting node in the backward pass. + Because the AccessNode has not been overwritten, we just need to create the same connection + in the backward pass. + """ + + # First, replicate the AccessNode and add it to the backward pass + # If it has not already been replicated and passed as a parameter + if replicated_node is None: + replicated_node = copy.deepcopy(forward_node) + backward_state.add_node(replicated_node) + if self.bwd_generator.separate_sdfgs: + # Need to copy over the descriptor from the forward pass + data_name = replicated_node.data + data_desc = copy.deepcopy(forward_node.desc(self.bwd_generator.sdfg)) + data_desc.transient = False + if data_name not in self.bwd_generator.backward_sdfg.arrays: + self.bwd_generator.backward_sdfg.add_datadesc(data_name, data_desc) + + # We also need to forward this array + if data_name not in self.bwd_generator.backward_input_arrays: + # If the data is needed inside a NestedSDFG + # This will make sure the added array is correctly forwarded + # and an in connector to the NestedSDFG is added + self.bwd_generator.backward_input_arrays[data_name] = data_desc + + # We replicate the exact link between this forward access node and the target node + # Get all the edges in the path + all_edges_inbetween = ad_utils.get_all_path_edges(state=forward_state, + source=forward_node, + starting_edge=starting_edge) + + # A dictionary to keep track of temporary nodes in the path + replicated_tmp_nodes = {} + + # For each edge in the path + for edge in all_edges_inbetween: + src, src_conn, dst, dst_conn, data = edge + bwd_src, bwd_src_conn, bwd_dst, bwd_dst_conn, bwd_data = src, src_conn, dst, dst_conn, copy.deepcopy(data) + + # If the destination is a map entry, + if isinstance(dst, nodes.MapEntry): + # We need to get the corresponding map entry in the backward pass. + bwd_dst = self.bwd_generator._find_backward_entry_node_for_map_entry(backward_state=backward_state, + entry_node=dst) + # Add the dst connector to the map + added = bwd_dst.add_in_connector(bwd_dst_conn) + assert added + + # If the destination is a map entry, + if isinstance(src, nodes.MapEntry): + # We need to get the corresponding map entry in the backward pass. + bwd_src = self.bwd_generator._find_backward_entry_node_for_map_entry(backward_state=backward_state, + entry_node=src) + # Add the src connector to the map + added = bwd_src.add_out_connector(bwd_src_conn) + assert added + + if src is forward_node: + # If this is the node we replicated + bwd_src = replicated_node + elif isinstance(src, nodes.AccessNode): + # This is a temporary AccessNodes + # we should have already seen and replicated this + assert src in replicated_tmp_nodes + bwd_src = replicated_tmp_nodes[src] + + if dst is target_node: + # If this is the final connection node + bwd_dst = self.bwd_generator.reverse_map[dst] + elif isinstance(dst, nodes.AccessNode): + # This is a temporary AccessNodes + # we want to replicate and add it to the path + bwd_dst = copy.deepcopy(dst) + backward_state.add_node(bwd_dst) + replicated_tmp_nodes[dst] = bwd_dst + + # Modify the data in the memlet in case the array is replicated outside of the function + bwd_data.data = replicated_node.data + + # Add the edge to the backward state + backward_state.add_edge(bwd_src, bwd_src_conn, bwd_dst, bwd_dst_conn, bwd_data) + + # If we just connected a view, we need to remove the view in connector + data_desc = self.bwd_generator.sdfg.arrays[forward_node.data] + if isinstance(forward_node, nodes.AccessNode) and isinstance(data_desc, dt.View): + if self.bwd_generator.separate_sdfgs: + # Remove the view connector + assert replicated_node.remove_in_connector("views") + else: + # if this is a view, we need to connect it to the AccessNode it is viewing + edge_src_in_edge = forward_state.in_edges(forward_node) + + # a view should only have one incoming edge + assert len(edge_src_in_edge) == 1 + edge_src_in_edge = edge_src_in_edge[0] + + # replicate the viewed node and its memlet and connect it + view_origin = edge_src_in_edge.src + replicated_view = copy.deepcopy(view_origin) + view_memlet = copy.deepcopy(edge_src_in_edge.data) + if self.bwd_generator.separate_sdfgs: + # if the sdfgs are separate, we need to add the descriptor for this data + origin_desc = self.bwd_generator.sdfg.arrays[view_origin.data] + origin_desc.transient = False + backward_state.sdfg.add_datadesc(view_origin.data, origin_desc) + backward_state.add_edge(replicated_view, None, replicated_node, "views", view_memlet) diff --git a/dace/autodiff/data_forwarding/recompute.py b/dace/autodiff/data_forwarding/recompute.py new file mode 100644 index 0000000000..ce6330dbff --- /dev/null +++ b/dace/autodiff/data_forwarding/recompute.py @@ -0,0 +1,298 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import copy +from typing import List + +# DaCe imports +import dace +import dace.sdfg.nodes as nodes +from dace.sdfg import SDFG, SDFGState, graph as dgraph, state as dstate +from dace.sdfg.state import LoopRegion + +# Autodiff imports +from dace.autodiff.base_abc import AutoDiffException +import dace.autodiff.utils as ad_utils + + +def resolve_overwrite_with_recomputation( + recomputation_nsdfg: nodes.NestedSDFG, + forward_state: SDFGState, + backward_state: SDFGState, + target_an: nodes.AccessNode, + target_node: nodes.Node, + starting_edge: dstate.MultiConnectorEdge, +): + """ + Experimental! Use recomputation in the backward pass to compute data that was overwritten in the forward pass. + """ + + # Add the nsdfg where it is required + _connect_recomputation_nsdfg(forward_state=forward_state, + backward_state=backward_state, + nsdfg=recomputation_nsdfg, + target_an=target_an, + target_node=target_node, + starting_edge=starting_edge) + + +def _connect_recomputation_nsdfg(bwd_generator: 'BackwardPassGenerator', forward_state: SDFGState, + backward_state: SDFGState, target_an: nodes.AccessNode, target_node: nodes.Node, + nsdfg: nodes.NestedSDFG, starting_edge: dstate.MultiConnectorEdge): + """ + + """ + # Connect all the SDFG inputs to the nested SDFG + # First, add the nested sdfg + for input in nsdfg.in_connectors.keys(): + # For each argument + input_name = input if "recomputation_" not in input else input[14:] + + # Get the first instance of this AN in the SDFG + first_instance = None + for node, parent in bwd_generator.forward_sdfg.all_nodes_recursive(): + if isinstance(node, nodes.AccessNode) and node.data == input: + first_instance = node + first_node_state = parent + break + + assert first_instance + + new_an = nodes.AccessNode(input_name) + backward_state.add_node(new_an) + + # Create a memlet passing all the data to the nested-SDFG + memlet = bwd_generator.forward_sdfg.make_array_memlet(input_name) + + # Add the connection to the nested SDFG + backward_state.add_edge(new_an, None, nsdfg, input, memlet) + + # Write the data to a new access node in the backward state + # Add a new AccessNode and array to the forward pass + # First, check if a recomputated array with this name already exists + if "recomputed_" + target_an.data not in bwd_generator.backward_sdfg.arrays: + new_recomp_node_name = "recomputed_" + target_an.data + else: + i = 0 + while True: + if f"recomputed_{i}_" + target_an.data not in bwd_generator.backward_sdfg.arrays: + new_recomp_node_name = f"recomputed_{i}_" + target_an.data + break + i += 1 + + # Get the new array shape + # This will be the shape of the current array + shape: List[int] = list(bwd_generator.forward_sdfg.arrays[target_an.data].shape) + + # Add the array descriptor and AccessNode to the forward state + original_desc = target_an.desc(forward_state) + new_recomp_node = backward_state.add_array( + name=new_recomp_node_name, + shape=shape, + dtype=original_desc.dtype, + transient=True, + ) + new_recomp_node.setzero = True + + # Create a memlet passing all the data to the nested-SDFG + memlet = bwd_generator.forward_sdfg.make_array_memlet(new_recomp_node.data) + + nsdfg_out_conn = list(nsdfg.out_connectors.keys()) + assert len(nsdfg_out_conn) == 1 + nsdfg_out_conn = nsdfg_out_conn[0] + + # Connect the output of the NestedSDFG + backward_state.add_edge(nsdfg, nsdfg_out_conn, new_recomp_node, None, memlet) + + # Connect the new AccessNode to the required computation + bwd_generator._connect_forward_accessnode_not_overwritten(forward_state=forward_state, + backward_state=backward_state, + forward_node=target_an, + target_node=target_node, + starting_edge=starting_edge, + replicated_node=new_recomp_node) + + +def _prune_descendants_recomputation_nsdfg(forward_state: SDFGState, target_an: nodes.AccessNode, + nsdfg: nodes.NestedSDFG): + """ + 1: From this Nested-SDFG, we remove everything that will be executed after the target access node to be recomputed + 2: Prune the unnecessary computation inside the forward state + Note: this is even necessary sometimes since the output could be overwritten in the same state + """ + + # 1 + # Get the states order for the nested_sdfg + states_order: List[SDFGState] = ad_utils.get_state_topological_order(nsdfg.sdfg) + state_index = states_order.index(forward_state) + descendant_states: List[SDFGState] = states_order[state_index:] + assert descendant_states.pop(0) == forward_state + + # Check if the target state is within a loop + target_within_loop, target_loop = ad_utils.state_within_loop(forward_state) + + # We will save the states that are within the same loop because they require special treatement + same_loop_states: List[SDFGState] = [] + for state in descendant_states: + # We want to avoid removing the descendant states that are inside the same loop region + if target_within_loop: + descendant_within_loop, descendant_loop = ad_utils.state_within_loop(state) + if descendant_within_loop and descendant_loop == target_loop: + # If the state is within the same loop, we don't remove it + same_loop_states.add(state) + continue + + # Remove the state from the nested_sdfg + parent = state.parent_graph + parent.remove_node(state) + + # Cleanup empty LoopRegions if any + for node in nsdfg.sdfg.all_nodes_recursive(): + if isinstance(node, LoopRegion) and len(node.nodes()) == 0: + parent = node.parent_graph + parent.remove_node(node) + + # 2 + # Within the same state + if target_within_loop: + # For now we keep all of the computation inside the loop + # TODO: if there is an overwrite to the same array in the decendnat computation + # We need to make a special case for the last iteration of the loop where the + # else branch of this if is executed and a special version of the loop is added + raise AutoDiffException("Recomputation with overwrites within loops is not supported yet.") + else: + # If the target state is not within a loop + # We remove all the descendant computation from the graph + + # Do a reverse bfs to get all the necessary computation + backward_nodes = {n for e in forward_state.edge_bfs(target_an, reverse=True) for n in [e.src, e.dst]} + + # Remove everything else + descendant_nodes = set(forward_state.nodes()) - backward_nodes + + for node in descendant_nodes: + if node is not target_an: + forward_state.remove_node(node) + + +def _prune_recomputation_sdfg(forward_state: SDFGState, target_an: nodes.AccessNode, nsdfg: nodes.NestedSDFG): + """ + 1: From this Nested-SDFG, we remove everything that will be executed after the target access node to be recomputed + 2: Prune the unnecessary computation inside the forward state + Note: this is even necessary sometimes since the output could be overwritten in the same state + 3: TODO: From the target access node, we go backward in the graph and see what elements are required to get this array + """ + + # 1 and 2 + _prune_descendants_recomputation_nsdfg(forward_state=forward_state, target_an=target_an, nsdfg=nsdfg) + + +def _rename_descriptors_for_recomputation_nsdfg(forward_sdfg: SDFG, nsdfg: nodes.NestedSDFG): + """ + """ + # Get all the nodes to rename in the NestedSDFG + to_rename = [] + for inp in nsdfg.in_connectors: + for node, parent in nsdfg.sdfg.all_nodes_recursive(): + if isinstance(node, nodes.AccessNode) and node.data == inp and parent.in_degree(node) > 0: + # This is an input that will be written to in the SDFG we need to rename it + to_rename.append(inp) + break + + if len(to_rename) > 0: + # Add a new state to copy the data at the start of the SDFG + initi_state = nsdfg.sdfg.add_state_before(nsdfg.sdfg.start_state, label=f"init_{nsdfg.label}") + + # Rename the descriptors in the nested SDFG in addition to the in connector + for name in to_rename: + # Create a new array + new_name = f"recomputation_{name}" + + # Change the accessnodes in the NestedSDFG + for node, parent in nsdfg.sdfg.all_nodes_recursive(): + if isinstance(node, nodes.AccessNode) and node.data == name: + node.data = new_name + + # Change the memlets in the SDFG + for edge, parent in nsdfg.sdfg.all_edges_recursive(): + # Skip interstate edges + if isinstance(edge.data, dace.InterstateEdge): + continue + + if edge.data.data == name: + edge.data.data = new_name + + # Add the desciptor + old_desc = nsdfg.sdfg.arrays[name] + new_desc = copy.deepcopy(old_desc) + + # Check if this is the output of the recomputation block + if name not in nsdfg.out_connectors: + new_desc.transient = True + else: + new_desc.transient = False + + nsdfg.sdfg.add_datadesc(name=new_name, datadesc=new_desc) + + # Add a copy operation between the input node and the new descriptor + input_node = nodes.AccessNode(name) + new_node = nodes.AccessNode(new_name) + initi_state.add_node(input_node) + initi_state.add_node(new_node) + + # Add memory copy edge + initi_state.add_edge(input_node, None, new_node, None, forward_sdfg.make_array_memlet(name)) + + # Change the output if necessary + if name in nsdfg.out_connectors: + nsdfg.remove_out_connector(name) + nsdfg.add_out_connector(new_name) + + +def get_recomputation_nsdfg(bwd_generator: 'BackwardPassGenerator', forward_state: SDFGState, + target_an: nodes.AccessNode) -> nodes.NestedSDFG: + """ + Given an AccessNode for data that needs to be forwarded from the forward pass to the backward pass, + Return a nested SDFG that recomputes this data from input data. + """ + nsdfg_label = "recomputation_nsdfg_" + target_an.data + + # Initially, we will replicate the whole SDFG into a Nested-SDFG and connect it + # TODO: we likely need a copy of the SDFG before starting AD if separate_sdfgs + nsdfg = nodes.NestedSDFG(label=nsdfg_label, + sdfg=copy.deepcopy(bwd_generator.sdfg), + inputs=bwd_generator.sdfg.arg_names, + outputs=[target_an.data]) + + # We need to make sure the output inside the NestedSDFG is not a transient (anymore) + nsdfg.sdfg.arrays[target_an.data].transient = False + + # Find the same target node and state in the nsdfg + nsdfg_forward_state: SDFGState = None + nb_occurrences = 0 + for state in nsdfg.sdfg.states(): + if state.label == forward_state.label: + nsdfg_forward_state = state + nb_occurrences += 1 + + # Sanity check + assert nb_occurrences == 1 + assert nsdfg_forward_state + + # Find the target AccessNode within the state + nsdfg_target_node: nodes.AccessNode = None + nb_occurrences = 0 + for node in nsdfg_forward_state.nodes(): + if isinstance(node, nodes.AccessNode) and node.data == target_an.data and nsdfg_forward_state.node_id( + node) == forward_state.node_id(target_an): + nsdfg_target_node = node + nb_occurrences += 1 + + # Sanity check + assert nb_occurrences == 1 + assert nsdfg_target_node + + _prune_recomputation_sdfg(nsdfg=nsdfg, forward_state=nsdfg_forward_state, target_an=nsdfg_target_node) + + # Change descriptors if the inputs are written to + _rename_descriptors_for_recomputation_nsdfg(forward_sdfg=bwd_generator.sdfg, nsdfg=nsdfg) + + return nsdfg diff --git a/dace/autodiff/data_forwarding/store.py b/dace/autodiff/data_forwarding/store.py new file mode 100644 index 0000000000..a744884e36 --- /dev/null +++ b/dace/autodiff/data_forwarding/store.py @@ -0,0 +1,683 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import copy +from typing import List, Tuple +import sympy as sp + +# DaCe imports +import dace.sdfg.nodes as nodes +from dace import dtypes, data as dt, symbolic +from dace.sdfg import SDFGState, graph as dgraph, state as dstate +from dace.memlet import Memlet +from dace.sdfg.state import LoopRegion + +# Autodiff imports +from dace.autodiff.base_abc import AutoDiffException +import dace.autodiff.utils as ad_utils + + +def resolve_overwrite_with_store(bwd_generator: 'BackwardPassGenerator', forward_state: SDFGState, + backward_state: SDFGState, forward_node: nodes.AccessNode, target_node: nodes.Node, + starting_edge: dstate.MultiConnectorEdge): + """ + Given the AccessNode pointing to the data required by the backward pass, + We will save the values of this array in a new array and forward it to the backward pass. + """ + + # Modify the forward pass to save the data in a new array + new_stored_array, memlets = _store_data(bwd_generator=bwd_generator, + forward_state=forward_state, + backward_state=backward_state, + forward_an=forward_node, + target_node=target_node, + edge=starting_edge) + + # Check if this data needs to be forwarded through NestedSDFGs + if bwd_generator.separate_sdfgs or forward_state.sdfg.parent_sdfg is not None: + # We need to make sure the new array is forwarded to the backward SDFG + if new_stored_array.data not in bwd_generator.backward_input_arrays: + # If the data is needed inside a NestedSDFG + # This will make sure the added array is correctly forwarded + # and an in connector to the NestedSDFG is added + data_desc = new_stored_array.desc(forward_state) + bwd_generator.backward_input_arrays[new_stored_array.data] = data_desc + + # Connect the new array to the target node + _connect_stored_data_to_target(bwd_generator=bwd_generator, + forward_state=forward_state, + backward_state=backward_state, + source_node=new_stored_array, + forward_node=forward_node, + starting_edge=starting_edge, + memlets=memlets, + target_node=target_node) + + +def _store_data(bwd_generator: 'BackwardPassGenerator', forward_state: SDFGState, backward_state: SDFGState, + forward_an: nodes.AccessNode, target_node: nodes.Node, + edge: dgraph.MultiConnectorEdge) -> Tuple[nodes.AccessNode, List[Memlet]]: + """ + Given an edge leading an AccessNode or a map to the target node in the forward state, + add a path from the connector for this AccessNode to store its values for all iterations. + This can increase the dimension of the array. i.e. the size of the stored array is + greater or equal to the size of the original array. + + :param edge: the edge connecting the AccessNode to save data from to a map node. + :return: the new AccessNode which contains the stored data, + a list of memlets connecting an assign tasklet to this new AccessNode. + """ + + # Get the connector and edge to save + if isinstance(edge.src, nodes.AccessNode) and edge.src is not forward_an: + + # Get the incoming edge to this AccessNode + in_edges = forward_state.in_edges(edge.src) + + # There should only be one incoming edge + assert len(in_edges) == 1 + + # Get the memlet path for the edge incoming to this AccessNode + memlet_path = forward_state.memlet_path(in_edges[0]) + + # The start of this path should be the forward AccessNode + assert forward_an is memlet_path[0].src + + # The last edge in the memlet path has the connector we want to save + edge = memlet_path[-1] + + # Add a new AccessNode and array to the forward pass + # First, check if a stored array with this name already exists + new_store_node_name = forward_state.sdfg._find_new_name("stored_" + forward_an.data) + + # Get the new array shape + # This will be the shape of the current array + shape: List[int] = list(bwd_generator.sdfg.arrays[forward_an.data].shape) + + # If the shape is an expression: + free_symbols_dict = {sym: None for sym in bwd_generator.sdfg.free_symbols} + if any(symbolic.issymbolic(s, free_symbols_dict) for s in shape): + # Otherwise, replace all the loop dependent allocations with the max length of the loop + # For example, an array of size [i+1] in a range(2, 10) loop will be stored in a [10, 10] array (1) + # Additionally, an array of size [32-i] in the same loop will be stored in a [10, 30] (2) + loops = _get_all_enclosing_loops(forward_state) + + if len(loops) > 0: + # Loop over the shape dimensions + for i, s in enumerate(shape): + if ad_utils.shape_has_symbols_to_replace(bwd_generator.sdfg, s): + loop_size, loop_index = _get_symbol_upper_bound_from_loop(bwd_generator, s, loops) + # Replace the symbol with the loop size and evaluate the expression + # Check if loop size can be converted to an integer + loop_index_sym = symbolic.pystr_to_symbolic(loop_index) + loop_size_sym = loop_size if isinstance(loop_size, int) else symbolic.pystr_to_symbolic(loop_size) + shape[i] = s.subs(loop_index_sym, loop_size_sym) + + # Plus the size of any enclosing loops + enclosed, _ = ad_utils.state_within_loop(forward_state=forward_state) + nb_enclosing_loops = 0 + loop_param_list = [] + if enclosed: + # Get all enclosing loops + all_encolsing_loops = _get_all_enclosing_loops(forward_state=forward_state) + nb_enclosing_loops = len(all_encolsing_loops) + # Get the size of each loop and add it to the list + for loop in all_encolsing_loops: + # Get the end of the loop + start, end = ad_utils.extract_loop_region_info(loop) + + # Check if the loop is increasing or decreasing + # First, try to convert the strings to ints if possible + # Note that we look for the start or end of the loop + # And not the size of the loop. + # This is because we access using the loop indices + # Using the loop sizes instead would require shifting accesses + _, new_dim = ad_utils.get_loop_end(start, end, loop) + + # First we check if the new dimension contains symbols + # These will need to be replaced with scalars for correct allocation + # The sdfg symbols are allowed to be in the shape + if ad_utils.shape_has_symbols_to_replace(bwd_generator.sdfg, new_dim): + # Take the expression to sympy for easier processing + if isinstance(new_dim, str): + new_dim = symbolic.pystr_to_symbolic(new_dim) + + # Try to replace the symbols with the loop size + loop_size, loop_index = _get_symbol_upper_bound_from_loop(bwd_generator, new_dim, all_encolsing_loops) + loop_index_sym = symbolic.pystr_to_symbolic(loop_index) + loop_size_sym = loop_size if isinstance(loop_size, int) else symbolic.pystr_to_symbolic(loop_size) + new_dim = new_dim.subs(loop_index_sym, loop_size_sym) + shape.insert(0, new_dim) + loop_param_list.insert(0, loop.loop_variable) + + # Add the array descriptor and AccessNode to the forward state + original_desc = forward_an.desc(forward_state) + + # We make a special case for a memlet of the type A[i, j] in an i, j loop + # In this case we only need an array of the same size as the forward node + if enclosed and edge.data.data == forward_an.data and len(edge.data.subset) == nb_enclosing_loops: + # Check that the memlet subset matches perfectly the order of loop nest + # Make sure the subset elements are (i,i,1) and (j,j,1) + # Then check if this matches the loop indices + if all( + str(subset[0]) == loop_param_list[i] and subset[0] == subset[1] and subset[2] == 1 + for i, subset in enumerate(edge.data.subset)): + # We only use the loop accesses + # Both should work since shape[:nb_enclosing_loops] == shape[nb_enclosing_loops:] + shape = shape[nb_enclosing_loops:] + + # We want to build the memlet as if this was not in a a loop + nb_enclosing_loops = 0 + + new_store_node = forward_state.add_array( + name=new_store_node_name, + shape=shape, + dtype=original_desc.dtype, + transient=True, + ) + + # Connect the edge source and connector to the new access node + # We will save the memlets we create and return them + # This is useful to make the connections for the backward state + memlets_stack = [] + + # The loop accesses will be the same within the state + # Prepare them for all edges + loop_access = ','.join([f'{loop_param_list[i]}' for i in range(nb_enclosing_loops)]) + + # In the other cases, we need to route the storing through maps + all_edges = ad_utils.get_all_path_edges(forward_state, forward_an, edge) + + # Get the map nest memlet informtation + start_range, param_list, shape_list, param_dict = ad_utils.get_map_nest_information(all_edges) + + # The parameters to add for the current memlet in the loop + # At first we will use all of the parameters that are used in the memlet + # param_dict = {key: val for key, val in param_dict.items() if key in edge.data.free_symbols} + new_param_dict = {} + + # Iterate through the subset + for index, element in enumerate(edge.data.subset): + if str(element[0]) in edge.data.free_symbols and str(element[0]) in param_dict.keys(): + # Add the range from the param_dict + new_param_dict.update({str(element[0]): param_dict[str(element[0])]}) + else: + # Add the range from the param_dict + new_param_dict.update({index: element}) + + params_to_add = new_param_dict + # First, we need to add an assign tasklet + assign_tasklet_node, assign_tasklet_node_out_connector = _get_assign_tasklet(forward_state=forward_state, + node=forward_an, + stored_node=new_store_node, + last_edge=edge, + loop_iterators=loop_access) + + # Start iterating + previous_node = assign_tasklet_node + previous_node_out_connector = assign_tasklet_node_out_connector + map_exist = None + for edge in reversed(all_edges): + if isinstance(edge.src, nodes.MapEntry): + # Get the corresponding map exit + map_exist = _find_map_exist_for_map_entry(map_entry=edge.src, state=forward_state) + + # Add the Connectors to the map + map_exit_in_connector = f"IN_stored_{new_store_node.label}" + map_exit_out_connector = f"OUT_stored_{new_store_node.label}" + added = map_exist.add_in_connector(map_exit_in_connector) + assert added + added = map_exist.add_out_connector(map_exit_out_connector) + assert added + + # Prepare the memlet data for this edge + access_list = [] + for key, val in new_param_dict.items(): + if isinstance(key, str): + if key in params_to_add.keys(): + access_list.append(key) + else: + start = val[0] + end = val[1] + access_list.append(f'{start}:{end}') + elif isinstance(key, int): + start = val[0] + end = val[1] + 1 + access_list.append(f'{start}:{end}') + else: + raise AutoDiffException("Found unexepected type in memlet parameters dictionary") + + in_state_access = ','.join(access_list) + + memlet_data = Memlet( + expr=f"{new_store_node.data}[{loop_access},{in_state_access}]") if loop_access else Memlet( + expr=f"{new_store_node.data}[{in_state_access}]") + + # Save the memlet for later + memlets_stack.append(memlet_data) + + # Connect the previous node to this map exist + forward_state.add_edge(previous_node, previous_node_out_connector, map_exist, map_exit_in_connector, + memlet_data) + + previous_node = map_exist + previous_node_out_connector = map_exit_out_connector + + # Remove the parameters seen in the current map + # Since they will become out of scope in the next iteration + params_to_add = {} + for key, val in new_param_dict.items(): + if isinstance(key, str): + if key not in edge.src.params: + start = val[0] + end = val[1] + params_to_add.update({key: (start, end)}) + elif isinstance(key, int): + params_to_add.update({key: val}) + else: + raise AutoDiffException("Found unexepected type in memlet parameters dictionary") + + else: + # Prepare the memlet data for this edge + access_list = [] + for key, val in new_param_dict.items(): + if isinstance(key, str): + start = val[0] + end = val[1] + access_list.append(f'{start}:{end}') + elif isinstance(key, int): + start = val[0] + end = val[1] + 1 + access_list.append(f'{start}:{end}') + else: + raise AutoDiffException("Found unexepected type in memlet parameters dictionary") + + in_state_access = ','.join(access_list) + + # Get the memlet data for the connection between the last map exit and the new store AccessNode + memlet_data = Memlet( + expr=f"{new_store_node.data}[{loop_access},{in_state_access}]") if loop_access else Memlet( + expr=f"{new_store_node.data}[{in_state_access}]") + + memlets_stack.append(memlet_data) + + # This should be the last connection + forward_state.add_edge(previous_node, previous_node_out_connector, new_store_node, None, memlet_data) + break + + # We need to add an empty memlet from the new store AccessNode to make sure the data is stored before it is + # potentially altered + # First, we check if this can be avoided + # We do a BFS exploration to see if the data we are trying to store is overwritten within the same execution state + bfs_nodes = list(forward_state.bfs_nodes(source=forward_an)) + + # We make sure that views are also compared with their original array to check for conflicts + conflict_arrays = [forward_an.data] + # Check if the access node is a view + if isinstance(forward_an.desc(forward_state), dt.View): + # Get the original array name + viewed_array = next(forward_state.in_edges_by_connector(forward_an, "views")).data.data + conflict_arrays.append(viewed_array) + + if any(isinstance(n, nodes.AccessNode) and n.data in conflict_arrays and n is not forward_an for n in bfs_nodes): + to_connect = [] + for out_edge in forward_state.out_edges(forward_an): + # Get the destination of the edge + dst = out_edge.dst + if not isinstance(dst, nodes.MapEntry) and dst is not assign_tasklet_node: + # This will not be necessary for maps since the storing is added to the same map + # We also don't connect the newly created assign tasklet to avoid creating a cycle + if dst not in to_connect: + # We only need to make a single connection to the new stored data + to_connect.append(dst) + + for node in to_connect: + # Connect the new store AccessNode to assure the store happens first + # If there isn't already a connnection between these two nodes + if not any(e.dst == node for e in forward_state.out_edges(new_store_node)): + forward_state.add_edge(new_store_node, None, node, None, Memlet()) + + # Another case for making sure data is stored before it is altered is when the map we save from writes itself to the data we want to save + # In this case this would depend on the codegen order of the tasklets within the map and is thus not safe + # Detect if this is the case + if map_exist: + # Check if this map exit writes to the data we want to save + if any( + isinstance(e.dst, nodes.AccessNode) and e.dst.data == forward_an.data + for e in forward_state.out_edges(map_exist)): + # Get the map entry of this map exit + tasklet_in_edges = forward_state.in_edges(assign_tasklet_node) + assert len(tasklet_in_edges) == 1 + tasklet_in_edge = tasklet_in_edges[0] + + # Safety check + if not isinstance(tasklet_in_edge.src, nodes.MapEntry): + raise AutoDiffException( + "The map exit writes to the data we want to save, but the storing strcuture is not what we expect" + ) + + # Get all the edges coming out of this specific in connector + collusion_edges = [ + e for e in forward_state.out_edges(tasklet_in_edge.src) + if e.src_conn == tasklet_in_edge.src_conn and e.dst != assign_tasklet_node + ] + + # We need to add an empty memlet from the new store tasklet to everything else that reads from that connector + for out_edge in collusion_edges: + forward_state.add_edge(assign_tasklet_node, None, out_edge.dst, None, Memlet()) + + return new_store_node, memlets_stack + + +def _connect_stored_data_to_target(bwd_generator: 'BackwardPassGenerator', forward_state: SDFGState, + backward_state: SDFGState, source_node: nodes.AccessNode, + forward_node: nodes.AccessNode, target_node: nodes.Node, memlets: List[Memlet], + starting_edge: dgraph.MultiConnectorEdge): + """ + Connect the source node to the sink target node (both in the backawrd state) through a set of maps using the parameter memelets. + We use the forward_sink_edge to track which maps to make this connection through. + :param source_node: the source node of the new memlet path + :param sink_node: the sink node of the new memlet path + :param memlets: the set of memlets to use for the edges in the path + :param forward_sink_edge: the sink edge connecting the original nodes in the forward state + """ + # First, if the stored data is not already in the sdfg descriptors, add it + # This is the case for NestedSDFGs + if source_node.data not in backward_state.sdfg.arrays: + # Get the data descriptor from the original sdfg + data_desc = copy.deepcopy(bwd_generator.sdfg.arrays[source_node.data]) + data_desc.transient = False # The stored data will be forwarded + backward_state.sdfg.add_datadesc(source_node.data, data_desc) + + # Get the memlet path from the forward state + all_edges = ad_utils.get_all_path_edges(forward_state, forward_node, starting_edge) + assert len(all_edges) > 0 + + # We will iterate and connect parent -> child + reversed_child_node = bwd_generator.reverse_map[target_node] + child_node = reversed_child_node + child_node_in_connector = all_edges[-1].dst_conn + + # Iterate through the maps in the path in reverse + for edge in reversed(all_edges): + edge_src = edge.src + if isinstance(edge_src, nodes.MapEntry): + # Get the correponding map exist + map_exit = _find_map_exist_for_map_entry(map_entry=edge_src, state=forward_state) + + # Use the lookup table to get the map entry in the backward state corresponding to this map exist in the forward state + # Sanity check: this map entry should already exist in the backward state + assert map_exit in bwd_generator.reverse_map + bwd_map_entry = bwd_generator.reverse_map[map_exit] + + # Get a new connector id + next_conn = bwd_map_entry.next_connector() + + # Add a new in connector to the mapexit + parent_node_in_connector = "IN_stored_" + source_node.data + "_" + next_conn + added = bwd_map_entry.add_in_connector(parent_node_in_connector) + assert added + + # Add a new out connector to the mapexit + parent_node_out_connector = "OUT_stored_" + source_node.data + "_" + next_conn + added = bwd_map_entry.add_out_connector(parent_node_out_connector) + assert added + + memlet_data = copy.deepcopy(memlets.pop(0)) + + # Add the edge with the corresponding memlet + backward_state.add_edge(bwd_map_entry, parent_node_out_connector, child_node, child_node_in_connector, + memlet_data) + + child_node = bwd_map_entry + child_node_in_connector = parent_node_in_connector + + if isinstance(edge_src, nodes.AccessNode): + # The connection from the stored data will be made here + assert edge_src == forward_node + memlet_data = copy.deepcopy(memlets.pop(0)) + + # Replicate the source stored node + replicated_source_node = copy.deepcopy(source_node) + backward_state.add_node(replicated_source_node) + + # Change the memlet data to read from the stored data and not the original data + memlet_data.data = replicated_source_node.data + + # Add the final connection to the source node + backward_state.add_edge(replicated_source_node, None, child_node, child_node_in_connector, memlet_data) + + # If this connection was made to a NestedSDFG and the forward node was a view, + # We need to change the strides in the data descriptor this points to + # Since the stored data is not a view + # For example, if the stride of A is 5 (because it points to a column in a 2d array), + # The stored data will only contain the row and the stride for it should be one + # This is only a problem if the view points to a NestedSDFG input, + # that expects a descriptor with the original view stride + if isinstance(child_node, nodes.NestedSDFG) and isinstance(forward_node.desc(bwd_generator.sdfg), dt.View): + # Get the strides of the stored data + stored_data_desc = bwd_generator.sdfg.arrays[source_node.data] + stored_strides = stored_data_desc.strides + + # Get the NestedSDFG input descriptor + input_desc = child_node.sdfg.arrays[child_node_in_connector] + + # Set the strides to be the last elements of the stored strides + # We take the last elements since we might add loop indices to the shape + # Sanity check the strides for this desc should be less than or equal to the stored strides + assert len(input_desc.strides) <= len(stored_strides) + input_desc.strides = stored_strides[-len(input_desc.shape):] + + # There should be the same number of memlets through the new path + assert len(memlets) == 0 + + +def _get_assign_tasklet(forward_state: SDFGState, + node: nodes.AccessNode, + stored_node: nodes.AccessNode, + last_edge: dgraph.MultiConnectorEdge, + loop_iterators: str, + cuda: bool = False): + """ + """ + # Create the assign tasklet + assign_tasklet_node_in_connector = "in_stored_" + node.data + assign_tasklet_node_out_connector = "out_stored_" + node.data + + # Create the memlet for the assignment + # This will be the same as the memlet going to the tasklet + assign_memlet_data = copy.deepcopy(last_edge.data) + param_dict = {} + memlet_access_iterators = [] + + # We check the incoming memlet volume + if assign_memlet_data.volume != 1: + # We need to add a map to iterate through the missing dimensions + # For this we will create an assign block containing a map + + # First, Get the missing dimensions + # Iterate through the subset + for element in last_edge.data.subset: + if str(element[0]) in last_edge.data.free_symbols: + # This is a symbol we will keep in the store memlet + memlet_access_iterators.append(str(element[0])) + else: + # This is a range tuple we need to add an iterator for + # Create a random new free symbol + free_symbol = forward_state.sdfg.find_new_symbol("si") + + # Add the new symbol here so that find_new_symbol doesn't return it again + forward_state.sdfg.add_symbol(free_symbol, dtypes.int64) + memlet_access_iterators.append(free_symbol) + param_dict.update({free_symbol: element}) + + # Build the memlets for input and output + in_state_access = ','.join(memlet_access_iterators) + input_memlet = Memlet(expr=f"{last_edge.data.data}[{in_state_access}]") + if loop_iterators: + output_memlet = Memlet(expr=f"{stored_node.data}[{loop_iterators},{in_state_access}]") + else: + output_memlet = Memlet(expr=f"{stored_node.data}[{in_state_access}]") + + assign_tasklet_node, map_entry, map_exit = forward_state.add_mapped_tasklet( + name=f"__store_{node.data}_assign_", + map_ranges=param_dict, + inputs={assign_tasklet_node_in_connector: input_memlet}, + code=f"{assign_tasklet_node_out_connector} = {assign_tasklet_node_in_connector}", + outputs={assign_tasklet_node_out_connector: output_memlet}, + schedule=dtypes.ScheduleType.GPU_Device if cuda else dtypes.ScheduleType.Default, + external_edges=False) + + # Add the necessary connectors for external connections + map_entry.add_in_connector("IN_store_block") + map_exit.add_out_connector("OUT_store_block") + + # Update the internal edges to route through the new connectors + # Find and update the edge from map_entry to tasklet + for e in list(forward_state.out_edges(map_entry)): + if e.dst == assign_tasklet_node: + # Update the source connector to route through our external connector + forward_state.remove_edge(e) + forward_state.add_edge(map_entry, "OUT_store_block", assign_tasklet_node, + assign_tasklet_node_in_connector, e.data) + map_entry.add_out_connector("OUT_store_block") + break + + # Find and update the edge from tasklet to map_exit + for e in list(forward_state.in_edges(map_exit)): + if e.src == assign_tasklet_node: + # Update the destination connector to route through our external connector + forward_state.remove_edge(e) + forward_state.add_edge(assign_tasklet_node, assign_tasklet_node_out_connector, map_exit, + "IN_store_block", e.data) + map_exit.add_in_connector("IN_store_block") + break + + # Make sure this block is connected correctly + assign_block = map_entry + assign_block_in_connector = "IN_store_block" + return_node = map_exit + return_connector = "OUT_store_block" + else: + # Volume is 1, create a simple tasklet without a map + assign_tasklet_node = nodes.Tasklet( + label=f"__store_{node.data}_assign_", + inputs={assign_tasklet_node_in_connector}, + outputs={assign_tasklet_node_out_connector}, + code=f"{assign_tasklet_node_out_connector} = {assign_tasklet_node_in_connector}", + ) + + # Add it to the state + forward_state.add_node(assign_tasklet_node) + + assign_block = assign_tasklet_node + assign_block_in_connector = assign_tasklet_node_in_connector + return_node = assign_tasklet_node + return_connector = assign_tasklet_node_out_connector + + # Get the last map + last_map = last_edge.src + last_map_connector = last_edge.src_conn + + # Add the new edge from the last map entrance to the new assign block + forward_state.add_edge(last_map, last_map_connector, assign_block, assign_block_in_connector, assign_memlet_data) + return return_node, return_connector + + +def _find_map_exist_for_map_entry(map_entry: nodes.MapEntry, state: SDFGState) -> nodes.MapExit: + """ + Find the map exist that corresponds to the input map entry + """ + src_candidates = [node for node in state.nodes() if isinstance(node, nodes.MapExit) and node.map == map_entry.map] + if len(src_candidates) != 1: + # this shouldn't happen; if we are within a scope, the exit nodes + # for the scope should already exist in the backward pass + raise AutoDiffException("Invalid graph") + + return src_candidates[0] + + +def _get_symbol_upper_bound_from_loop(bwd_generator: 'DataForwardingbwd_generator', s: sp.Symbol, + loops: List[LoopRegion]) -> int: + """ + Given a symbol and a list of loops, get the upper bound of the symbol from the loops. + Raises an error if the symbol is not a loop index or the upper bound cannot be extracted correctly. + """ + # Get the symbol to match + if isinstance(s, (sp.Symbol, sp.Expr)): + # We don't want to match global SDFG symbols + loop_indices = {symb for symb in s.free_symbols if str(symb) not in bwd_generator.sdfg.free_symbols} + if len(loop_indices) != 1: + raise AutoDiffException(f"Symbol dimension {s} couldn't be parsed correctly during storing") + loop_index = str(list(loop_indices)[0]) + elif isinstance(s, str): + # Convert the string to a symbolic expression and extract free symbols + try: + expr = sp.sympify(s) + except (sp.SympifyError, TypeError, ValueError) as e: + raise AutoDiffException(f"Symbol dimension {s} couldn't be parsed as a symbolic expression: {e}") + + # We don't want to match global SDFG symbols + loop_indices = {symb for symb in expr.free_symbols if str(symb) not in bwd_generator.sdfg.free_symbols} + if len(loop_indices) != 1: + raise AutoDiffException(f"Symbol dimension {s} couldn't be parsed correctly during storing") + loop_index = str(list(loop_indices)[0]) + else: + raise AutoDiffException(f"Symbol dimension {s} is not a string and not a sympy symbol") + + # If the loop bound can be directly extracted from the interstate edges + if loop_index in bwd_generator.interstate_symbols: + loop_size = bwd_generator.interstate_symbols[loop_index] + else: + # Get the loop range for this symbol + loop_size = None + for l in loops: + # Convert the sympy symbol to string to check if it macthes the loop variable + if loop_index in l.loop_variable: + # Get the max loop range + start, end = ad_utils.extract_loop_region_info(l) + + # Check if the loop variable has a negative coefficient + # by extracting the coefficient from the affine expression + s_expr = sp.sympify(s) if isinstance(s, str) else s + # Find the actual symbol in the expression that matches loop_index by name + loop_symbol = None + for sym in s_expr.free_symbols: + if str(sym) == loop_index: + loop_symbol = sym + break + + # Extract the coefficient of the loop variable + if loop_symbol is not None: + coeff = s_expr.coeff(loop_symbol) + # If coefficient is negative we need to use smallest instead of largest + matched = coeff is not None and (coeff < 0) == True + else: + # Loop variable not found in expression + matched = False + smallest, largest = ad_utils.get_loop_end(start, end, l) + if not matched: + loop_size = largest + else: + loop_size = smallest + + if loop_size is None: + raise AutoDiffException( + f"Can't figure out how to save the data inside: {l.label} because of its symbol shape {s}") + + # We will call this function recusrively until loop size is numeric or it is a global SDFG symbol + if ad_utils.shape_has_symbols_to_replace(bwd_generator.sdfg, loop_size): + loop_size, _ = _get_symbol_upper_bound_from_loop(bwd_generator, loop_size, loops) + return loop_size, loop_index + + +def _get_all_enclosing_loops(forward_state: SDFGState) -> List[LoopRegion]: + """ + Check if this state will be executed several times within a loop. + We check if any of the parents of this state is a loop region. + """ + all_loops = [] + parent = forward_state.parent_graph + while parent is not None: + if isinstance(parent, LoopRegion): + all_loops.append(parent) + parent = parent.parent_graph + return all_loops diff --git a/dace/autodiff/implementations/__init__.py b/dace/autodiff/implementations/__init__.py new file mode 100644 index 0000000000..21feb8e552 --- /dev/null +++ b/dace/autodiff/implementations/__init__.py @@ -0,0 +1,46 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +Backward Pass Implementations for SDFG Elements. + +This package provides backward (gradient) implementations for various SDFG node types. +Each implementation defines how to compute gradients for specific operations. + +Implementation Categories +------------------------- +1. **DaCe Nodes** (dace_nodes.py): + - Core SDFG elements: Tasklet, MapEntry, AccessNode, etc. + - Fundamental building blocks for all DaCe programs + - Registered in DaceNodeBackwardImplementations + +2. **DaCe Reduction Nodes** (dace_reduction_nodes.py): + - Reduction operations: Sum, Max, Min + - Registered using @autoregister decorator + +3. **ONNX Operations** (onnx_ops.py): + - ONNX-specific operations from dace.libraries.onnx + - Neural network layers and operators + - Supports ONNX model differentiation + +4. **PyTorch Operations** (pytorch_ops.py): + - Operations using PyTorch CUDA kernels + - Depthwise convolution backward pass +""" + +import dace.autodiff.implementations.dace_reduction_nodes +from dace.autodiff.implementations.dace_nodes import DaceNodeBackwardImplementations + +# ONNX ops are optional +try: + import dace.autodiff.implementations.onnx_ops +except ImportError: + pass + +# PyTorch ops are optional +try: + import dace.autodiff.implementations.pytorch_ops +except ImportError: + pass + +__all__ = [ + "DaceNodeBackwardImplementations", +] diff --git a/dace/autodiff/implementations/dace_nodes.py b/dace/autodiff/implementations/dace_nodes.py new file mode 100644 index 0000000000..2439e08137 --- /dev/null +++ b/dace/autodiff/implementations/dace_nodes.py @@ -0,0 +1,487 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" + Class for defining the reversal of pure SDFG nodes: AccessNode, Tasklet, MapEntry/Exit, NestedSDFG + Each method should return a tuple (reversed_node, BackwardResult) +""" +import ast +import collections +import copy +import numbers +import astunparse +import sympy as sp +from typing import List, Tuple + +# DaCe imports +import dace +import dace.sdfg.nodes as nodes +from dace import dtypes +from dace.data import Reference, Structure +from dace.sdfg import SDFGState +from dace.data import find_new_name + +# Autodiff imports +from dace.autodiff.base_abc import BackwardResult, AutoDiffException +import dace.autodiff.utils as ad_utils + + +class DaceNodeBackwardImplementations: + + def __init__(self, backward_pass_generator: 'BackwardPassGenerator'): + self.bwd_engine = backward_pass_generator + pass + + def _reverse_NestedSDFG( + self, + forward_state: SDFGState, + backward_state: SDFGState, + node: nodes.NestedSDFG, + given_gradients: List[str], + required_gradients: List[str], + ) -> Tuple[nodes.Node, BackwardResult]: + reverse_nsdfg = dace.SDFG(node.sdfg.name + "_backward") + + gen = self.bwd_engine.create_child_generator( + sdfg=node.sdfg, + given_gradients=given_gradients, + required_gradients=required_gradients, + backward_sdfg=reverse_nsdfg, + ) + backward_result, _, backward_input_arrays = gen.backward() + + # we need to defer add edges until after the arrays have been added because creation of the nested + # sdfg fails otherwise + deferred_edges = [] + + inputs = set(backward_result.given_grad_names[name] for name in sorted(given_gradients)) + # loop through the arrays that we need from the forward pass + for name, desc in sorted(backward_input_arrays.items()): + # if the name is not already passed to the reverse SDFG node ... + if name not in required_gradients and name not in node.in_connectors: + # ... this array needs to be forwarded out of the forward SDFG (i.e. it is an intermediate value) + # 1) add it to the current SDFG, and to self.bwd_engine.backward_input_arrays + # 2) add an out connector to the forward nested SDFG, add a write node to the current state, and an edge + # from the output to there + # 3) add a read node to the backward state, and an edge into it + + desc = node.sdfg.arrays[name] + + # if the original view node is in the in-connector, no need to connect it, continue + # if forwarded_name in node.in_connectors: + # continue + + # (1) + new_name = find_new_name(name + "_forwarded", self.bwd_engine.sdfg.arrays) + if new_name in self.bwd_engine.sdfg.arrays or new_name in self.bwd_engine.backward_input_arrays: + raise AutoDiffException( + "Attempted to create array with name '{}', but it already existed".format(new_name)) + + self.bwd_engine.sdfg.add_datadesc(new_name, copy.deepcopy(desc)) + self.bwd_engine.backward_input_arrays[new_name] = copy.deepcopy(desc) + + if self.bwd_engine.separate_sdfgs: + to_add = copy.deepcopy(desc) + to_add.transient = False + self.bwd_engine.backward_sdfg.add_datadesc(new_name, to_add) + + # (2) + node.sdfg.arrays[name].transient = False + added = node.add_out_connector(name, force=True) + assert added + write = forward_state.add_write(new_name) + forward_state.add_edge(node, name, write, None, self.bwd_engine.sdfg.make_array_memlet(new_name)) + + # (3) + read = backward_state.add_read(new_name) + deferred_edges.append( + dict(u=read, + u_connector=None, + v_connector=name, + memlet=self.bwd_engine.backward_sdfg.make_array_memlet(new_name))) + inputs.add(name) + else: + inputs.add(name) + + outputs = set(backward_result.required_grad_names[name] for name in required_gradients) + + for inp in inputs: + if inp in reverse_nsdfg.arrays: + reverse_nsdfg.arrays[inp].transient = False + for outp in outputs: + if outp in reverse_nsdfg.arrays: + reverse_nsdfg.arrays[outp].transient = False + # Create the sdfg and return it + nsdfg = backward_state.add_nested_sdfg( + reverse_nsdfg, + inputs=inputs, + outputs=outputs, + ) + + # If any input connectors point to symbols + for conn, _ in nsdfg.in_connectors.items(): + if conn in nsdfg.sdfg.symbols: + # We need to add a new symbol and create a mapping + new_symbol = find_new_name(conn, nsdfg.sdfg.symbols) + nsdfg.sdfg.add_symbol(new_symbol, nsdfg.sdfg.symbols[conn]) + nsdfg.sdfg.replace(conn, new_symbol) + nsdfg.symbol_mapping[new_symbol] = conn + # Remove it from the symbol mapping too + if conn in nsdfg.symbol_mapping: + del nsdfg.symbol_mapping[conn] + for edge_args in deferred_edges: + edge_args["v"] = nsdfg + backward_state.add_edge(**edge_args) + + return nsdfg, BackwardResult(required_grad_names=backward_result.required_grad_names, + given_grad_names=backward_result.given_grad_names) + + def _reverse_AccessNode( + self, + forward_state: SDFGState, + backward_state: SDFGState, + node: nodes.AccessNode, + given_gradients: List[str], + required_gradients: List[str], + ) -> Tuple[nodes.Node, BackwardResult]: + + desc = self.bwd_engine.sdfg.arrays[node.data] + if isinstance(desc, Reference): + raise AutoDiffException(f"AccessNode '{node.data}' points to a Reference, which is not yet supported") + if isinstance(desc, Structure): + raise AutoDiffException(f"AccessNode '{node.data}' points to a Structure, which is not yet supported") + + rev = nodes.AccessNode(self.bwd_engine.array_grad_name(node.data)) + # We want all gradient arrays to be initialized to zero + # This is important for correct gradient accumulation + rev.setzero = True + backward_state.add_node(rev) + required_grad_names = {None: None} + given_grad_names = {None: None} + + if "views" in node.in_connectors: + required_grad_names = {"views": "views"} + if "views" in node.out_connectors: + given_grad_names = {"views": "views"} + + return rev, BackwardResult(required_grad_names=required_grad_names, given_grad_names=given_grad_names) + + def _reverse_MapEntry( + self, + forward_state: SDFGState, + backward_state: SDFGState, + node: nodes.MapEntry, + given_gradients: List[str], + required_gradients: List[str], + ) -> Tuple[nodes.Node, BackwardResult]: + + required_grad_names = {n: ad_utils.invert_map_connector(n) for n in required_gradients} + given_grad_names = {n: ad_utils.invert_map_connector(n) for n in given_gradients} + result = BackwardResult(required_grad_names=required_grad_names, given_grad_names=given_grad_names) + rev = nodes.MapExit(self.bwd_engine.reverse_map[node.map]) + + for _, conn in sorted(given_grad_names.items()): + added = rev.add_in_connector(conn) + assert added + + for _, conn in sorted(required_grad_names.items()): + added = rev.add_out_connector(conn) + assert added + + backward_state.add_node(rev) + return rev, result + + def _reverse_MapExit( + self, + forward_state: SDFGState, + backward_state: SDFGState, + node: nodes.MapExit, + given_gradients: List[str], + required_gradients: List[str], + ): + self.bwd_engine.reverse_map[node.map] = copy.deepcopy(node.map) + + rev = nodes.MapEntry(self.bwd_engine.reverse_map[node.map]) + for conn in sorted(node.in_connectors): + added = rev.add_in_connector(conn) + assert added + + for conn in sorted(node.out_connectors): + added = rev.add_out_connector(conn) + assert added + + backward_state.add_node(rev) + # yapf: disable + return ( + rev, + BackwardResult(required_grad_names={ + n: ad_utils.invert_map_connector(n) + for n in required_gradients + }, + given_grad_names={ + n: ad_utils.invert_map_connector(n) + for n in given_gradients + }), + ) + # yapf: enable + + def _reverse_Tasklet( + self, + state: SDFGState, + backward_state: SDFGState, + tasklet: nodes.Tasklet, + given_gradients: List[str], + required_gradients: List[str], + ) -> Tuple[nodes.Node, BackwardResult]: + if tasklet.language is not dtypes.Language.Python: + raise AutoDiffException("Expected tasklet with language Python, got language {}".format(tasklet.language)) + + # tasklets should have scalar inputs (can be relaxed) + for _, _, _, _, memlet in state.in_edges(tasklet): + if memlet.data is not None: + try: + ad_utils.is_int_eq_value(memlet.subset.num_elements(), 1) + except AutoDiffException as e: + raise AutoDiffException( + "Autodiff only supported for tasklets with scalar inputs and outputs") from e + + for _, _, _, _, memlet in state.out_edges(tasklet): + if memlet.data is not None: + try: + ad_utils.is_int_eq_value(memlet.subset.num_elements(), 1) + except AutoDiffException as e: + raise AutoDiffException( + "Autodiff only supported for tasklets with scalar inputs and outputs") from e + + code_str = tasklet.code.as_string + + # check if this is a conditional tasklet + if self.bwd_engine._conditional_tasklet(tasklet): + # we want to extract the if and else expressions and pass them to sympy + if_expression, else_expression, conditional = ad_utils.extract_conditional_expressions(tasklet) + + if_code, if_rev_inputs, if_rev_outputs, if_result = self._differentiate_code_symbolically( + self.bwd_engine.sdfg, if_expression, state, tasklet, given_gradients, required_gradients) + + if else_expression: + else_code, else_rev_inputs, else_rev_outputs, else_result = self._differentiate_code_symbolically( + self.bwd_engine.sdfg, else_expression, state, tasklet, given_gradients, required_gradients) + assert else_rev_inputs == if_rev_inputs + assert if_rev_outputs == else_rev_outputs + assert else_result == if_result + + # prepare the tasklet code depending on the conditional type + # add the same conditional to the if_code + # first, add indentation + if_code = if_code.replace("\n", "\n\t") + if_code = f"if {conditional}:\n{if_code}" + + # add the conditional to the in connectors + if_rev_inputs.add(conditional) + joint_code = if_code + + if ":" not in code_str: + # only an if in the original code + assert else_expression + else_code = else_code.replace("\n", "\n\t") + else_code = f"else:\n{else_code}" + joint_code = f"{if_code}\n{else_code}" + + # in case there are no out_connectors, we will zero out the assigned-to AccessNode + if len(if_rev_outputs) == 0: + if_rev_outputs = {"__zero_out_conn__"} + + rev = nodes.Tasklet("_" + tasklet.label + "_reverse_", + inputs=if_rev_inputs, + outputs=if_rev_outputs, + code=joint_code, + debuginfo=tasklet.debuginfo) + + result = if_result + else: + code, rev_inputs, rev_outputs, result = self._differentiate_code_symbolically( + self.bwd_engine.sdfg, code_str, state, tasklet, given_gradients, required_gradients) + rev = nodes.Tasklet("_" + tasklet.label + "_reverse_", + inputs=rev_inputs, + outputs=rev_outputs, + code=code, + debuginfo=tasklet.debuginfo) + backward_state.add_node(rev) + return rev, result + + def _differentiate_code_symbolically( + self, + sdfg: dace.SDFG, + code_str: str, + forward_state: SDFGState, + tasklet: nodes.Tasklet, + given_gradients: List[str], + required_gradients: List[str], + ): + """Performs symbolic differentiation on tasklet code to generate the backward-pass tasklet. + + This method uses SymPy to symbolically differentiate expressions in a tasklet's code, + applying the chain rule to compute gradients with respect to input variables. + + :param sdfg: The parent SDFG containing the tasklet. + :param code_str: Code string from the tasklet to differentiate. + :param forward_state: The SDFGState containing the forward tasklet. + :param tasklet: The forward tasklet node being differentiated. + :param given_gradients: List of output connector names for which gradients are provided (∂L/∂output). + :param required_gradients: List of input connector names for which gradients must be computed (∂L/∂input). + :return: A 4-tuple containing (code, rev_inputs, rev_outputs, result) where code is the generated + Python code for the backward tasklet, rev_inputs is the set of input connector names, + rev_outputs is the set of output connector names, and result is the BackwardResult mapping. + :raises AutoDiffException: If symbolic differentiation fails (e.g., non-differentiable operations, + unexpected graph structure, missing input edges). + + .. note:: + - Uses SymPy's symbolic differentiation and common subexpression elimination (CSE) + - Supports indexed array accesses (e.g., A[i, j]) via IndexedBase + - Handles constant assignments by zeroing gradients + - Gradient names are generated with "_gradient" suffix to avoid conflicts + - SDFG-level symbols are excluded from backward tasklet inputs + - Type casting ensures gradient types match forward pass data types + - Applies chain rule: ∂L/∂input = ∂L/∂output * (∂output/∂input) + + Example:: + + Forward tasklet: ``y = x * x + 2 * x`` + Given gradient: dy (∂L/∂y) + Required gradient: dx (∂L/∂x) + Generated code: ``dx_gradient = dy_gradient * (2*x + 2)`` + """ + output_exprs, indexed_objects_map = ad_utils.code_to_exprs(code_str, tasklet, list(sdfg.symbols.keys())) + + # for each output that an input is used in, there will be an entry for the expression of the + # grad in this list in the final code snippet. When we generate the final code for the + # reverse tasklet, we need to add them all up. + rev_code = collections.defaultdict(list) + + # the outputs of the reversed nodes are the grads of inputs of the original node + rev_outputs = set() + rev_inputs = set() + + result = BackwardResult(required_grad_names={}, given_grad_names={}) + + # symbol generator to use for CSE + symbol_generator = sp.numbered_symbols() + + code = "" + + for output_conn in sorted(given_gradients): + + # special case for conditional tasklets with constant assignment + if len(required_gradients) == 0: + # for this we need to assing a zero to the gradient output + # pick a name for the input gradient + rev_input_grad_name = find_new_name(output_conn + "_gradient", rev_inputs) + result.given_grad_names[output_conn] = rev_input_grad_name + + # zero out the gradient + code = f"\n__zero_out_conn__ = 0.0" + rev_outputs = {} + rev_inputs = {rev_input_grad_name} + + # for each output_conn... + for inp in sorted(required_gradients): + # ...add the code to generate {inp}_grad + + if inp not in result.required_grad_names: + # pick a name for the gradient + rev_output_grad_name = find_new_name(inp + "_gradient", rev_outputs) + result.required_grad_names[inp] = rev_output_grad_name + rev_outputs.add(rev_output_grad_name) + else: + rev_output_grad_name = result.required_grad_names[inp] + + output_expr = output_exprs[output_conn] + # if the expression is a constant assignment, we need to cast the float to the sympy equivalent + if isinstance(output_expr, numbers.Real): + output_expr = sp.Float(output_expr) + + # We need to prepare the w.r.t expression + if inp in indexed_objects_map: + # if the input is an indexed object, we need to create the sympy expression + indexed_base = sp.IndexedBase(inp) + idx_objects = [sp.Idx(index) for index in indexed_objects_map[inp]] + inp_expr = indexed_base[tuple(idx_objects)] + else: + # if the input is not an indexed object, we can just use it as is + inp_expr = sp.symbols(inp) + + # symbolically differentiate the output w.r.t inp + diff_expr = output_expr.diff(inp_expr) + + # do common subexpression elimination + sub_expressions, diff_expr = sp.cse(diff_expr, symbols=symbol_generator) + + diff_expr = diff_expr[0] + + if diff_expr.atoms(sp.Derivative): + # the final result contains a call to sp.Derivative + raise AutoDiffException("Unable to symbolically differentiate expression: {}".format( + diff_expr.expr)) + + if output_conn not in result.given_grad_names: + # pick a name for the input gradient + rev_input_grad_name = find_new_name(output_conn + "_gradient", rev_inputs) + result.given_grad_names[output_conn] = rev_input_grad_name + else: + rev_input_grad_name = result.given_grad_names[output_conn] + + input_symbols = diff_expr.free_symbols\ + .union(s for _, e in sub_expressions for s in e.free_symbols)\ + .difference(e for e, _ in sub_expressions) + + string_symbols = {str(symb) for symb in input_symbols} + + # If there are any symbols that are defined at the global SDFG scope + # We do not need to add these as inputs to the reverse tasklet + string_symbols = string_symbols.difference(set(sdfg.symbols.keys())) + rev_inputs |= string_symbols | {rev_input_grad_name} + + diff_code_str = "{input} * ({diff_expr})".format(input=rev_input_grad_name, diff_expr=str(diff_expr)) + # small hack: our heaviside is lowercase + diff_code_str = diff_code_str.replace("Heaviside", "heaviside") + + diff_code_str = astunparse.unparse(ad_utils.SympyCleaner().visit(ast.parse(diff_code_str))) + + sub_expression_code_strs = "\n".join(f"{target} = {expression}" + for target, expression in sub_expressions) + + # get the the final type of the gradient: this is just the type of the input connector we creating the + # gradient for + cands = list(forward_state.in_edges_by_connector(tasklet, inp)) + if len(cands) != 1: + raise AutoDiffException(f"Unexpected graph structure, could not find input edge for connector {inp}" + f" on tasklet {tasklet}") + + converted_code = ad_utils.cast_consts_to_type(diff_code_str, sdfg.arrays[cands[0].data.data].dtype) + converted_code = converted_code.replace("\n", " ") + + converted_sub_expressions = ad_utils.cast_consts_to_type(sub_expression_code_strs, + sdfg.arrays[cands[0].data.data].dtype) + + # If there is indirection in the input + if inp in indexed_objects_map: + # We need to have indirection of the output container in the backward + output_code = rev_output_grad_name + "[" + " , ".join(indexed_objects_map[inp]) + "]" + + # We also need to add the indices as connectors so that they are forwarded from the forward pass + for idx in indexed_objects_map[inp]: + if idx not in rev_inputs: + # This needs to be available in the forward pass in the first place + if idx not in tasklet.in_connectors: + raise AutoDiffException( + f"Expected index {idx} to be an input connector of the tasklet {tasklet}, " + f"but it is not. This is required for the backward pass to work correctly.") + rev_inputs.add(idx) + else: + output_code = rev_output_grad_name + + code += converted_sub_expressions + "\n" + rev_code[output_code].append(converted_code) + + for output, exprs in sorted(rev_code.items()): + code += "\n" + output + " = " + " + ".join(exprs) + + return code, rev_inputs, rev_outputs, result diff --git a/dace/autodiff/implementations/dace_reduction_nodes.py b/dace/autodiff/implementations/dace_reduction_nodes.py new file mode 100644 index 0000000000..05232acbbf --- /dev/null +++ b/dace/autodiff/implementations/dace_reduction_nodes.py @@ -0,0 +1,307 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +DaCe Library Node Backward Pass Implementations for Automatic Differentiation. + +This module provides backward pass implementations for DaCe standard library nodes +in the automatic differentiation system. Each class implements the BackwardImplementation +interface to compute gradients for specific library operations during reverse-mode +automatic differentiation. + +""" + +import copy +import typing + +# DaCe core imports +import dace +import dace.dtypes as dtypes +import dace.libraries.standard.nodes +from dace import SDFGState, SDFG, Memlet +from dace.sdfg.nodes import Node + +# DaCe frontend imports +from dace.frontend.operations import detect_reduction_type +from dace.registry import autoregister_params + +# Autodiff imports +from dace.autodiff.base_abc import BackwardImplementation, BackwardContext, BackwardResult, AutoDiffException + +# Utility imports +from dace.sdfg.utils import in_desc_with_name, out_desc_with_name + + +@autoregister_params(node_type=dace.libraries.standard.nodes.Reduce, name="pure") +class ReverseReduce(BackwardImplementation): + """Backward implementation for DaCe Reduce library nodes. + + Supports Sum, Max, and Min reduction operations. The backward pass distributes + gradients appropriately based on the reduction type: + - Sum: Broadcasts gradients uniformly across reduced dimensions + - Max/Min: Routes gradients only to positions that achieved the extremal value + """ + + @staticmethod + def backward_can_be_applied(node: Node, state: SDFGState, sdfg: SDFG) -> bool: + """Check if backward pass can be applied to this reduction node. + + :param node: The reduction node to check. + :param state: The SDFG state containing the node (unused but required by interface). + :param sdfg: The SDFG containing the state (unused but required by interface). + :return: True if backward pass can be applied, False otherwise. + """ + reduction_type = detect_reduction_type(node.wcr) + if reduction_type not in (dtypes.ReductionType.Sum, dtypes.ReductionType.Max, dtypes.ReductionType.Min): + return False + + return True + + @staticmethod + def backward(forward_node: Node, context: BackwardContext, given_gradients: typing.List[typing.Optional[str]], + required_gradients: typing.List[typing.Optional[str]]) -> typing.Tuple[Node, BackwardResult]: + """Generate the backward pass for a reduction node. + + :param forward_node: The forward reduction node. + :param context: The backward pass context. + :param given_gradients: List of gradient names provided to this node. + :param required_gradients: List of gradient names required by this node. + :return: Tuple of the backward node and the backward result. + :raises AutoDiffException: If the node has invalid number of edges. + """ + reduction_type = detect_reduction_type(forward_node.wcr) + + if len(given_gradients) != 1: + raise AutoDiffException(f"Invalid SDFG: reduce node {forward_node} should have exactly one output edge, " + f"got {len(given_gradients)} output gradients") + + if len(required_gradients) != 1: + raise AutoDiffException(f"Invalid SDFG: reduce node {forward_node} should have exactly one input edge, " + f"got {len(required_gradients)} input gradients") + + input_name = next(iter(required_gradients)) + in_desc = in_desc_with_name(forward_node, context.forward_state, context.forward_sdfg, input_name) + + output_name = next(iter(given_gradients)) + out_desc = out_desc_with_name(forward_node, context.forward_state, context.forward_sdfg, output_name) + + all_axes: typing.List[int] = list(range(len(in_desc.shape))) + reduce_axes: typing.List[int] = all_axes if forward_node.axes is None else forward_node.axes + non_reduce_axes: typing.List[int] = [i for i in all_axes if i not in reduce_axes] + + result = BackwardResult.empty() + + return ReverseReduce._backward_reduction(forward_node, context, result, reduction_type, input_name, output_name, + in_desc, out_desc, all_axes, non_reduce_axes) + + @staticmethod + def _backward_reduction(forward_node: Node, context: BackwardContext, result: BackwardResult, + reduction_type: dtypes.ReductionType, input_name: str, output_name: str, in_desc, out_desc, + all_axes: typing.List[int], + non_reduce_axes: typing.List[int]) -> typing.Tuple[Node, BackwardResult]: + """Backward pass for Sum/Max/Min reductions. + + - Sum: Broadcasts gradients uniformly across reduced dimensions + - Max/Min: Routes gradients to positions that achieved the extremal value, + split equally among tied values + + :param forward_node: The forward reduction node. + :param context: The backward pass context. + :param result: The backward result to populate. + :param reduction_type: The type of reduction (Sum, Max, or Min). + :param input_name: Name of the input connector. + :param output_name: Name of the output connector. + :param in_desc: Input data descriptor. + :param out_desc: Output data descriptor. + :param all_axes: List of all axes indices. + :param non_reduce_axes: List of axes not being reduced. + :return: Tuple of the nested SDFG node and the backward result. + """ + is_extremal = reduction_type in (dtypes.ReductionType.Max, dtypes.ReductionType.Min) + type_name = { + dtypes.ReductionType.Sum: "sum", + dtypes.ReductionType.Max: "max", + dtypes.ReductionType.Min: "min" + }[reduction_type] + + sdfg = SDFG("_reverse_" + str(reduction_type).replace(".", "_") + "_") + + rev_input_conn_name = "input_gradient" + rev_output_conn_name = "output_gradient" + + result.required_grad_names[output_name] = rev_output_conn_name + result.given_grad_names[input_name] = rev_input_conn_name + + sdfg.add_array(rev_input_conn_name, shape=out_desc.shape, dtype=out_desc.dtype, strides=out_desc.strides) + sdfg.add_array(rev_output_conn_name, shape=in_desc.shape, dtype=in_desc.dtype, strides=in_desc.strides) + + nsdfg_inputs = {rev_input_conn_name} + + if is_extremal: + extremal_conn_name = f"input_{type_name}" + extremal_idx_conn_name = f"input_{type_name}_idx" + sdfg.add_array(extremal_conn_name, shape=out_desc.shape, dtype=out_desc.dtype, strides=out_desc.strides) + sdfg.add_array(extremal_idx_conn_name, shape=in_desc.shape, dtype=in_desc.dtype, strides=in_desc.strides) + nsdfg_inputs.update({extremal_conn_name, extremal_idx_conn_name}) + + # Add transient array to count matching elements per output position + count_arr_name = f"_{type_name}_count" + sdfg.add_array(count_arr_name, shape=out_desc.shape, dtype=out_desc.dtype, transient=True) + + reduce_all_axes = forward_node.axes is None or set(range(len(in_desc.shape))) == set(forward_node.axes) + + if is_extremal: + # Two-state approach for max/min: + # State 1: Count elements matching extremal value + # State 2: Compute normalized gradient + + count_state = sdfg.add_state(f"count_{type_name}_{id(forward_node)}") + grad_state = sdfg.add_state(f"grad_{type_name}_{id(forward_node)}") + sdfg.add_edge(count_state, grad_state, dace.InterstateEdge()) + + # State 1: Count matching elements + count_memlet = Memlet.simple(count_arr_name, + "0" if reduce_all_axes else ",".join("i" + str(i) for i in non_reduce_axes), + wcr_str="lambda x, y: x + y") + extremal_val_memlet_count = Memlet.simple( + extremal_conn_name, "0" if reduce_all_axes else ",".join("i" + str(i) for i in non_reduce_axes)) + extremal_idx_memlet_count = Memlet.simple(extremal_idx_conn_name, ",".join("i" + str(i) for i in all_axes)) + + _, _, count_exit = count_state.add_mapped_tasklet( + f"_count_{type_name}_matches_", { + "i" + str(i): "0:{}".format(shape) + for i, shape in enumerate(in_desc.shape) + }, { + "__extremal_val": extremal_val_memlet_count, + "__extremal_val_idx": extremal_idx_memlet_count + }, + "__count = 1.0 if __extremal_val == __extremal_val_idx else 0.0", {"__count": count_memlet}, + external_edges=True) + + # Set count array to zero before accumulation + count_out_edges = count_state.out_edges(count_exit) + if len(count_out_edges) == 1: + count_out_node = count_out_edges[0].dst + if isinstance(count_out_node, dace.nodes.AccessNode): + count_out_node.setzero = True + + # State 2: Compute normalized gradient (grad / count) + reduction_memlet = Memlet.simple( + rev_input_conn_name, "0" if reduce_all_axes else ",".join("i" + str(i) for i in non_reduce_axes)) + reverse_reduction_memlet = Memlet.simple(rev_output_conn_name, + ",".join("i" + str(i) for i in all_axes), + wcr_str="lambda x, y: x + y") + extremal_val_memlet = Memlet.simple( + extremal_conn_name, "0" if reduce_all_axes else ",".join("i" + str(i) for i in non_reduce_axes)) + extremal_idx_memlet = Memlet.simple(extremal_idx_conn_name, ",".join("i" + str(i) for i in all_axes)) + count_read_memlet = Memlet.simple( + count_arr_name, "0" if reduce_all_axes else ",".join("i" + str(i) for i in non_reduce_axes)) + + tasklet_inputs = { + "__in": reduction_memlet, + "__extremal_val": extremal_val_memlet, + "__extremal_val_idx": extremal_idx_memlet, + "__count": count_read_memlet + } + tasklet_code = "__out = __in / __count if __extremal_val == __extremal_val_idx else 0" + + _, _, exit_map = grad_state.add_mapped_tasklet(f"_{type_name}_grad_" + + str(reduction_type).replace(".", "_") + "_", { + "i" + str(i): "0:{}".format(shape) + for i, shape in enumerate(in_desc.shape) + }, + tasklet_inputs, + tasklet_code, {"__out": reverse_reduction_memlet}, + external_edges=True) + + state = grad_state + else: + # Sum reduction: simple broadcast + state = sdfg.add_state(f"block_{id(forward_node)}") + reduction_memlet = Memlet.simple( + rev_input_conn_name, "0" if reduce_all_axes else ",".join("i" + str(i) for i in non_reduce_axes)) + reverse_reduction_memlet = Memlet.simple(rev_output_conn_name, + ",".join("i" + str(i) for i in all_axes), + wcr_str="lambda x, y: x + y") + tasklet_inputs = {"__in": reduction_memlet} + tasklet_code = "__out = __in" + + _, _, exit_map = state.add_mapped_tasklet(f"_{type_name}_grad_" + str(reduction_type).replace(".", "_") + + "_", { + "i" + str(i): "0:{}".format(shape) + for i, shape in enumerate(in_desc.shape) + }, + tasklet_inputs, + tasklet_code, {"__out": reverse_reduction_memlet}, + external_edges=True) + + nsdfg = context.backward_state.add_nested_sdfg(sdfg, nsdfg_inputs, {rev_output_conn_name}) + + out_edges = state.out_edges(exit_map) + if len(out_edges) != 1: + raise AutoDiffException(f"Expected exactly one output edge from map exit, got {len(out_edges)}") + out_edge = out_edges[0] + out_node = out_edge.dst + if not isinstance(out_node, dace.nodes.AccessNode): + raise AutoDiffException(f"Expected AccessNode as output, got {type(out_node)}") + out_node.setzero = True + + if not is_extremal: + return nsdfg, result + + backward_state = context.backward_state + fwd_in_edges = context.forward_state.in_edges(forward_node) + if len(fwd_in_edges) != 1: + raise AutoDiffException(f"Expected exactly one input edge to forward node, got {len(fwd_in_edges)}") + fwd_in_edge = fwd_in_edges[0] + fwd_in_node = fwd_in_edge.src + if not isinstance(fwd_in_node, dace.nodes.AccessNode): + raise AutoDiffException(f"Expected AccessNode as input source, got {type(fwd_in_node)}") + + # Register forward input array for data forwarding (in case it's overwritten) + if fwd_in_node.data not in context.backward_generator.backward_input_arrays: + data_desc = copy.deepcopy(context.forward_sdfg.arrays[fwd_in_node.data]) + context.backward_generator.backward_input_arrays[fwd_in_node.data] = data_desc + + bwd_read = backward_state.add_read(fwd_in_node.data) + backward_state.add_edge(bwd_read, None, nsdfg, extremal_idx_conn_name, copy.deepcopy(fwd_in_edge.data)) + + if isinstance(context.forward_sdfg.arrays[fwd_in_node.data], (dace.data.View, dace.data.ArrayView)): + in_edge = context.forward_state.in_edges(fwd_in_node) + if len(in_edge) != 1: + raise AutoDiffException(f"Expected exactly one input edge to view node, got {len(in_edge)}") + in_edge = in_edge[0] + in_node = in_edge.src + if isinstance(in_node, dace.nodes.AccessNode): + if isinstance(context.forward_sdfg.arrays[in_node.data], (dace.data.View, dace.data.ArrayView)): + raise AutoDiffException(f"Nested views are not supported: {in_node.data}") + bwd_in_read = backward_state.add_read(in_node.data) + backward_state.add_edge(bwd_in_read, None, bwd_read, "views", copy.deepcopy(in_edge.data)) + + fwd_out_edges = context.forward_state.out_edges(forward_node) + if len(fwd_out_edges) != 1: + raise AutoDiffException(f"Expected exactly one output edge from forward node, got {len(fwd_out_edges)}") + fwd_out_edge = fwd_out_edges[0] + fwd_out_node = fwd_out_edge.dst + if not isinstance(fwd_out_node, dace.nodes.AccessNode): + raise AutoDiffException(f"Expected AccessNode as output destination, got {type(fwd_out_node)}") + + # Register forward output array for data forwarding (in case it's overwritten) + if fwd_out_node.data not in context.backward_generator.backward_input_arrays: + data_desc = copy.deepcopy(context.forward_sdfg.arrays[fwd_out_node.data]) + context.backward_generator.backward_input_arrays[fwd_out_node.data] = data_desc + + bwd_out_read = backward_state.add_read(fwd_out_node.data) + backward_state.add_edge(bwd_out_read, None, nsdfg, extremal_conn_name, copy.deepcopy(fwd_out_edge.data)) + + if isinstance(context.forward_sdfg.arrays[fwd_out_node.data], (dace.data.View, dace.data.ArrayView)): + out_edge = context.forward_state.out_edges(fwd_out_node) + if len(out_edge) != 1: + raise AutoDiffException(f"Expected exactly one output edge from view node, got {len(out_edge)}") + out_edge = out_edge[0] + out_node = out_edge.dst + if isinstance(out_node, dace.nodes.AccessNode): + if isinstance(context.forward_sdfg.arrays[out_node.data], (dace.data.View, dace.data.ArrayView)): + raise AutoDiffException(f"Nested views are not supported: {out_node.data}") + bwd_in_read = backward_state.add_read(out_node.data) + backward_state.add_edge(bwd_in_read, None, bwd_out_read, "views", copy.deepcopy(out_edge.data)) + + return nsdfg, result diff --git a/dace/autodiff/implementations/onnx_ops.py b/dace/autodiff/implementations/onnx_ops.py new file mode 100644 index 0000000000..e04f86c728 --- /dev/null +++ b/dace/autodiff/implementations/onnx_ops.py @@ -0,0 +1,1045 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +ONNX Backward Pass Implementations for Automatic Differentiation. + +This module provides backward pass implementations for ONNX operations in the DaCe autodiff +system. Each class implements the BackwardImplementation interface to compute gradients +for specific ONNX operations during reverse-mode automatic differentiation. + +The implementations handle various ONNX operations including: +- Mathematical operations (Einsum, Clip, Softmax, etc.) +- Neural network layers (Conv, LayerNormalization, etc.) +- Pooling operations (MaxPool, GlobalAveragePool) +- Utility operations (Transpose, Where, etc.) +""" + +import copy +import itertools +from typing import List, Optional, Tuple, Dict, Union + +import numpy as np + +# DaCe core imports +import dace +from dace.frontend.common import einsum +import dace.libraries +from dace.registry import autoregister_params +from dace import nodes as nd + +# ONNX-specific imports +import dace.libraries.onnx as donnx +from dace.libraries.onnx.converters import clean_onnx_name +from dace.libraries.onnx.op_implementations.linalg_ops import PureEinsum +from dace.transformation.onnx.replacement import onnx_constant_or_none + +# Autodiff imports +import dace.autodiff.utils as butils +from dace.autodiff.base_abc import BackwardImplementation, BackwardContext, BackwardResult + +# Utility imports +from dace.sdfg.utils import in_desc_with_name + + +def reverse_einsum_wrt_input(forward_node: 'donnx.nodes.onnx_op.ONNXOp', input_name: str) -> Tuple[List[str], str]: + """Produce the einsum string that computes the gradient of forward_node w.r.t. input_name. + + .. note:: + There is an edge case we currently don't handle (can be implemented though). + Something like 'ii->i' would become 'i->ii'. This is invalid because 'i' is repeated in the output. + + :param forward_node: The einsum node to reverse. + :param input_name: The connector on the forward node to produce the gradient computation for. + :return: Tuple of (list of forward node connectors required as inputs, einsum string). + The first parameter of the produced einsum string is implicitly the grad of Output. + """ + + _, input_idx = donnx.parse_variadic_param(input_name) + parser = einsum.EinsumParser(forward_node.equation) + + backward_input_expressions = [parser.output] + parser.inputs[:input_idx] + parser.inputs[input_idx + 1:] + backward_input_arrays = [ + f"Inputs__{i}" for i in itertools.chain(range(input_idx), range(input_idx + 1, len(parser.inputs))) + ] + + einsum_str = f"{','.join(backward_input_expressions)}->{parser.inputs[input_idx]}" + return backward_input_arrays, einsum_str + + +@autoregister_params(op="Einsum", name="default") +class DefaultEinsumBackward(BackwardImplementation): + """Backward implementation for ONNX Einsum operation. + + The symbolic autodiff can automatically derive matmuls, but the produced maps are more difficult to optimize. + This implementation provides a more efficient ONNX-based backward pass. + """ + + @staticmethod + def backward_can_be_applied(node: nd.Node, state: dace.SDFGState, sdfg: dace.SDFG) -> bool: + return PureEinsum.forward_can_be_applied(node, state, sdfg) + + @staticmethod + def backward(forward_node: nd.Node, context: BackwardContext, given_gradients: List[Optional[str]], + required_gradients: List[Optional[str]]) -> Tuple[nd.Node, BackwardResult]: + + nsdfg = dace.SDFG(forward_node.label + "_backward") + nstate = nsdfg.add_state() + + # setup arrays + output_desc = butils.forward_out_desc_with_name(forward_node, context, "Output") + result = BackwardResult.empty() + result.given_grad_names["Output"] = butils.add_backward_desc(nsdfg, context.forward_sdfg, output_desc, "Output") + access_output_grad = nstate.add_read(result.given_grad_names["Output"]) + + def create_access_node(connector: str) -> nd.AccessNode: + nsdfg.add_datadesc(connector, + copy.deepcopy(butils.forward_in_desc_with_name(forward_node, context, connector))) + return nstate.add_read(connector) + + # the forward inputs we will require + # maps the connector name to the accessnode + required_forward_inputs: Dict[str, nd.AccessNode] = {} + + for input_name in sorted(required_gradients): + # we add an einsum for each required gradient + forward_inputs, einsum_str = reverse_einsum_wrt_input(forward_node, input_name) + + einsum_node = donnx.ONNXEinsum(input_name + "_backward", equation=einsum_str) + nstate.add_node(einsum_node) + + # the first input is always the output grad + einsum_node.add_in_connector(f"Inputs__0") + nstate.add_edge(access_output_grad, None, einsum_node, "Inputs__0", + nsdfg.make_array_memlet(result.given_grad_names["Output"])) + + # add the other inputs from forward that we need + for i, forward_input in enumerate(sorted(forward_inputs)): + connector = f"Inputs__{i + 1}" + einsum_node.add_in_connector(connector) + if forward_input not in required_forward_inputs: + required_forward_inputs[forward_input] = create_access_node(forward_input) + + nstate.add_edge(required_forward_inputs[forward_input], None, einsum_node, connector, + nsdfg.make_array_memlet(forward_input)) + + # write out the gradient + butils.forward_in_desc_with_name(forward_node, context, input_name) + result.required_grad_names[input_name] = butils.add_backward_desc_for_connector( + nsdfg, forward_node, context, input_name, True) + memlet = nsdfg.make_array_memlet(result.required_grad_names[input_name]) + # Add a wcr for gradient accumulation + memlet.wcr = "lambda x, y: x + y" + nstate.add_edge(einsum_node, "Output", nstate.add_write(result.required_grad_names[input_name]), None, + memlet) + + result_node = context.backward_state.add_nested_sdfg( + nsdfg, + set(result.given_grad_names.values()).union(required_forward_inputs), + set(result.required_grad_names.values())) + + return result_node, result + + +@autoregister_params(op="Clip", name="default") +class DefaultClipBackward(BackwardImplementation): + """Backward implementation for ONNX Clip operation. + + Computes gradients by zeroing out regions where the input was clipped + and passing through gradients where the input was within bounds. + """ + + @staticmethod + def backward(forward_node: nd.Node, context: BackwardContext, given_gradients: List[Optional[str]], + required_gradients: List[Optional[str]]) -> Tuple[Union[nd.Node, dace.SDFG], BackwardResult]: + + result_node, result = butils.add_empty_sdfg_for_node(forward_node, ["input_grad", "output_grad", "input"], + context) + + nstate = result_node.sdfg.add_state() + + min_node = next(context.forward_state.in_edges_by_connector(forward_node, 'min')).src + max_node = next(context.forward_state.in_edges_by_connector(forward_node, 'max')).src + minval = onnx_constant_or_none(context.forward_sdfg, min_node) + maxval = onnx_constant_or_none(context.forward_sdfg, max_node) + + idesc = butils.forward_in_desc_with_name(forward_node, context, "input") + shape = idesc.shape + map_ranges = {f"i{i}": f"0:{s}" for i, s in enumerate(shape)} + + input_dtype = idesc.dtype + minstr = f"dace.{input_dtype.to_string()}({minval})" + maxstr = f"dace.{input_dtype.to_string()}({maxval})" + + index_str = f"{', '.join(map_ranges.keys())}" + code = f""" +if __input < {minstr} or __input > {maxstr}: + __input_grad = 0 +else: + __input_grad = __output_grad + """ + nstate.add_mapped_tasklet(forward_node.label + "_backward", + map_ranges=map_ranges, + inputs={ + f"__output_grad": dace.Memlet(f"output_grad[{index_str}]"), + f"__input": dace.Memlet(f"input[{index_str}]"), + }, + code=code, + outputs={f"__input_grad": dace.Memlet(f"input_grad[{index_str}]")}, + external_edges=True) + + return result_node, result + + +@autoregister_params(op="Dropout", name="default") +class DefaultDropoutBackward(BackwardImplementation): + """Backward implementation for ONNX Dropout operation. + + Applies the dropout mask to the output gradients and scales by the keep probability. + """ + + @staticmethod + def backward(forward_node: nd.Node, context: BackwardContext, given_gradients: List[Optional[str]], + required_gradients: List[Optional[str]]) -> Tuple[Union[nd.Node, dace.SDFG], BackwardResult]: + + result_node, result = butils.add_empty_sdfg_for_node(forward_node, + ["data_grad", "output_grad", "mask", "ratio"], context) + + nstate = result_node.sdfg.add_state() + + data_desc = butils.forward_in_desc_with_name(forward_node, context, "data") + shape = data_desc.shape + dtype = data_desc.dtype + dtype_str = dtype.to_string() + map_ranges = {f"i{i}": f"0:{s}" for i, s in enumerate(shape)} + index_str = f"{', '.join(map_ranges.keys())}" + code = f""" +scale = dace.{dtype_str}(1.0) / (1 - __ratio) +__data_grad = __output_grad * __mask * scale + """ + nstate.add_mapped_tasklet(forward_node.label + "_backward", + map_ranges=map_ranges, + inputs={ + "__output_grad": dace.Memlet(f"output_grad[{index_str}]"), + "__mask": dace.Memlet(f"mask[{index_str}]"), + "__ratio": dace.Memlet("ratio[0]") + }, + code=code, + outputs={f"__data_grad": dace.Memlet(f"data_grad[{index_str}]")}, + external_edges=True) + + return result_node, result + + +@autoregister_params(op="Softmax", name="default") +class DefaultSoftmaxBackward(BackwardImplementation): + """Backward implementation for ONNX Softmax operation. + + Computes gradients using the mathematical relationship: + dX = softmax(X) * (dY - sum(dY * softmax(X))) + where dY is the output gradient and dX is the input gradient. + """ + + @staticmethod + def backward(forward_node: nd.Node, context: BackwardContext, given_gradients: List[Optional[str]], + required_gradients: List[Optional[str]]) -> Tuple[Union[nd.Node, dace.SDFG], BackwardResult]: + + dim = forward_node.axis + + output_desc = copy.deepcopy(butils.forward_out_desc_with_name(forward_node, context, "output")) + output_desc.transient = False + + sums_shape = list(copy.deepcopy(output_desc.shape)) + sums_shape[dim] = 1 + + # Create new SDFG + nsdfg = dace.SDFG(forward_node.label + "_backward") + nstate = nsdfg.add_state() + + result = BackwardResult.empty() + + # Given gradients (from output of forward pass) + result.given_grad_names["output"] = "output_grad" + output_grad_desc = copy.deepcopy(output_desc) + nsdfg.add_datadesc("output_grad", output_grad_desc) + + # Required gradient to be computed + input_name = "input" + if "input" not in required_gradients: + # this can happen for example in bert, where the input to softmax is masked + input_name = next(iter(required_gradients)) + + input_grad_desc = copy.deepcopy(butils.forward_in_desc_with_name(forward_node, context, input_name)) + input_grad_desc.transient = False + input_grad_desc_dtype = input_grad_desc.dtype + result.required_grad_names[input_name] = "input_grad" + nsdfg.add_datadesc("input_grad", input_grad_desc) + + # We need the output of the forward op + nsdfg.add_datadesc("output", output_desc) + + # Intermediate arrays + prod_desc = copy.deepcopy(output_desc) + prod_desc.transient = True + nsdfg.add_datadesc("prod", prod_desc) + + sums_desc = dace.data.Array(input_grad_desc_dtype, sums_shape, transient=True) + nsdfg.add_datadesc("sums", sums_desc) + + sub_term_desc = copy.deepcopy(output_desc) + sub_term_desc.transient = True + nsdfg.add_datadesc("sub_term", sub_term_desc) + + # Add nodes + output_grad_read = nstate.add_read("output_grad") + forward_output_read = nstate.add_read("output") + input_grad_write = nstate.add_write("input_grad") + prod_access = nstate.add_access("prod") + sums_access = nstate.add_access("sums") + sub_term_access = nstate.add_access("sub_term") + + # prod = forward_output * output_grad + mul_node1 = donnx.ONNXMul("mul_prod") + nstate.add_node(mul_node1) + nstate.add_edge(forward_output_read, None, mul_node1, "A", nsdfg.make_array_memlet("output")) + nstate.add_edge(output_grad_read, None, mul_node1, "B", nsdfg.make_array_memlet("output_grad")) + nstate.add_edge(mul_node1, "C", prod_access, None, nsdfg.make_array_memlet("prod")) + + # sums = ReduceSum(prod, axes=[dim], keepdims=1) + reduce_sum_node = donnx.ONNXReduceSum("reduce_sum", keepdims=1, optional={"axes"}) + reduce_sum_node.axes = dim + nstate.add_node(reduce_sum_node) + + # Setup the axes input for the ReduceSum node + axes_name, _ = nsdfg.add_array(name="reduce_sum_axes", shape=[1], dtype=dace.int64, transient=True) + axes_access = nstate.add_access(axes_name) + axes_tasklet = nstate.add_tasklet("init_axes", {}, {"out"}, f"out = {dim};", language=dace.Language.CPP) + nstate.add_edge(axes_tasklet, "out", axes_access, None, dace.Memlet(f"{axes_name}")) + + nstate.add_edge(prod_access, None, reduce_sum_node, "data", nsdfg.make_array_memlet("prod")) + nstate.add_edge(axes_access, None, reduce_sum_node, "axes", nsdfg.make_array_memlet(axes_name)) + nstate.add_edge(reduce_sum_node, "reduced", sums_access, None, nsdfg.make_array_memlet("sums")) + + # sub_term = forward_output * sums + mul_node2 = donnx.ONNXMul("mul_sub_term") + nstate.add_node(mul_node2) + nstate.add_edge(forward_output_read, None, mul_node2, "A", nsdfg.make_array_memlet("output")) + nstate.add_edge(sums_access, None, mul_node2, "B", nsdfg.make_array_memlet("sums")) + nstate.add_edge(mul_node2, "C", sub_term_access, None, nsdfg.make_array_memlet("sub_term")) + + # input_grad = prod - sub_term + sub_node = donnx.ONNXSub("sub_input_grad") + nstate.add_node(sub_node) + nstate.add_edge(prod_access, None, sub_node, "A", nsdfg.make_array_memlet("prod")) + nstate.add_edge(sub_term_access, None, sub_node, "B", nsdfg.make_array_memlet("sub_term")) + nstate.add_edge(sub_node, "C", input_grad_write, None, nsdfg.make_array_memlet("input_grad")) + + # Create nested SDFG + result_node = context.backward_state.add_nested_sdfg( + nsdfg, + # Inputs to nested SDFG + {"output", "output_grad"}, + # Outputs from nested SDFG + {"input_grad"}) + + butils.connect_output_from_forward(forward_node, result_node, context, "output") + + return result_node, result + + +def _find_map_by_param(sdfg: dace.SDFG, pname: str) -> dace.nodes.MapEntry: + """Find the first map entry node by the given parameter name. + + :param sdfg: The SDFG to search. + :param pname: The parameter name to look for. + :return: The first MapEntry node containing the specified parameter. + :raises StopIteration: If no MapEntry with the parameter is found. + """ + return next(n for n, _ in sdfg.all_nodes_recursive() if isinstance(n, dace.nodes.MapEntry) and pname in n.params) + + +@autoregister_params(op="MaxPool", name="default") +class DefaultMaxPoolBackward(BackwardImplementation): + """Backward implementation for ONNX MaxPool operation. + + Implements gradient computation by routing gradients only to the locations + that achieved the maximum value in the forward pass. + """ + + @staticmethod + def backward(forward_node: nd.Node, context: BackwardContext, given_gradients: List[Optional[str]], + required_gradients: List[Optional[str]]) -> Tuple[Union[nd.Node, dace.SDFG], BackwardResult]: + + output_shape = butils.forward_out_desc_with_name(forward_node, context, "Y").shape + + N, C, H, W = output_shape + sty, stx = forward_node.strides + sy, sx = forward_node.kernel_shape + dtype = butils.forward_in_desc_with_name(forward_node, context, "X").dtype + + def maxpool_backward(X, Y_grad, X_grad): + for b, c, ti, tj in dace.map[0:N, 0:C, 0:H, 0:W]: + maxv = np.empty([1], dtype=dtype) + maxi = np.empty([1], dtype=np.int32) + maxj = np.empty([1], dtype=np.int32) + with dace.tasklet: + v >> maxv + v = -9999999 + + # Deterministic argmax + for i, j in dace.map[0:sy, 0:sx] @ dace.ScheduleType.Sequential: + with dace.tasklet: + o << X[b, c, sty * ti + i, stx * tj + j] + vin << maxv + v >> maxv(-1) + ind_i >> maxi(-1) + ind_j >> maxj(-1) + if o > vin: + v = o + ind_i = i + ind_j = j + with dace.tasklet: + igrad << Y_grad[b, c, ti, tj] + ind_i << maxi + ind_j << maxj + ograd >> X_grad(1)[b, c, :, :] + ograd[ind_i, ind_j] = igrad + + result_node, result = butils.backward_program_for_node(maxpool_backward, context, forward_node) + + return result_node, result + + +@autoregister_params(op="LogSoftmax", name="default") +class DefaultLogSoftmaxBackward(BackwardImplementation): + """Backward implementation for ONNX LogSoftmax operation. + + Computes gradients using the mathematical relationship for log-softmax: + dX = dY - exp(Y) * sum(dY) + where Y is the forward output and dY is the output gradient. + """ + + @staticmethod + def backward(forward_node: nd.Node, context: BackwardContext, given_gradients: List[Optional[str]], + required_gradients: List[Optional[str]]) -> Tuple[nd.Node, BackwardResult]: + + dim = forward_node.axis + output_shape = butils.forward_out_desc_with_name(forward_node, context, "output").shape + output_dtype = butils.forward_out_desc_with_name(forward_node, context, "output").dtype + + sums_shape = list(copy.deepcopy(output_shape)) + sums_shape[dim] = 1 + + def logsoftmax_backward(output, output_grad, input_grad): + exp_output = dace.define_local(output_shape, output_dtype) + donnx.ONNXExp(input=output, output=exp_output) + + grad_output_sum = dace.define_local(sums_shape, output_dtype) + donnx.ONNXReduceSum(data=output_grad, reduced=grad_output_sum, keepdims=1, axes=[dim]) + # let's not use ONNXMul here; not sure how this inplace op is handled by ORT... + exp_output[:] = exp_output * grad_output_sum + donnx.ONNXSub(A=output_grad, B=exp_output, C=input_grad) + + result_node, result = butils.backward_program_for_node(logsoftmax_backward, context, forward_node) + + butils.connect_output_from_forward(forward_node, result_node, context, "output") + return result_node, result + + +@autoregister_params(op="GlobalAveragePool", name="pure") +class PureGlobalAveragePoolingBackward(BackwardImplementation): + """Pure implementation of GlobalAveragePool backward pass. + + Broadcasts the output gradient uniformly across the spatial dimensions + with appropriate scaling by the pool size. + """ + + @staticmethod + def backward_can_be_applied(node: nd.Node, state: dace.SDFGState, sdfg: dace.SDFG) -> bool: + return len(in_desc_with_name(node, state, sdfg, "X").shape) == 4 + + @staticmethod + def backward(forward_node: nd.Node, context: BackwardContext, given_gradients: List[Optional[str]], + required_gradients: List[Optional[str]]) -> Tuple[nd.Node, BackwardResult]: + desc = butils.forward_in_desc_with_name(forward_node, context, "X") + N, C, H, W = desc.shape + dtype = desc.dtype + + inv = 1.0 / (H * W) + + def bwd(X_grad, Y_grad): + for n, c, h, w in dace.map[0:N, 0:C, 0:H, 0:W]: + with dace.tasklet: + y_grad << Y_grad[n, c] + x_grad >> X_grad[n, c, h, w] + x_grad = y_grad * dtype(inv) + + return butils.backward_program_for_node(bwd, context, forward_node) + + +@autoregister_params(op="Transpose", name="default") +class DefaultTransposeBackward(BackwardImplementation): + """Backward implementation for ONNX Transpose operation. + + The gradient of transpose is another transpose with inverted permutation. + """ + + @staticmethod + def backward(forward_node: nd.Node, context: BackwardContext, given_gradients: List[Optional[str]], + required_gradients: List[Optional[str]]) -> Tuple[nd.Node, BackwardResult]: + inv_perm = tuple(np.argsort(forward_node.perm)) + + node = donnx.ONNXTranspose(forward_node.name + "_backward", perm=inv_perm) + context.backward_state.add_node(node) + + result = BackwardResult.empty() + result.given_grad_names["transposed"] = "data" + result.required_grad_names["data"] = "transposed" + + return node, result + + +@autoregister_params(op="Where", name="default") +class WhereBackward(BackwardImplementation): + """Backward implementation for ONNX Where operation. + + Routes gradients based on the condition: gradients flow to X where condition is True, + and to Y where condition is False. + """ + + @staticmethod + def backward(forward_node: nd.Node, context: BackwardContext, given_gradients: List[Optional[str]], + required_gradients: List[Optional[str]]) -> Tuple[nd.Node, BackwardResult]: + # condition, X, Y -> Output + # Get condition descriptor for shape information + _ = butils.forward_in_desc_with_name(forward_node, context, "condition") + + # NOTE: We cannot use ONNX ops for further potential lowering + # transformations because ONNXMul does not support boolean inputs. + # notcondition = dace.define_local(condition_shape, condition_dtype) + # donnx.ONNXMul(A=condition, B=output_grad, C=X_grad) + # donnx.ONNXNot(X=condition, Y=notcondition) + # donnx.ONNXMul(A=notcondition, B=output_grad, C=Y_grad) + + if 'X' in required_gradients and 'Y' not in required_gradients: + + def where_backward(condition, output_grad, X_grad): + X_grad[:] = condition * output_grad + elif 'Y' in required_gradients and 'X' not in required_gradients: + + def where_backward(condition, output_grad, Y_grad): + Y_grad[:] = ~condition * output_grad + elif 'X' in required_gradients and 'Y' in required_gradients: + + def where_backward(condition, output_grad, X_grad, Y_grad): + X_grad[:] = condition * output_grad + Y_grad[:] = ~condition * output_grad + + result_node, result = butils.backward_program_for_node(where_backward, context, forward_node) + + return result_node, result + + +@autoregister_params(op="LayerNormalization", name="default") +class DefaultLayerNormalizationBackward(BackwardImplementation): + """Backward implementation for ONNX LayerNormalization operation. + + Computes gradients for input, scale, and bias parameters using the + mathematical formulation of layer normalization backward pass. + """ + + @staticmethod + def backward(forward_node: nd.Node, context: BackwardContext, given_gradients: List[Optional[str]], + required_gradients: List[Optional[str]]) -> Tuple[nd.Node, BackwardResult]: + # Create new SDFG + nsdfg = dace.SDFG(forward_node.label + "_backward") + nstate = nsdfg.add_state() + + # Get input/output descriptors + X_desc = copy.deepcopy(butils.forward_in_desc_with_name(forward_node, context, "X")) + Scale_desc = copy.deepcopy(butils.forward_in_desc_with_name(forward_node, context, "Scale")) + Y_grad_desc = copy.deepcopy(butils.forward_out_desc_with_name(forward_node, context, "Y")) + X_desc.transient = False + Y_grad_desc.transient = False + Scale_desc.transient = False + + result = BackwardResult.empty() + # setup gradient arrays + result.given_grad_names["Y"] = "Y_grad" + if "X" in required_gradients: + result.required_grad_names["X"] = "X_grad" + if "Scale" in required_gradients: + result.required_grad_names["Scale"] = "Scale_grad" + if "B" in required_gradients: + result.required_grad_names["B"] = "B_grad" + + # Add data descriptors to SDFG + nsdfg.add_datadesc("X", X_desc) + nsdfg.add_datadesc("Scale", Scale_desc) + nsdfg.add_datadesc("Y_grad", Y_grad_desc) + + if "X" in required_gradients: + X_grad_desc = copy.deepcopy(X_desc) + nsdfg.add_datadesc("X_grad", X_grad_desc) + if "Scale" in required_gradients: + Scale_grad_desc = copy.deepcopy(Scale_desc) + nsdfg.add_datadesc("Scale_grad", Scale_grad_desc) + if "B" in required_gradients: + B_desc = copy.deepcopy(butils.forward_in_desc_with_name(forward_node, context, "B")) + B_desc.transient = False + B_grad_desc = copy.deepcopy(B_desc) + nsdfg.add_datadesc("B_grad", B_grad_desc) + # Add B to SDFG inputs when needed + nsdfg.add_datadesc("B", B_desc) + + # Get axis and epsilon + axis = forward_node.axis if hasattr(forward_node, 'axis') else -1 + epsilon = forward_node.epsilon if hasattr(forward_node, 'epsilon') else 1e-5 + + rank = len(X_desc.shape) + if axis < 0: + axis = rank + axis + reduction_axes = list(range(axis, rank)) + leading_non_normalized_axes = list(range(axis)) + # Calculate normalization size for reference (currently unused) + _ = float(np.prod([X_desc.shape[i] for i in range(axis, rank)])) + + # Create axes tensor for reduction + axes_name = "reduction_axes" + axes_desc = dace.data.Array(dace.int64, [len(reduction_axes)]) + axes_desc.transient = True # Make it transient since it's internal + nsdfg.add_datadesc(axes_name, axes_desc) + axes_access = nstate.add_access(axes_name) + + # Initialize reduction axes as a constant array + axes_tasklet = nstate.add_tasklet(name="init_axes", + inputs={}, + outputs={"out": dace.pointer(dace.int64)}, + code=f"\n".join([f"out[{i}] = {0};" for i, _ in enumerate(reduction_axes)]), + language=dace.Language.CPP) + nstate.add_edge(axes_tasklet, "out", axes_access, None, dace.Memlet(f"{axes_name}[0:{len(reduction_axes)}]")) + + # Create mean descriptor with reduced shape + mean_shape = list(X_desc.shape) + for i in reduction_axes: + mean_shape[i] = 1 + mean_desc = dace.data.Array(X_desc.dtype, mean_shape) + mean_desc.transient = True + mean_name = "mean" + nsdfg.add_datadesc(mean_name, mean_desc) + + mean_op = donnx.ONNXReduceMean("mean_op", keepdims=1, optional={"axes"}) + mean_op.axes = reduction_axes + nstate.add_node(mean_op) + nstate.add_edge(nstate.add_read("X"), None, mean_op, "data", nsdfg.make_array_memlet("X")) + nstate.add_edge(axes_access, None, mean_op, "axes", nsdfg.make_array_memlet(axes_name)) + mean_access = nstate.add_access("mean") + nstate.add_edge(mean_op, "reduced", mean_access, None, nsdfg.make_array_memlet("mean")) + + # Recompute variance + diff_shape = list(X_desc.shape) + diff_desc = dace.data.Array(X_desc.dtype, diff_shape) + diff_desc.transient = True + diff_name = "diff" + nsdfg.add_datadesc(diff_name, diff_desc) + + diff_op = donnx.ONNXSub("diff_op") + nstate.add_node(diff_op) + nstate.add_edge(nstate.add_read("X"), None, diff_op, "A", nsdfg.make_array_memlet("X")) + nstate.add_edge(mean_access, None, diff_op, "B", nsdfg.make_array_memlet("mean")) + diff_access = nstate.add_access("diff") + nstate.add_edge(diff_op, "C", diff_access, None, nsdfg.make_array_memlet("diff")) + + # Create squared difference descriptor + sq_diff_shape = list(X_desc.shape) + sq_diff_desc = dace.data.Array(X_desc.dtype, sq_diff_shape) + sq_diff_desc.transient = True + sq_diff_name = "sq_diff" + nsdfg.add_datadesc(sq_diff_name, sq_diff_desc) + + sq_diff_op = donnx.ONNXMul("sq_diff_op") + nstate.add_node(sq_diff_op) + nstate.add_edge(diff_access, None, sq_diff_op, "A", nsdfg.make_array_memlet("diff")) + nstate.add_edge(diff_access, None, sq_diff_op, "B", nsdfg.make_array_memlet("diff")) + sq_diff_access = nstate.add_access("sq_diff") + nstate.add_edge(sq_diff_op, "C", sq_diff_access, None, nsdfg.make_array_memlet("sq_diff")) + + # Create variance descriptor with reduced shape + variance_shape = list(X_desc.shape) + for i in reduction_axes: + variance_shape[i] = 1 + variance_desc = dace.data.Array(X_desc.dtype, variance_shape) + variance_desc.transient = True + variance_name = "variance" + nsdfg.add_datadesc(variance_name, variance_desc) + + variance_op = donnx.ONNXReduceMean("variance_op", keepdims=1, optional={"axes"}) + variance_op.axes = reduction_axes + nstate.add_node(variance_op) + nstate.add_edge(sq_diff_access, None, variance_op, "data", nsdfg.make_array_memlet("sq_diff")) + nstate.add_edge(axes_access, None, variance_op, "axes", nsdfg.make_array_memlet(axes_name)) + variance_access = nstate.add_access("variance") + nstate.add_edge(variance_op, "reduced", variance_access, None, nsdfg.make_array_memlet("variance")) + + # Add epsilon to variance + epsilon_name, _ = nsdfg.add_scalar("epsilon", X_desc.dtype, transient=True) + epsilon_tasklet = nstate.add_tasklet( + "make_epsilon", + {}, + {"out"}, + f"out = {epsilon};", + language=dace.Language.CPP, + ) + epsilon_write = nstate.add_write(epsilon_name) + nstate.add_edge(epsilon_tasklet, "out", epsilon_write, None, dace.Memlet(f"{epsilon_name}[0]")) + + # Create variance_eps descriptor + variance_eps_desc = dace.data.Array(X_desc.dtype, variance_shape) + variance_eps_desc.transient = True + variance_eps_name = "variance_eps" + nsdfg.add_datadesc(variance_eps_name, variance_eps_desc) + + variance_eps_op = donnx.ONNXAdd("variance_eps_op") + nstate.add_node(variance_eps_op) + nstate.add_edge(variance_access, None, variance_eps_op, "A", nsdfg.make_array_memlet("variance")) + nstate.add_edge(epsilon_write, None, variance_eps_op, "B", nsdfg.make_array_memlet(epsilon_name)) + variance_eps_access = nstate.add_access("variance_eps") + nstate.add_edge(variance_eps_op, "C", variance_eps_access, None, nsdfg.make_array_memlet("variance_eps")) + + # Create std_dev descriptor + std_dev_desc = dace.data.Array(X_desc.dtype, variance_shape) + std_dev_desc.transient = True + std_dev_name = "std_dev" + nsdfg.add_datadesc(std_dev_name, std_dev_desc) + + std_dev_op = donnx.ONNXSqrt("std_dev_op") + nstate.add_node(std_dev_op) + nstate.add_edge(variance_eps_access, None, std_dev_op, "X", nsdfg.make_array_memlet("variance_eps")) + std_dev_access = nstate.add_access("std_dev") + nstate.add_edge(std_dev_op, "Y", std_dev_access, None, nsdfg.make_array_memlet("std_dev")) + + # Create inv_std_dev descriptor + one_name, _ = nsdfg.add_scalar("one", X_desc.dtype, transient=True) + one_tasklet = nstate.add_tasklet("make_one", {}, {"out"}, "out = 1.0;", language=dace.Language.CPP) + one_write = nstate.add_write(one_name) + nstate.add_edge(one_tasklet, "out", one_write, None, dace.Memlet(f"{one_name}[0]")) + + inv_std_dev_desc = dace.data.Array(X_desc.dtype, variance_shape) + inv_std_dev_desc.transient = True + inv_std_dev_name = "inv_std_dev" + nsdfg.add_datadesc(inv_std_dev_name, inv_std_dev_desc) + + inv_std_dev_op = donnx.ONNXDiv("inv_std_dev_op") + nstate.add_node(inv_std_dev_op) + nstate.add_edge(one_write, None, inv_std_dev_op, "A", nsdfg.make_array_memlet(one_name)) + nstate.add_edge(std_dev_access, None, inv_std_dev_op, "B", nsdfg.make_array_memlet("std_dev")) + inv_std_dev_access = nstate.add_access("inv_std_dev") + nstate.add_edge(inv_std_dev_op, "C", inv_std_dev_access, None, nsdfg.make_array_memlet("inv_std_dev")) + + # Create x_hat descriptor (normalized input) + x_hat_desc = dace.data.Array(X_desc.dtype, X_desc.shape) + x_hat_desc.transient = True + x_hat_name = "x_hat" + nsdfg.add_datadesc(x_hat_name, x_hat_desc) + + x_hat_op = donnx.ONNXMul("x_hat_op") + nstate.add_node(x_hat_op) + nstate.add_edge(diff_access, None, x_hat_op, "A", nsdfg.make_array_memlet("diff")) + nstate.add_edge(inv_std_dev_access, None, x_hat_op, "B", nsdfg.make_array_memlet("inv_std_dev")) + x_hat_access = nstate.add_access("x_hat") + nstate.add_edge(x_hat_op, "C", x_hat_access, None, nsdfg.make_array_memlet("x_hat")) + + # Compute bias gradient if needed + if "B" in required_gradients: + b_grad_op = donnx.ONNXReduceSum("b_grad_op", keepdims=0, optional={"axes"}) + # This reduction will sum over the leading non-normalized axes + b_grad_op.axes = leading_non_normalized_axes + nstate.add_node(b_grad_op) + nstate.add_edge(nstate.add_read("Y_grad"), None, b_grad_op, "data", nsdfg.make_array_memlet("Y_grad")) + nstate.add_edge(axes_access, None, b_grad_op, "axes", nsdfg.make_array_memlet(axes_name)) + nstate.add_edge(b_grad_op, "reduced", nstate.add_write("B_grad"), None, nsdfg.make_array_memlet("B_grad")) + + # Compute scale gradient if needed + if "Scale" in required_gradients: + dY_x_hat_desc = dace.data.Array(X_desc.dtype, X_desc.shape) + dY_x_hat_desc.transient = True + dY_x_hat_name = "dY_x_hat" + nsdfg.add_datadesc(dY_x_hat_name, dY_x_hat_desc) + + dY_x_hat_op = donnx.ONNXMul("dY_x_hat_op") + nstate.add_node(dY_x_hat_op) + nstate.add_edge(nstate.add_read("Y_grad"), None, dY_x_hat_op, "A", nsdfg.make_array_memlet("Y_grad")) + nstate.add_edge(x_hat_access, None, dY_x_hat_op, "B", nsdfg.make_array_memlet("x_hat")) + dY_x_hat_access = nstate.add_access("dY_x_hat") + nstate.add_edge(dY_x_hat_op, "C", dY_x_hat_access, None, nsdfg.make_array_memlet("dY_x_hat")) + + scale_grad_op = donnx.ONNXReduceSum("scale_grad_op", keepdims=0, optional={"axes"}) + scale_grad_op.axes = leading_non_normalized_axes + nstate.add_node(scale_grad_op) + nstate.add_edge(dY_x_hat_access, None, scale_grad_op, "data", nsdfg.make_array_memlet("dY_x_hat")) + nstate.add_edge(axes_access, None, scale_grad_op, "axes", nsdfg.make_array_memlet(axes_name)) + nstate.add_edge(scale_grad_op, "reduced", nstate.add_write("Scale_grad"), None, + nsdfg.make_array_memlet("Scale_grad")) + + # Compute X gradient if needed + if "X" in required_gradients: + # Create dX_hat descriptor (gradient with respect to normalized input) + dX_hat_desc = dace.data.Array(X_desc.dtype, X_desc.shape) + dX_hat_desc.transient = True + dX_hat_name = "dX_hat" + nsdfg.add_datadesc(dX_hat_name, dX_hat_desc) + + dX_hat_op = donnx.ONNXMul("dX_hat_op") + nstate.add_node(dX_hat_op) + nstate.add_edge(nstate.add_read("Y_grad"), None, dX_hat_op, "A", nsdfg.make_array_memlet("Y_grad")) + nstate.add_edge(nstate.add_read("Scale"), None, dX_hat_op, "B", nsdfg.make_array_memlet("Scale")) + dX_hat_access = nstate.add_access("dX_hat") + nstate.add_edge(dX_hat_op, "C", dX_hat_access, None, nsdfg.make_array_memlet("dX_hat")) + + # Compute mean of dX_hat over reduction axes + dX_hat_mean_desc = dace.data.Array(X_desc.dtype, variance_shape) + dX_hat_mean_desc.transient = True + dX_hat_mean_name = "dX_hat_mean" + nsdfg.add_datadesc(dX_hat_mean_name, dX_hat_mean_desc) + + dX_hat_mean_op = donnx.ONNXReduceMean("dX_hat_mean_op", keepdims=1, optional={"axes"}) + dX_hat_mean_op.axes = reduction_axes + nstate.add_node(dX_hat_mean_op) + nstate.add_edge(dX_hat_access, None, dX_hat_mean_op, "data", nsdfg.make_array_memlet("dX_hat")) + nstate.add_edge(axes_access, None, dX_hat_mean_op, "axes", nsdfg.make_array_memlet(axes_name)) + dX_hat_mean_access = nstate.add_access("dX_hat_mean") + nstate.add_edge(dX_hat_mean_op, "reduced", dX_hat_mean_access, None, nsdfg.make_array_memlet("dX_hat_mean")) + + # Compute dX_hat * x_hat + dX_hat_x_hat_desc = dace.data.Array(X_desc.dtype, X_desc.shape) + dX_hat_x_hat_desc.transient = True + dX_hat_x_hat_name = "dX_hat_x_hat" + nsdfg.add_datadesc(dX_hat_x_hat_name, dX_hat_x_hat_desc) + + dX_hat_x_hat_op = donnx.ONNXMul("dX_hat_x_hat_op") + nstate.add_node(dX_hat_x_hat_op) + nstate.add_edge(dX_hat_access, None, dX_hat_x_hat_op, "A", nsdfg.make_array_memlet("dX_hat")) + nstate.add_edge(x_hat_access, None, dX_hat_x_hat_op, "B", nsdfg.make_array_memlet("x_hat")) + dX_hat_x_hat_access = nstate.add_access("dX_hat_x_hat") + nstate.add_edge(dX_hat_x_hat_op, "C", dX_hat_x_hat_access, None, nsdfg.make_array_memlet("dX_hat_x_hat")) + + # Compute mean of dX_hat * x_hat over reduction axes + dX_hat_x_hat_mean_desc = dace.data.Array(X_desc.dtype, variance_shape) + dX_hat_x_hat_mean_desc.transient = True + dX_hat_x_hat_mean_name = "dX_hat_x_hat_mean" + nsdfg.add_datadesc(dX_hat_x_hat_mean_name, dX_hat_x_hat_mean_desc) + + dX_hat_x_hat_mean_op = donnx.ONNXReduceMean("dX_hat_x_hat_mean_op", keepdims=1, optional={"axes"}) + dX_hat_x_hat_mean_op.axes = reduction_axes + nstate.add_node(dX_hat_x_hat_mean_op) + nstate.add_edge(dX_hat_x_hat_access, None, dX_hat_x_hat_mean_op, "data", + nsdfg.make_array_memlet("dX_hat_x_hat")) + nstate.add_edge(axes_access, None, dX_hat_x_hat_mean_op, "axes", nsdfg.make_array_memlet(axes_name)) + dX_hat_x_hat_mean_access = nstate.add_access("dX_hat_x_hat_mean") + nstate.add_edge(dX_hat_x_hat_mean_op, "reduced", dX_hat_x_hat_mean_access, None, + nsdfg.make_array_memlet("dX_hat_x_hat_mean")) + + # Compute x_hat * mean(dX_hat * x_hat) + x_hat_dX_hat_x_hat_mean_desc = dace.data.Array(X_desc.dtype, X_desc.shape) + x_hat_dX_hat_x_hat_mean_desc.transient = True + x_hat_dX_hat_x_hat_mean_name = "x_hat_dX_hat_x_hat_mean" + nsdfg.add_datadesc(x_hat_dX_hat_x_hat_mean_name, x_hat_dX_hat_x_hat_mean_desc) + + x_hat_dX_hat_x_hat_mean_op = donnx.ONNXMul("x_hat_dX_hat_x_hat_mean_op") + nstate.add_node(x_hat_dX_hat_x_hat_mean_op) + nstate.add_edge(x_hat_access, None, x_hat_dX_hat_x_hat_mean_op, "A", nsdfg.make_array_memlet("x_hat")) + nstate.add_edge(dX_hat_x_hat_mean_access, None, x_hat_dX_hat_x_hat_mean_op, "B", + nsdfg.make_array_memlet("dX_hat_x_hat_mean")) + x_hat_dX_hat_x_hat_mean_access = nstate.add_access("x_hat_dX_hat_x_hat_mean") + nstate.add_edge(x_hat_dX_hat_x_hat_mean_op, "C", x_hat_dX_hat_x_hat_mean_access, None, + nsdfg.make_array_memlet("x_hat_dX_hat_x_hat_mean")) + + # Compute dX_hat - mean(dX_hat) - x_hat * mean(dX_hat * x_hat) + dX_hat_minus_mean_desc = dace.data.Array(X_desc.dtype, X_desc.shape) + dX_hat_minus_mean_desc.transient = True + dX_hat_minus_mean_name = "dX_hat_minus_mean" + nsdfg.add_datadesc(dX_hat_minus_mean_name, dX_hat_minus_mean_desc) + + dX_hat_minus_mean_op = donnx.ONNXSub("dX_hat_minus_mean_op") + nstate.add_node(dX_hat_minus_mean_op) + nstate.add_edge(dX_hat_access, None, dX_hat_minus_mean_op, "A", nsdfg.make_array_memlet("dX_hat")) + nstate.add_edge(dX_hat_mean_access, None, dX_hat_minus_mean_op, "B", nsdfg.make_array_memlet("dX_hat_mean")) + dX_hat_minus_mean_access = nstate.add_access("dX_hat_minus_mean") + nstate.add_edge(dX_hat_minus_mean_op, "C", dX_hat_minus_mean_access, None, + nsdfg.make_array_memlet("dX_hat_minus_mean")) + + # Final subtraction + dX_hat_final_desc = dace.data.Array(X_desc.dtype, X_desc.shape) + dX_hat_final_desc.transient = True + dX_hat_final_name = "dX_hat_final" + nsdfg.add_datadesc(dX_hat_final_name, dX_hat_final_desc) + + dX_hat_final_op = donnx.ONNXSub("dX_hat_final_op") + nstate.add_node(dX_hat_final_op) + nstate.add_edge(dX_hat_minus_mean_access, None, dX_hat_final_op, "A", + nsdfg.make_array_memlet("dX_hat_minus_mean")) + nstate.add_edge(x_hat_dX_hat_x_hat_mean_access, None, dX_hat_final_op, "B", + nsdfg.make_array_memlet("x_hat_dX_hat_x_hat_mean")) + dX_hat_final_access = nstate.add_access("dX_hat_final") + nstate.add_edge(dX_hat_final_op, "C", dX_hat_final_access, None, nsdfg.make_array_memlet("dX_hat_final")) + + # Multiply by inv_std_dev to get final X gradient + x_grad_op = donnx.ONNXMul("x_grad_op") + nstate.add_node(x_grad_op) + nstate.add_edge(inv_std_dev_access, None, x_grad_op, "A", nsdfg.make_array_memlet("inv_std_dev")) + nstate.add_edge(dX_hat_final_access, None, x_grad_op, "B", nsdfg.make_array_memlet("dX_hat_final")) + nstate.add_edge(x_grad_op, "C", nstate.add_write("X_grad"), None, nsdfg.make_array_memlet("X_grad")) + + # Set up inputs for nested SDFG + inputs = {"X", "Scale", "Y_grad"} + if "B" in required_gradients: + inputs.add("B") + + outputs = set(result.required_grad_names.values()) + bwd_node = context.backward_state.add_nested_sdfg(nsdfg, inputs, outputs) + return bwd_node, result + + +@autoregister_params(op="ReduceSum", name="default") +class DefaultReduceSumBackward(BackwardImplementation): + """Backward implementation for ONNX ReduceSum operation. + + The backward pass of a reduction is a broadcast of the output gradient + to match the input shape. Handles both keepdims=True and keepdims=False cases. + """ + + @staticmethod + def backward_can_be_applied(node: nd.Node, state: dace.SDFGState, sdfg: dace.SDFG) -> bool: + return True + + @staticmethod + def backward(forward_node: nd.Node, context: BackwardContext, given_gradients: List[Optional[str]], + required_gradients: List[Optional[str]]) -> Tuple[nd.Node, BackwardResult]: + + # The backward pass of a reduction is a broadcast. + # We use ONNXExpand to perform the broadcast. + # If keepdims=False, we first need to unsqueeze the gradient. + + input_desc = butils.forward_in_desc_with_name(forward_node, context, "data") + output_desc = butils.forward_out_desc_with_name(forward_node, context, "reduced") + + nsdfg = dace.SDFG(f"{forward_node.label}_backward") + nstate = nsdfg.add_state() + + result = BackwardResult.empty() + result.given_grad_names["reduced"] = "reduced_grad" + result.required_grad_names["data"] = "data_grad" + + reduced_grad_desc = copy.deepcopy(output_desc) + reduced_grad_desc.transient = False + nsdfg.add_datadesc("reduced_grad", reduced_grad_desc) + + data_grad_desc_tmp = copy.deepcopy(input_desc) + data_grad_desc_tmp.transient = True + nsdfg.add_datadesc("data_grad_tmp", data_grad_desc_tmp) + + data_grad_desc = copy.deepcopy(input_desc) + data_grad_desc.transient = False + nsdfg.add_datadesc("data_grad", data_grad_desc) + + grad_to_expand = "reduced_grad" + read_grad_to_expand = nstate.add_read(grad_to_expand) + + keepdims = getattr(forward_node, 'keepdims', 1) + + if not keepdims: + # When keepdims is False, the rank of the output is reduced. We need to + # unsqueeze the gradient to match the input rank before broadcasting. + + # Deduce reduced axes by comparing input and output shapes. + in_shape = input_desc.shape + out_shape = reduced_grad_desc.shape + unsqueezed_shape = [] + axes = [] + if len(in_shape) < len(out_shape): + raise ValueError(f"Input shape {in_shape} has fewer dimensions than output shape {out_shape}. " + f"This is unexpected for a ReduceSum operation.") + if len(in_shape) > len(out_shape): + # This assumes that non-reduced dimensions are preserved in order. + out_shape_idx = 0 + for i, dim in enumerate(in_shape): + if out_shape_idx < len(out_shape) and dim == out_shape[out_shape_idx]: + out_shape_idx += 1 + unsqueezed_shape.append(dim) + else: + axes.append(i) + unsqueezed_shape.append(1) + + # If shapes are equal, it's a no-op reduction and axes is empty. + if (not axes) != (len(in_shape) == len(out_shape)): + raise ValueError(f"Inconsistent state: axes={axes}, input_shape={in_shape}, output_shape={out_shape}. " + f"For equal shapes, axes should be empty.") + + if 'axes' in forward_node.in_connectors: + # The axes are a dynamic input to the forward node. Pass them to the backward node. + axes_desc = butils.forward_in_desc_with_name(forward_node, context, "axes") + axes_desc_copy = copy.deepcopy(axes_desc) + axes_desc_copy.transient = False + nsdfg.add_datadesc("axes", axes_desc_copy) + axes_access = nstate.add_read("axes") + elif axes: + # Create a constant array for the axes to be passed to Unsqueeze + axes_name_in_bwd, axes_desc_bwd = nsdfg.add_array(f"axes_for_unsqueeze_{forward_node.name}", + [len(axes)], + dace.int64, + transient=True) + axes_tasklet = nstate.add_tasklet( + 'init_axes', + {}, + {'out'}, + '\n'.join([f'out[{i}] = {v};' for i, v in enumerate(axes)]), + language=dace.Language.CPP, + ) + axes_access = nstate.add_access(axes_name_in_bwd) + nstate.add_edge(axes_tasklet, 'out', axes_access, None, + dace.Memlet.from_array(axes_name_in_bwd, axes_desc_bwd)) + + unsqueezed_desc = dace.data.Array(dtype=reduced_grad_desc.dtype, shape=unsqueezed_shape, transient=True) + nsdfg.add_datadesc("unsqueezed_grad", unsqueezed_desc) + + unsqueeze_op = donnx.ONNXUnsqueeze("unsqueeze_grad") + nstate.add_node(unsqueeze_op) + + nstate.add_edge(read_grad_to_expand, None, unsqueeze_op, "data", nsdfg.make_array_memlet("reduced_grad")) + nstate.add_edge(axes_access, None, unsqueeze_op, "axes", + dace.Memlet(data=axes_access.data, subset=f'0:{axes_access.desc(nsdfg).shape[0]}')) + + grad_to_expand = "unsqueezed_grad" + read_grad_to_expand = nstate.add_access(grad_to_expand) + nstate.add_edge(unsqueeze_op, "expanded", read_grad_to_expand, None, + nsdfg.make_array_memlet("unsqueezed_grad")) + + # Create shape tensor for ONNXExpand + shape_name, shape_desc = nsdfg.add_array("shape_for_expand", [len(input_desc.shape)], + dace.int64, + transient=True) + shape_tasklet = nstate.add_tasklet("init_shape", {}, {"out"}, + '\n'.join([f"out[{i}] = {s};" for i, s in enumerate(input_desc.shape)])) + shape_access = nstate.add_access(shape_name) + nstate.add_edge(shape_tasklet, "out", shape_access, None, dace.Memlet.from_array(shape_name, shape_desc)) + + expand_op = donnx.ONNXExpand("expand_grad") + nstate.add_node(expand_op) + write_data_grad_tmp = nstate.add_write("data_grad_tmp") + + nstate.add_edge(read_grad_to_expand, None, expand_op, "input", nsdfg.make_array_memlet(grad_to_expand)) + nstate.add_edge(shape_access, None, expand_op, "shape", nsdfg.make_array_memlet(shape_name)) + nstate.add_edge(expand_op, "output", write_data_grad_tmp, None, nsdfg.make_array_memlet("data_grad_tmp")) + + # We add an additional write from data_grad_tmp to data_grad + # This is necessary to accumulate gradients in the backward pass. + finale_memlet = nsdfg.make_array_memlet("data_grad") + finale_memlet.wcr = "lambda x, y: x + y" + write_data_grad = nstate.add_write("data_grad") + nstate.add_edge(write_data_grad_tmp, None, write_data_grad, None, finale_memlet) + + inputs = {"reduced_grad"} + if not keepdims and 'axes' in forward_node.in_connectors: + inputs.add("axes") + + result_node = context.backward_state.add_nested_sdfg(nsdfg, inputs, {"data_grad"}) + + return result_node, result diff --git a/dace/autodiff/implementations/pytorch_ops.py b/dace/autodiff/implementations/pytorch_ops.py new file mode 100644 index 0000000000..9f50edbb9e --- /dev/null +++ b/dace/autodiff/implementations/pytorch_ops.py @@ -0,0 +1,128 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. + +import copy +import itertools +from typing import List, Optional, Tuple + +import dace +import dace.libraries.torch +from dace.registry import autoregister_params +from dace import nodes as nd + +from dace.libraries.onnx.converters import clean_onnx_name + +import dace.autodiff.utils as butils +from dace.autodiff.base_abc import BackwardImplementation, BackwardContext, BackwardResult +from dace.sdfg.utils import in_desc_with_name + + +@autoregister_params(op="Conv", name="PyTorch-dwise") +class PyTorchConvBackward(BackwardImplementation): + """Depthwise convolution backward implementation using PyTorch. + + This implementation leverages PyTorch's optimized CUDA kernels for + depthwise convolution backward pass computation. + """ + + @staticmethod + def backward_can_be_applied(node: nd.Node, state: dace.SDFGState, sdfg: dace.SDFG) -> bool: + X_desc = in_desc_with_name(node, state, sdfg, "X") + return len(X_desc.shape) == 4 + + @staticmethod + def backward(forward_node: nd.Node, context: BackwardContext, given_gradients: List[Optional[str]], + required_gradients: List[Optional[str]]) -> Tuple[nd.Node, BackwardResult]: + + nsdfg = dace.SDFG(forward_node.label + "_backward") + X_desc = butils.forward_in_desc_with_name(forward_node, context, "X") + W_desc = butils.forward_in_desc_with_name(forward_node, context, "W") + + T = X_desc.dtype + if str(T) == 'float': + pytorch_dtype = 'kFloat' + elif str(T) == 'double': + pytorch_dtype = 'kDouble' + else: + raise ValueError(f"PyTorch backward conv expansion supports only float and double tensors, got {str(T)}. " + f"Supported types: float, double") + + # setup gradient arrays + result = BackwardResult.empty() + required_grads = set(required_gradients) + for r in sorted(required_grads): + result.required_grad_names[r] = butils.add_backward_desc_for_connector(nsdfg, + forward_node, + context, + r, + input=True) + result.given_grad_names["Y"] = butils.add_backward_desc_for_connector(nsdfg, + forward_node, + context, + "Y", + input=False) + + # setup non-gradient arrays + required_forward_inputs = ["W", "X"] + for i in sorted(required_forward_inputs): + new_desc = copy.deepcopy(butils.forward_in_desc_with_name(forward_node, context, i)) + new_desc.transient = False + nsdfg.add_datadesc(i, new_desc) + + # setup state + nstate = nsdfg.add_state() + unique_id = "{}_{}_{}_{}_bwd".format(clean_onnx_name(forward_node.name), context.forward_sdfg.sdfg_id, + context.forward_sdfg.node_id(context.forward_state), + context.forward_state.node_id(forward_node)) + + init_code = "" + finalize_code = "" + code_global = """ + #include + #include + """ + tasklet_inputs = {f"_{i}": dace.pointer(T) for i in itertools.chain(["dY"], sorted(required_forward_inputs))} + tasklet_outputs = {f"_d{i}": dace.pointer(T) for i in itertools.chain(sorted(required_gradients))} + + tasklet_code = f""" + std::vector x_shape = {{ {", ".join(map(str, X_desc.shape))} }}; + std::vector x_strides = {{ {", ".join(map(str, X_desc.strides))} }}; + std::vector w_shape = {{ {", ".join(map(str, W_desc.shape))} }}; + std::vector w_strides = {{ {", ".join(map(str, W_desc.strides))} }}; + at::Tensor x = at::from_blob(_X, x_shape, x_strides, [](void*){{}}, at::TensorOptions().device(at::kCUDA).dtype(at::{pytorch_dtype}).requires_grad(false)); + at::Tensor w = at::from_blob(_W, w_shape, w_strides, [](void*){{}}, at::TensorOptions().device(at::kCUDA).dtype(at::{pytorch_dtype}).requires_grad(false)); + at::Tensor dy = at::from_blob(_dY, x_shape, x_strides, [](void*){{}}, at::TensorOptions().device(at::kCUDA).dtype(at::{pytorch_dtype}).requires_grad(false)); + at::Tensor dw = at::from_blob(_dW, w_shape, w_strides, [](void*){{}}, at::TensorOptions().device(at::kCUDA).dtype(at::{pytorch_dtype}).requires_grad(false)); + at::Tensor dx = at::from_blob(_dX, x_shape, x_strides, [](void*){{}}, at::TensorOptions().device(at::kCUDA).dtype(at::{pytorch_dtype}).requires_grad(false)); + + std::vector kernel_shape = {{ {", ".join(map(str, forward_node.kernel_shape))} }}; + std::vector conv_strides = {{ {", ".join(map(str, forward_node.strides))} }}; + std::vector padding = {{ {", ".join(map(str, forward_node.pads[::2]))} }}; + std::vector dilation = {{ {", ".join(map(str, forward_node.dilations))} }}; + + at::thnn_conv_depthwise2d_backward_out(dx, dw, dy, x, w, kernel_shape, conv_strides, padding, dilation); + """ + + tasklet = nstate.add_tasklet(name=unique_id, + inputs=tasklet_inputs, + outputs=tasklet_outputs, + code=tasklet_code, + language=dace.dtypes.Language.CPP, + code_global=code_global, + code_init=init_code, + code_exit=finalize_code) + tasklet.environments = {dace.libraries.torch.environments.PyTorch.full_class_path()} + + nstate.add_edge(nstate.add_read(result.given_grad_names["Y"]), None, tasklet, f"_dY", + nsdfg.make_array_memlet((result.given_grad_names["Y"]))) + for name in sorted(required_forward_inputs): + nstate.add_edge(nstate.add_read(name), None, tasklet, f"_{name}", nsdfg.make_array_memlet(name)) + + for name in sorted(required_gradients): + arr_name = result.required_grad_names[name] + nstate.add_edge(tasklet, f"_d{name}", nstate.add_write(arr_name), None, nsdfg.make_array_memlet(arr_name)) + + inputs = {result.given_grad_names["Y"]}.union(required_forward_inputs) + outputs = {result.required_grad_names[n] for n in sorted(required_gradients)} + node = context.backward_state.add_nested_sdfg(nsdfg, inputs, outputs) + + return node, result diff --git a/dace/autodiff/library/__init__.py b/dace/autodiff/library/__init__.py new file mode 100644 index 0000000000..2a56e34067 --- /dev/null +++ b/dace/autodiff/library/__init__.py @@ -0,0 +1,31 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +Library Integration for Automatic Differentiation. + +This package provides integration between DaCe's autodiff system and various +libraries and frontends. It enables differentiation of code that uses +library operations and provides hooks for frontend-specific optimizations. +""" + +import dace.library + +from . import library + +# PyTorch integrations are optional +try: + from . import torch_integration + from dace.frontend.python.replacements import torch_autodiff + TORCH_INTEGRATION_AVAILABLE = True +except ImportError: + torch_integration = None + torch_autodiff = None + TORCH_INTEGRATION_AVAILABLE = False + +dace.library.register_library(__name__, "autodiff") + +__all__ = [ + "library", +] + +if TORCH_INTEGRATION_AVAILABLE: + __all__.extend(["torch_integration", "torch_autodiff"]) diff --git a/dace/autodiff/library/library.py b/dace/autodiff/library/library.py new file mode 100644 index 0000000000..b5cc0e5d97 --- /dev/null +++ b/dace/autodiff/library/library.py @@ -0,0 +1,286 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +Dace library for autodiff + +Includes the BackwardPass library node, and the replacements for the python frontend +""" +from typing import Dict, Set, Optional +import copy + +import dace +import dace.library +from dace import data, properties +from dace.transformation import transformation as pm +from dace.sdfg import SDFG, SDFGState, graph, nodes + +from dace.autodiff import backward_pass_generator as engine, analysis as autodiff_analysis +from dace.autodiff.utils import init_grad +from dace.sdfg.utils import in_edge_with_name +from dace.transformation.passes.analysis import AccessSets + + +@properties.make_properties +class ParameterArray(data.Array): + """ + An array for which a gradient can be computed. + """ + # since this can be None, this is not a DataProperty + gradient = properties.Property(dtype=str, desc="The corresponding gradient buffer", default=None, allow_none=True) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def __repr__(self): + return "Parameter" + data.Array.__repr__(self) + + def add_gradient_buffer(self, sdfg: SDFG, name: str) -> str: + """ + Find or create a gradient buffer for the parameter in the given SDFG. + + :param sdfg: the SDFG containing the parameter + :param name: the name of the parameter + :return: the name of the gradient buffer + """ + + if self.gradient: + return self.gradient + + # First, check if this array already has a gradient buffer in a nested + # SDFG. This happens, for example when pytorch modules are used in the + # frontend. In that case: + # 1. the parser assembles the closure of the module, which adds + # descriptors for all the parameters and their gradients (if they + # are required). + # 2. A nested sdfg is added for the module, with those array names. + # 3. The DaceProgram will then pass these arrays in when the + # DaceProgram is called, using the names from the closure that + # match the names from the NestedSDFG + # 4. When parsing the backward nodes, we want the gradient buffers in + # the closure to match the gradient buffers that we pass in. Thus, + # we need to make sure that we use the same name as the NestedSDFG + # + # Note that we do not currently do any nesting beyond this level, + # because nested modules are converted to one SDFG. + + cands = set() + for state in sdfg.nodes(): + for node in state.nodes(): + if not isinstance(node, nodes.NestedSDFG): + continue + + nested_names = set() + + for edge in state.in_edges(node): + if edge.data.data == name: + nested_names.add(edge.dst_conn) + for edge in state.out_edges(node): + if edge.data.data == name: + nested_names.add(edge.dst_conn) + + for name in nested_names: + nested_desc = node.sdfg.arrays[name] + if isinstance(nested_desc, ParameterArray) and nested_desc.gradient: + cands.add(nested_desc.gradient) + + if len(cands) > 1: + raise ValueError("Multiple gradient buffers found for parameter " + name) + elif len(cands) == 1: + # we found a name of a gradient buffer in a nested SDFG: + # reuse the same name in the outer sdfg if there is a matching descriptor + grad_name = cands.pop() + if grad_name in sdfg.arrays: + self.gradient = grad_name + return grad_name + else: + grad_name = sdfg._find_new_name('gradient_' + name) + + # Create a gradient buffer for the array + grad_desc = copy.deepcopy(self) + grad_desc.__class__ = data.Array + grad_desc.transient = True + grad_name = sdfg.add_datadesc(grad_name, grad_desc, find_new_name=True) + self.gradient = grad_name + return grad_name + + @staticmethod + def make_parameter(sdfg: SDFG, name: str): + """ + Converts an existing array into a parameter, without copying. + + :param sdfg: the SDFG containing the array. + :param name: the name of the array. + """ + desc = sdfg.arrays[name] + if isinstance(desc, ParameterArray): + return + + new_desc = copy.deepcopy(desc) + new_desc.__class__ = ParameterArray + new_desc.gradient = None + sdfg.arrays[name] = new_desc + + +@dace.library.expansion +class ExpandBackwardPass(pm.ExpandTransformation): + environments = [] + + @staticmethod + def expansion(node: 'BackwardPass', state: SDFGState, sdfg: SDFG): + + node.validate(sdfg, state) + + in_array_name = lambda connector_name: in_edge_with_name(node, state, connector_name).data.data + + array_grad_map = {} + + access_sets = AccessSets().apply_pass(sdfg, {}) + + nsdfg = SDFG("backward_" + sdfg.label) + + # Check for other BackwardPasses that also compute the same gradients as us + node.propagate_conflicts(sdfg, state) + + # get the names of the output arrays in the forward pass + given_gradients = node.outer_names_given_gradients(state) + + array_grad_map.update(node.required_gradients) + array_grad_map.update((in_array_name(value_conn_name), grad_conn_name) + for grad_conn_name, value_conn_name in node.given_gradients.items()) + + # remove the non-grad arrays as inputs from the forward pass; + # they were also just added for control dependencies + for forward_non_grad_conn_name in node.given_gradients.values(): + for edge in list(state.in_edges_by_connector(node, forward_non_grad_conn_name)): + state.remove_edge(edge) + if state.in_degree(edge.src) + state.out_degree(edge.src) == 0: + state.remove_node(edge.src) + node.remove_in_connector(forward_non_grad_conn_name) + + gen = engine.BackwardPassGenerator(sdfg=sdfg, + given_gradients=given_gradients, + required_gradients=node.required_gradients.keys(), + backward_sdfg=nsdfg, + array_grad_map=array_grad_map, + conflicted_gradient_buffers=node._conflicted_gradients) + + _, _, required_forwarded_values = gen.backward() + + # Add zero initialization for all gradients which we are the first to compute + for outer_edge in state.out_edges(node): + gradient_we_are_writing: str = outer_edge.data.data + is_written_with_wcr = any(edge.data.wcr is not None and edge.data.data == outer_edge.src_conn + for edge, _ in nsdfg.all_edges_recursive() + if isinstance(edge, graph.MultiConnectorEdge)) + + anyone_written_before_us = autodiff_analysis.is_previously_written(sdfg, + state, + node, + gradient_we_are_writing, + access_sets=access_sets) + if not anyone_written_before_us and is_written_with_wcr: + init_grad(gradient_we_are_writing, sdfg, state) + + for name in required_forwarded_values: + # get the access to the forwarded_value + # there should only be one since we don't allow inplace modification + n = [n for n in state.nodes() if isinstance(n, nodes.AccessNode) and n.data == name] + if len(n) > 1: + raise ValueError( + "Expected only one access node for forwarded value, does the graph have in-place modification?") + elif len(n) == 0: + n = state.add_read(name) + else: + n = n[0] + + node.add_in_connector(name) + state.add_edge(n, None, node, name, sdfg.make_array_memlet(name)) + + nsdfg.validate() + + return nsdfg + + +@dace.library.node +class BackwardPass(nodes.LibraryNode): + """ + The BackwardPass library node expands to an implementation of a + BackwardPass that computes the requested gradients. + + These gradients are computed using the DaCe autograd engine. + + The gradient will be computed for each array in the output connectors. + For this, the names of the output connectors must match the name of the + array for which the gradient is to be computed. + """ + + # Global properties + implementations = { + "differentiate": ExpandBackwardPass, + } + default_implementation = "differentiate" + + given_gradients = properties.DictProperty( + key_type=str, + value_type=str, + desc="Mapping between connector names of the given gradients and the names of the arrays they correspond to.") + required_gradients = properties.DictProperty( + key_type=str, + value_type=str, + desc= + "Mapping from array name for which a gradient should be computed to the name of the connector that will receive the gradient." + ) + + _conflicted_gradients = properties.SetProperty( + element_type=str, + desc="Keys from required_gradients for which the gradients are also computed elsewhere, and thus writes to the " + " buffer need to be with write-conflict-resolution. Note: this field is automatically populated upon expansion." + ) + + def __init__(self, name, given_gradients: Dict[str, str], *args, **kwargs): + super().__init__(name, *args, **kwargs) + self.given_gradients = given_gradients + self.required_gradients = {} + + def outer_names_given_gradients(self, state: SDFGState) -> Set[str]: + """ + Returns the names of the arrays that are passed as given gradients. + """ + in_array_name = lambda connector_name: in_edge_with_name(self, state, connector_name).data.data + return set(map(in_array_name, self.given_gradients.values())) + + def propagate_conflicts(self, sdfg: SDFG, state: SDFGState): + """ + Across this SDFG, check for other BackwardPasses that also compute the same gradients as us. + + If there are multiple BackwardPasses that compute the same gradients, update their list of conflicts. + """ + + ours = set(self.required_gradients) + + for state in sdfg.states(): + for node in state.nodes(): + if isinstance(node, BackwardPass): + if node is self: + continue + conflicts = ours.intersection(node.required_gradients) + if conflicts: + self._conflicted_gradients |= conflicts + node._conflicted_gradients |= conflicts + + def validate(self, sdfg, state): + # Check that there is a correspondence between given gradients and inputs + all_inputs = set(self.in_connectors) + for given_grad, tensor_name in self.given_gradients.items(): + if given_grad not in all_inputs: + raise ValueError("Given gradient '{}' is not an input of the node".format(given_grad)) + + all_inputs.remove(given_grad) + all_inputs.remove(tensor_name) + + if all_inputs: + raise ValueError("The following in connectors were not included in given_gradients: {}".format( + ', '.join(all_inputs))) + + # Check that we are computing at least one gradient + if len(self.out_connectors) == 0: + raise ValueError("BackwardPass node '{}' does not compute any gradients".format(self.name)) diff --git a/dace/autodiff/library/torch_integration.py b/dace/autodiff/library/torch_integration.py new file mode 100644 index 0000000000..fe9e150dac --- /dev/null +++ b/dace/autodiff/library/torch_integration.py @@ -0,0 +1,39 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +Hooks for PyTorch tensors to make them compatible with dace +""" +import copy + +try: + import torch + TORCH_AVAILABLE = True +except ImportError: + torch = None + TORCH_AVAILABLE = False + +from dace import data + +from dace.autodiff.library.library import ParameterArray + +if TORCH_AVAILABLE: + + def create_descriptor_tensor(self: torch.Tensor) -> data.Data: + """ + Creates a descriptor for a tensor. + If the tensor requires grad, we convert to a ParameterArray + """ + + desc = data.create_datadescriptor(self, no_custom_desc=True) + if not isinstance(desc, data.Array): + raise ValueError("Unsupported descriptor: {}".format(desc)) + + if not self.requires_grad: + return desc + + new_desc = copy.deepcopy(desc) + new_desc.__class__ = ParameterArray + new_desc.gradient = None + return new_desc + + # register with pytorch + torch.Tensor.__descriptor__ = create_descriptor_tensor diff --git a/dace/autodiff/torch.py b/dace/autodiff/torch.py new file mode 100644 index 0000000000..754a77a0a2 --- /dev/null +++ b/dace/autodiff/torch.py @@ -0,0 +1,124 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +from typing import Tuple, Dict, List + +import dace +from dace import data as dt + +from dace.autodiff.backward_pass_generator import BackwardPassGenerator +from dace.autodiff.base_abc import AutoDiffException, BackwardResult + +try: + from dace.libraries.onnx.converters import clean_onnx_name + from dace.frontend.ml.onnx import ONNXModel + ONNX_AVAILABLE = True +except ImportError: + ONNX_AVAILABLE = False + clean_onnx_name = None + ONNXModel = None + + +def make_backward_function( + model, # ONNXModel type hint removed for optional import + required_grads: List[str], +) -> Tuple[dace.SDFG, dace.SDFG, BackwardResult, Dict[str, dt.Data]]: + """ Convert an ONNXModel to a PyTorch differentiable function. This method should not be used on its own. + Instead use the ``backward=True`` parameter of :class:`dace.ml.DaceModule`. + + :param model: the model to convert. + :param required_grads: the list of inputs names of the module that we must compute gradients for. + :return: A 4-tuple of forward SDFG, backward SDFG, backward result, and input arrays for + backward pass (as mapping of names to DaCe data descriptors). + """ + if not ONNX_AVAILABLE: + raise ImportError("make_backward_function requires ONNX. Install with: pip install dace[ml]") + + if len(model.sdfg.nodes()) != 1: + raise AutoDiffException("Expected to find exactly one SDFGState, found {}".format(len(model.sdfg.nodes()))) + + forward_sdfg = model.sdfg + + backward_sdfg = dace.SDFG(forward_sdfg.name + "_backward") + + gen = BackwardPassGenerator(sdfg=forward_sdfg, + given_gradients=[clean_onnx_name(name) for name in model.outputs], + required_gradients=required_grads, + backward_sdfg=backward_sdfg) + + backward_result, backward_grad_arrays, backward_input_arrays = gen.backward() + + replaced_scalars = {} + + # get the forward state + forward_state = forward_sdfg.nodes() + # A loaded pytorch model should only have one state + if len(forward_state) != 1: + raise AutoDiffException(f"Expected forward SDFG to have exactly one state, found {len(forward_state)}") + forward_state = forward_state[0] + + # get the backward state + backward_state = backward_sdfg.nodes() + # A loaded pytorch model should only have one state + if len(backward_state) != 1: + raise AutoDiffException(f"Expected backward SDFG to have exactly one state, found {len(backward_state)}") + backward_state = backward_state[0] + + for name, desc in backward_input_arrays.items(): + if name not in forward_sdfg.arrays: + raise AutoDiffException("Expected to find array with name '{}' in SDFG".format(name)) + + forward_desc = forward_sdfg.arrays[name] + # we will save this output and pass it to the backward pass + + # Views should not be forwarded. Instead the backward pass generator should forward the source of the view, + # and rebuild the sequence of required views in the backward pass. + if type(forward_desc) is dt.View: + raise AutoDiffException( + f"Cannot forward View '{name}' to backward pass. " + "Views should not be forwarded; the backward pass generator should forward " + "the source of the view and rebuild the sequence of required views in the backward pass.") + if isinstance(forward_desc, dt.Scalar): + # we can't return scalars from SDFGs, so we add a copy to an array of size 1 + fwd_arr_name, _ = forward_sdfg.add_array(name + "_array", [1], + forward_desc.dtype, + transient=False, + storage=forward_desc.storage, + find_new_name=True) + bwd_arr_name, bwd_desc = backward_sdfg.add_array(name + "_array", [1], + forward_desc.dtype, + transient=False, + storage=forward_desc.storage, + find_new_name=True) + backward_sdfg.arrays[name].transient = True + + fwd_copy_state = forward_sdfg.add_state_after(forward_state, label="copy_out_" + fwd_arr_name) + bwd_copy_state = backward_sdfg.add_state_before(backward_state, label="copy_in_" + bwd_arr_name) + fwd_copy_state.add_edge(fwd_copy_state.add_read(name), None, fwd_copy_state.add_write(fwd_arr_name), None, + dace.Memlet(name + "[0]")) + + bwd_copy_state.add_edge(bwd_copy_state.add_read(bwd_arr_name), None, bwd_copy_state.add_write(name), None, + dace.Memlet(name + "[0]")) + replaced_scalars[name] = (bwd_arr_name, bwd_desc) + else: + forward_sdfg.arrays[name].transient = False + + for orig_name, (replaced_name, replaced_desc) in replaced_scalars.items(): + del backward_input_arrays[orig_name] + backward_input_arrays[replaced_name] = replaced_desc + + for fwd_name, bwd_name in backward_result.required_grad_names.items(): + desc = backward_sdfg.arrays[bwd_name] + if isinstance(desc, dt.Scalar): + arr_name, arr_desc = backward_sdfg.add_array(bwd_name + "_array", [1], + desc.dtype, + transient=False, + storage=desc.storage, + find_new_name=True) + desc.transient = True + bwd_copy_state = backward_sdfg.add_state_after(backward_state, label="copy_out_" + bwd_name) + bwd_copy_state.add_edge(bwd_copy_state.add_read(bwd_name), None, bwd_copy_state.add_write(arr_name), None, + dace.Memlet(bwd_name + "[0]")) + backward_result.required_grad_names[fwd_name] = arr_name + + backward_sdfg.validate() + + return forward_sdfg, backward_sdfg, backward_result, backward_input_arrays diff --git a/dace/autodiff/utils.py b/dace/autodiff/utils.py new file mode 100644 index 0000000000..d1693e3ef4 --- /dev/null +++ b/dace/autodiff/utils.py @@ -0,0 +1,910 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import ast +import collections +import copy +import inspect +import numbers +import re +from typing import Dict, List, Set, Tuple, Union + +import astunparse +import sympy as sp + +# DaCe imports +import dace +import dace.sdfg.utils as utils +from dace import dtypes +from dace import data as dt +from dace.frontend.python.parser import DaceProgram +from dace.sdfg import SDFG, SDFGState, graph as dgraph, nodes as nd, state as dstate +from dace.sdfg.state import LoopRegion + +# Autodiff imports +from dace.autodiff.base_abc import AutoDiffException, BackwardContext, BackwardResult + + +def forward_in_desc_with_name(forward_node: nd.Node, context: BackwardContext, name: str) -> dt.Data: + """Find the descriptor of the data that connects to input connector ``name``. + + :param forward_node: The node in the forward pass. + :param context: The backward context containing forward SDFG and state information. + :param name: The input connector name to find the descriptor for. + :return: The data descriptor that connects to the specified connector. + """ + return utils.in_desc_with_name(forward_node, context.forward_state, context.forward_sdfg, name) + + +def forward_out_desc_with_name(forward_node: nd.Node, context: BackwardContext, name: str) -> dt.Data: + """Find the descriptor of the data that connects to output connector ``name``. + + :param forward_node: The node in the forward pass. + :param context: The backward context containing forward SDFG and state information. + :param name: The output connector name to find the descriptor for. + :return: The data descriptor that connects to the specified connector. + """ + return utils.out_desc_with_name(forward_node, context.forward_state, context.forward_sdfg, name) + + +def add_backward_desc_for_connector(backward_sdfg: dace.SDFG, forward_node: nd.Node, context: BackwardContext, + connector: str, input: bool) -> str: + """Adds the backward array for the connector of ``forward_node``. + + :param backward_sdfg: The SDFG to add the backward array descriptor to. + :param forward_node: The forward node with the connector to create a descriptor for. + :param context: The backward context containing forward SDFG and state information. + :param connector: The connector name on the forward node. + :param input: True if the connector is an input, False if it's an output. + :return: The name of the newly added gradient array in ``backward_sdfg``. + """ + + if input: + edge = utils.in_edge_with_name(forward_node, context.forward_state, connector) + else: + edge = utils.out_edge_with_name(forward_node, context.forward_state, connector) + arr_name = edge.data.data + + forward_desc = context.forward_sdfg.arrays[arr_name] + + new_desc = copy.deepcopy(forward_desc) + new_desc.transient = False + return backward_sdfg.add_datadesc(arr_name + "_grad", new_desc, find_new_name=True) + + +def add_backward_desc(backward_sdfg: dace.SDFG, forward_sdfg: dace.SDFG, forward_desc: dt.Data, + forward_name: str) -> str: + """Adds the backward array for the given descriptor. + + :param backward_sdfg: The SDFG to add the backward array descriptor to. + :param forward_sdfg: The forward SDFG used for finding unique names. + :param forward_desc: The data descriptor of the forward array. + :param forward_name: A name for the forward array (doesn't have to match its actual name). + :return: The name of the newly added gradient array in ``backward_sdfg``. + """ + backward_name = dt.find_new_name(forward_name + "_grad", forward_sdfg.arrays) + new_desc = copy.deepcopy(forward_desc) + new_desc.transient = False + return backward_sdfg.add_datadesc(backward_name, new_desc) + + +def add_empty_sdfg_for_node(forward_node: nd.Node, required_descriptors: List[str], + context: BackwardContext) -> Tuple[nd.NestedSDFG, BackwardResult]: + """ Given a node, return an SDFG that can be used as a nested SDFG expansion for that node. + + ``required_descriptors`` may contain: + * Inputs/outputs of the forward node (these will be hooked up as required) + * Inputs/outputs of the forward node with the ``_grad`` suffix. These will be hooked up + as the gradients of the corresponding inputs/outputs. + + The descriptors will be initialized using the descriptors connected to edges of the + forward node. + + :param forward_node: the node in the forward pass + :param required_descriptors: A list of descriptors that should be added to the SDFG. + :param context: the backward context + :return: the nested SDFG and backward result for the forward node + """ + + nsdfg = dace.SDFG(forward_node.label + "_backward_expansion") + + def _get_fwd_descriptor(name): + """Returns the descriptor and whether it is an input""" + if name in forward_node.out_connectors: + return forward_out_desc_with_name(forward_node, context, name), False + elif name in forward_node.in_connectors: + return forward_in_desc_with_name(forward_node, context, name), True + + raise ValueError(f"Could not find {name} in inputs or outputs of {forward_node}") + + outputs_to_connect_from_forward = [] + + result = BackwardResult.empty() + inputs = set() + outputs = set() + + for name in required_descriptors: + if name.endswith("_grad"): + # hook this up as a gradient + desc, is_input = _get_fwd_descriptor(name[:-5]) + if is_input: + result.required_grad_names[name[:-5]] = name + else: + result.given_grad_names[name[:-5]] = name + # input grads are outputs of the backward node + if is_input: + outputs.add(name) + else: + inputs.add(name) + else: + desc, is_input = _get_fwd_descriptor(name) + if not is_input: + outputs_to_connect_from_forward.append(name) + inputs.add(name) + ndesc = copy.deepcopy(desc) + ndesc.transient = False + nsdfg.add_datadesc(name, ndesc) + + bwd_node = context.backward_state.add_nested_sdfg(nsdfg, inputs, outputs) + for output in outputs_to_connect_from_forward: + connect_output_from_forward(forward_node, bwd_node, context, output) + + return bwd_node, result + + +def backward_program_for_node(program, context: BackwardContext, + forward_node: nd.Node) -> Tuple[nd.Node, BackwardResult]: + """ Expand a function to the backward function for a node. + + The dtypes for the arguments will be extracted by matching the parameter names to edges. + + Gradient parameters should be the name of the forward parameter, appended with _grad. For these arguments the + data descriptors will match the data descriptors of the inputs/outputs they correspond to. + """ + + input_names = set(inp.name for inp in forward_node.schema.inputs) + output_names = set(outp.name for outp in forward_node.schema.outputs) + + if input_names.intersection(output_names): + # this is currently the case for only one onnx op + raise ValueError("program_for_node cannot be applied on nodes of this type;" + " '{}' is both an input and an output".format(next(input_names.intersection(output_names)))) + + def name_without_grad_in(name, collection): + return name[-5:] == "_grad" and name[:-5] in collection + + params = inspect.signature(program).parameters + + backward_result = BackwardResult.empty() + + inputs = {} + outputs = {} + for name, _ in params.items(): + if name in input_names: + inputs[name] = copy.deepcopy(forward_in_desc_with_name(forward_node, context, name)) + + elif name_without_grad_in(name, input_names): + outputs[name] = copy.deepcopy(forward_in_desc_with_name(forward_node, context, name[:-5])) + backward_result.required_grad_names[name[:-5]] = name + + elif name in output_names: + inputs[name] = copy.deepcopy(forward_out_desc_with_name(forward_node, context, name)) + + elif name_without_grad_in(name, output_names): + inputs[name] = copy.deepcopy(forward_out_desc_with_name(forward_node, context, name[:-5])) + backward_result.given_grad_names[name[:-5]] = name + + else: + raise ValueError("'{}' was not found as an input or output for {}".format(name, forward_node.schema.name)) + + program.__annotations__ = {**inputs, **outputs} + + sdfg = DaceProgram(program, (), {}, False, dace.DeviceType.CPU).to_sdfg() + + result_node = context.backward_state.add_nested_sdfg(sdfg, set(inputs), set(outputs)) + + return result_node, backward_result + + +def connect_output_from_forward(forward_node: nd.Node, backward_node: nd.Node, context: BackwardContext, + output_connector_name: str): + """ Connect an output of the forward node as an input to the backward node. This is done by forwarding the array + from the forward pass. + + Conceptually, this is similar to pytorch's ctx.save_for_backward. + + :param forward_node: the node in the forward pass. + :param backward_node: the node in the backward pass. + :param context: the backward context. + :param output_connector_name: the name of the connector on the backward pass. The output of that connector will + be forwarded to the connector of the same name on the backward node. + """ + output_edge = utils.out_edge_with_name(forward_node, context.forward_state, output_connector_name) + + # add the array of the output to backward_input_arrays that it will be forwarded by the autodiff engine + output_arr_name = output_edge.data.data + if output_arr_name not in context.backward_generator.backward_input_arrays: + data_desc = copy.deepcopy(context.forward_sdfg.arrays[output_arr_name]) + context.backward_generator.backward_input_arrays[output_arr_name] = data_desc + + if context.backward_generator.separate_sdfgs: + data_desc.transient = False + context.backward_sdfg.add_datadesc(output_arr_name, data_desc) + + read = context.backward_state.add_read(output_arr_name) + else: + cand = [ + n for n, _ in context.backward_state.all_nodes_recursive() + if isinstance(n, nd.AccessNode) and n.data == output_arr_name + ] + read = cand[0] + context.backward_state.add_edge(read, None, backward_node, output_connector_name, copy.deepcopy(output_edge.data)) + + +def cast_consts_to_type(code: str, dtype: dace.typeclass) -> str: + """Convert a piece of code so that constants are wrapped in casts to ``dtype``. + + For example:: + + x * (3 / 2) + + becomes:: + + x * (dace.float32(3) / dace.float32(2)) + + This is only done when it is required due to a Div operator to ensure proper + type casting in mathematical expressions during automatic differentiation. + + :param code: The code string to convert. + :param dtype: The DaCe typeclass to cast constants to. + :return: A string of the converted code with properly typed constants. + """ + + class CastConsts(ast.NodeTransformer): + + def __init__(self): + self._in_div_stack = collections.deque() + + def visit_Num(self, node): + if self._in_div_stack: + return ast.copy_location( + ast.parse(f"dace.{dtype.to_string()}({astunparse.unparse(node)})").body[0].value, node) + else: + return self.generic_visit(node) + + def visit_BinOp(self, node: ast.BinOp): + if node.op.__class__.__name__ == "Pow": + # within pow, we don't need to cast unless there is a new div + old_stack = self._in_div_stack + # reset the stack + self._in_div_stack = collections.deque() + node = self.generic_visit(node) + self._in_div_stack = old_stack + return node + + elif node.op.__class__.__name__ == "Div": + self._in_div_stack.append(None) + node = self.generic_visit(node) + self._in_div_stack.popleft() + return node + else: + return self.generic_visit(node) + + def visit_Constant(self, node): + if self._in_div_stack: + return ast.copy_location( + ast.parse(f"dace.{dtype.to_string()}({astunparse.unparse(node)})").body[0].value, node) + else: + return self.generic_visit(node) + + return astunparse.unparse(CastConsts().visit(ast.parse(code))) + + +def init_grad(data: str, sdfg: SDFG, current_state: SDFGState) -> None: + """Add a state where ``data`` is initialized with zero. + + This function creates a new state before the current state that initializes + the gradient array with zeros. It handles different storage types (CPU/GPU) + and array types appropriately. + + :param data: The name of the data array to initialize. + :param sdfg: The SDFG to add the initialization state to. + :param current_state: The current state; initialization will be done before this state. + :raises ValueError: If the storage type is not supported. + :raises AutoDiffException: If the data descriptor type is not supported. + """ + arr = sdfg.arrays[data] + + state = sdfg.add_state_before(current_state, label="init_" + data) + + scalar = 0 + if dtypes.can_access(dtypes.ScheduleType.CPU_Multicore, arr.storage): + cuda = False + elif dtypes.can_access(dtypes.ScheduleType.GPU_Default, arr.storage): + cuda = True + else: + raise ValueError(f"Unsupported storage {arr.storage}") + + if isinstance(arr, (dt.Array, dt.Scalar)): + state.add_mapped_tasklet( + "_init_" + data + "_", { + "i{}".format(i): "0:{}".format(shape) + for i, shape in enumerate(arr.shape) + }, {}, + "__out = {}".format(scalar), + {"__out": dace.Memlet.simple(data, ", ".join("i{}".format(i) for i in range(len(arr.shape))))}, + schedule=dtypes.ScheduleType.GPU_Device if cuda else dtypes.ScheduleType.Default, + external_edges=True) + elif type(arr) is dt.View: + # not need to initialize: the viewed array will always be visited + # (since a view can never be a required grad), and thus the viewed array will be initialized. + pass + else: + raise AutoDiffException("Unsupported data descriptor {}".format(arr)) + + +def extract_indices(expression: str) -> Dict[str, List[str]]: + """Extracts indexed array names and their indices from a given string expression. + + This function uses regular expressions to find patterns like "array[i, j, k]" + and returns a dictionary mapping array names to their index lists. + + :param expression: The string expression to analyze. + :return: A dictionary mapping array names to lists of their indices. + + Example:: + + >>> extract_indices("a[i, j] + b[k]") + {'a': ['i', 'j'], 'b': ['k']} + """ + # Regular expression to match the array names and their indices + pattern = r"(\w+)\[((?:\w+,?\s*)+)\]" + + # Find all matches in the given expression + matches = re.findall(pattern, expression) + + # Create a dictionary to store the arrays and their indices + index_map = {} + for name, indices in matches: + # Split indices by comma and remove any extra spaces + index_list = [index.strip() for index in indices.split(',')] + index_map[name] = index_list + + return index_map + + +def code_to_exprs(code: str, tasklet: nd.Tasklet, + symbols: List[str]) -> Tuple[Dict[str, sp.Expr], Dict[str, List[str]]]: + """ Convert a python string to a set of (simplified) symbolic sympy expressions. Currently, this + supports only code consisting of assignment statements. + + :param code: the code to convert + :param inputs: the inputs (i.e. the defined variables) for the code + :param outputs: the outputs to generate simplified expressions for + :return: map from outputs to symbolic expressions + """ + + inputs: List[str] = list(tasklet.in_connectors) + outputs: List[str] = list(tasklet.out_connectors) + + # Add the definition of global constant symbols that are presen in the code + # Prepare the Symbol declaration code + symbol_code = "" + for symb in symbols: + symbol_code += f" {symb} = sp.symbols('{symb}')\n" + + # We prepare a map of indexed objects and their indices + indexed_objects_map = extract_indices(code) + + # For now, make sure none of the outputs are indexed objects + indexed_outputs = [out for out in outputs if out in indexed_objects_map] + if indexed_outputs: + raise AutoDiffException(f"Indexed outputs are not currently supported: {indexed_outputs}") + + # Add the definition of indexed objects to the sympy code + indexed_objects_code = "" + for conn in inputs + outputs: + if (conn in inputs and isinstance(tasklet.in_connectors[conn], dace.dtypes.pointer) + or (conn in outputs and isinstance(tasklet.out_connectors[conn], dace.dtypes.pointer))): + if conn not in indexed_objects_map: + raise AutoDiffException(f"Expected connector '{conn}' to be in indexed objects map for pointer type") + indexed_objects_code += f" {conn} = sp.IndexedBase('{conn}')\n" + for idx in indexed_objects_map[conn]: + indexed_objects_code += f" {idx} = sp.symbols('{idx}', cls=sp.Idx)\n" + + code_fn = """ +def symbolic_execution({}): + # define functions from cmath.h + from sympy import exp, log + def log2(x): + return log(x, 2) + def log10(x): + return log(x, 10) + from sympy import sin, cos, tan, asin, acos, atan, sinh, cosh, tanh, asinh, acosh, atanh + from sympy import sin, cos, tan, asin, acos, atan, sinh, cosh, tanh, asinh, acosh, atanh + from sympy import Pow as pow, sqrt + from sympy import sign, floor, ceiling as ceil, Abs as abs, Abs as fabs + from sympy import Max as max, Min as min + from sympy import Max as fmax, Min as fmin + from sympy import erf + import sympy as sp +{} +{} +{} + return {} + """ + code_fn = code_fn.format( + ", ".join(inputs), + symbol_code, + indexed_objects_code, + "\n".join(" " + line.strip() for line in code.split("\n")), + ", ".join(outputs), + ) + + # Clean out type conversions from the code + code_fn = re.sub(r"dace\.(float32|int32|float64|int64)\((.*?)\)", r"\2", code_fn) + + try: + # need to have dace so things like `dace.float32(1)` work + temp_globals = {'dace': dace} + exec(code_fn, temp_globals) + + # no idea why, but simply calling symbolic_execution doesn't work + results = temp_globals["symbolic_execution"](*[sp.symbols(inp) for inp in inputs]) + + if len(outputs) > 1: + # make sure that everything is a sympy expression + for i, res in enumerate(results): + if not isinstance(res, sp.Expr): + results[i] = sp.sympify(res) + return dict(zip(outputs, results)), indexed_objects_map + else: + # make sure that everything is a sympy expression + if not isinstance(results, sp.Expr): + results = sp.sympify(results) + return {outputs[0]: results}, indexed_objects_map + except Exception as e: + raise AutoDiffException( + "Exception occurred while attempting to symbolically execute code:\n{}".format(code)) from e + + +def is_int_eq_value(value, target_value: int) -> bool: + if isinstance(value, numbers.Integral): + return value == target_value + + if len(value.free_symbols) > 0 or int(value) != target_value: + return False + + return True + + +def invert_map_connector(conn: str) -> str: + if conn.startswith("IN"): + return "OUT" + conn[2:] + elif conn.startswith("OUT"): + return "IN" + conn[3:] + else: + raise AutoDiffException("Could not parse map connector '{}'".format(conn)) + + +def path_src_node_in_subgraph(edge: dgraph.MultiConnectorEdge, subgraph: dstate.StateSubgraphView) -> bool: + path_src = subgraph.memlet_path(edge)[0].src + return path_src in subgraph.nodes() + + +def get_read_only_arrays(sdfg: SDFG) -> Set[str]: + """Get the arrays that are only read in SDFG. + + This function identifies arrays that are never written to (only have outgoing + edges with data or only empty memlets on incoming edges). + + :param sdfg: The SDFG to analyze. + :return: A set of array names that are read-only in the SDFG. + """ + written_to_arrays = set() + for node, parent in sdfg.all_nodes_recursive(): + if isinstance(node, nd.AccessNode): + if parent.in_degree(node) > 0 and any(not e.data.is_empty() for e in parent.in_edges(node)): + written_to_arrays.add(node.data) + + read_only_arrays = set(sdfg.arrays.keys()) - written_to_arrays + return read_only_arrays + + +def get_state_topological_order(graph) -> List[SDFGState]: + """ + Returns the SDFG states in topological order. + """ + all_nodes = list(utils.dfs_topological_sort(graph, graph.source_nodes())) + state_order = [] + for node in all_nodes: + if isinstance(node, SDFGState): + state_order.append(node) + elif isinstance(node, LoopRegion): + loop_state_order = get_state_topological_order(node) + state_order.extend(loop_state_order) + else: + raise AutoDiffException( + f"Unsupported node type {node} at the highest level of the SDFG while getting the state order") + + # All states in the graph need to be present in the state order + if isinstance(graph, SDFG) and set(state_order) != set(graph.states()): + raise AutoDiffException("Could not find all states of the SDFG in the state order") + return state_order + + +def shape_has_symbols_to_replace(sdfg: SDFG, shape: Union[str, sp.Symbol, sp.Expr]) -> bool: + """ + Check if the shape dimension passed as a parameter has a symbol that needs to be replaced. + We do not replace global SDFG symbols but rather the loop indices only. + """ + defined_symbols = sdfg.free_symbols | set(sdfg.arg_names) + if isinstance(shape, str): + shape = dace.symbolic.pystr_to_symbolic(shape) + return dace.symbolic.issymbolic(shape, defined_symbols) + + +def get_loop_end(start: str, end: str, loop: LoopRegion) -> str: + """ + Get the smallest and largest index of a loop given the start and end values. + This is an attempt at estimating the number of iterations of the loop. + """ + start_sym = dace.symbolic.pystr_to_symbolic(start) + end_sym = dace.symbolic.pystr_to_symbolic(end) + if not dace.symbolic.issymbolic(start_sym) and not dace.symbolic.issymbolic(end_sym): + int_start, int_end = int(start_sym), int(end_sym) + if int_start < int_end: + # Increasing loop + largest_index = int_end + smallest_index = int_start + else: + # Decreasing loop e.g., range(6, -1, -1) + # Since the start will be the first index there are start+1 iterations + largest_index = int_start + 1 + smallest_index = int_end + else: + # We check using the update statement + change = analyze_loop_change(loop.update_statement.as_string, loop.loop_variable) + if change == "increase": + # Increasing loop + largest_index = end + smallest_index = start + else: + # Decreasing loop + # Since the start will be the first index there are start+1 iterations + largest_index = start + "+1" + smallest_index = end + + return smallest_index, largest_index + + +def analyze_loop_change(code: str, loop_variable: str) -> str: + """Analyze if the given loop variable in the provided code increases or decreases. + + :param code: The Python code to analyze. + :param loop_variable: The name of the loop variable to analyze. + :return: ``'increase'``, ``'decrease'``, or ``'unknown'``. + """ + tree = ast.parse(code) + change_type = "unknown" + + for node in ast.walk(tree): + # Look for assignment statements + if isinstance(node, ast.Assign): + # Ensure the assignment targets the loop variable + if len(node.targets) == 1 and isinstance(node.targets[0], ast.Name): + target = node.targets[0].id + if target == loop_variable and isinstance(node.value, ast.BinOp): + # Check for `loop_variable = loop_variable + ...` + if isinstance(node.value.left, ast.Name) and node.value.left.id == loop_variable: + # Analyze the right-hand side for increase or decrease + rhs = node.value.right + if isinstance(rhs, ast.UnaryOp) and isinstance(rhs.op, ast.USub): # Unary negative + if isinstance(rhs.operand, ast.Constant) and isinstance(rhs.operand.value, (int, float)): + change_type = "decrease" + elif isinstance(rhs, ast.UnaryOp) and isinstance(rhs.op, ast.UAdd): # Unary positive + if isinstance(rhs.operand, ast.Constant) and isinstance(rhs.operand.value, (int, float)): + change_type = "increase" + elif isinstance(rhs, ast.Constant) and isinstance(rhs.value, (int, float)): + change_type = "increase" if rhs.value > 0 else "decrease" + if change_type == "unknown": + raise AutoDiffException(f"Could not determine loop variable change in code: {code}") + return change_type + + +def get_map_nest_information( + edges_list: List[dstate.MultiConnectorEdge]) -> Tuple[List, List[str], List, Dict[str, Tuple]]: + """ + """ + # First, get the shape of the new array + shape_list = [] + + # We will also need the starting range of the maps in the path + start_range = [] + + # And the names of the parameters of the maps in the path + param_list = [] + + for e in edges_list: + edge_src = e.src + if isinstance(edge_src, nd.MapEntry): + for rng in edge_src.map.range.ranges: + # the range contains the last index in the loop + # while we want the size so we add 1 + shape_list.append(rng[1] + 1) + start_range.append(rng[0]) + for par in edge_src.map.params: + param_list.append(par) + + if not (len(param_list) == len(shape_list) == len(start_range)): + raise AutoDiffException( + f"Mismatched lengths: params={len(param_list)}, shapes={len(shape_list)}, ranges={len(start_range)}") + + # Create a dictionary mapping parameters to their start and end ranges + param_dict = {param: (start, end) for param, start, end in zip(param_list, start_range, shape_list)} + return start_range, param_list, shape_list, param_dict + + +def get_all_path_edges(state: SDFGState, source: nd.Node, + starting_edge: dgraph.MultiConnectorEdge) -> List[dgraph.MultiConnectorEdge]: + """ + We will start from the target node and go back until we reach the destination. + Starting edge should be an in node + """ + all_edges = [] + memlet_path = state.memlet_path(starting_edge) + all_edges += memlet_path + first_source = memlet_path[0].src + if first_source == source: + return all_edges + + # If there is only one edge coming to the first node + if state.in_degree(first_source) == 1: + edge = state.in_edges(first_source)[0] + memlet_path = state.memlet_path(edge) + all_edges += memlet_path + first_source = memlet_path[0].src + if first_source == source: + return all_edges + + raise AutoDiffException("Can't easily find path. Upgrade function.") + + +def extract_conditional_expressions(tasklet_node: nd.Tasklet) -> Tuple[str, str, str]: + """ + Given a conditional tasklet node, extract the if and else expressions and return them with the conditional. + The else statement could be None in case there is only an if statement. The current supported formats are the following: + 1 - if cond: + out = expression_1 + which would return ("out = expression_1", None, "if cond") + 2- out = expression_1 if cond else expression 2 + """ + + tasklet_code = tasklet_node.code.as_string + + # check which type of assignment this is + if ":" in tasklet_code: + # get the conditional input connector through regular expression matching + matches = re.search(r"if (.)*:", tasklet_code) + if not matches: + raise AutoDiffException(f"Could not find 'if' statement in conditional tasklet code: {tasklet_code}") + conditional = matches.group() + + # remove the conditional from the code to get the expression + if_statement = tasklet_code.replace(conditional, "") + if_statement = if_statement.replace("\n", "") + + # remove indentation + if_statement = if_statement[3:] + + # extract the in connector only + conditional = conditional.replace(":", "") + conditional = conditional.replace("if ", "") + if conditional not in tasklet_node.in_connectors: + raise AutoDiffException( + f"Conditional '{conditional}' not found in tasklet input connectors: {list(tasklet_node.in_connectors.keys())}" + ) + + else_statement = None + + # match the out connector + matches = re.search(r"^(.)* =", if_statement) + if not matches: + raise AutoDiffException(f"Could not find output assignment in if statement: {if_statement}") + out_connector = matches.group() + + # remove the assignment from the if statement + if_statement = if_statement.replace(out_connector, "") + + # extract the out connector only + out_connector = out_connector[1:].replace(" =", "") + + else: + # get the conditional input connector through regular expression matching + matches = re.search(r"if (.)* else", tasklet_code) + if not matches: + raise AutoDiffException(f"Could not find 'if...else' statement in conditional tasklet code: {tasklet_code}") + conditional = matches.group() + + # extract the in connector only + conditional = conditional.replace("if ", "") + conditional = conditional.replace(" else", "") + + if conditional not in tasklet_node.in_connectors: + raise AutoDiffException( + f"Conditional '{conditional}' not found in tasklet input connectors: {list(tasklet_node.in_connectors.keys())}" + ) + + # get the if statement by matching what comes before the if until we encounter a parenthesis or = + matches = re.search(r"= \((.)* if", tasklet_code) + if not matches: + # try without the parenthesis + matches = re.search(r"= (.)* if", tasklet_code) + if not matches: + raise AutoDiffException(f"Could not find if expression pattern in tasklet code: {tasklet_code}") + + if_statement = matches.group() + + # extract the in statement only + if_statement = if_statement.replace("= (", "") + if_statement = if_statement.replace(" if", "") + + # get the else statement by matching the else and what comes after it until we encounter a parenthesis + matches = re.search(r"else (.)*\)", tasklet_code) + if not matches: + raise AutoDiffException(f"Could not find else expression pattern in tasklet code: {tasklet_code}") + else_statement = matches.group() + + # extract the in statement only + else_statement = else_statement.replace("else ", "") + + # remove the last closing parenthesis if it exists + if else_statement.endswith(")"): + else_statement = else_statement[:-1] + + # match the out connector + matches = re.search(r"^(.)* =", tasklet_code) + if not matches: + raise AutoDiffException(f"Could not find output assignment in tasklet code: {tasklet_code}") + out_connector = matches.group() + + # extract the in statement only + out_connector = out_connector.replace(" =", "") + + # sanity check this should be in the out connectors of the tasklet + if out_connector not in tasklet_node.out_connectors: + raise AutoDiffException( + f"Output connector '{out_connector}' not found in tasklet output connectors: {list(tasklet_node.out_connectors.keys())}" + ) + + # create the return expressions + if_expression = f"{out_connector} = {if_statement}" + else_expression = f"{out_connector} = {else_statement}" if else_statement else None + + return if_expression, else_expression, conditional + + +def check_edges_type_in_state(subgraph: dstate.StateSubgraphView) -> None: + """ + Check if all the edges in this state are of type float, int, or boolean. + """ + for edge, parent_subgraph in subgraph.all_edges_recursive(): + if isinstance(parent_subgraph, SDFGState): + parent_sdfg = parent_subgraph.parent + elif isinstance(parent_subgraph, dstate.StateSubgraphView): + parent_sdfg = parent_subgraph.graph.parent + elif isinstance(parent_subgraph, SDFG) or isinstance(parent_subgraph, LoopRegion): + # if there are any fancy things on the interstate edges we should probably throw an error + continue + else: + raise AutoDiffException("Unexpected subgraph structure") + + if edge.data.data: + edge_type = parent_sdfg.arrays[edge.data.data].dtype + if edge_type in [dace.string]: + raise AutoDiffException( + f"Expected Subgraph to differentiate to only contain float, int, and bool edges, but data {edge.data}" + f" on edge {edge} has type {edge_type}") + + +def state_within_loop(forward_state: SDFGState) -> Tuple[bool, LoopRegion]: + """ + Check if this state will be executed several times within a loop. + We check if any of the parents of this state is a loop region. + """ + parent = forward_state.parent_graph + while parent is not None: + if isinstance(parent, LoopRegion): + return True, parent + parent = parent.parent_graph + return False, None + + +class SympyCleaner(ast.NodeTransformer): + + def visit_Name(self, node): + if node.id == "pi": + return ast.copy_location(ast.parse("dace.math.pi").body[0].value, node) + return self.generic_visit(node) + + +def extract_loop_region_info(loop: LoopRegion) -> Tuple[str, str]: + """ + Use regular expression matching to extract the start and end of the loop region. + We only treat regular for-loops with incrementation and decrementation updates. + """ + + # Extract the loop iterator + it = loop.loop_variable + + # Extract the end of the loop from the conditional statement + conditional = loop.loop_condition.as_string + + stride_sign = get_stride_sign(loop) + + # If the stride is positive + if stride_sign > 0: + conditional_expression = fr".*{it} < .*" + else: + # If the stride is negative + conditional_expression = fr".*{it} > .*" + + # Match the conditional using regular expressions + matches = re.search(conditional_expression, conditional) + if not matches: + raise AutoDiffException(f"Could not match conditional expression '{conditional_expression}' in '{conditional}'") + expression = matches.group() + matches_inner = re.search(conditional_expression[:-2], conditional) + if not matches_inner: + raise AutoDiffException( + f"Could not match conditional pattern '{conditional_expression[:-2]}' in '{conditional}'") + expression_to_remove = matches_inner.group() + end = expression.replace(expression_to_remove, "") + + # TODO: need more generalized solution for functions in the loop bounds + if "floor" not in conditional: + # There is no function call in the statement, remove parenthesis + end = end.replace("(", "") + end = end.replace(")", "") + end = end.replace(" ", "") + else: + if expression_to_remove.startswith("(") and not expression_to_remove.endswith(")") and expression.endswith(")"): + # Remove extra parenthesis + end = end[:-1] + + # Get the start from the initialization code + init_code = loop.init_statement.as_string + matches = re.search(fr".*{it} = .*", init_code) + if not matches: + raise AutoDiffException(f"Could not find initialization pattern for loop variable '{it}' in '{init_code}'") + expression = matches.group() + matches = re.search(fr"{it} =", init_code) + if not matches: + raise AutoDiffException(f"Could not find assignment pattern for loop variable '{it}' in '{init_code}'") + expression_to_remove = matches.group() + start = expression.replace(expression_to_remove, "") + + # Remove parenthesis and space + start = start.replace("(", "") + start = start.replace(")", "") + start = start.replace(" ", "") + + return start, end + + +def get_stride_sign(loop: LoopRegion) -> int: + """Check if the stride for this loop is positive or negative. + + :param loop: The loop region to analyze. + :return: ``1`` if the stride is positive, ``-1`` if negative. + :raises AutoDiffException: If the loop has an unsupported structure. + """ + if loop.update_statement is None: + raise AutoDiffException("While loops are not yet supported in DaCe AD") + update_statement = loop.update_statement.as_string + if "-" in update_statement: + return -1 + if "+" in update_statement: + return 1 + + # unsupported loop structure + raise AutoDiffException(f"Expected the loop region {loop.label} to have a regular update statement." + f" Instead got: {update_statement}") diff --git a/dace/codegen/codegen.py b/dace/codegen/codegen.py index 3ccbb56dc6..8ad47f212e 100644 --- a/dace/codegen/codegen.py +++ b/dace/codegen/codegen.py @@ -176,7 +176,7 @@ def generate_code(sdfg: SDFG, validate=True) -> List[CodeObject]: sdfg.save(f'{tmp_dir}/test.sdfg', hash=False) sdfg2 = SDFG.from_file(f'{tmp_dir}/test.sdfg') sdfg2.save(f'{tmp_dir}/test2.sdfg', hash=False) - print('Testing SDFG serialization...') + if not filecmp.cmp(f'{tmp_dir}/test.sdfg', f'{tmp_dir}/test2.sdfg'): with open(f'{tmp_dir}/test.sdfg', 'r') as f1: with open(f'{tmp_dir}/test2.sdfg', 'r') as f2: diff --git a/dace/codegen/common.py b/dace/codegen/common.py index d8524eacbc..f5bbf445a2 100644 --- a/dace/codegen/common.py +++ b/dace/codegen/common.py @@ -171,3 +171,14 @@ def get_gpu_runtime() -> gpu_runtime.GPURuntime: 'environment variable to point to the libraries.') return gpu_runtime.GPURuntime(backend, libpath) + + +def platform_library_name(libname: str) -> str: + """ Get the filename of a library. + + :param libname: the name of the library. + :return: the filename of the library. + """ + prefix = config.Config.get('compiler', 'library_prefix') + suffix = config.Config.get('compiler', 'library_extension') + return f"{prefix}{libname}.{suffix}" diff --git a/dace/codegen/targets/cpp.py b/dace/codegen/targets/cpp.py index 7871962cad..b451668831 100644 --- a/dace/codegen/targets/cpp.py +++ b/dace/codegen/targets/cpp.py @@ -479,7 +479,8 @@ def ndcopy_to_strided_copy( """ # Cannot degenerate tiled copies - if any(ts != 1 for ts in subset.tile_sizes): + # In the case where subset is of type Indices, there are no tile_sizes + if hasattr(subset, 'tile_sizes') and any(ts != 1 for ts in subset.tile_sizes): return None # If the copy is contiguous, the difference between the first and last diff --git a/dace/codegen/targets/cpu.py b/dace/codegen/targets/cpu.py index ef9b42fe33..5e71cbb074 100644 --- a/dace/codegen/targets/cpu.py +++ b/dace/codegen/targets/cpu.py @@ -972,7 +972,6 @@ def write_and_resolve_expr(self, vec_prefix = 'v' vec_suffix = f'<{dtype.veclen}>' dtype = dtype.base_type - func = f'{vec_prefix}reduce{atomic}{vec_suffix}' # Special call for detected reduction types diff --git a/dace/frontend/ml/__init__.py b/dace/frontend/ml/__init__.py new file mode 100644 index 0000000000..6e6305d8f9 --- /dev/null +++ b/dace/frontend/ml/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. + +try: + from .torch import DaceModule +except ImportError: + DaceModule = None + +try: + from .onnx import ONNXModel +except ImportError: + ONNXModel = None + +__all__ = ['DaceModule', 'ONNXModel'] diff --git a/dace/frontend/ml/onnx/__init__.py b/dace/frontend/ml/onnx/__init__.py new file mode 100644 index 0000000000..aa0c16bf05 --- /dev/null +++ b/dace/frontend/ml/onnx/__init__.py @@ -0,0 +1,5 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. + +from .importer import ONNXModel + +__all__ = ['ONNXModel'] diff --git a/dace/frontend/ml/onnx/importer.py b/dace/frontend/ml/onnx/importer.py new file mode 100644 index 0000000000..226b3e821f --- /dev/null +++ b/dace/frontend/ml/onnx/importer.py @@ -0,0 +1,794 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +ONNX Model Importer for DaCe. + +This module provides the ONNXModel class, which is the main entry point for +importing ONNX models into DaCe. It handles the complete pipeline of: + +1. **Model Loading**: Loading ONNX models from files or protobuf objects +2. **Model Simplification**: Applying onnx-simplifier for optimization +3. **Shape Inference**: Computing tensor shapes symbolically or concretely +4. **Graph Conversion**: Converting ONNX graph to DaCe SDFG +5. **Weight Management**: Handling model parameters and initializers +6. **Compilation**: Compiling the SDFG to executable code +7. **Execution**: Running the model with NumPy or PyTorch tensors + +Key Features: +- Automatic shape inference for dynamic models +- Support for both CPU and CUDA execution +- Integration with PyTorch for seamless tensor conversion +- Configurable optimization levels +- Weight initialization and parameter management +- Support for nested models and subgraphs + +Typical Workflow: + >>> import onnx + >>> from dace.frontend.ml.onnx import ONNXModel + >>> + >>> # Load ONNX model + >>> onnx_model = onnx.load("model.onnx") + >>> dace_model = ONNXModel("my_model", onnx_model) + >>> + >>> # Run inference + >>> import numpy as np + >>> input_data = np.random.randn(1, 3, 224, 224).astype(np.float32) + >>> output = dace_model(input_data) + +The module also provides utility functions for: +- Type conversion between NumPy, PyTorch, and ONNX types +- Model validation and checking +- Shape inference helpers +- Weight loading and initialization + +Note: + This is a large module (900+ lines) that handles multiple concerns. + Consider the architectural recommendations in the code review for + potential refactoring into smaller, focused modules. +""" + +import collections +import copy +import tempfile +from itertools import chain, repeat +from typing import Any, Callable, Dict, List, Optional, OrderedDict, Tuple, Union + +import numpy as np + +# PyTorch is optional (only needed for tensor conversion features) +try: + import torch + TORCH_AVAILABLE = True +except ImportError: + torch = None + TORCH_AVAILABLE = False + +# ONNX is mandatory for this module +try: + import onnx + import onnx.checker + from onnx import numpy_helper +except ImportError as e: + raise ImportError("ONNX library is required. Install with: pip install dace[ml]") from e + +# ONNXRuntime for symbolic shape inference +try: + from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference + ONNXRUNTIME_AVAILABLE = True +except ImportError: + SymbolicShapeInference = None + ONNXRUNTIME_AVAILABLE = False + +# onnxsim is optional (only needed for model simplification) +try: + import onnxsim + ONNXSIM_AVAILABLE = True +except ImportError: + onnxsim = None + ONNXSIM_AVAILABLE = False + +import dace +from dace import config, SDFG, SDFGState, data as dt, dtypes, nodes +from dace.codegen import compiled_sdfg +from dace.frontend.python import parser +from dace.sdfg import utils as sdfg_utils +from dace.symbolic import pystr_to_symbolic +from dace.transformation.onnx import auto_optimize_onnx as auto_opt +from dace.transformation.onnx import expand_onnx_nodes as onnx_node_expander + +from dace.libraries.onnx.converters import clean_onnx_name, convert_attribute_proto, onnx_tensor_type_to_typeclass +from dace.libraries.onnx.nodes.onnx_op_registry import get_onnx_node, has_onnx_node +from dace.libraries.onnx.schema import ONNXParameterType + +#: Mapping from NumPy dtypes to PyTorch dtypes for tensor conversion +if TORCH_AVAILABLE: + numpy_to_torch_dtype_dict = { + np.bool_: torch.bool, + np.uint8: torch.uint8, + np.int8: torch.int8, + np.int16: torch.int16, + np.int32: torch.int32, + np.int64: torch.int64, + np.float16: torch.float16, + np.float32: torch.float32, + np.float64: torch.float64, + np.complex64: torch.complex64, + np.complex128: torch.complex128 + } + + #: Reverse mapping from PyTorch dtypes to NumPy dtypes + torch_to_numpy_dtype_dict = {v: k for k, v in numpy_to_torch_dtype_dict.items()} +else: + numpy_to_torch_dtype_dict = {} + torch_to_numpy_dtype_dict = {} + + +def _nested_HasField(obj, full_attr: str) -> bool: + """ + Check if a protobuf object has a nested field. + + This function performs a nested hasattr check by traversing dot-separated + attribute names on a protobuf object. + + :param obj: The protobuf object to check. + :param full_attr: Dot-separated attribute path (e.g., "graph.node"). + :return: True if all attributes in the path exist, False otherwise. + + Example:: + + >>> _nested_HasField(model, "graph.node") + True + """ + attrs = full_attr.split(".") + for attr in attrs: + if obj.HasField(attr): + obj = getattr(obj, attr) + else: + return False + return True + + +def infer_shapes_onnx_model(model: onnx.ModelProto, auto_merge: bool = False) -> onnx.ModelProto: + """ + Perform shape inference on an ONNX model using ONNXRuntime's symbolic shape inference. + + This function uses ONNXRuntime's symbolic shape inference tool which provides + better support for symbolic dimensions and dynamic shapes compared to ONNX's + built-in shape inference. + + :param model: The ONNX model to perform shape inference on. + :param auto_merge: Whether to automatically merge symbolic dimensions when possible. + :return: The ONNX model with inferred shapes. + + .. note:: + Falls back to ONNX's built-in shape inference if ONNXRuntime is not available + or if symbolic shape inference produces incomplete results. + """ + if not ONNXRUNTIME_AVAILABLE: + if config.Config.get_bool('debugprint'): + print("Warning: ONNXRuntime not available, falling back to ONNX shape inference.") + # Fallback to ONNX's built-in shape inference + import onnx.shape_inference + return onnx.shape_inference.infer_shapes(model, check_type=False, strict_mode=False, data_prop=True) + + try: + # Use static method API + model = SymbolicShapeInference.infer_shapes( + model, + int_max=2**31 - 1, + auto_merge=auto_merge, + guess_output_rank=False, + verbose=0, + ) + + # Check if shape inference completed successfully for all value_infos + incomplete_shapes = False + for value in model.graph.value_info: + if not _nested_HasField(value, "type.tensor_type.shape"): + incomplete_shapes = True + break + + if incomplete_shapes: + if config.Config.get_bool('debugprint'): + print("Warning: ONNXRuntime symbolic shape inference produced incomplete results, " + "falling back to ONNX shape inference.") + import onnx.shape_inference + return onnx.shape_inference.infer_shapes(model, check_type=False, strict_mode=False, data_prop=True) + + return model + except Exception as e: + if config.Config.get_bool('debugprint'): + print(f"Warning: ONNXRuntime symbolic shape inference failed ({e}), " + "falling back to ONNX shape inference.") + import onnx.shape_inference + return onnx.shape_inference.infer_shapes(model, check_type=False, strict_mode=False, data_prop=True) + + +def simplify_onnx_model(model: onnx.ModelProto, auto_merge: bool) -> onnx.ModelProto: + """ + Simplify an ONNX model using onnx-simplifier. + + This function applies various optimizations to the ONNX model including: + - Constant folding + - Dead code elimination + - Shape inference + - Operator fusion (except batch normalization) + + :param model: The ONNX model to simplify. + :param auto_merge: Whether to automatically merge nodes (passed to onnxsim). + :return: The simplified ONNX model. + :raises ImportError: If onnxsim is not installed. + :raises RuntimeError: If onnx-simplifier optimizations fail validation. + + .. note:: + Batch normalization fusion is skipped (skip_fuse_bn=True) to maintain + numerical accuracy and allow separate optimization strategies. + """ + if not ONNXSIM_AVAILABLE: + raise ImportError("onnxsim is required for model simplification. Install with: pip install dace[ml]") + + try: + model, check = onnxsim.simplify(model, skip_fuse_bn=True) + if not check: + raise RuntimeError("onnx-simplifier optimizations failed validation") + return model + except (onnx.checker.ValidationError, ValueError) as e: + # If simplification fails due to validation errors (e.g., missing shape info), + # return the original model + if config.Config.get_bool('debugprint'): + print(f"Warning: ONNX simplification failed with error: {e}. Continuing without simplification.") + return model + + +class ONNXModel: + """ Loads an ONNX model into an SDFG. + + :Example: + First download an ONNX model, such as + `efficientnet `_. + + .. testsetup:: + + import subprocess + model_path = os.path.join("..", "tests", "onnx_files", "efficientnet.onnx") + # Download model + if not os.path.exists(model_path): + subprocess.check_call([ + "wget", + "http://spclstorage.inf.ethz.ch/~rauscho/efficientnet-lite4-11.onnx", + "--output-document={}".format(model_path), + "--no-verbose" + ]) + + + .. testcode:: + + import onnx + import os + import numpy as np + from dace.onnx import ONNXModel + + model_path = os.path.join("..", "tests", "onnx_files", "efficientnet.onnx") + model = onnx.load(model_path) + dace_model = ONNXModel("efficientnet", model) + + test_input = np.random.rand(1, 3, 224, 224).astype(np.float32) + dace_model(test_input) + + """ + + def __init__(self, + name: str, + model: onnx.ModelProto, + cuda: bool = False, + auto_optimize: bool = False, + simplify: bool = False, + onnx_simplify: bool = True, + storage: Optional[dtypes.StorageType] = None, + save_transients: Optional[Dict[str, torch.Tensor]] = None, + auto_merge: bool = False): + """ + :param name: the name for the SDFG. + :param model: the model to import. + :param cuda: if ``True``, the model will be executed on the GPU. + :param simplify: if ``True``, apply simplification transformations after all nodes have been expanded. + :param onnx_simplify: if True, run ONNX-level simplifications such as constant folding and shape inference. + :param auto_optimize: if ``True``, apply automatic optimizations before calling. + :param storage: the storage type of the parameters, inputs and outputs. If None, will be set according to + ``cuda``. + :param save_transients: if not None, save transients to this dict (for debugging). + :param: whether to automatically merge conflicting shapes in symbolic shape inference. + :param auto_merge: whether to automatically merge symbolic shapes in symbolic shape inference. + """ + + onnx.checker.check_model(model) + + # Use temporary files for intermediate model saves + with tempfile.NamedTemporaryFile(suffix='.onnx', delete=True) as temp_original: + onnx.save(model, temp_original.name) + model = infer_shapes_onnx_model(model, auto_merge=auto_merge) + + with tempfile.NamedTemporaryFile(suffix='.onnx', delete=True) as temp_shapes: + onnx.save(model, temp_shapes.name) + + if onnx_simplify: + model = simplify_onnx_model(model, auto_merge) + with tempfile.NamedTemporaryFile(suffix='.onnx', delete=True) as temp_simplified: + onnx.save(model, temp_simplified.name) + + self.do_auto_optimize = auto_optimize + self.model = model + graph: onnx.GraphProto = model.graph + self.save_transients = save_transients + self.sdfg: SDFG = SDFG(name) #: the generated SDFG. + self.sdfg._parent_onnx_model = self + self.cuda = cuda + self.simplify = simplify + self.state: SDFGState = self.sdfg.add_state() #: the state containing the model computation. + + # Add all values to the SDFG, check for unsupported ops + ########################################## + + self.value_infos = {} + + self.inputs: List[str] = [] #: the inputs to the model + self.outputs: List[str] = [] #: the outputs of the model + + if storage is None: + storage = dtypes.StorageType.GPU_Global if self.cuda else dtypes.StorageType.Default + + for value, is_input in chain(zip(graph.input, repeat(True)), zip(graph.output, repeat(False))): + if not value.HasField("name"): + raise ValueError("Got input or output without name") + if is_input: + self.inputs.append(value.name) + else: + self.outputs.append(value.name) + + self.value_infos[value.name] = value + storage = storage + self._add_value_info(value, storage=storage) + + self.sdfg.arg_names = [clean_onnx_name(i) for i in self.inputs] + + for value in graph.value_info: + if not value.HasField("name"): + raise ValueError("Got input or output without name") + if value.name not in self.value_infos: + self.value_infos[value.name] = value + + # add weights + self.weights: Dict[str, torch.Tensor] = {} #: mapping from weight name to array + for init in graph.initializer: + self._add_constant_tensor(init, storage) + + access_nodes = {} + self._idx_to_node = [] + for i, node in enumerate(graph.node): + if not has_onnx_node(node.op_type): + raise ValueError("Unsupported ONNX operator: '{}'".format(node.op_type)) + + # extract the op attributes + op_attributes = { + attribute_proto.name: convert_attribute_proto(attribute_proto) + for attribute_proto in node.attribute + } + + if node.op_type == "Constant": + # Add constants to weights immediately + possible_values = [ + "sparse_value", "value", "value_float", "value_floats", "value_int", "value_ints", "value_string", + "value_strings" + ] + + # do some manual validation here since the node validation will never run + if set(op_attributes).difference(possible_values): + raise ValueError(f"Got unexpected attributes on Constant node " + f"{set(op_attributes).difference(possible_values)}") + + if len(op_attributes) != 1: + raise ValueError("Expected Constant node to have exactly one of its attributes set") + + if len(node.input) != 0 or len(node.output) != 1: + raise ValueError("Expected Constant node to have no inputs and exactly 1 output") + + value_name = next(iter(op_attributes)) + + self._add_constant_tensor((node.output[0], op_attributes[value_name]), storage) + continue + + if node.HasField("name"): + node_name = clean_onnx_name(node.name) + else: + node_name = node.op_type + "_" + str(i) + + # construct the dace node + [opset] = [i for i in model.opset_import if not i.domain] + node_schema = onnx.defs.get_schema(node.op_type, opset.version) + node_version = node_schema.since_version + op_node = get_onnx_node(node.op_type, node_version)(node_name, **op_attributes) + self.state.add_node(op_node) + self._idx_to_node.append(op_node) + + for param_idx, (name, is_input) in chain(enumerate(zip(node.input, repeat(True))), + enumerate(zip(node.output, repeat(False)))): + # Get parameter schema + params = op_node.schema.inputs if is_input else op_node.schema.outputs + params_len = len(params) + + # Determine parameter type and validate + if param_idx >= params_len: + # Variadic parameter beyond schema range + if params[-1].param_type != ONNXParameterType.Variadic: + raise ValueError( + "Expected the last {i_or_o} parameter to be variadic," + " since the {i_or_o} with idx {param_idx} has more parameters than the schema ({params_len})" + .format(i_or_o="input" if is_input else "output", + param_idx=param_idx, + params_len=params_len)) + param_type = ONNXParameterType.Variadic + conn_name = params[-1].name + "__" + str(param_idx - params_len + 1) + else: + param_type = params[param_idx].param_type + if param_type == ONNXParameterType.Variadic: + conn_name = params[param_idx].name + "__0" + else: + conn_name = params[param_idx].name + + # Handle optional parameters + if param_type == ONNXParameterType.Optional and not name: + continue + + # Validate required parameters + if param_type != ONNXParameterType.Optional and not name: + raise ValueError("Required {i_or_o} parameter '{param_name}' is not set".format( + i_or_o="input" if is_input else "output", param_name=params[param_idx].name)) + + # Create array if needed + if clean_onnx_name(name) not in self.sdfg.arrays: + if name not in self.value_infos: + raise ValueError("Could not find array with name '{}'".format(name)) + self._add_value_info(self.value_infos[name]) + + # Get or create access node + if name in access_nodes: + access = access_nodes[name] + else: + access = nodes.AccessNode(clean_onnx_name(name)) + self.state.add_node(access) + access_nodes[name] = access + + data_desc = self.sdfg.arrays[clean_onnx_name(name)] + + # Add connector and edge + if is_input: + if conn_name not in op_node.in_connectors: + assert op_node.add_in_connector(conn_name) + self.state.add_edge(access, None, op_node, conn_name, + dace.Memlet.from_array(clean_onnx_name(name), data_desc)) + else: + if conn_name not in op_node.out_connectors: + assert op_node.add_out_connector(conn_name) + self.state.add_edge(op_node, conn_name, access, None, + dace.Memlet.from_array(clean_onnx_name(name), data_desc)) + + # scalars need to be promoted to arrays so that we can return them from the dace program + # however, this is only for CPU: on GPU, scalars are already pointers + self._promoted_scalars = set() + + # insert copies from outputs to __return arrays + copy_out_state = self.sdfg.add_state_after(self.state, label='copy_out') + new_output_names = [] + for i, output in enumerate(self.outputs): + clean_name = clean_onnx_name(output) + new_output_name = '__return' + if len(self.outputs) > 1: + new_output_name += '_' + str(i) + new_output_names.append(new_output_name) + + desc = copy.deepcopy(self.sdfg.arrays[clean_name]) + if isinstance(desc, dt.Scalar) and not self.cuda: + desc = dt.Array(desc.dtype, (1, )) + self._promoted_scalars.add(new_output_name) + + # insert new descriptor + self.sdfg.arrays[new_output_name] = desc + desc.transient = False + + copy_out_state.add_edge(copy_out_state.add_read(clean_name), None, + copy_out_state.add_write(new_output_name), None, + self.sdfg.make_array_memlet(clean_name)) + + # finally, rename outputs, and fuse states + self.outputs = new_output_names + sdfg_utils.fuse_states(self.sdfg) + + if self.cuda: + self.sdfg.apply_gpu_transformations() + + def _add_constant_tensor(self, tensor: Union[onnx.TensorProto, Tuple[str, np.ndarray]], + storage: dtypes.StorageType): + if isinstance(tensor, tuple): + unclean_name, value = tensor + dtype = dtypes.dtype_to_typeclass(value.dtype.type) + shape = value.shape + np_array = value + else: + if not tensor.HasField("name"): + raise ValueError("Got tensor without name") + + if not tensor.HasField("data_type"): + raise ValueError("Initializer tensor '{}' has no type".format(tensor.name)) + unclean_name = tensor.name + dtype = onnx_tensor_type_to_typeclass(tensor.data_type) + shape = [d for d in tensor.dims] + np_array = numpy_helper.to_array(tensor) + + name = clean_onnx_name(unclean_name) + if unclean_name in self.inputs: + # remove the tensor from inputs since this is a constant + self.inputs.remove(unclean_name) + # note: inputs already have data-descriptors created for them, so + # we skip the below code + elif len(shape) == 0: + # this is a scalar + self.sdfg.add_scalar(name, dtype, storage=storage) + else: + if name not in self.sdfg.arrays: + self.sdfg.add_array(name, shape, dtype, storage=storage, transient=False) + else: + existing_arr = self.sdfg.arrays[name] + if existing_arr.dtype != dtype: + raise ValueError( + "Invalid ONNX model; found two values with name '{}', but different dtypes ({} and {})".format( + name, existing_arr.dtype, dtype)) + if tuple(existing_arr.shape) != tuple(shape): + raise ValueError( + "Invalid ONNX model; found two values with name '{}', but different dimensions ({} and {})". + format(name, existing_arr.shape, shape)) + + # we need to copy here because the weight_arr tensor is not writable + self.weights[unclean_name] = torch.from_numpy(np_array.copy()) + + def _add_value_info(self, value_info: onnx.ValueInfoProto, storage=dtypes.StorageType.Default): + if not value_info.HasField("name"): + raise ValueError("Got value without name") + + name = value_info.name + + if not _nested_HasField(value_info, "type.tensor_type.shape"): + raise ValueError("Value '{}' does not have a shape in this graph." + " Please run shape inference before importing.".format(name)) + + tensor_type = value_info.type.tensor_type + + if not tensor_type.HasField("elem_type"): + raise ValueError("Value '{}' does not have a type in this graph." + " Please run type inference before importing.".format(name)) + + shape = [] + for d in tensor_type.shape.dim: + if d.HasField("dim_value"): + shape.append(d.dim_value) + elif d.HasField("dim_param"): + parsed = pystr_to_symbolic(d.dim_param) + + for sym in parsed.free_symbols: + if clean_onnx_name(str(sym)) not in self.sdfg.symbols: + self.sdfg.add_symbol(clean_onnx_name(str(sym)), stype=int) + parsed = parsed.subs(sym, dace.symbol(clean_onnx_name(str(sym)))) + + shape.append(parsed) + else: + raise ValueError("Value '{}' does not have a shape in this graph." + " Please run shape inference before importing.".format(name)) + transient = name not in self.inputs + if len(shape) == 0: + self.sdfg.add_scalar(clean_onnx_name(name), + dtype=onnx_tensor_type_to_typeclass(tensor_type.elem_type), + transient=transient, + storage=storage) + else: + self.sdfg.add_array(clean_onnx_name(name), + shape=shape, + dtype=onnx_tensor_type_to_typeclass(tensor_type.elem_type), + transient=transient, + storage=storage) + + @property + def clean_weights(self): + return {clean_onnx_name(k): v for k, v in self.weights.items()} + + def compile_and_init(self) -> compiled_sdfg.CompiledSDFG: + """ Compile the SDFG and load parameters into GPU memory. """ + + compiled_sdfg = self.sdfg.compile() + + # copy all parameters to the device + self.initialized_parameters = {} + for name, arr in self.weights.items(): + if clean_onnx_name(name) in compiled_sdfg.sdfg.arrays: + desc = self.sdfg.arrays[clean_onnx_name(name)] + cuda = desc.storage in dace.dtypes.GPU_STORAGES + if type(desc) is dt.Scalar: + self.initialized_parameters[clean_onnx_name(name)] = arr.cuda() if cuda else arr.cpu().numpy()[()] + else: + self.initialized_parameters[clean_onnx_name(name)] = arr.cuda() if cuda else arr + + return compiled_sdfg + + def __call__(self, *args, + **kwargs) -> Union[Union[torch.Tensor, np.ndarray], Tuple[Union[torch.Tensor, np.ndarray]]]: + """ Execute the model. + + :param args: positional arguments to the model. The i-th argument will be passed as the i-th input of the + model. + :param kwargs: named arguments to the model. The passed names should match the names in the ONNX model. + :return: the output of the model (or a tuple of outputs if there are multiple). + """ + + transient_kwargs = {} + if self.save_transients is not None: + for node, parent in self.sdfg.all_nodes_recursive(): + if isinstance(node, nodes.AccessNode): + desc = self.sdfg.arrays[node.data] + if not isinstance(desc, dt.View) and desc.transient: + desc.transient = False + transient_kwargs[node.data] = desc + + if self.do_auto_optimize: + self.auto_optimize() + + compiled = self.compile_and_init() + + inputs, symbols, outputs = self._call_args(args=args, kwargs=kwargs) + + for name, desc in transient_kwargs.items(): + if name in self.initialized_parameters: + transient_kwargs[name] = self.initialized_parameters[name] + self.initialized_parameters.pop(name) + else: + transient_kwargs[name] = create_output_array(symbols, desc, use_torch=True, zeros=True) + self.save_transients[name] = transient_kwargs[name] + + compiled(**inputs, **outputs, **self.initialized_parameters, **symbols, **transient_kwargs) + + # demote scalars we promoted above + for scalar in self._promoted_scalars: + outputs[scalar] = outputs[scalar].reshape(()) + + if len(outputs) == 1: + return next(iter(outputs.values())) + + return tuple(outputs.values()) + + def _call_args(self, + *, + args, + kwargs, + torch_outputs: bool = None) -> Tuple[Dict[str, Any], Dict[str, Any], OrderedDict[str, Any]]: + """ Prepare the arguments for a call. + + This returns 4 dicts; one for each of the following: + 1. the inputs + 3. inferred values for symbols for dynamic dimensions + 4. outputs + + These arguments can be passed to `self.sdfg`. + + :param args: model positional args + :param kwargs: model kwargs + :param torch_outputs: if not None, the outputs will be torch tensors depending on the boolean value. + Otherwise the outputs will be torch tensors only if at least one of the inputs is a + torch tensor. + :return: the tuple of dicts + """ + inputs = kwargs + + # convert the positional args to kwargs + if len(args) > len(self.inputs): + raise ValueError("Expected {} arguments, got {}".format(len(self.inputs), len(args))) + + inputs.update(dict(zip(self.inputs, args))) + + # check that there are no missing inputs + if len(set(self.inputs).difference(inputs)) != 0: + raise ValueError("Missing inputs {}".format(", ".join(set(self.inputs).difference(inputs)))) + + # check that there are no unknown inputs + # NOTE symbols can only be passed as kwargs + if len(set(inputs).difference(self.inputs).difference(self.sdfg.free_symbols)) != 0: + raise ValueError("Unknown inputs {}".format(", ".join(set(inputs).difference(self.inputs)))) + + clean_inputs = {} + for input, arr in inputs.items(): + if input in self.sdfg.free_symbols: + clean_inputs[input] = arr + else: + clean_inputs[clean_onnx_name(input)] = arr + + inferred_symbols = parser.infer_symbols_from_datadescriptor(self.sdfg, { + **clean_inputs, + **self.initialized_parameters + }) + inferred_symbols = {k: int(v) for k, v in inferred_symbols.items()} + + if torch_outputs is None: + torch_outputs = any(self.sdfg.arrays[clean_onnx_name(o)].storage in dace.dtypes.GPU_STORAGES + for o in self.outputs) or any( + isinstance(inp, torch.Tensor) for _, inp in clean_inputs.items()) + + outputs = collections.OrderedDict() + # create numpy arrays for the outputs + for name in self.outputs: + clean_name = clean_onnx_name(name) + outputs[clean_name] = create_output_array(inferred_symbols, + self.sdfg.arrays[clean_name], + use_torch=torch_outputs, + zeros=True) + + # check that there's no overlap + seen = set() + for parameters in [clean_inputs, self.initialized_parameters, outputs, inferred_symbols]: + new_parameters = set(parameters) + assert not seen.intersection(new_parameters) + seen |= new_parameters + + return clean_inputs, inferred_symbols, outputs + + def expand_onnx_nodes(self): + onnx_node_expander(self.sdfg) + + def auto_optimize(self): + auto_opt( + self.sdfg, + self.cuda, + simplify=self.simplify, + # constants have been folded before GPU transforms + fold_constants=False) + + +def create_output_array(inferred_symbols: Dict[str, int], + desc: dt.Data, + use_torch=False, + zeros: bool = False) -> Union[np.ndarray, torch.tensor]: + """ Create the array for an output. This is either a numpy array or a torch tensor depending on `use_torch` + + When `self.force_torch_outputs` is True, the outputs will be tensors. Otherwise, the outputs will be tensors + :param inferred_symbols: the symbols inferred from `infer_symbols_from_datadescriptor`. + :param desc: the data descriptor for the array + :param use_torch: whether to return a numpy array or a torch tensor. + :param zeros: if true init with zeros else empty. + """ + + def eval_dim(dim): + for sym in dim.free_symbols: + dim = dim.subs(sym, inferred_symbols[sym.name]) + return dim + + cuda = desc.storage in dace.dtypes.GPU_STORAGES + if cuda and not use_torch: + raise ValueError("Got use_torch=False, but received a GPU descriptor") + + if isinstance(desc, dt.Scalar): + shape = [] + else: + shape = [eval_dim(d) if type(d) is dace.symbol else d for d in desc.shape] + + if use_torch: + # torch functions don't accept the empty shape, so create shape [1] then reshape to () + if len(shape) == 0: + shape = [1] + + # as_numpy_dtype doesn't seem to work for indexing into the dict + if desc.dtype == dace.pointer(dace.typeclass(None)): + # assuming 64 bit ptrs + dtype = torch.int64 + else: + dtype = numpy_to_torch_dtype_dict[getattr(np, desc.dtype.to_string())] + tens = (torch.zeros if zeros else torch.empty)(*shape, dtype=dtype) + if isinstance(desc, dt.Scalar): + tens = tens.reshape(()) + + return tens.cuda() if cuda else tens + else: + return (np.zeros if zeros else np.empty)(shape, dtype=getattr(np, desc.dtype.to_string())) diff --git a/dace/frontend/tensorflow/__init__.py b/dace/frontend/ml/tensorflow/__init__.py similarity index 100% rename from dace/frontend/tensorflow/__init__.py rename to dace/frontend/ml/tensorflow/__init__.py diff --git a/dace/frontend/tensorflow/tensorflow.py b/dace/frontend/ml/tensorflow/tensorflow.py similarity index 99% rename from dace/frontend/tensorflow/tensorflow.py rename to dace/frontend/ml/tensorflow/tensorflow.py index af71493214..ef6cfdb409 100644 --- a/dace/frontend/tensorflow/tensorflow.py +++ b/dace/frontend/ml/tensorflow/tensorflow.py @@ -17,8 +17,8 @@ from dace.data import Scalar from dace.sdfg.nodes import Tasklet, NestedSDFG from dace.symbolic import symstr, SymExpr -from dace.frontend.tensorflow.winograd import winograd_convolution -from dace.frontend.tensorflow.transformations.redundant_array import (TensorflowRedundantArray) +from .winograd import winograd_convolution +from .transformations.redundant_array import TensorflowRedundantArray try: import tensorflow as tf diff --git a/dace/frontend/tensorflow/transformations/__init__.py b/dace/frontend/ml/tensorflow/transformations/__init__.py similarity index 100% rename from dace/frontend/tensorflow/transformations/__init__.py rename to dace/frontend/ml/tensorflow/transformations/__init__.py diff --git a/dace/frontend/tensorflow/transformations/redundant_array.py b/dace/frontend/ml/tensorflow/transformations/redundant_array.py similarity index 100% rename from dace/frontend/tensorflow/transformations/redundant_array.py rename to dace/frontend/ml/tensorflow/transformations/redundant_array.py diff --git a/dace/frontend/tensorflow/winograd.py b/dace/frontend/ml/tensorflow/winograd.py similarity index 100% rename from dace/frontend/tensorflow/winograd.py rename to dace/frontend/ml/tensorflow/winograd.py diff --git a/dace/frontend/ml/torch/__init__.py b/dace/frontend/ml/torch/__init__.py new file mode 100644 index 0000000000..d1563fac13 --- /dev/null +++ b/dace/frontend/ml/torch/__init__.py @@ -0,0 +1,6 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. + +from .module import DaceModule +from .interface import module + +__all__ = ['DaceModule', 'module'] diff --git a/dace/frontend/ml/torch/interface.py b/dace/frontend/ml/torch/interface.py new file mode 100644 index 0000000000..6dc1f68d1f --- /dev/null +++ b/dace/frontend/ml/torch/interface.py @@ -0,0 +1,89 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +Python interface for DaCe PyTorch/Torch integration. + +This module provides decorators and utilities for converting PyTorch modules +to DaCe-accelerated implementations. +""" + +from functools import wraps +from typing import Optional, Tuple, List + +from dace.dtypes import paramdec + + +@paramdec +def module(moduleclass, + dummy_inputs: Optional[Tuple] = None, + cuda: Optional[bool] = None, + training: bool = False, + backward=False, + inputs_to_skip: Optional[List[str]] = None, + onnx_simplify: bool = True, + simplify: bool = True, + auto_optimize: bool = True, + sdfg_name: Optional[str] = None, + compile_torch_extension: bool = True, + debug_transients: bool = False): + """ + Decorator to apply on a definition of a ``torch.nn.Module`` to convert it + to a data-centric module upon construction. + + Example:: + + import dace.ml + import torch.nn as nn + + @dace.ml.module + class MyDecoratedModule(nn.Module): + def forward(self, x): + x = torch.log(x) + x = torch.sqrt(x) + return x + + module_instance = MyDecoratedModule() + module_instance(torch.ones(2)) # tensor([0., 0.]) + + .. Note:: + You must import ``dace.ml`` (not just ``dace``) to use this decorator. + + :param moduleclass: The model to wrap. + :param dummy_inputs: A tuple of tensors to use as input when tracing the model. + :param cuda: If ``True``, the module will execute using CUDA. + If ``None``, it will be detected from the module. + :param training: Whether to use train mode when tracing the model. + :param backward: Whether to enable the backward pass. + :param inputs_to_skip: If provided, a list of inputs to skip computing gradients for + (only relevant when the backward pass is enabled). + :param onnx_simplify: Whether to apply ONNX simplification using onnxsim. + :param simplify: Whether to apply simplification transforms after conversion. + This generally improves performance but can be slow. + :param auto_optimize: Whether to apply automatic optimizations. + :param sdfg_name: The name to give to the SDFG (defaults to moduleclass name). + :param compile_torch_extension: If ``True``, a torch C++ extension will be compiled + and used for this module. Otherwise, a Python ctypes implementation will be used. + :param debug_transients: If ``True``, the module will have all transients as outputs. + """ + wraps(moduleclass) + + def _create(*args, **kwargs): + # Lazy import DaceModule when decorator is actually used + try: + from dace.frontend.ml.torch import DaceModule + except ImportError: + raise ImportError("DaceModule requires PyTorch. Install with: pip install torch") + + return DaceModule(moduleclass(*args, **kwargs), + dummy_inputs=dummy_inputs, + cuda=cuda, + training=training, + backward=backward, + inputs_to_skip=inputs_to_skip, + onnx_simplify=onnx_simplify, + simplify=simplify, + auto_optimize=auto_optimize, + sdfg_name=sdfg_name, + compile_torch_extension=compile_torch_extension, + debug_transients=debug_transients) + + return _create diff --git a/dace/frontend/ml/torch/module.py b/dace/frontend/ml/torch/module.py new file mode 100644 index 0000000000..ba820faf87 --- /dev/null +++ b/dace/frontend/ml/torch/module.py @@ -0,0 +1,581 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" DaCe Python parsing functionality and entry point to Python frontend. """ +from dataclasses import dataclass +import collections +import itertools +import tempfile +import copy +import os +from typing import Any, Callable, Dict, OrderedDict, List, Optional, Set, Sequence, Tuple, Union + +# Try importing ML dependencies +try: + import torch + from torch import Tensor + import torch.nn as nn + from torch.onnx import TrainingMode + TORCH_AVAILABLE = True +except ImportError: + torch = None + Tensor = None + nn = None + TrainingMode = None + TORCH_AVAILABLE = False + +try: + import onnx + ONNX_AVAILABLE = True +except ImportError: + onnx = None + ONNX_AVAILABLE = False + +import dace +from dace import config, data +from dace.codegen import compiled_sdfg +from dace.sdfg import SDFG, nodes +from dace.frontend.python import common as pycommon +from dace.data import find_new_name + +if TORCH_AVAILABLE and ONNX_AVAILABLE: + from dace.libraries.onnx.converters import clean_onnx_name + from dace.libraries.torch import dispatchers + from dace.autodiff import torch as torch_autodiff + from dace.autodiff.library import library as autodiff_library + from dace.frontend.ml.onnx import ONNXModel + from dace.transformation.onnx import auto_optimize_onnx as auto_opt +else: + clean_onnx_name = None + dispatchers = None + torch_autodiff = None + autodiff_library = None + ONNXModel = None + auto_opt = None + +if TORCH_AVAILABLE and ONNX_AVAILABLE: + + def _onnx_delete_initializers(model: onnx.ModelProto, names: Set[str]) -> None: + """ + Delete the given initializers from the given onnx model. + + :param model: The ONNX model to modify. + :param names: Set of initializer names to delete. + :note: Operates in-place. + """ + to_remove = [] + for i, initializer in enumerate(model.graph.initializer): + if initializer.name in names: + to_remove.append(i) + + for i in reversed(to_remove): + model.graph.initializer.pop(i) + + class DaceModule(nn.Module, pycommon.SDFGConvertible): + """ A wrapper that converts a PyTorch ``nn.Module`` to a PyTorch compatible data-centric ``nn.Module``. + + :param module: the model to wrap. + :param dummy_inputs: a tuple of tensors to use as input when tracing ``model``. + :param cuda: if ``True``, the module will execute using CUDA. If ``None``, it will be detected from the + ``module``. + :param training: whether to use train mode when tracing ``model``. + :param backward: whether to enable the backward pass. + :param inputs_to_skip: if provided, a list of inputs to skip computing gradients for. + (only relevant when the backward pass is enabled) + :param onnx_simplify: whether to apply onnx simplification using onnxsim. + :param simplify: whether to apply simplification transforms after conversion (this generally improves performance, + but can be slow). + :param sdfg_name: the name to give to the sdfg (defaults to moduleclass name). + :param auto_optimize: whether to apply automatic optimizations. + :param compile_torch_extension: if True, a torch C++ extension will be compiled and used for this module. + Otherwise, a python ctypes implementation will be used. + :param debug_transients: if True, the module will have all transients as outputs. + + :Example: + >>> from dace.frontend.ml.torch import DaceModule + >>> class MyModule(nn.Module): + ... def forward(self, x): + ... x = torch.log(x) + ... x = torch.sqrt(x) + ... return x + >>> module = MyModule() + >>> module(torch.ones(2)) + tensor([0., 0.]) + >>> dace_module = DaceModule(module) + >>> dace_module(torch.ones(2)) + tensor([0., 0.]) + """ + + def __init__(self, + module: nn.Module, + dummy_inputs: Optional[Tuple[torch.Tensor, ...]] = None, + cuda: Optional[bool] = None, + training: bool = False, + backward: bool = False, + inputs_to_skip: Optional[List[str]] = None, + onnx_simplify: bool = True, + simplify: bool = True, + auto_optimize: bool = False, + debug_transients: bool = False, + compile_torch_extension: bool = True, + sdfg_name: Optional[str] = None): + + super(DaceModule, self).__init__() + + self.backward = backward + self.model = module + self.dace_model: Optional[ONNXModel] = None + self.training = training + self.sdfg: Optional[SDFG] = None + self.use_cuda = cuda + self.sdfg_name = sdfg_name or type(module).__name__ + self.auto_optimize = auto_optimize + self.onnx_simplify = onnx_simplify + self.simplify = simplify + self.debug_transients = debug_transients + self.compile_torch_extension = compile_torch_extension + self.inputs_to_skip = inputs_to_skip or [] + + self.function = None + + #: hooks that are executed after onnx graph is imported to an SDFG + self.post_onnx_hooks: OrderedDict[str, Callable[[DaceModule], None]] = collections.OrderedDict() + + #: hooks that are executed after the backpropagation sdfg has been created + self.post_autodiff_hooks: OrderedDict[str, Callable[[SDFG, SDFG], None]] = collections.OrderedDict() + + #: hooks that are executed after the sdfg is compiled + self.post_compile_hooks: OrderedDict[str, Callable[[compiled_sdfg.CompiledSDFG], + None]] = collections.OrderedDict() + # setup debug hook + if self.debug_transients: + + def transients_outputs(module): + for state in module.sdfg.nodes(): + for node in state.nodes(): + if (isinstance(node, nodes.AccessNode) and node.desc(module.sdfg).transient + and not isinstance(node.desc(module.sdfg), data.Scalar)): + if "mean" not in node.data and "std" not in node.data: + module.dace_model.outputs.append(node.data) + node.desc(module.sdfg).transient = False + + self.prepend_post_onnx_hook("make_transients_outputs", transients_outputs) + + # setup optimization hooks + if self.auto_optimize: + if self.backward: + + def auto_optimize_backward(fwd_sdfg, bwd_sdfg): + auto_opt(fwd_sdfg, self.use_cuda, simplify=self.simplify) + auto_opt(bwd_sdfg, self.use_cuda, simplify=self.simplify) + + self.append_post_autodiff_hook("auto_optimize", auto_optimize_backward) + else: + self.append_post_onnx_hook( + "auto_optimize", lambda dace_module: auto_opt( + dace_module.dace_model.sdfg, self.use_cuda, simplify=self.simplify)) + elif self.simplify: + if self.backward: + + def simplify_hook(fwd_sdfg, bwd_sdfg): + fwd_sdfg.simplify() + bwd_sdfg.simplify() + + self.append_post_autodiff_hook("simplify", simplify_hook) + else: + self.append_post_onnx_hook("simplify", lambda dace_module: dace_module.sdfg.simplify()) + + if dummy_inputs is not None: + self.function = self._initialize_sdfg(dummy_inputs) + + def reset_sdfg(self) -> None: + """Clear the SDFG so that optimizations are reapplied.""" + self.function = None + + def _detect_cuda_usage(self, dummy_inputs) -> bool: + """ + Detect whether CUDA should be used based on inputs and model parameters. + + :param dummy_inputs: Tuple of tensors to check. + :return: True if CUDA should be used, False otherwise. + """ + try: + module_is_cuda = next(iter(dummy_inputs)).is_cuda + except StopIteration: + module_is_cuda = False + + if not module_is_cuda: + # check the parameters + try: + module_is_cuda = next(self.model.parameters()).is_cuda + except StopIteration: + module_is_cuda = False + return module_is_cuda + + def prepend_post_onnx_hook(self, name: str, func: Callable[["DaceModule"], None]) -> None: + """ + Add a hook to be executed after ONNX graph import, at the beginning of the hook list. + + :param name: Name of the hook (will be made unique if necessary). + :param func: Callable to execute after ONNX import. + """ + if self.function is not None: + if config.Config.get_bool('debugprint'): + print(f"Warning: Added a hook after the model was already initialized. This hook " + f"(with name {name}) will not be executed!") + name = find_new_name(name, self.post_onnx_hooks) + self.post_onnx_hooks[name] = func + self.post_onnx_hooks.move_to_end(name, last=False) + + def append_post_onnx_hook(self, name: str, func: Callable[["DaceModule"], None]) -> None: + """ + Add a hook to be executed after ONNX graph import, at the end of the hook list. + + :param name: Name of the hook (will be made unique if necessary). + :param func: Callable to execute after ONNX import. + """ + if self.function is not None: + if config.Config.get_bool('debugprint'): + print(f"Warning: Added a hook after the model was already initialized. This hook " + f"(with name {name}) will not be executed!") + name = find_new_name(name, self.post_onnx_hooks) + self.post_onnx_hooks[name] = func + + def prepend_post_autodiff_hook(self, name: str, func: Callable[[SDFG, SDFG], None]) -> None: + """ + Add a hook to be executed after autodiff, at the beginning of the hook list. + + :param name: Name of the hook (will be made unique if necessary). + :param func: Callable to execute after autodiff, receiving forward and backward SDFGs. + """ + if self.function is not None: + if config.Config.get_bool('debugprint'): + print(f"Warning: Added a hook after the model was already initialized. This hook " + f"(with name {name}) will not be executed!") + name = find_new_name(name, self.post_autodiff_hooks) + self.post_autodiff_hooks[name] = func + self.post_autodiff_hooks.move_to_end(name, last=False) + + def append_post_autodiff_hook(self, name: str, func: Callable[[SDFG, SDFG], None]) -> None: + """ + Add a hook to be executed after autodiff, at the end of the hook list. + + :param name: Name of the hook (will be made unique if necessary). + :param func: Callable to execute after autodiff, receiving forward and backward SDFGs. + """ + if self.function is not None: + if config.Config.get_bool('debugprint'): + print(f"Warning: Added a hook after the model was already initialized. This hook " + f"(with name {name}) will not be executed!") + name = find_new_name(name, self.post_autodiff_hooks) + self.post_autodiff_hooks[name] = func + + def prepend_post_compile_hook(self, name: str, func: Callable[[compiled_sdfg.CompiledSDFG], None]) -> None: + """ + Add a hook to be executed after compilation, at the beginning of the hook list. + + :param name: Name of the hook (will be made unique if necessary). + :param func: Callable to execute after compilation, receiving the compiled SDFG. + """ + if self.function is not None: + if config.Config.get_bool('debugprint'): + print(f"Warning: Added a hook after the model was already initialized. This hook " + f"(with name {name}) will not be executed!") + name = find_new_name(name, self.post_compile_hooks) + self.post_compile_hooks[name] = func + self.post_compile_hooks.move_to_end(name, last=False) + + def append_post_compile_hook(self, name: str, func: Callable[[compiled_sdfg.CompiledSDFG], None]) -> None: + """ + Add a hook to be executed after compilation, at the end of the hook list. + + :param name: Name of the hook (will be made unique if necessary). + :param func: Callable to execute after compilation, receiving the compiled SDFG. + """ + if self.function is not None: + if config.Config.get_bool('debugprint'): + print(f"Warning: Added a hook after the model was already initialized. This hook " + f"(with name {name}) will not be executed!") + name = find_new_name(name, self.post_compile_hooks) + self.post_compile_hooks[name] = func + + def _initialize_sdfg(self, dummy_inputs): + """ + Initialize the SDFG by converting the PyTorch module to ONNX and then to DaCe. + + :param dummy_inputs: Tuple of tensors to use for tracing. + :return: Forward function to be called during execution. + """ + # determine whether we are using CUDA + if self.use_cuda is None: + self.use_cuda = self._detect_cuda_usage(dummy_inputs) + + if self.use_cuda: + self.model = self.model.cuda() + + # TODO change to StringIO if not too big + with tempfile.TemporaryDirectory() as dir_name: + export_name = os.path.join(dir_name, "export.onnx") + + # save the state of the model, and restore it after tracing + state = copy.deepcopy(self.state_dict()) + torch.onnx.export( + self.model, + dummy_inputs, + export_name, + verbose=config.Config.get_bool('debugprint'), + # Some models will require training even when we don't want to train: + # when training is set to EVAL, pytorch currently performs an optimization pass ("onnx_eval_peephole") + # that renames weights and thus breaks the model in some settings. + training=(TrainingMode.TRAINING if self.training else TrainingMode.EVAL), + opset_version=18, + export_params=not self.backward, + # pytorch constant folding will add new unnamed inputs to the graph and remove some of the + # named parameters of the model: this means that we can't match with the state dict + # anymore, so we disable this. Our CF is more flexible. + do_constant_folding=False, + keep_initializers_as_inputs=True, + dynamo=False) + self.load_state_dict(state) + onnx_model_exported = onnx.load(export_name) + + # Remove buffers and parameters from initializers + # they should already be in the inputs (from the pytorch exporter) + # this prevents onnx tools from messing with parameters + input_names = set() + for name, _ in itertools.chain(self.named_parameters(), self.named_buffers()): + # pytorch adds a "model." prefix here that isn't in the onnx export; + # remove it + if not name.startswith("model."): + raise ValueError("Expected parameter names to start with 'model.'") + input_names.add(name[6:]) + + # save the parameters as they are now for later access + self._exported_parameters = dict( + (n, p) for n, p in itertools.chain(self.model.named_parameters(), self.model.named_buffers())) + + _onnx_delete_initializers(onnx_model_exported, input_names) + + # load using importer + dace_model = ONNXModel(self.sdfg_name, + onnx_model_exported, + onnx_simplify=self.onnx_simplify, + cuda=self.use_cuda, + auto_optimize=self.auto_optimize) + self.sdfg = dace_model.sdfg + self.dace_model = dace_model + + self.sdfg.validate() + + for _, hook in self.post_onnx_hooks.items(): + hook(self) + + # choose the backend that will generate the function to call during + # forward + if self.compile_torch_extension: + function_generator = dispatchers.register_and_compile_torch_extension + else: + function_generator = dispatchers.get_ctypes_dispatcher + + if self.backward: + + # Determine what grads we need + # For now: we want gradients for all inputs that are not pytorch buffers + named_buffers = {n for n, _ in self.model.named_buffers()} + required_gradients = [ + clean_onnx_name(name) for name in self.dace_model.inputs + if name not in named_buffers and name not in self.inputs_to_skip + ] + named_parameters = dict(self.model.named_parameters()) + required_gradients.extend( + clean_onnx_name(name) for name, param in named_parameters.items() if param.requires_grad) + required_gradients = list(set(required_gradients)) + + self.forward_sdfg, self.backward_sdfg, self._ad_result, self._ad_inp_arrs = torch_autodiff.make_backward_function( + dace_model, required_gradients) + + for _, hook in self.post_autodiff_hooks.items(): + hook(self.forward_sdfg, self.backward_sdfg) + self.compiled_function = function_generator(self, dummy_inputs) + else: + self.compiled_function = function_generator(self, dummy_inputs) + + # order the parameters + parameters_to_pass = self._call_params() + + def forward(*args): + return self.compiled_function.function(*self.compiled_function.ptr, *args, *parameters_to_pass) + + return forward + + def _call_params(self) -> Tuple[Union[Tensor, nn.parameter.Parameter], ...]: + """ + Get the parameters that we need to pass to the model, in the correct order. + + :return: Tuple of parameters and buffers in the order expected by the SDFG. + """ + # self.dace_model.inputs contains the buffers, parameters and the inputs. + # We only want the parameters and buffers + model_inputs = self.dace_model.inputs + + # find the index of the first input that is a parameter or buffer + start_idx = 0 + while start_idx < len(model_inputs) and model_inputs[start_idx] not in self._exported_parameters: + start_idx += 1 + + return tuple(self._exported_parameters[i] for i in model_inputs[start_idx:]) + + def forward(self, *actual_inputs): + """ + Execute the forward pass using the traced module. + + :param actual_inputs: Input tensors to the model. + :return: Output tensors from the model. + """ + if self.function is None: + self.function = self._initialize_sdfg(actual_inputs) + + return self.function(*actual_inputs) + + # SDFGConvertible methods: + # used when the model is called in a DaceProgram. + ################################################# + + def __sdfg__(self, *args): + """ + Get the SDFG representation of this module (SDFGConvertible interface). + + :param args: Arguments (currently unused). + :return: The SDFG representation. + :raises ValueError: If the model has not been initialized yet. + """ + if self.sdfg is None: + raise ValueError("Using a PyTorch model in a DaceProgram requires that the model is initialized first. " + "Either call this model using some inputs, or pass 'dummy_inputs' to the constructor.") + for name, param in self._exported_parameters.items(): + onnx_name = clean_onnx_name(name) + if param.requires_grad: + autodiff_library.ParameterArray.make_parameter(self.sdfg, onnx_name) + return self.sdfg + + def _add_gradient_buffers(self) -> List[str]: + """ + Allocate gradient buffers for all parameters, and add their descriptors to the SDFG. + + :return: a list of the sdfg array names of the gradient buffers + """ + + assert self.sdfg is not None + if hasattr(self, '_gradient_buffers'): + return self._gradient_buffers + + buffers = [] + for name, param in self._exported_parameters.items(): + onnx_name = clean_onnx_name(name) + desc = self.sdfg.arrays[onnx_name] + + if param.requires_grad: + # allocate gradient buffer + param.grad = torch.empty_like(param.data) + + # add gradient buffer descriptor to sdfg + autodiff_library.ParameterArray.make_parameter(self.sdfg, onnx_name) + desc: autodiff_library.ParameterArray = self.sdfg.arrays[onnx_name] + grad_name = desc.add_gradient_buffer(self.sdfg, onnx_name) + grad_desc = self.sdfg.arrays[grad_name] + grad_desc.transient = False + buffers.append(grad_name) + self._gradient_buffers = buffers + return buffers + + def __sdfg_signature__(self): + """ + Get the SDFG signature (SDFGConvertible interface). + + :return: Tuple of (input names, output names). + :raises ValueError: If the SDFG has not been generated yet. + """ + if self.dace_model is None: + raise ValueError("Can't determine signature before SDFG is generated.") + inputs = [clean_onnx_name(name) for name in self.dace_model.inputs] + grad_buffers = self._add_gradient_buffers() + inputs.extend(grad_buffers) + + return inputs, [] + + @staticmethod + def _tensor_from_param(param) -> Tensor: + """ + Extract tensor from parameter while preserving requires_grad flag. + + :param param: PyTorch parameter. + :return: Tensor with correct requires_grad setting. + """ + t = param.data + # Accessing .data on a Parameter resets the requires_grad flag + t.requires_grad = param.requires_grad + return t + + def __sdfg_closure__(self, reevaluate: Optional[Dict[str, str]] = None) -> Dict[str, Any]: + """ + Get the SDFG closure (SDFGConvertible interface). + + :param reevaluate: Optional dictionary for reevaluation (unused). + :return: Dictionary mapping parameter names to their tensor values. + """ + result = {} + for name, param in self._exported_parameters.items(): + onnx_name = clean_onnx_name(name) + result[onnx_name] = self._tensor_from_param(param) + if param.requires_grad: + grad_name = self.sdfg.arrays[onnx_name].gradient + assert grad_name, "Expected gradient descriptor to be present" + assert param.grad is not None, "Expected gradient buffer to be allocated" + result[grad_name] = param.grad + + return result + + def closure_resolver(self, + constant_args: Dict[str, Any], + given_args: Set[str], + parent_closure: Optional[pycommon.SDFGClosure] = None) -> pycommon.SDFGClosure: + """ + Resolve closure for SDFG execution (SDFGConvertible interface). + + :param constant_args: Constant arguments. + :param given_args: Arguments already provided. + :param parent_closure: Optional parent closure. + :return: SDFGClosure object containing closure arrays. + """ + assert self.sdfg is not None, "SDFG must be initialized before resolving closure" + result = pycommon.SDFGClosure() + + class TensorClosure: + """Helper class to wrap tensor access in a callable.""" + + def __init__(self, t): + self.t = t + + def __call__(self): + return self.t + + for name, param in self._exported_parameters.items(): + onnx_name = clean_onnx_name(name) + desc = self.sdfg.arrays[onnx_name] + + if param.requires_grad: + # the gradient was already added when __sdfg_signature__ was called earlier + assert desc.gradient, "Expected gradient descriptor to be present" + grad_name = desc.gradient + # also add the gradient to the closure, because we need to write to it + result.closure_arrays[grad_name] = (grad_name, self.sdfg.arrays[grad_name], + TensorClosure(param.grad), False) + + result.closure_arrays[onnx_name] = (name, desc, TensorClosure(self._tensor_from_param(param)), False) + return result + +else: + # Stub class when ML dependencies are not available + class DaceModule: + """Stub class for DaceModule when PyTorch and ONNX are not installed.""" + + def __init__(self, *args, **kwargs): + raise ImportError("DaceModule requires PyTorch and ONNX. Install with: pip install dace[ml]") diff --git a/dace/frontend/python/interface.py b/dace/frontend/python/interface.py index 446079d8f1..18943d6799 100644 --- a/dace/frontend/python/interface.py +++ b/dace/frontend/python/interface.py @@ -3,7 +3,7 @@ import inspect from functools import wraps -from typing import Any, Callable, Deque, Dict, Generator, Optional, Tuple, TypeVar, Union, overload +from typing import Any, Callable, Deque, Dict, Generator, Optional, Tuple, TypeVar, Union, overload, TYPE_CHECKING from dace import dtypes from dace.dtypes import paramdec diff --git a/dace/frontend/python/replacements/torch_autodiff.py b/dace/frontend/python/replacements/torch_autodiff.py new file mode 100644 index 0000000000..d110046be9 --- /dev/null +++ b/dace/frontend/python/replacements/torch_autodiff.py @@ -0,0 +1,163 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +Integration with the dace python frontend +""" + +from contextlib import contextmanager +from typing import Optional, Union, Sequence +import itertools +import warnings + +import torch +import torch.autograd + +from dace import SDFG, SDFGState, config, data +import dace.sdfg.sdfg +from dace.transformation import optimizer +from dace.frontend.python import common +from dace.frontend.common import op_repository +from dace.frontend.python import newast +from dace.transformation.passes.fusion_inline import InlineControlFlowRegions +from dace.data import find_new_name +from dace.sdfg.utils import expand_nodes +from dace.libraries.onnx.op_implementations.common import iterables_equal +from dace.autodiff import analysis as autodiff_analysis + +from dace.autodiff.library.library import ParameterArray, BackwardPass + +TensorOrTensors = Union[str, Sequence[str]] + + +@op_repository.replaces('torch.autograd.backward') +def backward(pv: newast.ProgramVisitor, + sdfg: SDFG, + state: SDFGState, + tensors: TensorOrTensors, + grads: Optional[TensorOrTensors] = None): + """ + Adds a backward pass node to the SDFG. + + This function analyses the dependency tree of the tensors and computes + gradients for each Parameter that was used to compute the tensors. + """ + + # First, remove function call regions + transformation = InlineControlFlowRegions() + transformation.set_opts({ + 'no_inline_function_call_regions': False, + 'no_inline_named_regions': False, + 'no_inline_loops': True, + 'no_inline_conditional': True + }) + transformation.apply_pass(sdfg, {}) + + if isinstance(tensors, str): + tensors = [tensors] + + if isinstance(grads, str): + grads = [grads] + + if grads is None: + grads = [] + # when the tensors are scalars, we can implicity create the grads with ones + for tensor in tensors: + tensor_desc = sdfg.arrays[tensor] + if tensor_desc.total_size == 1: + constant_name = sdfg._find_new_name("one") + desc = data.Scalar(tensor_desc.dtype, transient=True, storage=tensor_desc.storage) + sdfg.add_constant(constant_name, 1, dtype=desc) + sdfg.arrays[constant_name] = desc + grads.append(constant_name) + else: + raise common.DaceSyntaxError(pv, None, "grad can be implicitly created only for scalar outputs") + + if len(grads) != len(tensors): + raise common.DaceSyntaxError(pv, None, "grads and tensors must correspond, but they were not the same length") + + for grad, tensor in zip(grads, tensors): + if grad not in sdfg.arrays and grad not in sdfg.constants_prop: + raise common.DaceSyntaxError(pv, None, "Gradient {} is not an array".format(grad)) + if tensor not in sdfg.arrays: + raise common.DaceSyntaxError(pv, None, "Tensor {} is not an array".format(tensor)) + + grad_desc = sdfg.arrays[grad] if grad in sdfg.arrays else sdfg.constants_prop[grad][0] + + if not iterables_equal(grad_desc.shape, sdfg.arrays[tensor].shape): + raise common.DaceSyntaxError(pv, None, + "Gradient {} and tensor {} have different shapes".format(grad, tensor)) + + given_gradients = dict(zip(grads, tensors)) + + bwd_node = BackwardPass('backward', + inputs=set(itertools.chain(tensors, grads)), + outputs=set(), + given_gradients=given_gradients) + state.add_node(bwd_node) + + for inp in itertools.chain(tensors, grads): + state.add_edge(state.add_read(inp), None, bwd_node, inp, sdfg.make_array_memlet(inp)) + + # determine what grdaients to compute + dependencies = autodiff_analysis.dependency_analysis(sdfg) + + to_compute = { + dependency + for tensor in tensors + for dependency in dependencies[tensor] if isinstance(sdfg.arrays[dependency], ParameterArray) + } + + for param in to_compute: + param_desc: ParameterArray = sdfg.arrays[param] + grad_name = param_desc.add_gradient_buffer(sdfg, param) + + conn_name = find_new_name(grad_name, bwd_node.out_connectors) + bwd_node.required_gradients[param] = conn_name + bwd_node.add_out_connector(conn_name) + write_an = state.add_write(grad_name) + write_an.setzero = True + state.add_edge(bwd_node, conn_name, write_an, None, sdfg.make_array_memlet(grad_name)) + + +@op_repository.replaces_attribute('ParameterArray', 'grad') +def grad(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, arr: str) -> str: + """ + Returns the name of the gradient buffer of the given array. + + The Array must have been marked as requires_grad_ using + ``arr.requires_grad_()``, otherwise there will be an error + """ + + if arr not in sdfg.arrays: + raise common.DaceSyntaxError(pv, None, "Array {} is not defined".format(arr)) + desc = sdfg.arrays[arr] + if not isinstance(desc, ParameterArray): + raise common.DaceSyntaxError( + pv, None, "Called .grad on an Array that was not a Parameter. Convert it to a parameter " + " first using .requires_grad_()") + + return desc.gradient + + +@op_repository.replaces_method('Array', 'requires_grad_') +@op_repository.replaces_method('Scalar', 'requires_grad_') +def requires_grad_(pv: newast.ProgramVisitor, sdfg: SDFG, state: SDFGState, self: str): + """ + Converts a array to a ParameterArray. This creates a descriptor for + the gradient buffer for this array. + """ + + if self not in sdfg.arrays: + raise common.DaceSyntaxError(pv, None, "Array {} is not defined".format(self)) + ParameterArray.make_parameter(sdfg, self) + + +@op_repository.replaces_method('Array', 'backward') +@op_repository.replaces_method('Scalar', 'backward') +def backward_method(pv: newast.ProgramVisitor, sdfg: SDFG, state: SDFGState, self: str, grad: Optional[str] = None): + """ + Alias for ``torch.autograd.backward(self)`` + """ + backward(pv, sdfg, state, self, grad) + + +dace.hooks.register_sdfg_call_hook(before_hook=lambda sdfg: expand_nodes(sdfg, lambda n: isinstance(n, BackwardPass))) diff --git a/dace/libraries/blas/blas_helpers.py b/dace/libraries/blas/blas_helpers.py index 6a568f6e4a..42a5c3287b 100644 --- a/dace/libraries/blas/blas_helpers.py +++ b/dace/libraries/blas/blas_helpers.py @@ -181,7 +181,7 @@ def get_gemm_opts(a_strides, b_strides, c_strides) -> Dict[str, Any]: }, } - if sAM == 1: + if sAM == 1 and sAK != 1: optA = 'm' elif sAK == 1: optA = 'k' diff --git a/dace/libraries/onnx/__init__.py b/dace/libraries/onnx/__init__.py new file mode 100644 index 0000000000..449c7913e8 --- /dev/null +++ b/dace/libraries/onnx/__init__.py @@ -0,0 +1,61 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +DaCe ONNX Integration Library. + +This module provides comprehensive support for importing and executing ONNX models +in DaCe. It enables: + +- Importing ONNX models and converting them to DaCe SDFGs +- Implementing ONNX operations as DaCe library nodes +- Automatic shape inference for dynamic models +- Multiple implementation strategies (pure, optimized, etc.) + +Main Components: +- ONNXModel: Main class for importing and manipulating ONNX models +- ONNXOp: Base class for ONNX operation nodes in SDFGs +- Schema system: Type checking and validation for ONNX operations + +The library is registered with DaCe and uses 'pure' as the default implementation +strategy for ONNX operations. +""" + +from dace.library import register_library, _DACE_REGISTERED_LIBRARIES + +try: + # Import schema and node utilities (nodes are lazy-loaded via __getattr__) + from .schema import onnx_representation, ONNXAttributeType, ONNXAttribute, ONNXTypeConstraint, ONNXParameterType, ONNXSchema, ONNXParameter + from .nodes import get_onnx_node, has_onnx_node + + register_library(__name__, "dace.libraries.onnx") + _DACE_REGISTERED_LIBRARIES["dace.libraries.onnx"].default_implementation = "pure" + + ONNX_AVAILABLE = True + + def __getattr__(name): + """Lazy attribute access for ONNX node classes, ONNXModel, and utilities.""" + if name == 'ONNXModel': + from dace.frontend.ml.onnx import ONNXModel as _ONNXModel + return _ONNXModel + if name == 'parse_variadic_param': + from .nodes.node_utils import parse_variadic_param as _parse_variadic_param + return _parse_variadic_param + if name.startswith('ONNX'): + # Initialize registry and get the node class + from .nodes.onnx_op_registry import _initialize_onnx_registry + _initialize_onnx_registry() + from .nodes import onnx_op_registry + if hasattr(onnx_op_registry, name): + return getattr(onnx_op_registry, name) + raise AttributeError(f"module '{__name__}' has no attribute '{name}'") + +except ImportError: + # ONNX library not available + ONNXModel = None + onnx_representation = None + ONNXAttributeType = None + ONNXAttribute = None + ONNXTypeConstraint = None + ONNXParameterType = None + ONNXSchema = None + ONNXParameter = None + ONNX_AVAILABLE = False diff --git a/dace/libraries/onnx/converters.py b/dace/libraries/onnx/converters.py new file mode 100644 index 0000000000..5ba326934a --- /dev/null +++ b/dace/libraries/onnx/converters.py @@ -0,0 +1,247 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +Type conversion utilities for ONNX-DaCe integration. + +This module provides conversion functions between ONNX and DaCe type systems: +- Converting ONNX protobuf types to DaCe types +- Converting DaCe types to ONNX representation +- Handling ONNX AttributeProto conversions +- Type validation and name sanitization + +Key Functions: +- convert_onnx_proto: Convert ONNX protobuf objects to Python/DaCe types +- onnx_tensor_type_to_typeclass: Convert ONNX tensor types to DaCe typeclasses +- typeclass_to_onnx_str: Convert DaCe types to ONNX string representation +- clean_onnx_name: Sanitize ONNX names for valid DaCe identifiers +""" + +import re +from typing import Union + +import onnx +from dace import config, dtypes as dt +from dace.dtypes import typeclass +from onnx.numpy_helper import to_array + + +def get_proto_attr(proto, name: str): + """Safely access a protobuf attribute with encoding validation. + + This function provides defensive checks against encoding issues when accessing + protobuf attributes. Python's getattr expects strings, but protobuf uses UTF-8. + + :param proto: The protobuf object to access. + :param name: The attribute name to retrieve (must be ASCII). + :return: The value of the requested attribute. + :raises ValueError: If the name is not ASCII-encodable. + + .. note:: + HasField checks may break in proto3, but ONNX doesn't use proto3 yet. + """ + + def is_ascii(s: str) -> bool: + """Check if a string is ASCII-encodable.""" + try: + s.encode('ascii') + except UnicodeEncodeError: + return False + else: + return True + + if not is_ascii(name): + raise ValueError( + f"Attempted to access non-ASCII property name '{name}' on protobuf {proto} (type {type(proto)}). " + "Please open an issue") + + return getattr(proto, name) + + +def convert_onnx_proto(attribute): + from dace.libraries.onnx.schema import ONNXAttributeType, _KNOWN_ONNX_PROTOS, ONNXParameterType + + if type(attribute) in _KNOWN_ONNX_PROTOS: + return _KNOWN_ONNX_PROTOS[type(attribute)].from_onnx_proto(attribute) + + # Check ONNX enum types BEFORE basic types, because ONNX enums derive from + # IntEnum and would incorrectly match isinstance(attribute, int) + if type(attribute) is onnx.defs.OpSchema.FormalParameterOption: + if attribute == onnx.defs.OpSchema.FormalParameterOption.Single: + return ONNXParameterType.Single + elif attribute == onnx.defs.OpSchema.FormalParameterOption.Optional: + return ONNXParameterType.Optional + elif attribute == onnx.defs.OpSchema.FormalParameterOption.Variadic: + return ONNXParameterType.Variadic + else: + raise NotImplementedError( + "Only single, optional and variadic formal parameters are supported, got".format(attribute)) + + if type(attribute) is onnx.defs.OpSchema.AttrType: + if attribute == onnx.defs.OpSchema.AttrType.FLOAT: + return ONNXAttributeType.Float + elif attribute == onnx.defs.OpSchema.AttrType.FLOATS: + return ONNXAttributeType.Floats + elif attribute == onnx.defs.OpSchema.AttrType.INT: + return ONNXAttributeType.Int + elif attribute == onnx.defs.OpSchema.AttrType.INTS: + return ONNXAttributeType.Ints + elif attribute == onnx.defs.OpSchema.AttrType.STRING: + return ONNXAttributeType.String + elif attribute == onnx.defs.OpSchema.AttrType.STRINGS: + return ONNXAttributeType.Strings + elif attribute == onnx.defs.OpSchema.AttrType.TENSOR: + return ONNXAttributeType.Tensor + else: + if config.Config.get_bool('debugprint'): + print("Got unsupported attribute type {}".format(attribute)) + return ONNXAttributeType.Unsupported + + if type(attribute) is onnx.AttributeProto: + return convert_attribute_proto(attribute) + + # Check basic Python types after ONNX enums (must be after enum checks) + if isinstance(attribute, (int, str, bool, float)): + return attribute + + raise NotImplementedError("No conversion implemented for {} (type {})".format(attribute, type(attribute))) + + +def convert_attribute_proto(proto): + # we cache the reverse map as an attribute of the method + if hasattr(convert_attribute_proto, "inv_map"): + inv_map = convert_attribute_proto.inv_map + else: + inv_map = {} + for k, v in onnx.AttributeProto.AttributeType.items(): + if k == "FLOAT": + inv_map[v] = lambda attr: get_proto_attr(attr, "f") + elif k == "FLOATS": + inv_map[v] = lambda attr: list(get_proto_attr(attr, "floats")) + elif k == "INT": + inv_map[v] = lambda attr: get_proto_attr(attr, "i") + elif k == "INTS": + inv_map[v] = lambda attr: list(get_proto_attr(attr, "ints")) + elif k == "STRING": + inv_map[v] = lambda attr: get_proto_attr(attr, "s").decode('utf-8') + elif k == "STRINGS": + inv_map[v] = lambda attr: list(map(lambda x: x.decode('utf-8'), get_proto_attr(attr, "strings"))) + elif k == "TENSOR": + inv_map[v] = lambda attr: to_array(get_proto_attr(attr, "t")) + + convert_attribute_proto.inv_map = inv_map + + onnx_type = get_proto_attr(proto, "type") + + if onnx_type == 0: + # in case of undefined return None + return None + + if onnx_type not in inv_map: + type_str = {v: k for k, v in onnx.AttributeProto.AttributeType.items()}[onnx_type] + raise NotImplementedError( + "Only FLOAT, FLOATS, INT, INTS, STRING, STRINGS and TENSOR attributes are supported, got attribute with type {}" + .format(type_str)) + + return inv_map[onnx_type](proto) + + +ONNX_DTYPES_TO_DACE_TYPE_CLASS = { + 'bool': dt.bool, + 'int8': dt.int8, + 'int16': dt.int16, + 'int32': dt.int32, + 'int64': dt.int64, + 'uint8': dt.uint8, + 'uint16': dt.uint16, + 'uint32': dt.uint32, + 'uint64': dt.uint64, + 'float16': dt.float16, + 'float': dt.float32, + 'double': dt.float64, + 'complex64': dt.complex64, + 'complex128': dt.complex128, +} + + +def typeclass_to_onnx_tensor_type_int(dtype: typeclass) -> int: + # we cache the reverse map as an attribute of the method + if not hasattr(typeclass_to_onnx_tensor_type_int, "inv_map"): + typeclass_to_onnx_tensor_type_int.inv_map = { + v: getattr(onnx.TensorProto.DataType, k.upper()) + for k, v in ONNX_DTYPES_TO_DACE_TYPE_CLASS.items() + } + + return typeclass_to_onnx_tensor_type_int.inv_map[dtype] + + +def onnx_tensor_type_to_typeclass(elem_type: int) -> typeclass: + # we cache the reverse map as an attribute of the method + if hasattr(onnx_tensor_type_to_typeclass, "inv_map"): + inv_map = onnx_tensor_type_to_typeclass.inv_map + else: + k: str + v: int + inv_map = {} + for k, v in onnx.TensorProto.DataType.items(): + if k.lower() in ONNX_DTYPES_TO_DACE_TYPE_CLASS: + inv_map[v] = ONNX_DTYPES_TO_DACE_TYPE_CLASS[k.lower()] + + onnx_tensor_type_to_typeclass.inv_map = inv_map + + if elem_type not in inv_map: + raise ValueError("Got unsupported ONNX tensor type: {}".format({ + v: k + for k, v in onnx.TensorProto.DataType.items() + }[elem_type])) + + return inv_map[elem_type] + + +def typeclass_to_onnx_str(dtype: typeclass) -> str: + # we cache the reverse map as an attribute of the method + if hasattr(typeclass_to_onnx_str, "inv_map"): + inv_map = typeclass_to_onnx_str.inv_map + else: + inv_map = {v: k for k, v in ONNX_DTYPES_TO_DACE_TYPE_CLASS.items()} + + if dtype not in inv_map: + raise ValueError("Attempted to convert unsupported dace type to ONNX type: {}".format(dtype)) + + return inv_map[dtype] + + +def onnx_type_str_to_typeclass(onnx_str) -> Union[typeclass, None]: + """Converts an onnx type string, like tensor(float16) to a dace typeclass""" + + results = re.findall(r"^tensor\((.+)\)", onnx_str) + if len(results) != 1 or results[0] not in ONNX_DTYPES_TO_DACE_TYPE_CLASS: + # we return None here, these types will be filtered out later + return None + + return ONNX_DTYPES_TO_DACE_TYPE_CLASS[str(results[0])] + + +def clean_onnx_name(name: str) -> str: + """Sanitize an ONNX name to make it a valid DaCe identifier. + + This function transforms ONNX names that may contain invalid characters + or patterns into valid DaCe identifiers by: + + - Prefixing names starting with digits with "ONNX_" + - Replacing special characters with their textual equivalents + + :param name: The ONNX name to sanitize. + :return: A valid DaCe identifier based on the ONNX name. + + Example:: + + >>> clean_onnx_name("123_layer") + 'ONNX_123_layer' + >>> clean_onnx_name("my.tensor:0") + 'myDOTtensorCOLON0' + """ + # If the first character is a digit, add the ONNX_ prefix + if re.match("^[0-9]", name): + name = f"ONNX_{name}" + + # Replace special characters with their textual equivalents + return (name.replace(".", "DOT").replace(":", "COLON").replace("/", "SLASH").replace("-", "DASH")) diff --git a/dace/libraries/onnx/forward_implementation_abc.py b/dace/libraries/onnx/forward_implementation_abc.py new file mode 100644 index 0000000000..5f13a65cf8 --- /dev/null +++ b/dace/libraries/onnx/forward_implementation_abc.py @@ -0,0 +1,105 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +Abstract Base Class for ONNX Operation Forward Implementations. + +This module defines the interface that all ONNX operation implementations must +follow in DaCe. It uses a registry pattern to allow multiple implementations +for each ONNX operation, enabling: + +- Pure Python implementations for correctness +- Optimized implementations for performance +- Hardware-specific implementations (CPU, GPU, FPGA) +- Custom user-provided implementations + +The ONNXForward ABC provides: +- Registration mechanism via @make_registry decorator +- Implementation selection based on applicability +- Expansion of ONNX ops to DaCe SDFG nodes + +Implementation Registration: + Implementations register themselves by inheriting from ONNXForward and + using the @op_implementation decorator with: + - `op`: ONNX operation name (e.g., "Conv", "MatMul") + - `name`: Implementation name (e.g., "pure", "optimized") + +Example: + @op_implementation(op="MatMul", name="pure") + class PureMatMul(ONNXForward): + @staticmethod + def forward(node, state, sdfg): + # Implementation here + pass +""" + +import abc +import typing + +from dace import SDFG, SDFGState +from dace.registry import make_registry +from dace.sdfg.nodes import Node + +from dace.libraries.onnx.nodes.onnx_op import ONNXOp + + +@make_registry +class ONNXForward(abc.ABC): + """ + Abstract base class for ONNX operation forward implementations. + + This class defines the interface for implementing ONNX operations in DaCe. + Subclasses must implement the `forward` method to expand an ONNX operation + node into DaCe SDFG constructs. + + The registry system allows multiple implementations per operation, with + selection based on applicability criteria. + """ + + @staticmethod + def forward_can_be_applied(node: ONNXOp, state: SDFGState, sdfg: SDFG) -> bool: + """Check whether this implementation can be applied to the given node. + + This method is called during SDFG expansion to determine if this + implementation is suitable for the given context. The default + implementation returns True (always applicable). + + :param node: The ONNX operation node to expand. + :param state: The SDFG state containing the node. + :param sdfg: The parent SDFG. + :return: True if this implementation can be applied, False otherwise. + """ + return True + + @staticmethod + @abc.abstractmethod + def forward(node: ONNXOp, state: SDFGState, sdfg: SDFG) -> typing.Union[Node, SDFG]: + """Expand an ONNX operation node into DaCe SDFG constructs. + + This is the main method that must be implemented by subclasses. It takes + an ONNX operation node and replaces it with equivalent DaCe constructs + (tasklets, nested SDFGs, library nodes, etc.). + + :param node: The ONNX operation node to expand. + :param state: The SDFG state containing the node. + :param sdfg: The parent SDFG. + :return: The expanded node or a nested SDFG representing the operation. + """ + ... + + @classmethod + def registered_implementations(cls, op_name: str) -> typing.List[typing.Tuple[str, "ONNXForward"]]: + """Get all registered implementations for a specific ONNX operation. + + :param op_name: The ONNX operation name (e.g., "Conv", "MatMul"). + :return: List of tuples (implementation_name, implementation_class) for + all registered implementations of the given operation. + """ + impls = [] + for impl, args in cls.extensions().items(): + if "op" in args and args["op"] == op_name: + impls.append((args["name"], impl)) + + return impls + + +# Import op_implementations to trigger registration of all implementations +import dace.libraries.onnx.op_implementations diff --git a/dace/libraries/onnx/nodes/__init__.py b/dace/libraries/onnx/nodes/__init__.py new file mode 100644 index 0000000000..0ed1814fe5 --- /dev/null +++ b/dace/libraries/onnx/nodes/__init__.py @@ -0,0 +1,2 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +from .onnx_op_registry import get_onnx_node, has_onnx_node diff --git a/dace/libraries/onnx/nodes/node_utils.py b/dace/libraries/onnx/nodes/node_utils.py new file mode 100644 index 0000000000..d377660f65 --- /dev/null +++ b/dace/libraries/onnx/nodes/node_utils.py @@ -0,0 +1,90 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +Utility functions for ONNX node operations. + +This module provides helper functions for working with ONNX operation nodes +in DaCe SDFGs, including: + +- Parsing variadic parameter names +- Validating parameter formats +- Schema utilities for ONNX operations + +These utilities support the ONNX node system by handling the complexities +of variadic inputs/outputs and parameter naming conventions. +""" + +from typing import Tuple + +from dace.libraries.onnx.schema import ONNXParameterType, ONNXSchema + + +def parse_variadic_param(param: str) -> Tuple[str, int]: + """Parse a variadic parameter name into its base name and index. + + ONNX operations can have variadic inputs/outputs, which are named using + the convention 'base_name__index' (e.g., 'input__0', 'input__1'). + This function extracts the base name and numeric index. + + :param param: The variadic parameter name in format 'name__number'. + :return: A tuple of (base_name, index) where base_name is the parameter name + and index is the variadic position (zero-indexed). + :raises ValueError: If the parameter format is invalid, has leading zeros + in the number, or the number is negative. + + Example:: + + >>> parse_variadic_param("input__0") + ('input', 0) + >>> parse_variadic_param("output__5") + ('output', 5) + >>> parse_variadic_param("input__01") # raises ValueError + """ + split = param.split('__') + if len(split) != 2: + raise ValueError("Unable to parse variadic parameter '{}'".format(param)) + name = split[0] + number = split[1] + + if number[0] == '0' and len(number) > 1: + raise ValueError("Variadic parameters must not be numbered with leading zeros, got: '{}'".format(number)) + + number = int(number) + if number < 0: + raise ValueError("Variadic parameter numberings must be greater than zero, got: '{}'".format(number)) + return name, number + + +def get_position(schema: ONNXSchema, is_input: bool, parameter_name: str): + """Get the position that the parameter has in the ONNX op. + + :param schema: The ONNX schema containing parameter definitions. + :param is_input: True if looking for input parameters, False for output parameters. + :param parameter_name: The name of the parameter to find position for. + :return: The position index of the parameter in the operation signature. + :raises ValueError: If parameter is not found, has incorrect variadic format, + or schema validation fails. + """ + if "__" in parameter_name: + parameter_name, variadic_number = parse_variadic_param(parameter_name) + else: + variadic_number = None + + matches = [(i, param) for i, param in enumerate(schema.inputs if is_input else schema.outputs) + if param.name == parameter_name] + if len(matches) != 1: + raise ValueError("Error in schema: found more or less than one parameter with name {}".format(parameter_name)) + + index, param = matches[0] + + if variadic_number is not None and param.param_type != ONNXParameterType.Variadic: + raise ValueError("Got variadic index for non-variadic parameter {}".format(parameter_name)) + + if variadic_number is None and param.param_type == ONNXParameterType.Variadic: + raise ValueError("Did not get variadic index for variadic parameter {}. " + "Specify a variadic index by renaming the parameter to {}__i, where i is a number".format( + parameter_name, parameter_name)) + + if variadic_number is not None: + return variadic_number + index + else: + return index diff --git a/dace/libraries/onnx/nodes/onnx_op.py b/dace/libraries/onnx/nodes/onnx_op.py new file mode 100644 index 0000000000..e448bf8443 --- /dev/null +++ b/dace/libraries/onnx/nodes/onnx_op.py @@ -0,0 +1,295 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import itertools +from typing import Iterator, Tuple, List + +import dace.sdfg.nodes as nd +from dace.sdfg import SDFG, SDFGState +from dace.properties import Property, make_properties +from dace.sdfg.graph import MultiConnectorEdge + +from dace.libraries.onnx.nodes.node_utils import parse_variadic_param +from dace.libraries.onnx.schema import ONNXSchema, ONNXParameterType + + +def get_missing_arguments_message(function_name, missing_arguments, argument_type): + names = list(map(lambda x: "'" + x + "'", missing_arguments)) + + if len(missing_arguments) == 1: + arglist = names[0] + else: + arglist = ", ".join(names[:-1]) + ", and " + names[-1] + + return "{function_name} missing {num_missing} required {argument_type}{s}: {arglist}".format( + function_name=function_name, + num_missing=len(missing_arguments), + argument_type=argument_type, + s='' if len(missing_arguments) == 1 else 's', + arglist=arglist) + + +@make_properties +class ONNXOp(nd.LibraryNode): + """ Abstract superclass for all ONNX ops. Do not use this class, use the concrete subclasses + (e.g. :class:`~dace.libraries.onnx.nodes.onnx_op.ONNXConv`) instead. + """ + + # Global properties + # these two are filled out in the generated constructor + implementations = {} + default_implementation = None + default_backward_implementation = None + + # Object fields + schema = Property(dtype=ONNXSchema, desc="The operator's ONNX OpSchema", allow_none=True) + + backward_implementation = Property( + dtype=str, + allow_none=True, + desc="Which implementation this library node will expand into in the backward pass.") + + def iter_outputs_in_onnx_order(self, state: SDFGState) -> List[MultiConnectorEdge]: + """ Iterate through the input edges in the same order as they would appear in an ONNX node proto. + This assumes that the node has been validated! + + :param state: the state containing this node. + :return: the out edges in the order as they would appear in the node proto. + """ + return self._iter_params_in_onnx_order(state, inputs=False) + + def iter_inputs_in_onnx_order(self, state: SDFGState) -> List[MultiConnectorEdge]: + """ Iterate through the output edges in the same order as they would appear in an ONNX node proto. + This assumes that the node has been validated! + + :param state: the state containing this node. + :return: the in edges in the order as they would appear in the node proto. + """ + return self._iter_params_in_onnx_order(state, inputs=True) + + def _iter_params_in_onnx_order(self, state: SDFGState, inputs: bool = False) -> List[MultiConnectorEdge]: + parameters = list(self.schema.inputs if inputs else self.schema.outputs) + if len(parameters) == 0: + return [] + if parameters[-1].param_type == ONNXParameterType.Variadic: + name = parameters[-1].name + parameters = itertools.chain([param.name for param in parameters[:-1]], + (name + "__" + str(i) for i in itertools.count())) + else: + parameters = [param.name for param in parameters] + + edges = state.in_edges(self) if inputs else state.out_edges(self) + parameters = list(itertools.islice(parameters, len(edges))) + conn_to_edge = {edge.dst_conn if inputs else edge.src_conn: edge for edge in edges} + + return [conn_to_edge[name] for name in parameters] + + def iter_edges( + self, + state: SDFGState, + ignore_unknown=False, + ) -> Iterator[Tuple[MultiConnectorEdge, bool]]: + """ Returns an iterator over tuples of an edge and a boolean that indicates whether that edge is an input, + ordered by the order required by the schema. + This method assumes that this node has been validated. + + :param state: the state containing this node. + :param ignore_unknown: whether to ignore any edges that don't exist in the ONNX schema. Otherwise, an + error will be thrown. + """ + in_edges: List[MultiConnectorEdge] = state.in_edges(self) + out_edges: List[MultiConnectorEdge] = state.out_edges(self) + + def get_idx(parameters, name): + if '__' in name: + name, number = parse_variadic_param(name) + else: + number = 0 + + matched = [i for i, param in enumerate(parameters) if param.name == name] + + if len(matched) != 1: + if ignore_unknown: + return None + raise ValueError("Found {} connectors with name '{}', expected to find exactly one".format( + len(matched), name)) + + parameter_idx = matched[0] + + # add on the variadic parameter index + parameter_idx += number + + return parameter_idx + + if ignore_unknown: + in_edges = [e for e in in_edges if get_idx(self.schema.inputs, e.dst_conn) is not None] + out_edges = [e for e in out_edges if get_idx(self.schema.outputs, e.src_conn) is not None] + + sorted_in = sorted(in_edges, key=lambda edge: get_idx(self.schema.inputs, edge.dst_conn)) + sorted_out = sorted(out_edges, key=lambda edge: get_idx(self.schema.outputs, edge.src_conn)) + + return itertools.chain(zip(sorted_in, itertools.repeat(True)), zip(sorted_out, itertools.repeat(False))) + + def validate(self, sdfg: SDFG, state: SDFGState): + """ Validate this node. + + :param sdfg: the parent sdfg. + :param state: the parent state. + """ + in_edges = state.in_edges(self) + out_edges = state.out_edges(self) + + # check that we don't have connectors to None + all_connectors = {edge.dst_conn for edge in in_edges}.union(edge.src_conn for edge in out_edges) + if None in all_connectors: + raise ValueError("Edges to ONNX Ops must not have connector None") + + # check that all edges have connectors + ########################################## + for edge, is_input in self.iter_edges(state): + if is_input: + conn_name = edge.dst_conn + if conn_name not in self.in_connectors: + raise ValueError("Memlet {} leading to nonexistent input connector '{}'".format( + edge.data, conn_name)) + else: + conn_name = edge.src_conn + if conn_name not in self.out_connectors: + raise ValueError("Memlet {} leading to nonexistent output connector '{}'".format( + edge.data, conn_name)) + + # check that we have all required in_edges + ########################################## + required_inputs = {inp.name for inp in self.schema.inputs if inp.param_type == ONNXParameterType.Single} + passed_inputs = {inp.dst_conn + for inp in in_edges if '__' not in inp.dst_conn} # we will test variadic inputs separately + known_inputs = {inp.name for inp in self.schema.inputs} + + missing_inputs = required_inputs.difference(passed_inputs) + if len(missing_inputs) > 0: + raise ValueError(get_missing_arguments_message(self.schema.name, missing_inputs, "input")) + + # check that we have all required out_edges + ########################################## + required_outputs = {outp.name for outp in self.schema.outputs if outp.param_type == ONNXParameterType.Single} + passed_outputs = {outp.src_conn + for outp in out_edges if '__' not in outp.src_conn} # we will test variadic inputs separately + known_outputs = {outp.name for outp in self.schema.outputs} + + missing_outputs = required_outputs.difference(passed_outputs) + if len(missing_outputs) > 0: + raise ValueError(get_missing_arguments_message(self.schema.name, missing_outputs, "output")) + + # check that we have no unknown in edges + ########################################## + unknown_inputs = passed_inputs.difference(known_inputs) + if len(unknown_inputs) > 0: + raise TypeError("Got an unexpected argument '{}'".format(list(unknown_inputs)[0])) + + # check that we have no unknown out edges + ########################################## + unknown_outputs = passed_outputs.difference(known_outputs) + if len(unknown_outputs) > 0: + raise TypeError("Got an unexpected argument '{}'".format(list(unknown_outputs)[0])) + + # check variadic params + ########################################## + variadic_inputs = {inp.name for inp in self.schema.inputs if inp.param_type == ONNXParameterType.Variadic} + passed_variadic_inputs = {edge.dst_conn for edge in in_edges if '__' in edge.dst_conn} + + seen_variadic_numbers = set() + for param in passed_variadic_inputs: + name, number = parse_variadic_param(param) + if name not in variadic_inputs: + raise ValueError("Got an unexpected variadic argument '{}'".format(param)) + if number in seen_variadic_numbers: + raise ValueError("Got two variadic inputs with index {}, expected at most one".format(number)) + seen_variadic_numbers.add(number) + + # check that we have seen every number + for i in range(len(seen_variadic_numbers)): + if i not in seen_variadic_numbers: + raise ValueError( + "Since {} variadic inputs were passed, expected variadic parameter with number {}".format( + len(seen_variadic_numbers), i)) + + variadic_outputs = {outp.name for outp in self.schema.outputs if outp.param_type == ONNXParameterType.Variadic} + passed_variadic_outputs = {edge.src_conn for edge in out_edges if '__' in edge.src_conn} + seen_variadic_numbers = set() + for param in passed_variadic_outputs: + name, number = parse_variadic_param(param) + if name not in variadic_outputs: + raise ValueError("Got an unexpected variadic argument '{}'".format(param)) + if number in seen_variadic_numbers: + raise ValueError("Got two variadic outputs with index {}, expected at most one".format(number)) + seen_variadic_numbers.add(number) + + # check that we have seen every number + for i in range(len(seen_variadic_numbers)): + if i not in seen_variadic_numbers: + raise ValueError( + "Since {} variadic outputs were passed, expected variadic parameter with number {}".format( + len(seen_variadic_numbers), i)) + + # check that type params solve + ########################################## + + assigned_params = {} + for edge, is_input in self.iter_edges(state): + conn_name = edge.dst_conn if is_input else edge.src_conn + + if '__' in conn_name: + parsed_name, number = parse_variadic_param(conn_name) + else: + parsed_name = conn_name + + matching = [ + inp for inp in (self.schema.inputs if is_input else self.schema.outputs) if inp.name == parsed_name + ] + + if len(matching) != 1: + raise ValueError("Expected to find one {} parameter in schema with name '{}', but found {}".format( + "input" if is_input else "output", parsed_name, len(matching))) + matched = matching[0] + + if '__' in conn_name and matched.param_type != ONNXParameterType.Variadic: + raise ValueError("Got variadic argument '{}' for non-variadic parameter '{}'." + " Ensure that non-variadic args do not contain '__'".format(conn_name, matched.name)) + + if '__' not in conn_name and matched.param_type == ONNXParameterType.Variadic: + raise ValueError( + "Expected variadic argument for variadic parameter '{}', got '{}'. Use '{}__i' as the connector" + " name, where i is the desired index of the variadic parameter.".format( + matched.name, conn_name, conn_name)) + + edge_data = edge.data.data + edge_dtype = sdfg.arrays[edge_data].dtype + # edge_dtype can be a vector type + if matched.param_type == ONNXParameterType.Variadic and not matched.homogeneous: + # non homogeneous parameters don't need to be consistent + pass + elif matched.type_str in assigned_params and (assigned_params[matched.type_str] != edge_dtype and + assigned_params[matched.type_str] != edge_dtype.base_type): + raise ValueError( + "Could not solve type constraints;" + " excepted type '{expected}' for {param_type} '{conn_name}', got type '{actual}'".format( + expected=assigned_params[matched.type_str], + param_type="input" if is_input else "output", + conn_name=matched.name, + actual=edge_dtype)) + + # otherwise, matched.type_str was not assigned a type yet: try to assign it + cons = self.schema.type_constraints[matched.type_str] + if edge_dtype not in cons.types and edge_dtype.base_type not in cons.types: + raise ValueError( + "Expected type in '{possible}' for {param_type} '{conn_name}', got type '{actual}'".format( + possible=cons.types, + param_type="input" if is_input else "output", + conn_name=matched.name, + actual=edge_dtype)) + assigned_params[matched.type_str] = edge_dtype.base_type + + # check that we have all required attributes + ########################################## + required_attrs = {name for name, attr in self.schema.attributes.items() if attr.required} + for attr in required_attrs: + if getattr(self, attr) is None: + raise ValueError("Expected value for required attribute '{}', got None".format(attr)) diff --git a/dace/libraries/onnx/nodes/onnx_op_registry.py b/dace/libraries/onnx/nodes/onnx_op_registry.py new file mode 100644 index 0000000000..6b145a0859 --- /dev/null +++ b/dace/libraries/onnx/nodes/onnx_op_registry.py @@ -0,0 +1,351 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import collections +from typing import Iterator, Tuple, List, Dict, Type + +import dace +import dace.library +import dace.sdfg.nodes as nd +import dace.frontend.common.op_repository as dace_op_repo +from dace.frontend.python.newast import ProgramVisitor +from dace import config, SDFG, SDFGState, dtypes, data +from dace.properties import Property, ListProperty, make_properties +from dace.sdfg.graph import MultiConnectorEdge +from dace.transformation.transformation import ExpandTransformation + +from dace.libraries.onnx.nodes.node_utils import parse_variadic_param +from dace.libraries.onnx.schema import ONNXSchema, ONNXAttributeType, _ATTR_TYPE_TO_PYTHON_TYPE, ONNXParameterType, ONNXAttribute, ONNXParameter, ONNXTypeConstraint + +import dace.libraries.onnx.nodes.onnx_op as onnx_op +from dace.frontend.python.common import StringLiteral + +import onnx + + +def _get_typecons_docstring(cons: ONNXTypeConstraint) -> str: + """Generate documentation string for type constraints.""" + return " * **{}** -- {}".format(cons.type_str, + ", ".join(":class:`{}`".format(t.to_string()) for t in cons.types)) + + +def _get_connector_docstring(param: ONNXParameter) -> str: + """Generate documentation string for connectors.""" + return " * **{}** ({}, {}) -- {}".format(param.name, param.type_str, param.param_type.name.lower(), + param.description) + + +def _get_attr_docstring(attr: ONNXAttribute) -> str: + """Generate documentation string for attributes.""" + param_doc = ":param {}: {}".format(attr.name, attr.description) + + if attr.attribute_type is ONNXAttributeType.Unsupported: + return "" + + if attr.attribute_type is ONNXAttributeType.Tensor: + type_string = "numpy.ndarray" + else: + type_string = _ATTR_TYPE_TO_PYTHON_TYPE[attr.attribute_type].__name__ + + type_string = ":class:`{}`".format(type_string) + + if attr.attribute_type in [ONNXAttributeType.Ints, ONNXAttributeType.Floats, ONNXAttributeType.Strings]: + type_string = ":class:`List` [{}]".format(type_string) + + if not attr.required: + type_string = ":class:`Optional` [{}], default={}".format(type_string, repr(attr.default_value)) + + param_type = ":type {}: {}".format(attr.name, type_string) + + return param_doc + "\n" + param_type + + +def _get_all_schemas(): + """Get all ONNX schemas with version history.""" + name_to_schemas = collections.defaultdict(list) + for schema in onnx.defs.get_all_schemas_with_history(): + name_to_schemas[schema.name].append(schema) + + all_schemas = [] + for name, schemas in name_to_schemas.items(): + all_schemas.extend(schemas) + + return all_schemas + + +def register_op_repo_replacement(cls: Type[onnx_op.ONNXOp], cls_name: str, dace_schema: ONNXSchema): + """Register an op repository replacement for the given ONNX operation class.""" + + @dace_op_repo.replaces("dace.libraries.onnx.{}".format(cls_name)) + def op_repo_replacement(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, **kwargs): + attrs = {name: value for name, value in kwargs.items() if name in dace_schema.attributes} + # Remove used attrs + kwargs = {k: v for k, v in kwargs.items() if k not in attrs} + + onnx_node = cls(name=cls_name, **attrs) + state.add_node(onnx_node) + + input_names = dace_schema.non_variadic_inputs() + variadic_inputs = dace_schema.variadic_inputs() + + output_names = dace_schema.non_variadic_outputs() + variadic_outputs = dace_schema.variadic_outputs() + + inputs = { + name: arr_name + for name, arr_name in kwargs.items() + if (name in input_names or + # variadic params + ("__" in name and parse_variadic_param(name)[0] in variadic_inputs)) + } + + kwargs = {k: v for k, v in kwargs.items() if k not in inputs} + + outputs = { + name: arr_name + for name, arr_name in kwargs.items() + if (name in output_names or + # variadic params + ("__" in name and parse_variadic_param(name)[0] in variadic_outputs)) + } + + kwargs = {k: v for k, v in kwargs.items() if k not in outputs} + + if len(kwargs) > 0: + raise TypeError(f"Unknown arguments {', '.join(kwargs)}") + + # Remove all non-string attributes + # Sometimes constants are passed as inputs, but they do not require AccessNodes + # so we add them first as attributes to the node + for inp, arr_name in inputs.items(): + if not isinstance(arr_name, str): + setattr(onnx_node, inp, arr_name) + + inputs = {inp: arr_name for inp, arr_name in inputs.items() if isinstance(arr_name, str)} + + for inp, arr_name in inputs.items(): + read = state.add_read(arr_name) + state.add_edge(read, None, onnx_node, inp, sdfg.make_array_memlet(arr_name)) + onnx_node.add_in_connector(inp) + + for outp, arr_name in outputs.items(): + write = state.add_read(arr_name) + state.add_edge(onnx_node, outp, write, None, sdfg.make_array_memlet(arr_name)) + onnx_node.add_out_connector(outp) + return [] + + +_ONNX_OPS = {} +_REGISTRY_INITIALIZED = False + + +def _initialize_onnx_registry(): + """ + Lazy initialization of ONNX operator registry. + This function is called only when ONNX nodes are actually used. + We add a global flag (_REGISTRY_INITIALIZED) to avoid re-initializing the registry multiple times. + """ + global _REGISTRY_INITIALIZED, _ONNX_OPS + + if _REGISTRY_INITIALIZED: + return + + _REGISTRY_INITIALIZED = True + + # Import these here to avoid circular imports at module load time + from dace.libraries.onnx.forward_implementation_abc import ONNXForward + import dace.libraries.onnx.op_implementations # Registers implementations + + # Generate all of the Op Nodes + for schema in _get_all_schemas(): + try: + dace_schema = ONNXSchema.from_onnx_proto(schema) + # If the schema has a parameter name that exists as both an input and an output, prepend "in_" and "out_" + intersecting_names = set(i.name for i in dace_schema.inputs).intersection(o.name + for o in dace_schema.outputs) + for name in intersecting_names: + in_cands = [i for i in dace_schema.inputs if i.name == name] + out_cands = [i for i in dace_schema.outputs if i.name == name] + assert len(in_cands) == len(out_cands) == 1 + in_cands[0].name = "in_" + name + out_cands[0].name = "out_" + name + + except Exception as e: + if config.Config.get_bool('debugprint'): + print("Import of {} failed: {}".format(schema.name, e)) + continue + + attrs = {} + # Add properties for each op attribute + for name, attr in dace_schema.attributes.items(): + if attr.attribute_type in [ + ONNXAttributeType.Int, ONNXAttributeType.String, ONNXAttributeType.Float, ONNXAttributeType.Tensor + ]: + attrs[name] = Property(dtype=_ATTR_TYPE_TO_PYTHON_TYPE[attr.attribute_type], + desc=attr.description, + allow_none=True, + default=None if attr.default_value is None else attr.default_value) + elif attr.attribute_type in [ONNXAttributeType.Ints, ONNXAttributeType.Strings, ONNXAttributeType.Floats]: + attrs[name] = ListProperty(element_type=_ATTR_TYPE_TO_PYTHON_TYPE[attr.attribute_type], + desc=attr.description, + allow_none=True, + default=None if attr.default_value is None else attr.default_value) + elif attr.required: + raise NotImplementedError("Required attribute '{}' has an unsupported type".format(attr.name)) + + required_attrs = {name for name, attr in dace_schema.attributes.items() if attr.required} + + def __init__(self, name, *args, location=None, optional=set(), **op_attributes): + super(onnx_op.ONNXOp, self).__init__( + name, + location=location, + # Add required parameters as in/out connectors, without types for now + inputs={ + inp.name + for inp in self.schema.inputs if inp.param_type == ONNXParameterType.Single or ( + inp.name in optional and inp.param_type == ONNXParameterType.Optional) + }, + outputs={ + out.name + for out in self.schema.outputs if out.param_type == ONNXParameterType.Single or ( + out.name in optional and out.param_type == ONNXParameterType.Optional) + }) + self.backward_implementation = None + + if len(args) > 0: + raise TypeError("__init__() takes 1 positional arguments but {} were given".format(1 + len(args))) + + missing_arguments = required_attrs.difference(op_attributes) + if len(missing_arguments) > 0: + + raise TypeError( + onnx_op.get_missing_arguments_message("__init__()", missing_arguments, "keyword-only argument")) + + unknown_attrs = set(op_attributes).difference(self.schema.attributes) + if len(unknown_attrs) > 0: + raise TypeError("{}.__init__() got an unexpected keyword argument '{}'".format( + self.schema.name, + list(unknown_attrs)[0])) + + for name, attr in op_attributes.items(): + if isinstance(attr, StringLiteral): + attr = attr.value + setattr(self, name, attr) + + input_connector_docstrings = "\n".join(_get_connector_docstring(param) for param in dace_schema.inputs) + output_connector_docstrings = "\n".join(_get_connector_docstring(param) for param in dace_schema.outputs) + + cls_name = "ONNX" + dace_schema.name + + # The first line of the init docstring contains the signature of the method. This will be picked up by sphinx and + # means that the generated sphinx docs have a proper signature, and not just *args, **kwargs. + init_docstring = "__init__(name, *, {})\n".format(", ".join(attr.name if attr.required else attr.name + "=" + + repr(attr.default_value) + for _, attr in dace_schema.attributes.items())) + init_docstring += ":param name: The name of the node.\n" + "\n".join( + _get_attr_docstring(attr) for _, attr in dace_schema.attributes.items()) + + docstring = "\n" + dace_schema.doc + type_docstrings = "\n".join(_get_typecons_docstring(cons) for _, cons in dace_schema.type_constraints.items()) + docstring += "\n\n" + docstring += ":Node Inputs:" + input_connector_docstrings + docstring += "\n\n" + docstring += ":Node Outputs:" + output_connector_docstrings + docstring += "\n\n" + docstring += ":Type Constraints:" + type_docstrings + + attrs['__doc__'] = docstring + "\n" + attrs['schema'] = dace_schema + + attrs['__init__'] = __init__ + + cls_name_ver = cls_name + "_" + str(dace_schema.since_version) + + cls = type(cls_name_ver, (onnx_op.ONNXOp, ), attrs) + cls = dace.library.node(cls) + cls.__init__.__doc__ = "\n" + init_docstring + # Set library name for lazy-loaded nodes + cls._dace_library_name = "dace.libraries.onnx" + + # Register pure implementations + registered = False + for impl, args in ONNXForward.extensions().items(): + if "op" in args and args["op"] == schema.name: + + class Expansion(ExpandTransformation): + environments = [] + forward_impl: ONNXForward = impl + + @classmethod + def expansion(cls, node, state, sdfg, **kwargs): + # validate + node.validate(sdfg, state) + + if cls.forward_impl.forward_can_be_applied(node, state, sdfg): + result = cls.forward_impl.forward(node, state, sdfg, **kwargs) + if hasattr(cls.forward_impl, "environments"): + cls.environments.extend(cls.forward_impl.environments) + return result + + implementation_name = args["name"] + + # Give the Expansion class a unique name and register it in globals + # so it can be located during deserialization + expansion_class_name = f"{cls_name_ver}_Expansion_{implementation_name}" + Expansion.__name__ = expansion_class_name + Expansion.__qualname__ = expansion_class_name + globals()[expansion_class_name] = Expansion + + cls.register_implementation(implementation_name, Expansion) + registered = True + + if not registered: + # WARNING: No implementation found for this op + cls.default_implementation = None + + version = schema.since_version + + if cls_name not in _ONNX_OPS: + _ONNX_OPS[cls_name] = {} + _ONNX_OPS[cls_name][version] = cls + + for name, ver_to_cls in _ONNX_OPS.items(): + _ONNX_OPS[name] = dict(sorted(ver_to_cls.items())) + for i, (version, cls) in enumerate(_ONNX_OPS[name].items()): + if i == len(_ONNX_OPS[name]) - 1: + # last version registered as the default + globals()[name] = cls + # register python frontend replacement + register_op_repo_replacement(cls, name, cls.schema) + # all other versions are registered with version as a suffix + globals()[name + "_" + str(version)] = cls + + +def has_onnx_node(name: str) -> bool: + """Check if an ONNX operator is supported. + + :param name: The operator name. + :return: True if the operator is supported, False otherwise. + """ + _initialize_onnx_registry() + return ("ONNX" + name) in _ONNX_OPS + + +def get_onnx_node(name: str, version: int = -1) -> onnx_op.ONNXOp: + """Get the ONNX Operator node for an operator by name. + + :param name: The operator name. + :param version: The version of the operator (-1 for latest). + :return: The ONNX operator node class. + :raises ValueError: If no version of the operator is found for the given version. + """ + _initialize_onnx_registry() + name_to_versions = list(_ONNX_OPS["ONNX" + name].items()) + + if version == -1: + # Take the latest version + return name_to_versions[-1][1] + else: + # Take the latest version which is less than or equal to the given version + for ver, cls in reversed(name_to_versions): + if ver <= version: + return cls + raise ValueError(f"No version of {name} found for version {version}") diff --git a/dace/libraries/onnx/onnx.md b/dace/libraries/onnx/onnx.md new file mode 100644 index 0000000000..0f43d037a3 --- /dev/null +++ b/dace/libraries/onnx/onnx.md @@ -0,0 +1,993 @@ +Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +# DaCe ONNX Integration Library - Design Document + +## Table of Contents + +1. [Introduction](#1-introduction) +2. [Architecture Overview](#2-architecture-overview) +3. [Directory Structure](#3-directory-structure) +4. [Core Components](#4-core-components) +5. [Import Pipeline](#5-import-pipeline) +6. [Shape Inference System](#6-shape-inference-system) +7. [Implementation Strategies](#7-implementation-strategies) +8. [Key Algorithms](#8-key-algorithms) +9. [Extension Points](#9-extension-points) + +--- + +## 1. Introduction + +### 1.1 Purpose + +The DaCe ONNX Integration Library enables **the importing and executing of ONNX (Open Neural Network Exchange) models** within the DaCe framework. It provides a pipeline for converting ONNX neural network models into optimized DaCe SDFGs (Stateful DataFlow Graphs) that can run efficiently on CPUs, GPUs, and other accelerators. + +### 1.2 Current Capabilities + +- **Model Import**: Load ONNX models from files or protobuf objects +- **Shape Inference**: Automatic computation of tensor shapes (symbolic and concrete) +- **Multi-Strategy Implementations**: Pure (correctness), optimized (performance), hardware-specific (GPU/FPGA) +- **Type Safety**: Schema-based validation and type checking +- **Framework Integration**: Interoperability with PyTorch and NumPy + +### 1.3 Use Cases + +1. **ML Inference Optimization**: Optimize pre-trained models for production deployment +2. **Hardware Acceleration**: Leverage DaCe's code generation for GPU/FPGA execution +3. **Cross-Framework Compatibility**: Run PyTorch/TensorFlow models in DaCe ecosystem +4. **Research and Experimentation**: Analyze and optimize neural network architectures +5. **Custom Optimization**: Apply DaCe transformations to ML workloads +6. **Benchmarking**: Compare performance across different implementations + +### 1.4 ONNX Background + +ONNX is an open standard for representing machine learning models, supported by major frameworks: +- **Export**: PyTorch, TensorFlow, Keras, scikit-learn +- **Operators**: 150+ standard operations (Conv, MatMul, Attention, etc.) +- **Opsets**: Versioned operator specifications (current: opset 18) +- **Use**: Model exchange, optimization, deployment + +--- + +## 2. Architecture Overview + +### 2.1 High-Level System Diagram + +``` +┌─────────────────────────────────────────────────────────────┐ +│ USER INTERFACE │ +│ ┌──────────────┐ ┌──────────────┐ ┌─────────────────┐ │ +│ │ ONNXModel │ │ ONNX Backend │ │ Direct ONNX Op │ │ +│ │ (main API) │ │ (testing) │ │ calls │ │ +│ └──────┬───────┘ └──────┬───────┘ └────────┬────────┘ │ +└─────────┼─────────────────┼───────────────────┼─────────────┘ + │ │ │ + └─────────────────┼───────────────────┘ + ▼ +┌─────────────────────────────────────────────────────────────┐ +│ IMPORT PIPELINE │ +│ ┌──────────────────────────────────────────────────────┐ │ +│ │ ONNXModel (frontend/ml/onnx/importer.py) │ │ +│ │ 1. Load Model → 4. Graph Construction │ │ +│ │ 2. Simplify → 5. Weight Management │ │ +│ │ 3. Shape Infer → 6. Compilation │ │ +│ └──────────────────┬───────────────────────────────────┘ │ +└─────────────────────┼───────────────────────────────────────┘ + │ + ┌────────────┼───────────┐ + ▼ ▼ ▼ +┌──────────────┐ ┌─────────┐ ┌──────────────────┐ +│ REGISTRY │ │ SCHEMA │ │ SHAPE INFERENCE │ +├──────────────┤ ├─────────┤ ├──────────────────┤ +│ Dynamic Node │ │ Type │ │ Symbolic Shape │ +│ Generation │ │ System │ │ Inference │ +│ │ │ │ │ (Microsoft impl) │ +│ • 100+ ops │ │ • Valid-│ │ │ +│ • Versioning │ │ ation │ │ • Dynamic dims │ +│ • Properties │ │ • Const-│ │ • Auto-merge │ +│ • Connectors │ │ raints│ │ • Concrete eval │ +└──────────────┘ └─────────┘ └──────────────────┘ + │ │ │ + └────────────┼────────────┘ + ▼ +┌────────────────────────────────────────────────────────────┐ +│ IMPLEMENTATION LAYER │ +│ ┌─────────────┐ ┌──────────────┐ ┌──────────────────┐ │ +│ │ Pure │ │ Optimized │ │ Hardware │ │ +│ │ (SDFGs) │ │ (img ops) │ │ (cuDNN, etc) │ │ +│ ├─────────────┤ ├──────────────┤ ├──────────────────┤ │ +│ │ Reference │ │ Performance │ │ GPU/FPGA │ │ +│ │ impl for │ │ focused │ │ specialized │ │ +│ │ correctness │ │ operations │ │ libraries │ │ +│ └─────────────┘ └──────────────┘ └──────────────────┘ │ +└────────────────────────────────────────────────────────────┘ + ▼ + DaCe SDFG with ONNX Nodes + ▼ + Expansion → Optimization → Code Generation +``` + +### 2.2 Component Interaction Flow + +``` +ONNX Model File + ↓ +ONNXModel.__init__() + ↓ +1. onnx.checker.check_model() → Validate + ↓ +2. shape_inference.infer_shapes() → Compute shapes + ↓ +3. onnxsim.simplify() (optional) → Optimize ONNX graph + ↓ +4. Create SDFG structure + ↓ +5. For each ONNX node: + ├─→ get_onnx_node(op_type, version) → Retrieve node class + ├─→ Create instance with attributes + ├─→ Add connectors from schema + └─→ Create edges with memlets + ↓ +6. Load weights (initializers) + ↓ +7. Handle outputs (scalar promotion, return arrays) + ↓ +8. Apply GPU transformations (if cuda=True) + ↓ +SDFG with ONNX Library Nodes + ↓ +Compilation (triggered by first call or explicit compile()): + ├─→ Expand ONNX nodes (select implementation) + ├─→ Apply DaCe optimizations + ├─→ Generate C++/CUDA code + └─→ Compile to binary + ↓ +Execution: + ├─→ Infer runtime symbols from input shapes + ├─→ Call compiled function with inputs + weights + └─→ Return outputs (NumPy or PyTorch tensors) +``` + +--- + +## 3. Directory Structure + +### 3.1 File Organization + +``` +dace/libraries/onnx/ +├── __init__.py # Library registration (61 lines) +│ └── Exports: ONNXModel (lazy), get_onnx_node, has_onnx_node, schema types +│ +├── (Note: ONNXModel is at dace/frontend/ml/onnx/importer.py) +│ +├── schema.py # Type system +│ ├── @onnx_representation decorator +│ ├── ONNXSchema +│ ├── ONNXAttribute +│ ├── ONNXParameter +│ └── ONNXTypeConstraint +│ +├── converters.py # Type conversions +│ ├── convert_onnx_proto() +│ ├── onnx_tensor_type_to_typeclass() +│ ├── clean_onnx_name() +│ └── convert_attribute_proto() +│ +├── forward_implementation_abc.py # Implementation interface +│ └── ONNXForward (ABC + registry) +│ +├── nodes/ # ONNX operation nodes +│ ├── onnx_op.py # Base class +│ │ └── ONNXOp - Abstract superclass for all ONNX ops +│ ├── onnx_op_registry.py # Dynamic generation +│ │ ├── _get_all_schemas() +│ │ ├── _create_node_class() +│ │ └── get_onnx_node() / has_onnx_node() +│ └── node_utils.py # Utilities +│ └── parse_variadic_param() +│ +├── op_implementations/ # Implementation strategies +│ ├── __init__.py # Package exports (11 lines) +│ ├── elementwise_ops.py # Element-wise operations (212 lines) +│ ├── reduction_ops.py # Reduction operations (304 lines) +│ ├── array_ops.py # Array operations (681 lines) +│ ├── linalg_ops.py # Linear algebra ops (359 lines) +│ ├── normalization_ops.py # Normalization ops (281 lines) +│ ├── image_ops.py # Image operations (443 lines) +│ ├── img_op_implementations.py # Optimized image ops (563 lines) +│ ├── criteria_implementations.py # Conditional selection (90 lines) +│ ├── common.py # Common utilities (11 lines) +│ └── utils.py # Helpers (223 lines) +│ ├── @op_implementation decorator +│ ├── @python_pure_op_implementation +│ ├── program_for_node() +│ └── empty_sdfg_for_node() +│ +└── shape_inference/ # Dynamic shape support + └── (Empty - uses onnxruntime.tools.symbolic_shape_infer instead) + +``` + +### 3.2 File Size Distribution + +| File | Lines | Purpose | +|------|-------|---------| +| `array_ops.py` | 681 | Array operations (Concat, Gather, etc.) | +| `img_op_implementations.py` | 563 | Optimized image operations | +| `image_ops.py` | 443 | Image operation implementations | +| `linalg_ops.py` | 359 | Linear algebra operations | +| `onnx_op_registry.py` | 351 | Dynamic node class generation | +| `schema.py` | 333 | Type system and validation | +| `reduction_ops.py` | 304 | Reduction operations | +| `onnx_op.py` | 295 | Base class for ONNX operations | +| `normalization_ops.py` | 281 | Normalization operations | +| `converters.py` | 247 | Type conversion utilities | +| `utils.py` | 223 | Implementation helpers | +| `elementwise_ops.py` | 212 | Element-wise operations | +| `forward_implementation_abc.py` | 105 | Implementation interface | + +**Note**: ONNXModel (794 lines) is located at [dace/frontend/ml/onnx/importer.py](../../frontend/ml/onnx/importer.py). + +--- + +## 4. Core Components + +### 4.1 ONNXModel: The Main Entry Point + +**Location**: [dace/frontend/ml/onnx/importer.py](../../frontend/ml/onnx/importer.py) + +The `ONNXModel` class is the primary interface for importing and executing ONNX models. + +#### Key Features + +- **Model Loading**: Loads models from files or ONNX protobuf objects +- **Automatic Optimization**: Provides optional ONNX-level simplification +- **Shape Inference**: Handles dynamic and symbolic shapes automatically +- **Weight Management**: Loads and manages model parameters efficiently +- **Compilation**: Supports lazy or explicit compilation to optimized code +- **Execution**: Provides direct `__call__` interface with NumPy/PyTorch tensors +- **GPU Support**: Automatic GPU transformation when `cuda=True` + +#### Constructor Signature + +```python +class ONNXModel: + def __init__( + self, + name: str, + model: Union[str, onnx.ModelProto], + cuda: bool = False, + apply_strict: bool = False, + auto_optimize: bool = True, + onnx_simplify: bool = True, + infer_shapes: bool = True, + auto_merge: bool = False + ): + """ + Import an ONNX model into DaCe. + + Args: + name: Name for the generated SDFG + model: Path to .onnx file or onnx.ModelProto object + cuda: Enable GPU execution + apply_strict: Strict ONNX validation + auto_optimize: Apply DaCe optimizations on first run + onnx_simplify: Apply onnx-simplifier before import + infer_shapes: Run shape inference + auto_merge: Auto-merge conflicting symbolic shapes + """ +``` + +#### Main Methods + +- **`__call__()`**: Execute the model with inputs +- **`compile()`**: Explicitly compile the SDFG +- **`save()`**: Save compiled model to disk +- **`infer_symbols()`**: Infer symbolic dimension values from input shapes + +--- + +### 4.2 Registry System: Dynamic Node Generation + +**Location**: [nodes/onnx_op_registry.py](nodes/onnx_op_registry.py) + +The registry system **dynamically generates Python classes** for all ONNX operations at import time, eliminating the need to manually write 100+ node classes. + +#### How It Works + +**Process**: +``` +1. Query ONNX for all supported operations + ↓ +2. For each operation (e.g., "Conv"): + ├─ Get all versions (e.g., Conv_1, Conv_11, Conv_13) + ├─ Convert ONNX OpSchema to ONNXSchema + └─ For each version: + ├─ Create Python properties from attributes + ├─ Generate __init__ constructor + ├─ Add input/output connectors + ├─ Generate documentation + ├─ Create class with type() + └─ Register with DaCe library system + ↓ +3. Store in global registry: + _ONNX_OPS["Conv"][11] = ONNXConv_11 + _ONNX_OPS["Conv"][13] = ONNXConv_13 + ↓ +4. Export latest version to module: + ONNXConv = ONNXConv_13 +``` + +#### Generated Class Structure + +For each ONNX operation, the registry generates: + +- **Class Name**: `ONNX{OpName}_{Version}` (e.g., `ONNXConv_11`) +- **Properties**: One DaCe property per ONNX attribute +- **Constructor**: Validates required attributes, sets defaults +- **Connectors**: Input/output connectors from schema +- **Schema**: Embedded `ONNXSchema` for validation +- **Implementations**: Linked expansion transformations +- **Documentation**: Auto-generated from ONNX docs + +#### API Functions + +```python +def has_onnx_node(name: str) -> bool: + """Check if ONNX operation is supported.""" + +def get_onnx_node(name: str, opset_version: int = None) -> Type[ONNXOp]: + """Get ONNX node class by name and version.""" +``` + +--- + +### 4.3 Schema System: Type Safety + +**Location**: [schema.py](schema.py) + +The schema system provides a Python representation layer for ONNX protobuf schemas, enabling type-safe interactions. + +#### Key Components + +**ONNXSchema** - Complete operation specification: +```python +@dataclass +class ONNXSchema: + name: str # Operation name (e.g., "Conv") + since_version: int # First opset supporting this + doc: str # Documentation + inputs: List[ONNXParameter] # Input specifications + outputs: List[ONNXParameter] # Output specifications + attributes: Dict[str, ONNXAttribute] # Attribute specs + type_constraints: Dict[str, ONNXTypeConstraint] # Type constraints +``` + +**ONNXParameter** - Input/output parameter: +```python +@dataclass +class ONNXParameter: + name: str # Parameter name + type_str: str # Type constraint reference + param_type: ONNXParameterType # Single/Optional/Variadic + description: str # Documentation + homogeneous: bool # For variadic params +``` + +**ONNXAttribute** - Operation configuration: +```python +@dataclass +class ONNXAttribute: + name: str # Attribute name + type: ONNXAttributeType # Int/Float/String/Tensor/etc. + required: bool # Must be provided? + default_value: Any # Default if not provided + description: str # Documentation +``` + +**ONNXTypeConstraint** - Allowed types: +```python +@dataclass +class ONNXTypeConstraint: + type_param_str: str # Type parameter (e.g., "T") + allowed_types: List[typeclass] # Allowed DaCe types + description: str # Documentation +``` + +#### The @onnx_representation Decorator + +Enables creating Python classes from ONNX protobufs: + +```python +@onnx_representation(onnx.TensorProto) +class ONNXTensor: + dims: List[int] + data_type: int + # ... other fields +``` + +Automatically generates: +- `__init__()` constructor +- `from_onnx_proto()` class method +- `from_json()` / `to_json()` serialization +- Registration in the global protobuf registry + +--- + +### 4.4 ONNXOp Base Class + +**Location**: [nodes/onnx_op.py](nodes/onnx_op.py) + +`ONNXOp` is the abstract base class for all ONNX operation nodes in DaCe SDFGs. + +#### Key Methods + +- **`iter_inputs_in_onnx_order()`**: Get input edges in schema order +- **`iter_outputs_in_onnx_order()`**: Get output edges in schema order +- **`iter_edges()`**: Iterate all edges with input/output flag +- **Validation**: Automatic schema-based validation during SDFG construction + +#### Properties + +- `schema`: The operation's ONNXSchema +- `backward_implementation`: Which backward impl to use (for autodiff) +- `implementations`: Available forward implementations +- `default_implementation`: Default expansion strategy + +--- + +### 4.5 Type Converters + +**Location**: [converters.py](converters.py) + +Provides bidirectional conversion between ONNX, DaCe, NumPy, and PyTorch type systems. + +#### Key Functions + +**Type Conversion**: +- `onnx_tensor_type_to_typeclass()`: ONNX type enum → DaCe typeclass +- `typeclass_to_onnx_tensor_type_int()`: DaCe typeclass → ONNX type enum +- `convert_onnx_proto()`: Generic protobuf → Python conversion +- `convert_attribute_proto()`: ONNX AttributeProto → Python value + +**Name Sanitization**: +- `clean_onnx_name()`: Makes ONNX names valid DaCe identifiers + - Prefixes digit-starting names: `123` → `ONNX_123` + - Replaces special characters: `.` → `DOT`, `:` → `COLON`, `/` → `SLASH` + +**Helper Functions**: +- `get_proto_attr()`: Provides safe protobuf attribute access with encoding checks + +--- + +## 5. Import Pipeline + +### 5.1 Complete Workflow + +``` +┌─────────────────────────────────────────────────────────┐ +│ Phase 1: Model Loading and Validation │ +├─────────────────────────────────────────────────────────┤ +│ 1. Load ONNX model (from file or protobuf) │ +│ 2. Run onnx.checker.check_model() │ +│ 3. Validate model conforms to ONNX spec │ +└──────────────────────┬──────────────────────────────────┘ + ▼ +┌─────────────────────────────────────────────────────────┐ +│ Phase 2: Shape Inference │ +├─────────────────────────────────────────────────────────┤ +│ 1. Run symbolic shape inference │ +│ 2. Compute concrete shapes where possible │ +│ 3. Create symbolic dimensions for dynamic shapes │ +│ 4. Auto-merge conflicting symbols (optional) │ +└──────────────────────┬──────────────────────────────────┘ + ▼ +┌─────────────────────────────────────────────────────────┐ +│ Phase 3: ONNX-Level Optimization (optional) │ +├─────────────────────────────────────────────────────────┤ +│ 1. Apply onnxsim.simplify() │ +│ - Constant folding │ +│ - Dead code elimination │ +│ - Operator fusion │ +│ 2. Validate optimization preserves semantics │ +└──────────────────────┬──────────────────────────────────┘ + ▼ +┌─────────────────────────────────────────────────────────┐ +│ Phase 4: SDFG Construction │ +├─────────────────────────────────────────────────────────┤ +│ 1. Create empty SDFG with initial state │ +│ 2. Register inputs/outputs as data descriptors │ +│ 3. For each ONNX node: │ +│ a. Get node class from registry │ +│ b. Extract and convert attributes │ +│ c. Create node instance │ +│ d. Add input/output connectors │ +│ e. Create AccessNodes for data │ +│ f. Add edges with memlets │ +│ 4. Handle special cases (Constants, Identities) │ +└──────────────────────┬──────────────────────────────────┘ + ▼ +┌─────────────────────────────────────────────────────────┐ +│ Phase 5: Weight Management │ +├─────────────────────────────────────────────────────────┤ +│ 1. Load initializers (weights/biases) from ONNX │ +│ 2. Convert to PyTorch tensors │ +│ 3. Store in self.weights dictionary │ +│ 4. Create corresponding DaCe arrays (non-transient) │ +└──────────────────────┬──────────────────────────────────┘ + ▼ +┌─────────────────────────────────────────────────────────┐ +│ Phase 6: Output Handling │ +├─────────────────────────────────────────────────────────┤ +│ 1. Promote scalars to arrays (CPU only) │ +│ 2. Create return arrays (__return, __return_0, etc.) │ +│ 3. Add copy-out state for outputs │ +│ 4. Fuse states for efficiency │ +└──────────────────────┬──────────────────────────────────┘ + ▼ +┌─────────────────────────────────────────────────────────┐ +│ Phase 7: GPU Transformation (if cuda=True) │ +├─────────────────────────────────────────────────────────┤ +│ 1. Apply sdfg.apply_gpu_transformations() │ +│ 2. Convert memory to GPU_Global storage │ +│ 3. Add GPU kernel launch infrastructure │ +└──────────────────────┬──────────────────────────────────┘ + ▼ + SDFG with ONNX Library Nodes +``` + +### 5.2 Node Construction Details + +For each ONNX operation in the graph: + +**Step 1: Operation Lookup** +```python +if not has_onnx_node(node.op_type): + raise ValueError(f"Unsupported operation: {node.op_type}") +``` + +**Step 2: Attribute Extraction** +```python +attributes = {attr.name: convert_attribute_proto(attr) + for attr in node.attribute} +``` + +**Step 3: Node Class Retrieval** +```python +node_class = get_onnx_node(node.op_type, model_opset_version) +``` + +**Step 4: Instance Creation** +```python +dace_node = node_class(name=node.name, **attributes) +``` + +**Step 5: Connector and Edge Creation** +```python +for input_param in node_class.schema.inputs: + # Validate parameter type (Single/Optional/Variadic) + # Create or reuse AccessNode + # Add connector to operation node + # Create Memlet edge with full array semantics +``` + +### 5.3 Special Handling + +- **Constants**: Directly added to weights, no node created +- **Identities**: Can be elided during optimization +- **Variadic Parameters**: Use naming convention `param_name__index` +- **Optional Parameters**: Checked for presence, skipped if absent + +--- + +## 6. Shape Inference System + +### 6.1 Purpose and Motivation + +ONNX models often have **dynamic shapes** where tensor dimensions depend on runtime inputs: +- Batch size: Variable number of samples +- Sequence length: Variable-length sequences (NLP) +- Image dimensions: Variable-size images + +Shape inference computes tensor shapes either symbolically or concretely for all intermediate tensors in the model. + +### 6.2 Integration + +Shape inference uses `onnxruntime.tools.symbolic_shape_infer.SymbolicShapeInference` from the ONNX Runtime library. + +Called during model import: +```python +from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference +model = SymbolicShapeInference.infer_shapes(model, auto_merge=auto_merge) +``` + +### 6.3 Capabilities + +**Symbolic Dimensions**: +```python +# Input shape: [batch_size, 3, 224, 224] +# After Conv: [batch_size, 64, 112, 112] +# After Pool: [batch_size, 64, 56, 56] +``` + +**Concrete Evaluation**: +```python +# Known: kernel_size=3, stride=2, padding=1, input_size=224 +# Computed: output_size = (224 + 2*1 - 3) / 2 + 1 = 112 +``` + +**Broadcasting**: +```python +# Shape A: [batch, 256, 1, 1] +# Shape B: [batch, 256, 7, 7] +# Result: [batch, 256, 7, 7] +``` + +**Auto-Merge** (optional): +```python +# Before: tensor_0: [batch_0, seq_len_0] +# tensor_1: [batch_1, seq_len_1] +# After: tensor_0: [batch, seq_len] +# tensor_1: [batch, seq_len] +``` + +### 6.4 Implementation Details + +Shape inference uses the **ONNX Runtime implementation** (`onnxruntime.tools.symbolic_shape_infer`) which provides: + +- Helper functions for dimension extraction and axis handling +- `SymbolicShapeInference` class with per-operation rules +- Sympy-based symbolic computation +- Integration with ONNX's native shape inference +- Special handling for complex operations (Reshape, Transpose, Concat) + +### 6.5 DaCe Integration + +Symbolic dimensions are added to the SDFG symbol table: +```python +for dim_name in symbolic_dimensions: + sdfg.add_symbol(dim_name, dace.int64) +``` + +At runtime, DaCe infers symbol values from input shapes: +```python +symbols = {} +if 'batch_size' in sdfg.symbols: + symbols['batch_size'] = input_tensor.shape[0] +``` + +--- + +## 7. Implementation Strategies + +### 7.1 The ONNXForward Interface + +**Location**: [forward_implementation_abc.py](forward_implementation_abc.py) + +```python +@make_registry +class ONNXForward(abc.ABC): + """Abstract base for ONNX operation implementations.""" + + @staticmethod + def forward_can_be_applied(node: ONNXOp, state: SDFGState, + sdfg: SDFG) -> bool: + """Check if implementation is applicable.""" + return True + + @staticmethod + @abc.abstractmethod + def forward(node: ONNXOp, state: SDFGState, + sdfg: SDFG) -> Union[Node, SDFG]: + """Expand node to DaCe constructs.""" + ... +``` + +### 7.2 Implementation Types + +#### 1. Pure Implementations + +**Location**: Implementations are organized across multiple files in [op_implementations/](op_implementations/): +- `elementwise_ops.py` - Element-wise operations (Add, Mul, Div, etc.) +- `reduction_ops.py` - Reduction operations (ReduceMean, ReduceSum, etc.) +- `array_ops.py` - Array operations (Concat, Gather, Reshape, etc.) +- `linalg_ops.py` - Linear algebra operations (MatMul, Gemm, etc.) +- `normalization_ops.py` - Normalization operations (BatchNorm, LayerNorm, etc.) +- `image_ops.py` - Image operations (Conv, Pool, etc.) + +**Purpose**: Provides reference implementations focused on correctness + +**Characteristics**: +- Written in Python/NumPy style +- Automatically parsed via the DaCe Python frontend +- Semantically correct according to ONNX specifications +- May not be optimally performant until further transformations are applied + +**Implementation Pattern**: +```python +@python_pure_op_implementation +def Relu(X: dace.float32[H, W]): + """Pure implementation of ReLU activation.""" + return np.maximum(X, 0) +``` + +**Process**: +1. Decorator creates an `ONNXForward` subclass +2. Function is parsed via the DaCe Python frontend +3. Converted to SDFG with maps and tasklets +4. Result: Efficient parallel code generation + +#### 2. Optimized Implementations + +**Location**: [op_implementations/img_op_implementations.py](op_implementations/img_op_implementations.py) + +**Purpose**: Provides performance-optimized implementations for specific operations + +**Examples**: +- `Conv`: Optimized convolution with im2col or Winograd +- `MaxPool/AveragePool`: Efficient pooling operations +- `BatchNormalization`: Fused batch normalization + +**Characteristics**: +- Hand-crafted SDFG construction +- May use library calls (BLAS, cuDNN) +- Optimized for specific hardware/configurations + +#### 3. Hardware-Specific Implementations + +**Concept**: Implementations optimized for specific hardware + +**Examples** (potential): +- `cuDNN` implementations for GPU (Conv, Pool, BatchNorm) +- `MKL-DNN` implementations for CPU +- `FPGA` implementations for reconfigurable hardware + +**Selection via Applicability**: +```python +@op_implementation(op="Conv", name="cudnn") +class CuDNNConv(ONNXForward): + @staticmethod + def forward_can_be_applied(node, state, sdfg): + return sdfg.gpu and has_cudnn() +``` + +### 7.3 Implementation Selection + +**Process**: + +1. Query the registry for the operation's implementations +2. Filter by applicability: `forward_can_be_applied()` +3. Prefer user-specified implementation (if set) +4. Fall back to the default implementation +5. Expand the node using the selected implementation + +**Priority Order**: +1. User-specified implementation (node property) +2. First applicable implementation (by registration order) +3. Default implementation (usually "pure") + +### 7.4 Common Implementation Patterns + +#### Pattern A: Pure Python with Decorator + +```python +@python_pure_op_implementation +def Softmax(X: dace.float32[N, M], axis: int = -1): + """Softmax activation function.""" + exp_x = np.exp(X - np.max(X, axis=axis, keepdims=True)) + return exp_x / np.sum(exp_x, axis=axis, keepdims=True) +``` + +#### Pattern B: Manual SDFG Construction + +```python +@op_implementation(op="MatMul", name="blas") +class BLASMatMul(ONNXForward): + @staticmethod + def forward(node, state, sdfg): + # Create nested SDFG + nsdfg = dace.SDFG(f"{node.label}_matmul") + nstate = nsdfg.add_state() + + # Use BLAS library node + from dace.libraries.blas import MatMul + matmul_node = MatMul("matmul") + + # Connect inputs/outputs + # ... + + return nsdfg +``` + +#### Pattern C: Library Call Integration + +```python +@op_implementation(op="Conv", name="optimized") +class OptimizedConv(ONNXForward): + @staticmethod + def forward(node, state, sdfg): + # Leverage existing DaCe library nodes + from dace.libraries.standard import Conv2D + + # Convert ONNX semantics to library call + conv_node = Conv2D(...) + + # Return library node (further expanded by DaCe) + return conv_node +``` + +### 7.5 Implementation Utilities + +**Location**: [op_implementations/utils.py](op_implementations/utils.py) + +**Key Functions**: + +- `@op_implementation(op, name)`: Register implementation with registry +- `@python_pure_op_implementation`: Create implementation from Python function +- `program_for_node()`: Convert Python function to nested SDFG +- `empty_sdfg_for_node()`: Create empty nested SDFG template + +--- + +## 8. Key Algorithms + +### 8.1 Dynamic Node Class Generation + +**Algorithm**: Creates Python classes at import time + +``` +For each ONNX operation in onnx.defs.get_all_schemas(): + 1. Extract OpSchema from ONNX + 2. Convert to ONNXSchema (DaCe representation) + 3. For each version of the operation: + a. Generate class name: ONNX{OpName}_{Version} + b. Create properties from attributes: + - Map ONNX types to DaCe property types + - Set defaults and required flags + c. Generate __init__ constructor: + - Validate required attributes provided + - Convert types (e.g., StringLiteral → str) + - Set up connectors for parameters + d. Generate documentation from schema + e. Create class with type(): + cls = type(cls_name, (ONNXOp,), attrs) + f. Register as DaCe library node: + cls = dace.library.node(cls) + g. Link implementations: + - Query ONNXForward.extensions() + - Create ExpandTransformation wrappers + - Register with node class + h. Store in registry: + _ONNX_OPS[op_name][version] = cls + 4. Export latest version to module: + globals()[f"ONNX{OpName}"] = latest_version +``` + +**Result**: 100+ operation classes generated automatically, ready for use + +### 8.2 Schema-Based Validation + +**Algorithm**: Validates node construction + +``` +When creating ONNX node instance: + 1. Check required attributes provided: + missing = required_attrs - provided_attrs + if missing: raise ValueError(...) + + 2. Validate connector usage: + For each edge connected to node: + a. Determine parameter (input/output) + b. Check parameter type (Single/Optional/Variadic) + c. Validate connector naming: + - Single/Optional: exact name + - Variadic: name__index format + d. Verify edge data type matches constraints + + 3. Type constraint checking: + For each connector with type constraint: + a. Get connector data type + b. Look up constraint allowed types + c. Verify type in allowed set + d. If not: raise validation error +``` + +### 8.3 Runtime Symbol Inference + +**Algorithm**: Infers symbolic dimension values from inputs + +``` +When executing ONNXModel: + 1. Collect all symbols in SDFG: + symbols = sdfg.free_symbols + + 2. For each input tensor: + For each dimension in tensor.shape: + if dimension_name in symbols: + inferred_symbols[dimension_name] = dimension_value + + 3. Verify all required symbols inferred: + missing = symbols - inferred_symbols.keys() + if missing: raise ValueError(...) + + 4. Pass symbols to compiled SDFG: + result = compiled_sdfg(inputs..., **inferred_symbols) +``` + +### 8.4 Type Conversion Pipeline + +**Algorithm**: Converts between type systems + +``` +ONNX Type → DaCe Type: + 1. Extract ONNX type enum (e.g., TensorProto.FLOAT) + 2. Look up in cached mapping: + dace_type = onnx_to_dace_type_map[onnx_type] + 3. Return DaCe typeclass (e.g., dace.float32) + +DaCe Type → NumPy Type: + 1. Get DaCe typeclass + 2. Extract numpy_dtype property + 3. Return numpy dtype (e.g., np.float32) + +NumPy Type → PyTorch Type: + 1. Look up in numpy_to_torch_dtype_dict + 2. Return torch dtype (e.g., torch.float32) +``` + +--- + +## 9. Extension Points + +### 9.1 Adding New ONNX Operations + +If an ONNX operation is not yet supported, you can add it by creating an implementation: + +**Step 1: Create Implementation Class** + +```python +from dace.libraries.onnx.forward_implementation_abc import ONNXForward +from dace.libraries.onnx.op_implementations.utils import op_implementation + +@op_implementation(op="CustomOp", name="pure") +class CustomOpImplementation(ONNXForward): + @staticmethod + def forward_can_be_applied(node, state, sdfg): + # Check if this implementation is applicable + return True + + @staticmethod + def forward(node, state, sdfg): + # Create nested SDFG for operation + # ... + return nested_sdfg +``` + +**Step 2: Register Implementation** + +The `@op_implementation` decorator automatically registers the implementation with the ONNXForward registry. + +**Step 3: Use in Models** + +The operation will now be available when importing ONNX models that use it. + +### 9.2 Custom Implementations for Existing Operations + +Override the default implementation with a custom one: + +```python +@op_implementation(op="Conv", name="my_optimized_conv") +class MyOptimizedConv(ONNXForward): + @staticmethod + def forward_can_be_applied(node, state, sdfg): + # Only apply for specific configurations + return (node.kernel_shape == [3, 3] and + node.stride == [1, 1]) + + @staticmethod + def forward(node, state, sdfg): + # Custom optimized implementation + # ... +``` + +**Selection**: Set `node.default_implementation = "my_optimized_conv"` or allow DaCe to select automatically based on applicability. diff --git a/dace/libraries/onnx/op_implementations/__init__.py b/dace/libraries/onnx/op_implementations/__init__.py new file mode 100644 index 0000000000..256b3a444a --- /dev/null +++ b/dace/libraries/onnx/op_implementations/__init__.py @@ -0,0 +1,11 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +from .utils import * +from .common import * +from .elementwise_ops import * +from .reduction_ops import * +from .normalization_ops import * +from .array_ops import * +from .linalg_ops import * +from .image_ops import * +from .img_op_implementations import * +from .criteria_implementations import * diff --git a/dace/libraries/onnx/op_implementations/array_ops.py b/dace/libraries/onnx/op_implementations/array_ops.py new file mode 100644 index 0000000000..c918d9b6b3 --- /dev/null +++ b/dace/libraries/onnx/op_implementations/array_ops.py @@ -0,0 +1,681 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +Array and Tensor Manipulation Operations for ONNX in DaCe. + +This module provides pure DaCe implementations for ONNX array/tensor manipulation +operations. These operations handle shape manipulation, slicing, and other array transformations. + +The module contains: +- Shape manipulation operations (Reshape, Flatten, Squeeze, Unsqueeze, Expand) +- Slicing and indexing operations (Slice, SliceAllConstant, Gather) +- Concatenation and splitting operations (Concat, Split) +- Transposition operations (Transpose, EinsumTranspose) +- Shape query operations (Shape) + +Each implementation follows the ONNX specification and is designed to be: +- Semantically correct according to ONNX standards +- Efficient when converted to DaCe SDFGs +""" + +import copy +from math import prod +import typing + +import dace +import numpy as np +from dace import SDFG, SDFGState, subsets +from dace.sdfg.nodes import Node +from dace.sdfg.utils import in_desc_with_name, in_edge_with_name, out_desc_with_name +from dace.libraries.onnx.forward_implementation_abc import ONNXForward +from dace.libraries.onnx.op_implementations.common import iterables_equal +from dace.libraries.onnx.nodes import onnx_op +from dace.libraries.onnx.op_implementations.utils import (empty_sdfg_for_node, op_implementation, program_for_node, + python_pure_op_implementation) +from dace.transformation.onnx import constant_folding +from dace.transformation.onnx.replacement import onnx_constant_or_none +from dace.libraries.onnx import converters + +# ============================================================================== +# Concatenation Operations +# ============================================================================== + + +@op_implementation(op="Concat", name="pure") +class PureConcat(ONNXForward): + + @staticmethod + def forward_can_be_applied(node: 'ONNXOp', state: SDFGState, sdfg: SDFG) -> bool: + return True + + @staticmethod + def forward(node: 'ONNXOp', state: SDFGState, sdfg: SDFG) -> typing.Union[Node, SDFG]: + axis = node.axis + + num_inputs = len(state.in_edges(node)) + + def inp_name(i): + return f"inputs__{i}" + + out_name = "concat_result" + + inp_data = [in_desc_with_name(node, state, sdfg, inp_name(i)) for i in range(num_inputs)] + out_data = out_desc_with_name(node, state, sdfg, out_name) + + nsdfg = dace.SDFG(node.label) + + inp_data_descs = [copy.deepcopy(desc) for desc in inp_data] + for i, desc in enumerate(inp_data_descs): + desc.transient = False + nsdfg.add_datadesc(inp_name(i), desc) + out_data_desc = copy.deepcopy(out_data) + out_data_desc.transient = False + nsdfg.add_datadesc(out_name, out_data_desc) + + inp_shapes = [d.shape for d in inp_data] + out_shape = out_data_desc.shape + + nstate = nsdfg.add_state() + out_write = nstate.add_write(out_name) + + for inp_idx in range(num_inputs): + inp_read = nstate.add_read(inp_name(inp_idx)) + + tasklet = nstate.add_tasklet( + f'concat_{inp_idx}', + {'inp': inp_data_descs[inp_idx].dtype}, + {'out': out_data_desc.dtype}, + "out = inp", + ) + + map_entry, map_exit = nstate.add_map(f"concat_map_{inp_idx}", { + f"i{i}": f"0:{s}" + for i, s in enumerate(inp_shapes[inp_idx]) + }) + + inp_access = [f'i{i}' for i, _ in enumerate(inp_shapes[inp_idx])] + inp_access_str = ", ".join(inp_access) + inp_memlet = dace.Memlet(f"{inp_name(inp_idx)}[{inp_access_str}]") + + stack_idx_offset = "" + for i in range(inp_idx): + stack_idx_offset += f" + ({inp_shapes[i][axis]})" + + out_access = [f'i{i}' for i, _ in enumerate(out_shape)] + if stack_idx_offset: + out_access[axis] += stack_idx_offset + out_access_str = ", ".join(out_access) + out_memlet = dace.Memlet(f"{out_name}[{out_access_str}]") + + nstate.add_memlet_path(inp_read, map_entry, tasklet, memlet=inp_memlet, dst_conn="inp") + nstate.add_memlet_path(tasklet, map_exit, out_write, memlet=out_memlet, src_conn="out") + + return nsdfg + + +# ============================================================================== +# Shape Manipulation Operations - Unsqueeze +# ============================================================================== + + +@op_implementation(op="Unsqueeze", name="pure") +class PureUnsqueeze(ONNXForward): + + @staticmethod + def forward_can_be_applied(node: 'ONNXOp', state: SDFGState, sdfg: SDFG) -> bool: + return True + + @staticmethod + def forward(node: 'ONNXOp', state: SDFGState, sdfg: SDFG) -> typing.Union[Node, SDFG]: + + # Get input/output descriptors + expanded_desc = copy.deepcopy(out_desc_with_name(node, state, sdfg, "expanded")) + + def prog(data, expanded): + expanded[:] = np.reshape(data, expanded_desc.shape) + + return program_for_node(prog, sdfg, state, node) + + +# ============================================================================== +# Shape Manipulation Operations - Squeeze +# ============================================================================== + + +@op_implementation(op="Squeeze", name="pure") +class PureSqueeze(ONNXForward): + + @staticmethod + def forward_can_be_applied(node: 'ONNXOp', state: SDFGState, sdfg: SDFG) -> bool: + return True + + @staticmethod + def forward(node: 'ONNXOp', state: SDFGState, sdfg: SDFG) -> typing.Union[Node, SDFG]: + squeezed_desc = copy.deepcopy(out_desc_with_name(node, state, sdfg, "squeezed")) + + def prog(data, squeezed): + squeezed[:] = np.reshape(data, squeezed_desc.shape) + + return program_for_node(prog, sdfg, state, node) + + +# ============================================================================== +# Shape Manipulation Operations - Expand +# ============================================================================== + + +@op_implementation(op="Expand", name="pure") +class PureExpand(ONNXForward): + + @staticmethod + def forward_can_be_applied(node: onnx_op.ONNXOp, state: SDFGState, sdfg: SDFG) -> bool: + return True + + @staticmethod + def forward(node: onnx_op.ONNXOp, state: SDFGState, sdfg: SDFG) -> typing.Union[Node, SDFG]: + + shape = out_desc_with_name(node, state, sdfg, "output").shape + + def prog(input, output): + output = np.broadcast_to(input, shape) + + return program_for_node(prog, sdfg, state, node) + + +@op_implementation(op="Expand", name="pure") +class PureExpand(ONNXForward): + """ Handle no-op case for Expand """ + + @staticmethod + def forward_can_be_applied(node: onnx_op.ONNXOp, state: SDFGState, sdfg: SDFG) -> bool: + return iterables_equal( + in_desc_with_name(node, state, sdfg, "input").shape, + out_desc_with_name(node, state, sdfg, "output").shape) + + @staticmethod + def forward(node: onnx_op.ONNXOp, state: SDFGState, sdfg: SDFG) -> typing.Union[Node, SDFG]: + + constant_folding.remove_node_and_computation(sdfg, state, node, "shape") + + def prog(input, output): + output[:] = input + + return program_for_node(prog, sdfg, state, node) + + +# ============================================================================== +# Transposition Operations +# ============================================================================== + + +@python_pure_op_implementation( + perm=lambda node, data: node.perm if node.perm is not None else list(reversed(range(len(data.shape))))) +def Transpose(data, transposed): + transposed[:] = np.transpose(data, axes=perm) + + +@op_implementation(op="Transpose", name="einsum") +class EinsumTranspose(ONNXForward): + + @staticmethod + def forward(node: onnx_op.ONNXOp, state: SDFGState, sdfg: SDFG) -> typing.Union[Node, SDFG]: + from dace.libraries.onnx.nodes.onnx_op_registry import ONNXEinsum # avoid import loop + perm = node.perm + input_desc = in_desc_with_name(node, state, sdfg, "data") + output_desc = out_desc_with_name(node, state, sdfg, "transposed") + + letters = [chr(ord('z') - i) for i in range(26)] + input_letters = "".join(letters[i] for i, _ in enumerate(input_desc.shape)) + output_letters = "".join(letters[i] for i in perm) + equation_str = f"{input_letters}->{output_letters}" + + nsdfg = dace.SDFG(node.label + "_expansion") + nstate = nsdfg.add_state() + einsum_node: onnx_op.ONNXOp = ONNXEinsum(node.label + "_einsum_expansion", equation=equation_str) + + nstate.add_node(einsum_node) + einsum_node.add_in_connector("Inputs__0") + nsdfg.add_datadesc("data", copy.deepcopy(input_desc)) + nsdfg.add_datadesc("transposed", copy.deepcopy(output_desc)) + nsdfg.arrays["data"].transient = False + nsdfg.arrays["transposed"].transient = False + + nstate.add_edge(nstate.add_read("data"), None, einsum_node, "Inputs__0", nsdfg.make_array_memlet("data")) + nstate.add_edge(einsum_node, "Output", nstate.add_write("transposed"), None, + nsdfg.make_array_memlet("transposed")) + + return nsdfg + + +# ============================================================================== +# Reshape and Flatten Operations +# ============================================================================== + + +@python_pure_op_implementation(shape=lambda reshaped: reshaped.shape, + allowzero=lambda node: getattr(node, 'allowzero', 0)) +def Reshape(data, reshaped): + # If allowzero is 0 (default), we use numpy's reshape which doesn't allow zeros + # If allowzero is 1, we need to handle zeros in the shape tensor + if allowzero == 0: + reshaped[:] = np.reshape(data, shape) + else: + # For allowzero=1, we need to handle zeros in the shape tensor + # This means we need to preserve the original dimension size when a zero is encountered + new_shape = list(shape) + for i, dim in enumerate(new_shape): + if dim == 0: + new_shape[i] = data.shape[i] + reshaped[:] = np.reshape(data, new_shape) + + +@python_pure_op_implementation(shape=lambda input, node: [prod(input.shape[:node.axis]), prod(input.shape[node.axis:])]) +def Flatten(input, output): + output[:] = input.reshape(shape) + + +# ============================================================================== +# Slicing Operations +# ============================================================================== + + +@op_implementation(op="Slice", name="pure") +class PureSlice(ONNXForward): + ''' + Slice expansion + ''' + + @staticmethod + def forward_can_be_applied(node: onnx_op.ONNXOp, state: SDFGState, sdfg: SDFG) -> bool: + # Check that all the inputs (even the optional ones) are present and constant + + if not hasattr(sdfg, "_parent_onnx_model"): + return False + + constant_starts = in_edge_with_name(node, state, "starts").src.data in sdfg._parent_onnx_model.clean_weights + + if not constant_starts: + return False + if in_edge_with_name(node, state, "ends").src.data not in sdfg._parent_onnx_model.clean_weights: + return False + + # optional inputs + is_axes_present = True + try: + if in_edge_with_name(node, state, "axes").src.data not in sdfg._parent_onnx_model.clean_weights: + return False + except ValueError: + is_axes_present = False + + is_steps_present = True + try: + if in_edge_with_name(node, state, "steps").src.data not in sdfg._parent_onnx_model.clean_weights: + return False + except ValueError: + is_steps_present = False + + # Current constraints: axes and steps must be explict. Axes must be zero and steps must be 1 + if not is_axes_present or not is_steps_present: + return False + + step = sdfg._parent_onnx_model.clean_weights[in_edge_with_name(node, state, "steps").src.data].numpy()[0] + axis = sdfg._parent_onnx_model.clean_weights[in_edge_with_name(node, state, "axes").src.data].numpy()[0] + + if step != 1 or axis != 0: + return False + + return True + + @staticmethod + def forward(node: onnx_op.ONNXOp, state: SDFGState, sdfg: SDFG) -> typing.Union[Node, SDFG]: + + start = sdfg._parent_onnx_model.clean_weights[in_edge_with_name(node, state, "starts").src.data].numpy()[0] + end = sdfg._parent_onnx_model.clean_weights[in_edge_with_name(node, state, "ends").src.data].numpy()[0] + + output_shape = out_desc_with_name(node, state, sdfg, "output").shape + if end == np.iinfo(np.int64).max: + # Pytorch exporter artifact + end = start + output_shape[0] + + def prog(data, output): + tmp = data[start:end:1, :] + # We need reshape to avoid Invalid Edge errors + output[:] = np.reshape(tmp, output.shape) + + return program_for_node(prog, sdfg, state, node) + + +@op_implementation(op="Slice", name="pure") +class PureSliceAllConstant(ONNXForward): + + @staticmethod + def _get_constant(conn: str, node: onnx_op.ONNXOp, state: SDFGState, sdfg: SDFG): + try: + srcnode = next(state.in_edges_by_connector(node, conn)).src + except StopIteration: + # Return default values + if conn == "steps": + return 1 + return None + # Scalar copied to GPU + if 'gpu_' in srcnode.data: + srcnode = state.predecessors(srcnode)[0] + return onnx_constant_or_none(sdfg, srcnode) + + @staticmethod + def forward_can_be_applied(node: onnx_op.ONNXOp, state: SDFGState, sdfg: SDFG) -> bool: + for inconn in ("axes", "ends", "starts", "steps"): + if PureSliceAllConstant._get_constant(inconn, node, state, sdfg) is None: + return False + return True + + @staticmethod + def forward(node: onnx_op.ONNXOp, state: SDFGState, sdfg: SDFG) -> typing.Union[Node, SDFG]: + axes = PureSliceAllConstant._get_constant('axes', node, state, sdfg) + ends = PureSliceAllConstant._get_constant('ends', node, state, sdfg) + starts = PureSliceAllConstant._get_constant('starts', node, state, sdfg) + steps = PureSliceAllConstant._get_constant('steps', node, state, sdfg) + + constant_folding.remove_node_and_computation(sdfg, state, node, "axes") + constant_folding.remove_node_and_computation(sdfg, state, node, "ends") + constant_folding.remove_node_and_computation(sdfg, state, node, "starts") + constant_folding.remove_node_and_computation(sdfg, state, node, "steps") + + nsdfg = dace.SDFG(node.label + "_expansion") + nstate = nsdfg.add_state() + + idesc = in_desc_with_name(node, state, sdfg, "data") + odesc = out_desc_with_name(node, state, sdfg, "output") + nsdfg.add_datadesc("data", copy.deepcopy(idesc)) + nsdfg.add_datadesc("output", copy.deepcopy(odesc)) + nsdfg.arrays["data"].transient = False + nsdfg.arrays["output"].transient = False + + if not isinstance(axes, (tuple, list)): + axes = [axes] + ends = [ends] + starts = [starts] + steps = [steps] + + # Set up slicing memlet + rng = [(0, s - 1, 1) for s in idesc.shape] + for axis, start, end, step in zip(axes, starts, ends, steps): + s = idesc.shape[axis] + if end > s: + end = s + rng[axis] = (start, end - 1, step) + + sbs = subsets.Range(rng) + osbs = subsets.Range.from_array(odesc) + + # Make copy / view + rnode = nstate.add_read("data") + wnode = nstate.add_write("output") + + nstate.add_nedge(rnode, wnode, dace.Memlet(data="data", subset=sbs, other_subset=osbs)) + + return nsdfg + + +# ============================================================================== +# Split Operations +# ============================================================================== + + +@op_implementation(op="Split", name="pure") +class SplitPure(ONNXForward): + + @staticmethod + def forward_can_be_applied(node: 'ONNXOp', state: SDFGState, sdfg: SDFG) -> bool: + from dace.transformation.onnx.replacement import onnx_constant_or_none + + # Check if we have either split input or num_outputs attribute + has_split_input = len(list(state.in_edges_by_connector(node, "split"))) > 0 + has_num_outputs = hasattr(node, 'num_outputs') + + if not (has_split_input or has_num_outputs): + return False + + # If split input is provided, it must be a constant + if has_split_input: + split_node = next(state.in_edges_by_connector(node, "split")).src + if not onnx_constant_or_none(sdfg, split_node): + return False + + return True + + @staticmethod + def forward(node: 'ONNXOp', state: SDFGState, sdfg: SDFG) -> typing.Union[Node, SDFG]: + from dace.transformation.onnx.replacement import onnx_constant_or_none + + nsdfg = dace.SDFG(node.label + "_expansion") + nstate = nsdfg.add_state() + + split_dim = node.axis + idesc = in_desc_with_name(node, state, sdfg, "input") + nsdfg.add_datadesc("input", copy.deepcopy(idesc)) + nsdfg.arrays["input"].transient = False + + rnode = nstate.add_read("input") + + # Get split sizes either from input or compute from num_outputs + if len(list(state.in_edges_by_connector(node, "split"))) > 0: + # Get split sizes from input tensor + split_node = next(state.in_edges_by_connector(node, "split")).src + split_sizes = onnx_constant_or_none(sdfg, split_node) + if split_sizes is None: + raise ValueError("Split sizes must be constant") + + # Add split input as a data descriptor + split_desc = copy.deepcopy(in_desc_with_name(node, state, sdfg, "split")) + split_desc.transient = False + nsdfg.add_datadesc("split", split_desc) + split_read = nstate.add_read("split") + else: + # Compute split sizes from num_outputs + num_outputs = node.num_outputs + total_size = idesc.shape[split_dim] + base_size = total_size // num_outputs + remainder = total_size % num_outputs + split_sizes = [base_size + (1 if i < remainder else 0) for i in range(num_outputs)] + + # Verify split sizes + if sum(split_sizes) != idesc.shape[split_dim]: + raise ValueError( + f"Sum of split sizes ({sum(split_sizes)}) must equal dimension size ({idesc.shape[split_dim]})") + + offset = 0 + for i, odim in enumerate(split_sizes): + # Set up new node shape and memlet + new_shape = list(idesc.shape) + new_shape[split_dim] = odim + rng = subsets.Range([(0, s - 1, 1) if j != split_dim else (offset, offset + odim - 1, 1) + for j, s in enumerate(new_shape)]) + offset += odim + + # Set up data descriptor + oname = f"outputs__{i}" + odesc = copy.deepcopy(out_desc_with_name(node, state, sdfg, oname)) + odesc.transient = False + nsdfg.add_datadesc(oname, odesc) + wnode = nstate.add_write(oname) + + # Perform copy (view) + nstate.add_nedge(rnode, wnode, + dace.Memlet(data="input", subset=rng, other_subset=subsets.Range.from_array(odesc))) + + return nsdfg + + +# ============================================================================== +# Shape Query Operations +# ============================================================================== + + +@op_implementation(op="Shape", name="pure") +class PureShape(ONNXForward): + + @staticmethod + def forward_can_be_applied(node: 'ONNXOp', state: SDFGState, sdfg: SDFG) -> bool: + data_desc = in_desc_with_name(node, state, sdfg, "data") + + try: + np.array(data_desc.shape, np.int64) + except Exception: + # this happens if the shape is symbolic, for example + return False + + return True + + @staticmethod + def forward(node: 'ONNXOp', state: SDFGState, sdfg: SDFG) -> typing.Union[Node, SDFG]: + + data_desc = in_desc_with_name(node, state, sdfg, "data") + shape_val = np.array(data_desc.shape, np.int64) + + nsdfg = dace.SDFG(node.label + "_expansion") + nstate = nsdfg.add_state() + nsdfg.add_datadesc( + "data", + copy.deepcopy(data_desc), + ) + nsdfg.arrays["data"].transient = False + nsdfg.add_array("shape", shape_val.shape, dtype=dace.int64) + s = nstate.add_write("shape") + + for i, v in enumerate(shape_val): + tasklet = nstate.add_tasklet("write_shape", {}, {'shape_scalar': dace.int64}, f"shape_scalar = {v}") + nstate.add_edge(tasklet, "shape_scalar", s, None, dace.Memlet("shape[{}]".format(i))) + + return nsdfg + + +# ============================================================================== +# Gather Operations +# ============================================================================== + + +@op_implementation(op="Gather", name="pure") +class PureGather(ONNXForward): + + @staticmethod + def forward(node: 'ONNXOp', state: SDFGState, sdfg: SDFG) -> typing.Union[Node, SDFG]: + # To understand this operator, read the docs for np.take. + # The ONNX docs are not easy to understand (and are incorrect in opset 11) + + nsdfg, nstate, _, _ = empty_sdfg_for_node(sdfg, state, node, add_access_nodes=False) + out_desc = out_desc_with_name(node, state, sdfg, "output") + out_shape = out_desc.shape + idx_desc = in_desc_with_name(node, state, sdfg, "indices") + idx_shape = idx_desc.shape + data_shape = in_desc_with_name(node, state, sdfg, "data").shape + + # FIXME: we can sometimes generate views + + # Generate a copy kernel that loops over every element in the output + # and read the correct element according to the indices + + axis = node.axis + + map_ranges = [(f"i{i}", f"0:{s}") for i, s in enumerate(out_shape)] + # the map ranges can be partitioned into two parts. + # the first part is the range over the indices, the second part is the + # range over the data + if isinstance(idx_desc, dace.data.Scalar): + # handle the edgecase here because the shape of a scalar in dace is + # (1,) not () + idx_len = 0 + else: + idx_len = len(idx_shape) + map_ranges_indices = map_ranges[axis:axis + idx_len] + map_ranges_data = map_ranges[:axis] + map_ranges[axis + idx_len:] + + # compute the indexing expressions + fst = lambda x: x[0] + output_idx_str = 'output[' + ', '.join(map(fst, map_ranges)) + ']' + # the memlet string used to read data, which reads the whole axis + data_memlet_elems = list(map(fst, map_ranges_data)) + data_memlet_elems.insert(axis, f'0:{data_shape[axis]}') + + data_memlet_str = 'data[' + ', '.join(data_memlet_elems) + ']' + + indices_idx_str = 'indices' + if map_ranges_indices: + indices_idx_str += '[' + ', '.join(map(fst, map_ranges_indices)) + ']' + else: + indices_idx_str += '[0]' + + tasklet, me, mx = nstate.add_mapped_tasklet(node.label + "_tasklet", + map_ranges=map_ranges, + inputs={ + "__data": dace.Memlet(data_memlet_str), + "idx": dace.Memlet(indices_idx_str), + }, + code=f"__output = __data[idx]", + outputs={"__output": dace.Memlet(output_idx_str)}, + external_edges=True) + + # required to make underlying code to see it as a pointer and enable index-based access + # even if the data contains just a single element + tasklet.in_connectors["__data"] = dace.pointer(out_desc.dtype) + + return nsdfg + + +# ============================================================================== +# Utility Operations +# ============================================================================== + + +@python_pure_op_implementation +def Where(condition, X, Y, output): + output[:] = np.where(condition, X, Y) + + +@python_pure_op_implementation +def Identity(input, output): + output[:] = input + + +@op_implementation(op="Cast", name="pure") +class PureCast(ONNXForward): + + @staticmethod + def forward_can_be_applied(node: onnx_op.ONNXOp, state: SDFGState, sdfg: SDFG) -> bool: + + if (in_desc_with_name(node, state, sdfg, "input").dtype == out_desc_with_name(node, state, sdfg, + "output").dtype): + return True + + target_type = node.to + try: + converters.onnx_tensor_type_to_typeclass(target_type) + except ValueError: + return False + + return True + + @staticmethod + def forward(node: onnx_op.ONNXOp, state: SDFGState, sdfg: SDFG) -> typing.Union[Node, SDFG]: + input_desc = in_desc_with_name(node, state, sdfg, "input") + output_desc = out_desc_with_name(node, state, sdfg, "output") + if (input_desc.dtype == output_desc.dtype): + + def prog(input, output): + output[:] = input + + return program_for_node(prog, sdfg, state, node) + else: + + nsdfg, nstate, _, _ = empty_sdfg_for_node(sdfg, state, node, add_access_nodes=False) + + shape = out_desc_with_name(node, state, sdfg, "output").shape + map_ranges = {f"i{i}": f"0:{s}" for i, s in enumerate(shape)} + index_str = f"{', '.join(map_ranges.keys())}" + tasklet, _, _ = nstate.add_mapped_tasklet(node.label + "_tasklet", + map_ranges=map_ranges, + inputs={f"__input": dace.Memlet(f"input[{index_str}]")}, + code=f"__output = __input", + outputs={"__output": dace.Memlet(f"output[{index_str}]")}, + external_edges=True) + + return nsdfg diff --git a/dace/libraries/onnx/op_implementations/common.py b/dace/libraries/onnx/op_implementations/common.py new file mode 100644 index 0000000000..fcda74fc27 --- /dev/null +++ b/dace/libraries/onnx/op_implementations/common.py @@ -0,0 +1,11 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +Common utilities and helper functions for ONNX pure implementations. +""" + + +def iterables_equal(a, b) -> bool: + """ Return whether the two iterables ``a`` and ``b`` are equal. """ + if len(a) != len(b): + return False + return all(x == y for x, y in zip(a, b)) diff --git a/dace/libraries/onnx/op_implementations/criteria_implementations.py b/dace/libraries/onnx/op_implementations/criteria_implementations.py new file mode 100644 index 0000000000..42eccd6bca --- /dev/null +++ b/dace/libraries/onnx/op_implementations/criteria_implementations.py @@ -0,0 +1,90 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. + +from typing import Union + +import numpy as np + +import dace +from dace import SDFG, SDFGState, nodes as nd + +from dace.libraries.onnx.op_implementations.utils import op_implementation, program_for_node +from dace.libraries.onnx.nodes import onnx_op +from dace.libraries.onnx.forward_implementation_abc import ONNXForward + +from dace.sdfg.utils import in_desc_with_name + + +@op_implementation(op="SoftmaxCrossEntropyLoss", name="pure") +class PureSoftmaxCrossEntropyLoss(ONNXForward): + """Pure implementation of SoftmaxCrossEntropyLoss operation.""" + + @staticmethod + def forward_can_be_applied(node: onnx_op.ONNXOp, state: SDFGState, sdfg: SDFG) -> bool: + """Check if this implementation can be applied to the given node. + + :param node: The SoftmaxCrossEntropyLoss ONNX node. + :param state: The SDFG state containing the node. + :param sdfg: The parent SDFG. + :return: True if the implementation can be applied, False otherwise. + """ + # Softmax is weird in opset 11, so let's stick to 2D for now + if len(in_desc_with_name(node, state, sdfg, "scores").shape) != 2: + return False + + if node.ignore_index is not None and node.ignore_index >= 0: + return False + + # The weights and log_prob arguments are optional + # We don't support them in this implementation + if 'weights' in node.in_connectors: + return False + if 'log_prob' in node.out_connectors: + return False + + return True + + @staticmethod + def forward(node: onnx_op.ONNXOp, state: SDFGState, sdfg: SDFG) -> Union[nd.Node, SDFG]: + """Generate the forward pass implementation for SoftmaxCrossEntropyLoss. + + :param node: The SoftmaxCrossEntropyLoss ONNX node. + :param state: The SDFG state containing the node. + :param sdfg: The parent SDFG. + :return: A nested SDFG implementing the SoftmaxCrossEntropyLoss operation. + """ + + if node.reduction == 'mean': + + def reduction(x): + return np.mean(x) + elif node.reduction == 'none': + + def reduction(x): + return x + elif node.reduction == 'sum': + + def reduction(x): + return np.sum(x) + else: + raise ValueError("Unsupported reduction: {}".format(node.reduction)) + reduction = dace.program(reduction) + + # This implementation doesn't use ONNX LogSoftmax, and thus saves the + # final sum reduction by just grabbing the label scores directly, and + # skipping the computation of log softmax for all non-label scores + def prog(scores, labels, output): + # Extract the scores for the labels + + # Compute the log softmax normalization + maximum = np.maximum.reduce(scores, axis=1, keepdims=True) + max_sub = scores - maximum + exponent = np.exp(max_sub) + sum = np.add.reduce(exponent, axis=1) + log_sum = np.log(sum) + + # Compute the loss values + label_exponents = max_sub[:, labels] + losses = log_sum - label_exponents + output[:] = reduction(losses) + + return program_for_node(prog, sdfg, state, node) diff --git a/dace/libraries/onnx/op_implementations/elementwise_ops.py b/dace/libraries/onnx/op_implementations/elementwise_ops.py new file mode 100644 index 0000000000..bf66a868d5 --- /dev/null +++ b/dace/libraries/onnx/op_implementations/elementwise_ops.py @@ -0,0 +1,212 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +Elementwise and mathematical ONNX operations. + +This module contains pure implementations of elementwise mathematical operations including: +- Basic arithmetic: Add, Sub, Mul, Div, Pow +- Unary math functions: Log, Exp, Sqrt, Sin, Cos, Tanh, Erf, Neg, Reciprocal +- Activation functions: Relu, LeakyRelu, Sigmoid, Softplus +- Utility operations: Clip + +All operations support broadcasting where applicable. +""" + +import typing + +import dace +import numpy as np +from dace import SDFG, SDFGState +from dace.sdfg.nodes import Node + +from dace.libraries.onnx.forward_implementation_abc import ONNXForward +from dace.libraries.onnx.nodes import onnx_op +from dace.libraries.onnx.op_implementations.utils import (op_implementation, out_desc_with_name, program_for_node, + python_pure_op_implementation) +from dace.sdfg.utils import in_desc_with_name, in_edge_with_name, out_desc_with_name +from dace.transformation.onnx.replacement import onnx_constant_or_none + +# ============================================================================ +# Unary Mathematical Operations +# ============================================================================ + + +@python_pure_op_implementation +def Log(input, output): + """ONNX Log operation implementation. + + Computes the natural logarithm of the input tensor element-wise. + + :param input: Input tensor of any numeric type. + :param output: Output tensor with the same shape and type as input. + """ + output[:] = np.log(input) + + +@python_pure_op_implementation +def Exp(input, output): + """ONNX Exp operation implementation. + + Computes the exponential of the input tensor element-wise. + + :param input: Input tensor of any numeric type. + :param output: Output tensor with the same shape and type as input. + """ + output[:] = np.exp(input) + + +@python_pure_op_implementation +def Sqrt(X, Y): + """ONNX Sqrt operation implementation. + + Computes the square root of the input tensor element-wise. + + :param X: Input tensor of any numeric type. + :param Y: Output tensor with the same shape and type as X. + """ + Y[:] = dace.elementwise(lambda x: sqrt(x), X) + + +@python_pure_op_implementation +def Sin(input, output): + output[:] = np.sin(input) + + +@python_pure_op_implementation +def Cos(input, output): + output[:] = np.cos(input) + + +@python_pure_op_implementation +def Tanh(input, output): + output[:] = dace.elementwise(lambda x: tanh(x), input) + + +@python_pure_op_implementation +def Erf(input, output): + output[:] = dace.elementwise(lambda x: erf(x), input) + + +@python_pure_op_implementation +def Neg(X, Y): + Y[:] = -X + + +@python_pure_op_implementation(string=lambda X: "lambda x: dace.{}(1) / x".format(X.dtype.to_string())) +def Reciprocal(X, Y): + Y[:] = dace.elementwise(string, X) + + +@python_pure_op_implementation +def Softplus(X, Y): + Y[:] = np.log(1 + np.exp(X)) + + +@python_pure_op_implementation(dtype=lambda X: X.dtype) +def Sigmoid(X, Y): + Y[:] = dace.elementwise(lambda x: dtype(1) / (dtype(1) + exp(-x)), X) + + +# ============================================================================ +# Binary Arithmetic Operations +# ============================================================================ + + +@op_implementation(op="Pow", name="pure") +class PurePow(ONNXForward): + + @staticmethod + def forward_can_be_applied(node: 'ONNXOp', state: SDFGState, sdfg: SDFG) -> bool: + return True + + @staticmethod + def forward(node: 'ONNXOp', state: SDFGState, sdfg: SDFG) -> typing.Union[Node, SDFG]: + + # Special case for constant exponents + y_value = None + try: + if hasattr(sdfg, "_parent_onnx_model") and in_edge_with_name( + node, state, "Y").src.data in sdfg._parent_onnx_model.clean_weights: + y_value = sdfg._parent_onnx_model.clean_weights[in_edge_with_name(node, state, "Y").src.data].numpy() + except ValueError: + pass + + if y_value is not None and y_value.ndim == 0: + y_value = int(y_value) + + def prog(X, Z): + Z[:] = X**y_value + + return program_for_node(prog, sdfg, state, node) + + # General case + def prog(X, Y, Z): + Z[:] = X**Y + + return program_for_node(prog, sdfg, state, node) + + +@python_pure_op_implementation +def Add(A, B, C): + C[:] = A + B + + +@python_pure_op_implementation +def Sub(A, B, C): + C[:] = A - B + + +@python_pure_op_implementation +def Mul(A, B, C): + C[:] = A * B + + +@python_pure_op_implementation +def Div(A, B, C): + C[:] = A / B + + +# ============================================================================ +# Activation Functions and Clipping +# ============================================================================ + + +@python_pure_op_implementation(cast_lambda=lambda X: "lambda x: max(x, dace.{}(0))".format(X.dtype.to_string())) +def Relu(X, Y): + Y[:] = dace.elementwise(cast_lambda, X) + + +@python_pure_op_implementation( + cast_lambda=lambda node, X: "lambda x: (max(x, dace.{dtype}(0)) + {alpha} * min(x, dace.{dtype}(0)))".format( + dtype=X.dtype.to_string(), alpha=node.alpha)) +def LeakyRelu(X, Y): + Y[:] = dace.elementwise(cast_lambda, X) + + +@op_implementation(op="Clip", name="pure") +class PureClip(ONNXForward): + + @staticmethod + def forward_can_be_applied(node: onnx_op.ONNXOp, state: SDFGState, sdfg: SDFG) -> bool: + min_node = next(state.in_edges_by_connector(node, 'min')).src + max_node = next(state.in_edges_by_connector(node, 'max')).src + # TODO other cases + return (onnx_constant_or_none(sdfg, min_node) is not None and onnx_constant_or_none(sdfg, max_node) is not None) + + @staticmethod + def forward(node: onnx_op.ONNXOp, state: SDFGState, sdfg: SDFG) -> typing.Union[Node, SDFG]: + + min_node = next(state.in_edges_by_connector(node, 'min')).src + max_node = next(state.in_edges_by_connector(node, 'max')).src + minval = onnx_constant_or_none(sdfg, min_node) + maxval = onnx_constant_or_none(sdfg, max_node) + + input_dtype = in_desc_with_name(node, state, sdfg, "input").dtype + minstr = f"dace.{input_dtype.to_string()}({minval})" + maxstr = f"dace.{input_dtype.to_string()}({maxval})" + + lfunc = f"lambda x: min(max(x, {minstr}), {maxstr})" + + def prog(input, output): + output[:] = dace.elementwise(lfunc, input) + + return program_for_node(prog, sdfg, state, node) diff --git a/dace/libraries/onnx/op_implementations/image_ops.py b/dace/libraries/onnx/op_implementations/image_ops.py new file mode 100644 index 0000000000..55bad5d2f8 --- /dev/null +++ b/dace/libraries/onnx/op_implementations/image_ops.py @@ -0,0 +1,443 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +Image and Signal Processing Operations for ONNX in DaCe. + +This module provides implementations for ONNX operations related to image and signal +processing, including resizing, interpolation, and related transformations. + +Operations implemented: +- Resize: Image resizing with various interpolation modes (nearest, linear, cubic) + and coordinate transformation modes +""" + +import copy +import typing + +import dace +from dace import SDFG, SDFGState +from dace.sdfg.nodes import Node +from dace.sdfg.utils import in_desc_with_name, out_desc_with_name + +from dace.libraries.onnx.forward_implementation_abc import ONNXForward +from dace.libraries.onnx.op_implementations.utils import op_implementation + + +@op_implementation(op="Resize", name="pure") +class PureResize(ONNXForward): + + @staticmethod + def forward_can_be_applied(node: 'ONNXOp', state: SDFGState, sdfg: SDFG) -> bool: + # Check if we have either scales or sizes (but not both) + has_scales = len(list(state.in_edges_by_connector(node, 'scales'))) > 0 + has_sizes = len(list(state.in_edges_by_connector(node, 'sizes'))) > 0 + + if has_scales == has_sizes: + return False + + # Check interpolation mode + mode = getattr(node, 'mode', 'nearest') + if mode is not None and mode not in ['nearest', 'linear', 'cubic']: + return False + + # Check nearest mode if using nearest interpolation + if mode == 'nearest': + nearest_mode = getattr(node, 'nearest_mode', 'round_prefer_floor') + if nearest_mode is not None and nearest_mode not in [ + 'round_prefer_floor', 'round_prefer_ceil', 'floor', 'ceil' + ]: + return False + + # Check coordinate transformation mode + coord_mode = getattr(node, 'coordinate_transformation_mode', 'half_pixel') + if coord_mode is not None and coord_mode not in [ + 'half_pixel', 'half_pixel_symmetric', 'pytorch_half_pixel', 'align_corners', 'asymmetric', + 'tf_crop_and_resize' + ]: + return False + + # For tf_crop_and_resize, roi must be present + if coord_mode == 'tf_crop_and_resize': + has_roi = len(list(state.in_edges_by_connector(node, 'roi'))) > 0 + if not has_roi: + return False + + # Check keep_aspect_ratio_policy if using sizes + if has_sizes: + policy = getattr(node, 'keep_aspect_ratio_policy', 'stretch') + if policy is not None and policy not in ['stretch', 'not_larger', 'not_smaller']: + return False + + # Check antialias + antialias = getattr(node, 'antialias', 0) + if antialias is not None and antialias not in [0, 1]: + return False + + # Check exclude_outside + exclude_outside = getattr(node, 'exclude_outside', 0) + if exclude_outside is not None and exclude_outside not in [0, 1]: + return False + + # Check extrapolation_value + extrapolation_value = getattr(node, 'extrapolation_value', 0.0) + if extrapolation_value is not None and not isinstance(extrapolation_value, (int, float)): + return False + + # Check cubic coefficient + if mode == 'cubic': + cubic_coeff_a = getattr(node, 'cubic_coeff_a', -0.75) + if cubic_coeff_a is not None and not isinstance(cubic_coeff_a, (int, float)): + return False + + # Check axes if provided + axes = getattr(node, 'axes', None) + if axes is not None: + if not isinstance(axes, (list, tuple)): + return False + # Check for duplicate axes + if len(set(axes)) != len(axes): + return False + # Check for valid axis values + rank = len(in_desc_with_name(node, state, sdfg, 'X').shape) + for axis in axes: + if not isinstance(axis, int) or axis < -rank or axis >= rank: + return False + + # Check input shapes + x_desc = in_desc_with_name(node, state, sdfg, 'X') + rank = len(x_desc.shape) + if has_scales: + scales_desc = in_desc_with_name(node, state, sdfg, 'scales') + if len(scales_desc.shape) != 1: + return False + if len(axes) if axes is not None else rank != scales_desc.shape[0]: + return False + if has_sizes: + sizes_desc = in_desc_with_name(node, state, sdfg, 'sizes') + if len(sizes_desc.shape) != 1: + return False + if len(axes) if axes is not None else rank != sizes_desc.shape[0]: + return False + + # Check output shape + y_desc = out_desc_with_name(node, state, sdfg, 'Y') + if len(x_desc.shape) != len(y_desc.shape): + return False + + return True + + @staticmethod + def forward(node: 'ONNXOp', state: SDFGState, sdfg: SDFG) -> typing.Union[Node, SDFG]: + + inp_name = 'X' + out_name = 'Y' + + nsdfg = dace.SDFG(node.label) + + # Add required input and output descriptors + inp_data_desc = copy.deepcopy(in_desc_with_name(node, state, sdfg, inp_name)) + inp_data_desc.transient = False + nsdfg.add_datadesc(inp_name, inp_data_desc) + + out_data_desc = copy.deepcopy(out_desc_with_name(node, state, sdfg, out_name)) + out_data_desc.transient = False + nsdfg.add_datadesc(out_name, out_data_desc) + + # Check for optional parameters + has_scales = len(list(state.in_edges_by_connector(node, 'scales'))) > 0 + has_sizes = len(list(state.in_edges_by_connector(node, 'sizes'))) > 0 + has_roi = len(list(state.in_edges_by_connector(node, 'roi'))) > 0 + + # Get axes to resize + axes = node.axes or list(range(len(inp_data_desc.shape))) + + # Convert negative axes to positive + axes = [ax if ax >= 0 else len(inp_data_desc.shape) + ax for ax in axes] + + # Add optional parameter descriptors if they exist + if has_scales: + scales_name = 'scales' + scales_data_desc = copy.deepcopy(in_desc_with_name(node, state, sdfg, scales_name)) + scales_data_desc.transient = False + nsdfg.add_datadesc(scales_name, scales_data_desc) + + if has_sizes: + sizes_name = 'sizes' + sizes_data_desc = copy.deepcopy(in_desc_with_name(node, state, sdfg, sizes_name)) + sizes_data_desc.transient = False + nsdfg.add_datadesc(sizes_name, sizes_data_desc) + + if has_roi: + roi_name = 'roi' + roi_data_desc = copy.deepcopy(in_desc_with_name(node, state, sdfg, roi_name)) + roi_data_desc.transient = False + nsdfg.add_datadesc(roi_name, roi_data_desc) + + num_dims = len(inp_data_desc.shape) + + # setup inner SDFG + nstate = nsdfg.add_state() + + inp_read = nstate.add_read(inp_name) + out_write = nstate.add_write(out_name) + + # Add reads for optional parameters + tasklet_inputs = {'__inp': dace.pointer(inp_data_desc.dtype)} + if has_scales: + scales_read = nstate.add_read(scales_name) + tasklet_inputs['__scales'] = dace.pointer(scales_data_desc.dtype) + if has_sizes: + sizes_read = nstate.add_read(sizes_name) + tasklet_inputs['__sizes'] = dace.pointer(sizes_data_desc.dtype) + if has_roi: + roi_read = nstate.add_read(roi_name) + tasklet_inputs['__roi'] = dace.pointer(roi_data_desc.dtype) + + # Generate tasklet code for interpolation + tasklet_code = [] + + # Get interpolation parameters + coord_mode = getattr(node, 'coordinate_transformation_mode', 'half_pixel') + mode = getattr(node, 'mode', 'nearest') + antialias = getattr(node, 'antialias', 0) + exclude_outside = getattr(node, 'exclude_outside', 0) + extrapolation_value = getattr(node, 'extrapolation_value', 0.0) + + # Add cubic interpolation helper functions if needed + if mode == 'cubic': + cubic_coeff_a = getattr(node, 'cubic_coeff_a', -0.75) + tasklet_code.append(f""" + // Cubic interpolation helper functions + float cubic_weight(float x) {{ + float a = {cubic_coeff_a}; + float absx = abs(x); + if (absx < 1.0) {{ + return (a + 2.0) * absx * absx * absx - (a + 3.0) * absx * absx + 1.0; + }} else if (absx < 2.0) {{ + return a * absx * absx * absx - 5.0 * a * absx * absx + 8.0 * a * absx - 4.0 * a; + }} + return 0.0; + }} + """) + + # Loop over output dimensions + tasklet_code.append(""" + // Loop over all output dimensions + """) + + # Create nested loops for each dimension + for i in range(len(out_data_desc.shape)): + tasklet_code.append(f"for (int i{i} = 0; i{i} < {out_data_desc.shape[i]}; i{i}++) {{") + + # Calculate input indices + tasklet_code.append(""" + // Calculate input indices for each dimension + int inp_indices[{}]; + """.format(num_dims)) + + # Declare all size variables at the beginning + for i in range(num_dims): + if i in axes: + tasklet_code.append(f"float inp_size_{i};") + tasklet_code.append(f"float out_size_{i};") + + for i in range(num_dims): + tasklet_code.append(f"// Dimension {i}") + if i in axes: + axis_idx = axes.index(i) + if has_scales: + tasklet_code.append(f""" + float scale_{i} = __scales[{axis_idx}]; + inp_size_{i} = {inp_data_desc.shape[i]}; + out_size_{i} = {out_data_desc.shape[i]}; + float x_resized_{i} = i{i}; + float x_original_{i}; + """) + + # Add coordinate transformation based on mode + if coord_mode == 'half_pixel': + tasklet_code.append(f""" + x_original_{i} = (x_resized_{i} + 0.5) / scale_{i} - 0.5; + """) + elif coord_mode == 'half_pixel_symmetric': + tasklet_code.append(f""" + float adjustment_{i} = out_size_{i} / (out_size_{i} - 1); + float center_{i} = inp_size_{i} / 2; + float offset_{i} = center_{i} * (1 - adjustment_{i}); + x_original_{i} = offset_{i} + (x_resized_{i} + 0.5) / scale_{i} - 0.5; + """) + elif coord_mode == 'pytorch_half_pixel': + tasklet_code.append(f""" + x_original_{i} = out_size_{i} > 1 ? (x_resized_{i} + 0.5) / scale_{i} - 0.5 : 0; + """) + elif coord_mode == 'align_corners': + tasklet_code.append(f""" + x_original_{i} = x_resized_{i} * (inp_size_{i} - 1) / (out_size_{i} - 1); + """) + elif coord_mode == 'asymmetric': + tasklet_code.append(f""" + x_original_{i} = x_resized_{i} / scale_{i}; + """) + elif coord_mode == 'tf_crop_and_resize': + tasklet_code.append(f""" + float roi_start_{i} = __roi[{axis_idx}]; + float roi_end_{i} = __roi[{len(axes) + axis_idx}]; + if (out_size_{i} > 1) {{ + x_original_{i} = roi_start_{i} * (inp_size_{i} - 1) + x_resized_{i} * (roi_end_{i} - roi_start_{i}) * (inp_size_{i} - 1) / (out_size_{i} - 1); + }} else {{ + x_original_{i} = 0.5 * (roi_start_{i} + roi_end_{i}) * (inp_size_{i} - 1); + }} + """) + + # Add interpolation mode handling + if mode == 'nearest': + nearest_mode = getattr(node, 'nearest_mode', 'round_prefer_floor') + if nearest_mode == 'floor': + tasklet_code.append(f"inp_indices[{i}] = int(floor(x_original_{i}));") + elif nearest_mode == 'ceil': + tasklet_code.append(f"inp_indices[{i}] = int(ceil(x_original_{i}));") + else: # round_prefer_floor or round_prefer_ceil + tasklet_code.append(f"inp_indices[{i}] = int(round(x_original_{i}));") + elif mode == 'linear': + tasklet_code.append(f""" + float x0_{i} = floor(x_original_{i}); + float x1_{i} = ceil(x_original_{i}); + float w0_{i} = x1_{i} - x_original_{i}; + float w1_{i} = x_original_{i} - x0_{i}; + inp_indices[{i}] = int(x0_{i}); + inp_indices[{i} + {num_dims}] = int(x1_{i}); // Store second index for linear interpolation + """) + elif mode == 'cubic': + tasklet_code.append(f""" + float x0_{i} = floor(x_original_{i}); + float x1_{i} = x0_{i} + 1; + float x2_{i} = x0_{i} + 2; + float x3_{i} = x0_{i} + 3; + float w0_{i} = cubic_weight(x_original_{i} - x0_{i}); + float w1_{i} = cubic_weight(x_original_{i} - x1_{i}); + float w2_{i} = cubic_weight(x_original_{i} - x2_{i}); + float w3_{i} = cubic_weight(x_original_{i} - x3_{i}); + inp_indices[{i}] = int(x0_{i}); + inp_indices[{i} + {num_dims}] = int(x1_{i}); // Store indices for cubic interpolation + inp_indices[{i} + {2*num_dims}] = int(x2_{i}); + inp_indices[{i} + {3*num_dims}] = int(x3_{i}); + """) + else: # has_sizes + tasklet_code.append(f""" + inp_size_{i} = {inp_data_desc.shape[i]}; + out_size_{i} = {out_data_desc.shape[i]}; + inp_indices[{i}] = int(floor(i{i} * inp_size_{i} / out_size_{i})); + """) + else: + tasklet_code.append(f"inp_indices[{i}] = i{i};") + + # Calculate input index + tasklet_code.append(""" + // Calculate input index + int inp_idx = 0; + """) + for i in range(num_dims): + tasklet_code.append(f"inp_idx += inp_indices[{i}] * {inp_data_desc.strides[i]};") + + # Calculate output index + tasklet_code.append(""" + // Calculate output index + int out_idx = 0; + """) + for i in range(num_dims): + tasklet_code.append(f"out_idx += i{i} * {out_data_desc.strides[i]};") + + # Perform interpolation based on mode + if mode == 'linear': + tasklet_code.append(f""" + // Linear interpolation + float x0 = __inp [inp_idx]; + float x1 = __inp [inp_idx + {inp_data_desc.strides[axes[0]]}]; // Second index for linear interpolation + float result = w0 * x0 + w1 * x1; + """) + elif mode == 'cubic': + tasklet_code.append(f""" + // Cubic interpolation + float x0 = __inp [inp_idx]; + float x1 = __inp [inp_idx + {inp_data_desc.strides[axes[0]]}]; + float x2 = __inp [inp_idx + {2*inp_data_desc.strides[axes[0]]}]; + float x3 = __inp [inp_idx + {3*inp_data_desc.strides[axes[0]]}]; + float result = w0 * x0 + w1 * x1 + w2 * x2 + w3 * x3; + """) + else: # nearest or default + tasklet_code.append(""" + // Nearest neighbor interpolation + float result = __inp [inp_idx]; + """) + + # Handle antialiasing if enabled + if antialias == 1 and mode in ['linear', 'cubic']: + tasklet_code.append(""" + // Apply antialiasing filter + float scale = __scales[0]; // Assuming first axis is being resized + if (scale < 1.0) { + float filter_scale = max(1.0, 1.0 / scale); + result *= filter_scale; + } + """) + + # Handle exclude_outside if enabled + if exclude_outside == 1: + tasklet_code.append(f""" + // Handle exclude_outside + bool is_outside = false; + for (int i = 0; i < {num_dims}; i++) {{ + if (inp_indices[i] < 0 || inp_indices[i] >= {inp_data_desc.shape[0]}) {{ + is_outside = true; + break; + }} + }} + if (is_outside) {{ + result = 0.0; + }} + """) + + # Handle extrapolation_value for tf_crop_and_resize + if coord_mode == 'tf_crop_and_resize': + tasklet_code.append(f""" + // Handle extrapolation for tf_crop_and_resize + bool is_outside = false; + for (int i = 0; i < {num_dims}; i++) {{ + if (inp_indices[i] < 0 || inp_indices[i] >= {inp_data_desc.shape[0]}) {{ + is_outside = true; + break; + }} + }} + if (is_outside) {{ + result = {extrapolation_value}; + }} + """) + + # Write the result to output + tasklet_code.append(""" + // Write output + __out [out_idx] = result; + """) + + # Close dimension loops + for i in range(len(out_data_desc.shape)): + tasklet_code.append("}") + + tasklet = nstate.add_tasklet(f'tasklet_reshape', + tasklet_inputs, {'__out': dace.pointer(out_data_desc.dtype)}, + "\n".join(tasklet_code), + language=dace.Language.CPP) + + # Connect tasklet inputs + nstate.add_edge(inp_read, None, tasklet, "__inp", dace.Memlet.from_array(inp_name, inp_data_desc)) + if has_scales: + nstate.add_edge(scales_read, None, tasklet, "__scales", + dace.Memlet.from_array(scales_name, scales_data_desc)) + if has_sizes: + nstate.add_edge(sizes_read, None, tasklet, "__sizes", dace.Memlet.from_array(sizes_name, sizes_data_desc)) + if has_roi: + nstate.add_edge(roi_read, None, tasklet, "__roi", dace.Memlet.from_array(roi_name, roi_data_desc)) + + # Connect tasklet output + nstate.add_edge(tasklet, "__out", out_write, None, dace.Memlet.from_array(out_name, out_data_desc)) + + return nsdfg diff --git a/dace/libraries/onnx/op_implementations/img_op_implementations.py b/dace/libraries/onnx/op_implementations/img_op_implementations.py new file mode 100644 index 0000000000..1c3727b07c --- /dev/null +++ b/dace/libraries/onnx/op_implementations/img_op_implementations.py @@ -0,0 +1,563 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import copy +import functools +import typing + +import numpy as np + +import dace +from dace import SDFGState, SDFG, dtypes +from dace.sdfg import nodes, propagation +from dace.transformation.dataflow import MapExpansion, MapCollapse +from dace.sdfg.nodes import Node +from dace.libraries.onnx.forward_implementation_abc import ONNXForward +from dace.libraries.onnx.nodes.onnx_op import ONNXOp +from dace.libraries.onnx.op_implementations.utils import op_implementation, program_for_node +from dace.sdfg.utils import in_desc_with_name, out_desc_with_name, in_edge_with_name, out_edge_with_name + + +def _prod(sequence): + return functools.reduce(lambda a, b: a * b, sequence, 1) + + +@op_implementation(op="MaxPool", name="pure") +class PureMaxPool2D(ONNXForward): + """Pure implementation of 2D MaxPool operation.""" + + @staticmethod + def forward_can_be_applied(node: ONNXOp, state: SDFGState, sdfg: SDFG) -> bool: + """Check if this implementation can be applied to the given node. + + :param node: The MaxPool ONNX node. + :param state: The SDFG state containing the node. + :param sdfg: The parent SDFG. + :return: True if the implementation can be applied, False otherwise. + """ + X = in_desc_with_name(node, state, sdfg, "X") + + if "Indices" in {e.src_conn for e in state.out_edges(node)}: + return False + + image_dims = len(X.shape) - 2 + + # Only do 2D for now + if image_dims != 2: + return False + + if node.pads is not None and (len(node.pads) != image_dims * 2): + return False + + if node.strides is not None and len(node.strides) != image_dims: + return False + + if node.auto_pad != 'NOTSET': + return False + + if node.ceil_mode != 0 or node.storage_order != 0: + return False + + if node.dilations is not None and (not all(d == 1 + for d in node.dilations) or len(node.dilations) != image_dims): + return False + return True + + @staticmethod + def forward(node: 'ONNXOp', state: SDFGState, sdfg: SDFG) -> typing.Union[nodes.Node, SDFG]: + """Generate the forward pass implementation for MaxPool2D. + + :param node: The MaxPool ONNX node. + :param state: The SDFG state containing the node. + :param sdfg: The parent SDFG. + :return: A nested SDFG implementing the MaxPool operation. + """ + X = in_desc_with_name(node, state, sdfg, "X") + Y = out_desc_with_name(node, state, sdfg, "Y") + + image_dims = len(X.shape) - 2 + batch_size = X.shape[0] + num_channels = X.shape[1] + strides = node.strides if node.strides is not None else [1 for _ in range(image_dims)] + pads = node.pads if node.pads is not None else [0 for _ in range(image_dims) * 2] + stride_x, stride_y = strides + assert pads[0] == pads[2] and pads[1] == pads[3] + pad_x, pad_y, _, _ = pads + filter_hx, filter_hy = node.kernel_shape + input_size_x, input_size_y = X.shape[2:] + output_size_x, output_size_y = Y.shape[2:] + + # Create new SDFG + nsdfg = dace.SDFG(node.label + "_expansion") + nstate = nsdfg.add_state() + + # Add data descriptors + nsdfg.add_datadesc("X", copy.deepcopy(X)) + nsdfg.add_datadesc("Y", copy.deepcopy(Y)) + nsdfg.arrays["X"].transient = False + nsdfg.arrays["Y"].transient = False + + # Add access nodes + X_read = nstate.add_read("X") + Y_write = nstate.add_write("Y") + + # Create tasklet that performs the max pooling operation + tasklet = nstate.add_tasklet(name=node.label + "_tasklet", + inputs={"__X": dace.pointer(X.dtype)}, + outputs={"__Y": dace.pointer(Y.dtype)}, + code=f""" + // Initialize output with minimum value + for (int b = 0; b < {batch_size}; b++) {{ + for (int c = 0; c < {num_channels}; c++) {{ + for (int out_x = 0; out_x < {output_size_x}; out_x++) {{ + for (int out_y = 0; out_y < {output_size_y}; out_y++) {{ + __Y[b * {Y.strides[0]} + c * {Y.strides[1]} + out_x * {Y.strides[2]} + out_y * {Y.strides[3]}] = {dtypes.min_value(Y.dtype)}; + }} + }} + }} + }} + + // Main max pooling computation + for (int b = 0; b < {batch_size}; b++) {{ + for (int c = 0; c < {num_channels}; c++) {{ + for (int out_x = 0; out_x < {output_size_x}; out_x++) {{ + for (int out_y = 0; out_y < {output_size_y}; out_y++) {{ + for (int hx = 0; hx < {filter_hx}; hx++) {{ + for (int hy = 0; hy < {filter_hy}; hy++) {{ + int sx = hx + out_x * {stride_x} - {pad_x}; + int sy = hy + out_y * {stride_y} - {pad_y}; + + if (0 <= sx && sx < {input_size_x} && 0 <= sy && sy < {input_size_y}) {{ + float input_val = __X[b * {X.strides[0]} + c * {X.strides[1]} + sx * {X.strides[2]} + sy * {X.strides[3]}]; + float& output_val = __Y[b * {Y.strides[0]} + c * {Y.strides[1]} + out_x * {Y.strides[2]} + out_y * {Y.strides[3]}]; + output_val = max(output_val, input_val); + }} + }} + }} + }} + }} + }} + }} + """, + language=dace.Language.CPP) + + # Connect the tasklet with memlets + nstate.add_edge(X_read, None, tasklet, "__X", dace.Memlet.from_array("X", X)) + nstate.add_edge(tasklet, "__Y", Y_write, None, dace.Memlet.from_array("Y", Y)) + + return nsdfg + + +@op_implementation(op="Conv", name="pure") +class PureConv2D(ONNXForward): + """Convolution implementation with support for grouped and depthwise convolutions.""" + + @staticmethod + def forward_can_be_applied(node: ONNXOp, state: SDFGState, sdfg: SDFG) -> bool: + """Check if this implementation can be applied to the given node. + + :param node: The Conv ONNX node. + :param state: The SDFG state containing the node. + :param sdfg: The parent SDFG. + :return: True if the implementation can be applied, False otherwise. + """ + X = in_desc_with_name(node, state, sdfg, "X") + W = in_desc_with_name(node, state, sdfg, "W") + try: + B = in_desc_with_name(node, state, sdfg, "B") + except Exception as e: + B = None + + image_dims = len(X.shape) - 2 + num_filters = W.shape[0] + num_channels = X.shape[1] + + if (X.dtype not in [dace.float16, dace.float32, dace.float64] + or W.dtype not in [dace.float16, dace.float32, dace.float64]): + return False + + # Only do 2D for now + if len(X.shape) != 4 or len(W.shape) != 4: + return False + + # Check group convolution constraints + groups = node.group if node.group is not None else 1 + if groups < 1: + return False + + # For grouped convolution: + # - Input channels must be divisible by groups + # - Output channels (num_filters) must be divisible by groups + # - Weight shape[1] should be num_channels // groups + if num_channels % groups != 0: + return False + if num_filters % groups != 0: + return False + if W.shape[1] != num_channels // groups: + return False + + if node.dilations is not None and (not all(d == 1 + for d in node.dilations) or len(node.dilations) != image_dims): + return False + + if node.pads is not None and (len(node.pads) != image_dims * 2): + return False + + if node.strides is not None and (len(node.strides) != image_dims): + return False + + if B is not None and B.shape[0] != num_filters: + return False + + if node.auto_pad != 'NOTSET': + return False + + return True + + @staticmethod + def forward(node: 'ONNXOp', state: SDFGState, sdfg: SDFG) -> typing.Union[nodes.Node, SDFG]: + """Generate the forward pass implementation for Conv2D. + + :param node: The Conv ONNX node. + :param state: The SDFG state containing the node. + :param sdfg: The parent SDFG. + :return: A nested SDFG implementing the Conv operation. + """ + X = in_desc_with_name(node, state, sdfg, "X") + W = in_desc_with_name(node, state, sdfg, "W") + Y = out_desc_with_name(node, state, sdfg, "Y") + + # Check if bias is present in input connectors + B = in_desc_with_name(node, state, sdfg, "B") if "B" in node.in_connectors else None + + if node.kernel_shape is not None: + filter_hx, filter_hy = node.kernel_shape + else: + filter_hx, filter_hy = W.shape[2:] + + num_filters = W.shape[0] + num_channels = X.shape[1] + batch_size = X.shape[0] + + # Get number of groups (default to 1 for standard convolution) + groups = node.group if node.group is not None else 1 + channels_per_group = num_channels // groups + filters_per_group = num_filters // groups + + input_size_x, input_size_y = X.shape[2:] + output_size_y, output_size_x = Y.shape[2:] + stride_y, stride_x = node.strides or [1, 1] + pad_x, pad_y, _, _ = node.pads or [0, 0, 0, 0] + + dtype = X.dtype + + # Create new SDFG + nsdfg = dace.SDFG(node.label + "_expansion") + nstate = nsdfg.add_state() + + # Add data descriptors + nsdfg.add_datadesc("X", copy.deepcopy(X)) + nsdfg.add_datadesc("W", copy.deepcopy(W)) + nsdfg.add_datadesc("Y", copy.deepcopy(Y)) + if B is not None: + nsdfg.add_datadesc("B", copy.deepcopy(B)) + + # Set arrays as non-transient since they are inputs/outputs + nsdfg.arrays["X"].transient = False + nsdfg.arrays["W"].transient = False + nsdfg.arrays["Y"].transient = False + if B is not None: + nsdfg.arrays["B"].transient = False + + # Add access nodes + X_read = nstate.add_read("X") + W_read = nstate.add_read("W") + Y_write = nstate.add_write("Y") + if B is not None: + B_read = nstate.add_read("B") + + # Generate C++ code for the grouped convolution + code = f""" + // Initialize output + {f''' + // Initialize with bias + for (int b = 0; b < {batch_size}; b++) {{ + for (int m = 0; m < {num_filters}; m++) {{ + for (int out_x = 0; out_x < {output_size_x}; out_x++) {{ + for (int out_y = 0; out_y < {output_size_y}; out_y++) {{ + __Y[b * {Y.strides[0]} + m * {Y.strides[1]} + out_x * {Y.strides[2]} + out_y * {Y.strides[3]}] = __B[m]; + }} + }} + }} + }} + ''' if B is not None else f''' + // Zero-initialize output + for (int b = 0; b < {batch_size}; b++) {{ + for (int m = 0; m < {num_filters}; m++) {{ + for (int out_x = 0; out_x < {output_size_x}; out_x++) {{ + for (int out_y = 0; out_y < {output_size_y}; out_y++) {{ + __Y[b * {Y.strides[0]} + m * {Y.strides[1]} + out_x * {Y.strides[2]} + out_y * {Y.strides[3]}] = 0; + }} + }} + }} + }} + '''} + + // Main grouped convolution computation + for (int b = 0; b < {batch_size}; b++) {{ + for (int g = 0; g < {groups}; g++) {{ + // Each group processes a subset of input/output channels + int in_channel_start = g * {channels_per_group}; + int out_channel_start = g * {filters_per_group}; + + for (int m = 0; m < {filters_per_group}; m++) {{ + int out_channel = out_channel_start + m; + + for (int out_x = 0; out_x < {output_size_x}; out_x++) {{ + for (int out_y = 0; out_y < {output_size_y}; out_y++) {{ + // Only convolve with channels in the same group + for (int c = 0; c < {channels_per_group}; c++) {{ + int in_channel = in_channel_start + c; + + for (int hx = 0; hx < {filter_hx}; hx++) {{ + for (int hy = 0; hy < {filter_hy}; hy++) {{ + int sx = hx + out_x * {stride_x} - {pad_x}; + int sy = hy + out_y * {stride_y} - {pad_y}; + + if (0 <= sx && sx < {input_size_x} && 0 <= sy && sy < {input_size_y}) {{ + // Note: Weight tensor layout for grouped conv: + // [num_filters, channels_per_group, filter_hx, filter_hy] + float filter = __W[out_channel * {W.strides[0]} + c * {W.strides[1]} + hx * {W.strides[2]} + hy * {W.strides[3]}]; + float image = __X[b * {X.strides[0]} + in_channel * {X.strides[1]} + sx * {X.strides[2]} + sy * {X.strides[3]}]; + __Y[b * {Y.strides[0]} + out_channel * {Y.strides[1]} + out_x * {Y.strides[2]} + out_y * {Y.strides[3]}] += filter * image; + }} + }} + }} + }} + }} + }} + }} + }} + }} + """ + + # Create tasklet inputs and outputs + tasklet_inputs = { + "__X": dace.pointer(X.dtype), + "__W": dace.pointer(W.dtype), + } + tasklet_outputs = { + "__Y": dace.pointer(Y.dtype), + } + + if B is not None: + tasklet_inputs["__B"] = dace.pointer(B.dtype) + + # Create the tasklet + tasklet = nstate.add_tasklet(name=node.label + "_tasklet", + inputs=tasklet_inputs, + outputs=tasklet_outputs, + code=code, + language=dace.Language.CPP) + + # Connect the tasklet with memlets + nstate.add_edge(X_read, None, tasklet, "__X", dace.Memlet.from_array("X", X)) + nstate.add_edge(W_read, None, tasklet, "__W", dace.Memlet.from_array("W", W)) + if B is not None: + nstate.add_edge(B_read, None, tasklet, "__B", dace.Memlet.from_array("B", B)) + nstate.add_edge(tasklet, "__Y", Y_write, None, dace.Memlet.from_array("Y", Y)) + + return nsdfg + + +@op_implementation(op="BatchNormalization", name="pure") +class PureBatchNormalization(ONNXForward): + """Pure implementation of BatchNormalization operation.""" + + @staticmethod + def forward_can_be_applied(node: ONNXOp, state: SDFGState, sdfg: SDFG) -> bool: + """Check if this implementation can be applied to the given node. + + :param node: The BatchNormalization ONNX node. + :param state: The SDFG state containing the node. + :param sdfg: The parent SDFG. + :return: True if the implementation can be applied, False otherwise. + """ + X = in_desc_with_name(node, state, sdfg, "X") + if len(X.shape) != 4: + return False + + if "in_mean" in node.in_connectors and "input_mean" not in node.in_connectors: + # Replace the old names with the new ones + node.add_in_connector("input_mean", node.in_connectors["in_mean"]) + node.remove_in_connector("in_mean") + + if "in_var" in node.in_connectors and "input_var" not in node.in_connectors: + # Replace the old names with the new ones + node.add_in_connector("input_var", node.in_connectors["in_var"]) + node.remove_in_connector("in_var") + + # Check for the new output names + if not {"scale", "B", "input_mean", "input_var"}.issubset(node.in_connectors): + return False + + return True + + @staticmethod + def forward(node: ONNXOp, state: SDFGState, sdfg: SDFG) -> typing.Union[nodes.Node, SDFG]: + """Generate the forward pass implementation for BatchNormalization. + + :param node: The BatchNormalization ONNX node. + :param state: The SDFG state containing the node. + :param sdfg: The parent SDFG. + :return: A nested SDFG implementing the BatchNormalization operation. + """ + shape = copy.deepcopy(in_desc_with_name(node, state, sdfg, "X").shape) + reduce_axes = list(shape) + num_channels = reduce_axes.pop(1) + + N = _prod(reduce_axes) + broadcast_shape = [num_channels, 1, 1] + dtype = in_desc_with_name(node, state, sdfg, "X").dtype + eps = node.epsilon + momentum = node.momentum + inv_momentum = 1 - node.momentum + + axis = tuple(i for i in range(len(shape)) if i != 1) + + # Check if training_mode attribute exists + if not hasattr(node, "training_mode"): + # By default, set to False (inference mode) + node.training_mode = False + + if node.training_mode: + # TRAINING: compute batch statistics and update running statistics (EMA like PyTorch) + def prog(input_mean, scale, input_var, B, X, Y, running_mean, running_var): + # Batch mean, variance over axis=(0,2,3) for NCHW (your `axis`/`N` already set) + batch_mean = np.add.reduce(X, axis=axis) / N + + batch_mean_broadcastable = dace.define_local(broadcast_shape, dtype) + batch_mean_broadcastable[:] = batch_mean + X_minus_mean = X - batch_mean_broadcastable + + batch_var = np.add.reduce(X_minus_mean * X_minus_mean, axis=axis) / N + batch_var_eps = np.reshape(batch_var + eps, broadcast_shape) + + inv_std = dace.elementwise(lambda x: dtype(1.0) / sqrt(x), batch_var_eps) + normalized = X_minus_mean * inv_std + + scale_reshaped = np.reshape(scale, broadcast_shape) + bias_reshaped = np.reshape(B, broadcast_shape) + Y[:] = normalized * scale_reshaped + bias_reshaped + + # FIXED: PyTorch EMA + # running = (1 - momentum) * running + momentum * batch + running_mean[:] = input_mean * (1.0 - momentum) + batch_mean * momentum + running_var[:] = input_var * (1.0 - momentum) + batch_var * momentum + + new_sdfg = program_for_node(prog, sdfg, state, node) + + # Keep your "write-back" edges as-is + new_state = sdfg.add_state_after(sdfg.nodes()[0]) + rm_name = out_edge_with_name(node, state, "running_mean").data.data + new_state.add_edge(new_state.add_read(rm_name), None, + new_state.add_read(in_edge_with_name(node, state, "input_mean").data.data), None, + sdfg.make_array_memlet(rm_name)) + rv_name = out_edge_with_name(node, state, "running_var").data.data + new_state.add_edge(new_state.add_read(rv_name), None, + new_state.add_read(in_edge_with_name(node, state, "input_var").data.data), None, + sdfg.make_array_memlet(rv_name)) + else: + # EVAL: use provided running statistics; DO NOT recompute mean/var + def prog(input_mean, scale, input_var, B, X, Y): + mean_b = dace.define_local(broadcast_shape, dtype) + var_b = dace.define_local(broadcast_shape, dtype) + mean_b[:] = input_mean + var_b[:] = input_var + + X_minus_mean = X - mean_b + inv_std = dace.elementwise(lambda x: dtype(1.0) / sqrt(x + eps), var_b) + + normalized = X_minus_mean * inv_std + scale_b = np.reshape(scale, broadcast_shape) + bias_b = np.reshape(B, broadcast_shape) + Y[:] = normalized * scale_b + bias_b + + new_sdfg = program_for_node(prog, sdfg, state, node) + + return new_sdfg + + +@op_implementation(op="GlobalAveragePool", name="pure") +class PureGlobalAveragePool(ONNXForward): + """Pure implementation of GlobalAveragePool operation.""" + + @staticmethod + def forward_can_be_applied(node: 'ONNXOp', state: SDFGState, sdfg: SDFG) -> bool: + """Check if this implementation can be applied to the given node. + + :param node: The GlobalAveragePool ONNX node. + :param state: The SDFG state containing the node. + :param sdfg: The parent SDFG. + :return: Always True for this implementation. + """ + return True + + @staticmethod + def forward(node: 'ONNXOp', state: SDFGState, sdfg: SDFG) -> typing.Union[Node, SDFG]: + """Generate the forward pass implementation for GlobalAveragePool. + + :param node: The GlobalAveragePool ONNX node. + :param state: The SDFG state containing the node. + :param sdfg: The parent SDFG. + :return: A nested SDFG implementing the GlobalAveragePool operation. + """ + from dace.libraries.onnx.nodes.onnx_op_registry import ONNXReduceMean + + # Get input and output descriptors + X_desc = copy.deepcopy(in_desc_with_name(node, state, sdfg, "X")) + Y_desc = copy.deepcopy(out_desc_with_name(node, state, sdfg, "Y")) + + # Create new SDFG + nsdfg = dace.SDFG(node.label + "_expansion") + nstate = nsdfg.add_state() + + # Add data descriptors + nsdfg.add_datadesc("X", X_desc) + nsdfg.add_datadesc("Y", Y_desc) + nsdfg.arrays["X"].transient = False + nsdfg.arrays["Y"].transient = False + + # Add access nodes + X_read = nstate.add_read("X") + Y_write = nstate.add_write("Y") + + # Create axes array for reduction over spatial dimensions (2, 3) + axes_name = "axes" + rank = len(X_desc.shape) # e.g., (N, C, H, W) -> 4 + axes_values = list(range(2, rank)) + axes_arr_dtype = dace.int64 + axes_arr_shape = [len(axes_values)] + _, axes_desc = nsdfg.add_array(axes_name, axes_arr_shape, axes_arr_dtype, transient=True) + axes_node = nstate.add_access(axes_name) + + # Add a tasklet to initialize the axes array + axes_init_tasklet = nstate.add_tasklet("init_axes", + set(), {"out": dace.pointer(axes_arr_dtype)}, + "\n".join( + [f"out [{idx}] = {val};" for idx, val in enumerate(axes_values)]), + language=dace.Language.CPP) + nstate.add_edge(axes_init_tasklet, "out", axes_node, None, dace.Memlet(f"{axes_name}[0:{len(axes_values)}]")) + + # Create ONNXReduceMean node + reduce_mean_op = ONNXReduceMean("reduce_mean", keepdims=1) + reduce_mean_op.axes = axes_values + nstate.add_node(reduce_mean_op) + reduce_mean_op.add_in_connector("data") + reduce_mean_op.add_in_connector("axes") + reduce_mean_op.add_out_connector("reduced") + + # Connect the ReduceMean operation + nstate.add_edge(X_read, None, reduce_mean_op, "data", nsdfg.make_array_memlet("X")) + nstate.add_edge(axes_node, None, reduce_mean_op, "axes", nsdfg.make_array_memlet(axes_name)) + nstate.add_edge(reduce_mean_op, "reduced", Y_write, None, nsdfg.make_array_memlet("Y")) + + return nsdfg diff --git a/dace/libraries/onnx/op_implementations/linalg_ops.py b/dace/libraries/onnx/op_implementations/linalg_ops.py new file mode 100644 index 0000000000..ca5c2bd9ad --- /dev/null +++ b/dace/libraries/onnx/op_implementations/linalg_ops.py @@ -0,0 +1,359 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +Linear algebra operations for ONNX. + +This module contains implementations of linear algebra operations including: +- MatMul: Matrix multiplication with broadcasting +- Gemm: General matrix multiplication (alpha*A*B + beta*C) +- Einsum: Einstein summation notation for tensor operations + +""" + +import copy +import itertools +import typing + +import dace +from dace import SDFG, SDFGState, nodes +from dace.sdfg.nodes import Node +from dace.sdfg.utils import in_desc_with_name, out_desc_with_name + +from dace import config +from dace.libraries.onnx.forward_implementation_abc import ONNXForward +from dace.libraries.onnx.nodes import onnx_op +from dace.libraries.onnx.op_implementations.utils import in_desc_with_name, op_implementation, out_desc_with_name +from dace.frontend.common import create_einsum_sdfg + +# ============================================================================ +# Matrix Multiplication +# ============================================================================ + + +@op_implementation(op="MatMul", name="pure") +class PureMatMul(ONNXForward): + + @staticmethod + def forward_can_be_applied(node: onnx_op.ONNXOp, state: SDFGState, sdfg: SDFG) -> bool: + input0_dim = len(in_desc_with_name(node, state, sdfg, "A").shape) + input1_dim = len(in_desc_with_name(node, state, sdfg, "B").shape) + + if input0_dim == 1 or input1_dim == 1: + return False + return True + + @staticmethod + def forward(node: onnx_op.ONNXOp, state: SDFGState, sdfg: SDFG) -> typing.Union[Node, SDFG]: + from dace.libraries.onnx.nodes.onnx_op_registry import ONNXEinsum # avoid import loop + + A_desc = in_desc_with_name(node, state, sdfg, "A") + B_desc = in_desc_with_name(node, state, sdfg, "B") + Y_desc = out_desc_with_name(node, state, sdfg, "Y") + input0_dim = A_desc.shape + input1_dim = B_desc.shape + + # list containing letters from z-a + letters = [chr(ord('z') - i) for i in range(26)] + # i j k are used for the last dimensions + letters = [l for l in letters if l not in ['i', 'j', 'k']] + + if len(input0_dim) == 1: + if len(input1_dim) != 2: + raise ValueError("invalid dimensions") + arg1 = 'k' + arg2 = 'kj' + result = 'j' + elif len(input1_dim) == 1: + if len(input0_dim) != 2: + raise ValueError("invalid dimensions") + arg1 = 'ik' + arg2 = 'k' + result = 'i' + else: + # build the einsum. The last two dimensions are always just the matrix multiply einsum + # dace will later specialize to a batched matmul if possible + arg1 = 'ik' + arg2 = 'kj' + result = 'ij' + if input0_dim[-2] != input0_dim[-1]: + if dace.symbolic.issymbolic(input0_dim[-2]): + if config.Config.get_bool('debugprint'): + print( + f"Warning: overriding symbol {input0_dim[-2]} with value {input1_dim[-1]} in descriptor of input A of node {node}" + ) + new_shape = list(A_desc.shape) + new_shape[-1] = input1_dim[-2] + A_desc.shape = new_shape + elif dace.symbolic.issymbolic(input1_dim[-1]): + if config.Config.get_bool('debugprint'): + print( + f"Warning: overriding symbol {input0_dim[-1]} with value {input0_dim[-2]} in descriptor of input B of node {node}" + ) + new_shape = list(B_desc.shape) + new_shape[-2] = input0_dim[-1] + B_desc.shape = new_shape + input0_dim = input0_dim[:-2] + input1_dim = input1_dim[:-2] + for dim0, dim1 in itertools.zip_longest(reversed(input0_dim), reversed(input1_dim)): + if dim0 is None: + # only dim0 exists + letter = letters.pop() + arg2 = letter + arg2 + result = letter + result + elif dim1 is None: + # only dim1 exists + letter = letters.pop() + arg1 = letter + arg1 + result = letter + result + else: + # both exist + letter = letters.pop() + arg1 = letter + arg1 + arg2 = letter + arg2 + result = letter + result + + einsum_str = '{},{}->{}'.format(arg1, arg2, result) + + # we lower to an ONNXEinsum node instead straight to the dace einsum to + # make the autodiff simpler + nsdfg = dace.SDFG(node.label + "_expansion") + nstate = nsdfg.add_state() + einsum_node: onnx_op.ONNXOp = ONNXEinsum(node.label + "_einsum_expansion", equation=einsum_str) + + nstate.add_node(einsum_node) + einsum_node.add_in_connector("Inputs__0") + einsum_node.add_in_connector("Inputs__1") + nsdfg.add_datadesc("A", copy.deepcopy(A_desc)) + nsdfg.add_datadesc("B", copy.deepcopy(B_desc)) + nsdfg.add_datadesc("Y", copy.deepcopy(Y_desc)) + nsdfg.arrays["A"].transient = False + nsdfg.arrays["B"].transient = False + nsdfg.arrays["Y"].transient = False + + nstate.add_edge(nstate.add_read("A"), None, einsum_node, "Inputs__0", nsdfg.make_array_memlet("A")) + nstate.add_edge(nstate.add_read("B"), None, einsum_node, "Inputs__1", nsdfg.make_array_memlet("B")) + nstate.add_edge(einsum_node, "Output", nstate.add_write("Y"), None, nsdfg.make_array_memlet("Y")) + + return nsdfg + + +# ============================================================================ +# Einstein Summation +# ============================================================================ + + +@op_implementation(op="Einsum", name="pure") +class PureEinsum(ONNXForward): + + @staticmethod + def forward_can_be_applied(node: onnx_op.ONNXOp, state: SDFGState, sdfg: SDFG) -> bool: + if "..." in node.equation: + return False + return True + + @staticmethod + def forward(node: onnx_op.ONNXOp, state: SDFGState, sdfg: SDFG) -> typing.Union[Node, SDFG]: + nsdfg = dace.SDFG(node.label + "_expansion") + nstate = nsdfg.add_state() + + for e in node.iter_inputs_in_onnx_order(state): + desc = copy.deepcopy(in_desc_with_name(node, state, sdfg, e.dst_conn)) + desc.transient = False + nsdfg.add_datadesc(e.dst_conn, desc) + for e in node.iter_outputs_in_onnx_order(state): + desc = copy.deepcopy(out_desc_with_name(node, state, sdfg, e.src_conn)) + desc.transient = False + nsdfg.add_datadesc(e.src_conn, desc) + + # Check if there is a wcr sum to accumulate the result instead of initialization the output + # This is necessary for gradient accumulation to be consistent + output_edge = state.out_edges(node) + assert len(output_edge) == 1, "Einsum node should have exactly one output edge" + output_edge = output_edge[0] + beta = 1 if output_edge.data.wcr else 0 + create_einsum_sdfg(nsdfg, + nstate, + node.equation.replace(" ", ""), + *(e.dst_conn for e in node.iter_inputs_in_onnx_order(state)), + output="Output", + beta=beta) + return nsdfg + + +# ============================================================================ +# General Matrix Multiplication (Gemm) +# ============================================================================ + + +@op_implementation(op="Gemm", name="pure") +class PureGemm(ONNXForward): + + @staticmethod + def forward_can_be_applied(node: onnx_op.ONNXOp, state: SDFGState, sdfg: SDFG) -> bool: + return True + + @staticmethod + def forward(node: onnx_op.ONNXOp, state: SDFGState, sdfg: SDFG) -> typing.Union[Node, SDFG]: + from dace.libraries.onnx.nodes.onnx_op_registry import ONNXEinsum # avoid import loop + A_desc = in_desc_with_name(node, state, sdfg, "A") + B_desc = in_desc_with_name(node, state, sdfg, "B") + Y_desc = out_desc_with_name(node, state, sdfg, "Y") + input0_dim = A_desc.shape + input1_dim = B_desc.shape + + # list containing letters from z-a + letters = [chr(ord('z') - i) for i in range(26)] + # i j k are used for the last dimensions + letters = [l for l in letters if l not in ['i', 'j', 'k']] + + if len(input0_dim) == 1: + if len(input1_dim) != 2: + raise ValueError("invalid dimensions") + arg1 = 'k' + arg2 = 'kj' + result = 'j' + elif len(input1_dim) == 1: + if len(input0_dim) != 2: + raise ValueError("invalid dimensions") + arg1 = 'ik' + arg2 = 'k' + result = 'i' + else: + # build the einsum. The last two dimensions are always just the matrix multiply einsum + # dace will later specialize to a batched matmul if possible + arg1 = 'ik' + arg2 = 'kj' + result = 'ij' + if input0_dim[-2] != input0_dim[-1]: + if dace.symbolic.issymbolic(input0_dim[-2]): + if config.Config.get_bool('debugprint'): + print( + f"Warning: overriding symbol {input0_dim[-2]} with value {input1_dim[-1]} in descriptor of input A of node {node}" + ) + new_shape = list(A_desc.shape) + new_shape[-1] = input1_dim[-2] + A_desc.shape = new_shape + elif dace.symbolic.issymbolic(input1_dim[-1]): + if config.Config.get_bool('debugprint'): + print( + f"Warning: overriding symbol {input0_dim[-1]} with value {input0_dim[-2]} in descriptor of input B of node {node}" + ) + new_shape = list(B_desc.shape) + new_shape[-2] = input0_dim[-1] + B_desc.shape = new_shape + input0_dim = input0_dim[:-2] + input1_dim = input1_dim[:-2] + for dim0, dim1 in itertools.zip_longest(reversed(input0_dim), reversed(input1_dim)): + if dim0 is None: + # only dim0 exists + letter = letters.pop() + arg2 = letter + arg2 + result = letter + result + elif dim1 is None: + # only dim1 exists + letter = letters.pop() + arg1 = letter + arg1 + result = letter + result + else: + # both exist + letter = letters.pop() + arg1 = letter + arg1 + arg2 = letter + arg2 + result = letter + result + + if node.transA == 1: + arg1 = ''.join(reversed(arg1)) + if node.transB == 1: + arg2 = ''.join(reversed(arg2)) + + einsum_str = '{},{}->{}'.format(arg1, arg2, result) + + # we lower to an ONNXEinsum node instead straight to the dace einsum to + # make the autodiff simpler + nsdfg = dace.SDFG(node.label + "_expansion") + nstate = nsdfg.add_state() + + # Einsum: "A", "B" -> mm_result + einsum_node: nodes.LibraryNode = ONNXEinsum(node.label + "_einsum_expansion", equation=einsum_str) + + nstate.add_node(einsum_node) + einsum_node.add_in_connector("Inputs__0") + einsum_node.add_in_connector("Inputs__1") + nsdfg.add_datadesc("A", copy.deepcopy(A_desc)) + nsdfg.add_datadesc("B", copy.deepcopy(B_desc)) + nsdfg.add_datadesc("Y", copy.deepcopy(Y_desc)) + nsdfg.arrays["A"].transient = False + nsdfg.arrays["B"].transient = False + nsdfg.arrays["Y"].transient = False + + # Decide on array names based on alpha and beta + uid = state.node_id(node) + mm_result = "Y" + if node.alpha != 1 or node.beta != 0: + mm_result = f"Ytmp_{uid}" + scal_result = mm_result + if node.alpha != 1: + scal_result = f"scaled_{uid}" + + # Create arrays according to alpha and beta + if node.alpha != 1 or node.beta != 0: + Ytmp_desc = out_desc_with_name(node, state, sdfg, "Y") + nsdfg.add_datadesc(f"Ytmp_{uid}", copy.deepcopy(Ytmp_desc)) + nsdfg.arrays[f"Ytmp_{uid}"].transient = True + if node.beta != 0: + beta_desc = out_desc_with_name(node, state, sdfg, "Y") + nsdfg.add_datadesc(f"scaled_{uid}", copy.deepcopy(beta_desc)) + nsdfg.arrays[f"scaled_{uid}"].transient = True + + nstate.add_edge(nstate.add_read("A"), None, einsum_node, "Inputs__0", nsdfg.make_array_memlet("A")) + nstate.add_edge(nstate.add_read("B"), None, einsum_node, "Inputs__1", nsdfg.make_array_memlet("B")) + mm_result_node = nstate.add_write(mm_result) + nstate.add_edge(einsum_node, "Output", mm_result_node, None, nsdfg.make_array_memlet(mm_result)) + + # Multiply by alpha: mm_result -> scal_result + if node.alpha != 1: + nstate.add_mapped_tasklet( + node.label + '_alphascale', + { + k: f'0:{Ytmp_desc.shape[i]}' + for i, k in enumerate(result) + }, + dict(a=dace.Memlet(data=mm_result, subset=','.join(result))), + f'o = a * dace.{Ytmp_desc.dtype}({node.alpha})', + dict(o=dace.Memlet(data=scal_result, subset=','.join(result))), + external_edges=True, + input_nodes=dict(a=mm_result_node), + ) + + # Multiply by beta: scal_result, "C" -> "Y" + if node.beta != 0: + C_desc = in_desc_with_name(node, state, sdfg, "C") + nsdfg.add_datadesc("C", copy.deepcopy(C_desc)) + nsdfg.arrays["C"].transient = False + scal_result_node = next(n for n in nstate.sink_nodes() + if isinstance(n, dace.nodes.AccessNode) and n.data == scal_result) + beta_scale_code = f'o = s + c * dace.{C_desc.dtype}({node.beta})' + if node.beta == 1: + beta_scale_code = f'o = s + c' + + # Support broadcasting in C -> Y + c_index = result[-len(C_desc.shape):] + for c_shp, y_shp in zip(reversed(C_desc.shape), reversed(Y_desc.shape)): + if c_shp != y_shp: + raise ValueError('Could not broadcast dimensions from C ' + 'to Y in ONNXGemm') + + nstate.add_mapped_tasklet( + node.label + '_betascale', + { + k: f'0:{Y_desc.shape[i]}' + for i, k in enumerate(result) + }, + dict(s=dace.Memlet(data=scal_result, subset=','.join(result)), + c=dace.Memlet(data="C", subset=','.join(c_index))), + beta_scale_code, + dict(o=dace.Memlet(data="Y", subset=','.join(result))), + external_edges=True, + input_nodes={scal_result: scal_result_node}, + ) + + return nsdfg diff --git a/dace/libraries/onnx/op_implementations/normalization_ops.py b/dace/libraries/onnx/op_implementations/normalization_ops.py new file mode 100644 index 0000000000..f434e3d5ac --- /dev/null +++ b/dace/libraries/onnx/op_implementations/normalization_ops.py @@ -0,0 +1,281 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +Normalization operations for ONNX. + +This module contains implementations of normalization operations including: +- Softmax, LogSoftmax: Softmax normalization +- LayerNormalization: Layer normalization +- Dropout: Dropout regularization +""" + +import copy +import typing + +import dace +import numpy as np +from dace import SDFG, SDFGState, nodes +from dace.sdfg.utils import in_desc_with_name, out_desc_with_name + +from dace.libraries.onnx.forward_implementation_abc import ONNXForward +from dace.libraries.onnx.op_implementations.utils import (in_desc_with_name, op_implementation, out_desc_with_name, + python_pure_op_implementation) + +# ============================================================================ +# Softmax Operations +# ============================================================================ + +softmax_compute = dict(axis=lambda node, input: tuple(range(len(input.shape)))[node.axis:]) + + +@python_pure_op_implementation(**softmax_compute) +def Softmax(input, output): + maximum = np.maximum.reduce(input, axis=axis, keepdims=True) + exp_values = np.exp(input - maximum) + sum_exp = np.add.reduce(exp_values, axis=axis, keepdims=True) + output[:] = exp_values / sum_exp + + +@python_pure_op_implementation(**softmax_compute) +def LogSoftmax(input, output): + maximum = np.maximum.reduce(input, axis=axis, keepdims=True) + max_sub = input - maximum + exponent = np.exp(max_sub) + sum = np.add.reduce(exponent, axis=axis, keepdims=True) + log_sum = np.log(sum) + output[:] = max_sub - log_sum + + +# ============================================================================ +# Layer Normalization +# ============================================================================ + + +def _layernorm_axis(node, X): + axis = node.axis if hasattr(node, 'axis') and node.axis >= 0 else len(X.shape) + node.axis + return tuple(range(axis, len(X.shape))) + + +def _layernorm_norm_size(node, X): + axis = node.axis if hasattr(node, 'axis') and node.axis >= 0 else len(X.shape) + node.axis + return int(np.prod([X.shape[i] for i in range(axis, len(X.shape))])) + + +def _layernorm_epsilon(node, X): + eps = getattr(node, 'epsilon', 1e-5) + return X.dtype.type(eps) + + +def _layernorm_one(X): + return X.dtype.type(1) + + +layernorm_compute = dict(axis=_layernorm_axis, + epsilon=_layernorm_epsilon, + norm_size=_layernorm_norm_size, + one=_layernorm_one) + + +@python_pure_op_implementation(**layernorm_compute) +def LayerNormalization(X, Scale, B, Y): + sum_x = np.add.reduce(X, axis=axis, keepdims=True) + mean = sum_x / norm_size + diff = X - mean + sum_sq = np.add.reduce(diff * diff, axis=axis, keepdims=True) + variance = sum_sq / norm_size + inv_std = one / np.sqrt(variance + epsilon) + normalized = diff * inv_std + Y[:] = normalized * Scale + B + + +# ============================================================================ +# Dropout +# ============================================================================ + + +@op_implementation(op="Dropout", name="pure") +class PureDropout(ONNXForward): + """ Dropout implementation with support for training and inference modes. + """ + + @staticmethod + def forward_can_be_applied(node: 'ONNXOp', state: SDFGState, sdfg: SDFG) -> bool: + # Get input descriptor + data = in_desc_with_name(node, state, sdfg, "data") + + # Check if optional inputs are present + has_ratio = "ratio" in node.in_connectors + has_training_mode = "training_mode" in node.in_connectors + + # Check data type + if data.dtype not in [dace.float16, dace.float32, dace.float64]: + return False + + # If ratio is provided as input, it should be a scalar + if has_ratio: + ratio = in_desc_with_name(node, state, sdfg, "ratio") + if ratio.total_size != 1: + return False + + # If training_mode is provided as input, it should be a scalar boolean + if has_training_mode: + training_mode = in_desc_with_name(node, state, sdfg, "training_mode") + if training_mode.total_size != 1: + return False + + return True + + @staticmethod + def forward(node: 'ONNXOp', state: SDFGState, sdfg: SDFG) -> typing.Union[nodes.Node, SDFG]: + # Get descriptors + data = in_desc_with_name(node, state, sdfg, "data") + output = out_desc_with_name(node, state, sdfg, "output") + + # Check for optional mask output + has_mask_output = "mask" in node.out_connectors + mask = out_desc_with_name(node, state, sdfg, "mask") if has_mask_output else None + + # Check for optional inputs + has_ratio_input = "ratio" in node.in_connectors + has_training_mode_input = "training_mode" in node.in_connectors + + ratio_desc = in_desc_with_name(node, state, sdfg, "ratio") if has_ratio_input else None + training_mode_desc = in_desc_with_name(node, state, sdfg, "training_mode") if has_training_mode_input else None + + # Get dropout ratio (from attribute or will be provided as input) + # ONNX spec: default ratio is 0.5 if not specified + dropout_ratio = getattr(node, 'ratio', 0.5) if not has_ratio_input else None + + # Get seed if specified (for reproducible dropout) + seed = getattr(node, 'seed', None) + + # Calculate total elements + total_elements = data.total_size + + # Get data type + dtype = data.dtype + dtype_str = str(dtype).replace("dace.", "") + + # Create new SDFG + nsdfg = dace.SDFG(node.label + "_expansion") + nstate = nsdfg.add_state() + + # Add data descriptors + nsdfg.add_datadesc("data", copy.deepcopy(data)) + nsdfg.add_datadesc("output", copy.deepcopy(output)) + + if has_mask_output: + nsdfg.add_datadesc("mask", copy.deepcopy(mask)) + + if has_ratio_input: + nsdfg.add_datadesc("ratio", copy.deepcopy(ratio_desc)) + + if has_training_mode_input: + nsdfg.add_datadesc("training_mode", copy.deepcopy(training_mode_desc)) + + # Set arrays as non-transient + nsdfg.arrays["data"].transient = False + nsdfg.arrays["output"].transient = False + if has_mask_output: + nsdfg.arrays["mask"].transient = False + if has_ratio_input: + nsdfg.arrays["ratio"].transient = False + if has_training_mode_input: + nsdfg.arrays["training_mode"].transient = False + + # Add access nodes + data_read = nstate.add_read("data") + output_write = nstate.add_write("output") + mask_write = nstate.add_write("mask") if has_mask_output else None + ratio_read = nstate.add_read("ratio") if has_ratio_input else None + training_mode_read = nstate.add_read("training_mode") if has_training_mode_input else None + + # Generate C++ code for dropout + # Note: This implementation uses a simple linear congruential generator for portability + # In production, you might want to use a better random number generator + + code = f""" + #include + #include + + // Get dropout ratio + {dtype_str} ratio = {dropout_ratio if not has_ratio_input else '__ratio'}; + + // Get training mode (default to false if not specified) + bool training_mode = {('__training_mode' if has_training_mode_input else 'false')}; + + // If in inference mode, just copy input to output + if (!training_mode) {{ + for (int i = 0; i < {total_elements}; i++) {{ + __output[i] = __data[i]; + {"__mask[i] = true;" if has_mask_output else ""} + }} + }} else {{ + // Training mode: apply dropout + + // Initialize random seed + static uint64_t rng_state = {seed if seed is not None else 'uint64_t(std::time(nullptr))'}; + + // Scale factor for remaining values (1 / (1 - ratio)) + {dtype_str} scale = ({dtype_str})(1.0 / (1.0 - ratio)); + + // Apply dropout + for (int i = 0; i < {total_elements}; i++) {{ + // Simple LCG for random number generation + // This generates a random number in [0, 1) + rng_state = (rng_state * 1664525ULL + 1013904223ULL); + double random_val = double(rng_state) / double(UINT64_MAX); + + // Dropout: keep if random value is greater than ratio + bool keep = (random_val >= ratio); + + if (keep) {{ + // Scale the kept values + __output[i] = __data[i] * scale; + {"__mask[i] = true;" if has_mask_output else ""} + }} else {{ + // Drop the value + __output[i] = 0; + {"__mask[i] = false;" if has_mask_output else ""} + }} + }} + }} + """ + + # Create tasklet inputs and outputs + tasklet_inputs = { + "__data": dace.pointer(data.dtype), + } + tasklet_outputs = { + "__output": dace.pointer(output.dtype), + } + + if has_ratio_input: + tasklet_inputs["__ratio"] = ratio_desc.dtype + if has_training_mode_input: + tasklet_inputs["__training_mode"] = training_mode_desc.dtype + if has_mask_output: + tasklet_outputs["__mask"] = dace.pointer(mask.dtype) + + # Create the tasklet + tasklet = nstate.add_tasklet(name=node.label + "_tasklet", + inputs=tasklet_inputs, + outputs=tasklet_outputs, + code=code, + language=dace.Language.CPP) + + # Connect the tasklet with memlets + nstate.add_edge(data_read, None, tasklet, "__data", dace.Memlet.from_array("data", data)) + + if has_ratio_input: + nstate.add_edge(ratio_read, None, tasklet, "__ratio", dace.Memlet.from_array("ratio", ratio_desc)) + + if has_training_mode_input: + nstate.add_edge(training_mode_read, None, tasklet, "__training_mode", + dace.Memlet.from_array("training_mode", training_mode_desc)) + + nstate.add_edge(tasklet, "__output", output_write, None, dace.Memlet.from_array("output", output)) + + if has_mask_output: + nstate.add_edge(tasklet, "__mask", mask_write, None, dace.Memlet.from_array("mask", mask)) + + return nsdfg diff --git a/dace/libraries/onnx/op_implementations/reduction_ops.py b/dace/libraries/onnx/op_implementations/reduction_ops.py new file mode 100644 index 0000000000..cd9a361b8e --- /dev/null +++ b/dace/libraries/onnx/op_implementations/reduction_ops.py @@ -0,0 +1,304 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +Reduction operations for ONNX. + +This module contains implementations of reduction operations including: +- ReduceSum, ReduceMean: Standard reductions over specified axes +- ReduceMax, ReduceMin: Min/max reductions +- CumSum: Cumulative sum along an axis +- Sum: Element-wise sum of multiple inputs + +""" + +import copy +import typing + +import dace +import numpy as np +from dace import SDFG, SDFGState +from dace.sdfg.nodes import Node +from dace.sdfg.utils import in_desc_with_name, in_edge_with_name, out_desc_with_name +from dace.libraries.onnx.forward_implementation_abc import ONNXForward +from dace.libraries.onnx.nodes import onnx_op +from dace.libraries.onnx.op_implementations.common import iterables_equal +from dace.libraries.onnx.op_implementations.utils import (empty_sdfg_for_node, in_desc_with_name, op_implementation, + out_desc_with_name, program_for_node) + +# ============================================================================ +# Cumulative Sum +# ============================================================================ + + +@op_implementation(op="CumSum", name="pure") +class PureCumSum(ONNXForward): + + @staticmethod + def forward_can_be_applied(node: onnx_op.ONNXOp, state: SDFGState, sdfg: SDFG) -> bool: + if node.exclusive or node.reverse: + return False + try: + if hasattr(sdfg, "_parent_onnx_model") and in_edge_with_name( + node, state, "axis").src.data not in sdfg._parent_onnx_model.clean_weights: + return False + except ValueError: + return False + return True + + @staticmethod + def forward(node: onnx_op.ONNXOp, state: SDFGState, sdfg: SDFG) -> typing.Union[Node, SDFG]: + axis = sdfg._parent_onnx_model.clean_weights[in_edge_with_name(node, state, "axis").src.data].numpy().item() + + def prog(x, y): + y[:] = np.cumsum(x, axis=axis) + + return program_for_node(prog, sdfg, state, node) + + +# ============================================================================ +# ReduceMean Operations +# ============================================================================ + + +@op_implementation(op="ReduceMean", name="pure") +class PureReduceMean(ONNXForward): + + @staticmethod + def forward_can_be_applied(node: onnx_op.ONNXOp, state: SDFGState, sdfg: SDFG) -> bool: + is_axes_present = True + try: + if hasattr(sdfg, "_parent_onnx_model") and in_edge_with_name( + node, state, "axes").src.data not in sdfg._parent_onnx_model.clean_weights: + return False + except ValueError: + is_axes_present = False + + if not is_axes_present and hasattr(node, "axes"): + is_axes_present = True + + if not is_axes_present: + return False + + return True + + @staticmethod + def forward(node: onnx_op.ONNXOp, state: SDFGState, sdfg: SDFG) -> typing.Union[Node, SDFG]: + axes = None + try: + if hasattr(sdfg, "_parent_onnx_model") and in_edge_with_name( + node, state, "axes").src.data in sdfg._parent_onnx_model.clean_weights: + axes = sdfg._parent_onnx_model.clean_weights[in_edge_with_name(node, state, "axes").src.data].numpy() + except ValueError: + pass + if axes is not None: + if len(axes) == 1: + axes = axes[0] + else: + axes = tuple(axes) + else: + axes = node.axes if hasattr(node, "axes") else None + + def prog(data, reduced): + reduced[:] = np.mean(data, axis=axes) + + return program_for_node(prog, sdfg, state, node) + + +# ============================================================================ +# ReduceSum Operations +# ============================================================================ + + +@op_implementation(op="ReduceSum", name="pure") +class PureReduceSum(ONNXForward): + + @staticmethod + def forward_can_be_applied(node: onnx_op.ONNXOp, state: SDFGState, sdfg: SDFG) -> bool: + is_axes_present = True + try: + if hasattr(sdfg, "_parent_onnx_model") and in_edge_with_name( + node, state, "axes").src.data not in sdfg._parent_onnx_model.clean_weights: + return False + except ValueError: + is_axes_present = False + + if not is_axes_present and hasattr(node, "axes"): + is_axes_present = True + + if not is_axes_present: + return False + + return True + + @staticmethod + def forward(node: onnx_op.ONNXOp, state: SDFGState, sdfg: SDFG) -> typing.Union[Node, SDFG]: + axes = None + try: + if hasattr(sdfg, "_parent_onnx_model") and in_edge_with_name( + node, state, "axes").src.data in sdfg._parent_onnx_model.clean_weights: + axes = sdfg._parent_onnx_model.clean_weights[in_edge_with_name(node, state, "axes").src.data].numpy() + except ValueError: + pass + if axes is not None: + if len(axes) == 1: + axes = axes[0] + else: + axes = tuple(axes) + else: + axes = node.axes if hasattr(node, "axes") else None + + def prog(data, reduced): + reduced[:] = np.sum(data, axis=axes) + + return program_for_node(prog, sdfg, state, node) + + +# ============================================================================ +# ReduceMax and ReduceMin Operations +# ============================================================================ + + +@op_implementation(op="ReduceMax", name="pure") +class PureReduceMax(ONNXForward): + + @staticmethod + def forward_can_be_applied(node: onnx_op.ONNXOp, state: SDFGState, sdfg: SDFG) -> bool: + is_axes_present = True + try: + if hasattr(sdfg, "_parent_onnx_model") and in_edge_with_name( + node, state, "axes").src.data not in sdfg._parent_onnx_model.clean_weights: + return False + except ValueError: + is_axes_present = False + + if not is_axes_present and hasattr(node, "axes"): + is_axes_present = True + + if not is_axes_present: + return False + + return True + + @staticmethod + def forward(node: onnx_op.ONNXOp, state: SDFGState, sdfg: SDFG) -> typing.Union[Node, SDFG]: + axes = None + try: + if hasattr(sdfg, "_parent_onnx_model") and in_edge_with_name( + node, state, "axes").src.data in sdfg._parent_onnx_model.clean_weights: + axes = sdfg._parent_onnx_model.clean_weights[in_edge_with_name(node, state, "axes").src.data].numpy() + except ValueError: + pass + if axes is not None: + if len(axes) == 1: + axes = axes[0] + else: + axes = tuple(axes) + else: + axes = node.axes if hasattr(node, "axes") else None + + def prog(data, reduced): + reduced[:] = np.max(data, axis=axes) + + return program_for_node(prog, sdfg, state, node) + + +@op_implementation(op="ReduceMin", name="pure") +class PureReduceMin(ONNXForward): + + @staticmethod + def forward_can_be_applied(node: onnx_op.ONNXOp, state: SDFGState, sdfg: SDFG) -> bool: + is_axes_present = True + try: + if hasattr(sdfg, "_parent_onnx_model") and in_edge_with_name( + node, state, "axes").src.data not in sdfg._parent_onnx_model.clean_weights: + return False + except ValueError: + is_axes_present = False + + if not is_axes_present and hasattr(node, "axes"): + is_axes_present = True + + if not is_axes_present: + return False + + return True + + @staticmethod + def forward(node: onnx_op.ONNXOp, state: SDFGState, sdfg: SDFG) -> typing.Union[Node, SDFG]: + axes = None + try: + if hasattr(sdfg, "_parent_onnx_model") and in_edge_with_name( + node, state, "axes").src.data in sdfg._parent_onnx_model.clean_weights: + axes = sdfg._parent_onnx_model.clean_weights[in_edge_with_name(node, state, "axes").src.data].numpy() + except ValueError: + pass + if axes is not None: + if len(axes) == 1: + axes = axes[0] + else: + axes = tuple(axes) + else: + axes = node.axes if hasattr(node, "axes") else None + + def prog(data, reduced): + reduced[:] = np.min(data, axis=axes) + + return program_for_node(prog, sdfg, state, node) + + +# ============================================================================ +# Sum (Multi-input sum) +# ============================================================================ + + +@op_implementation(op="Sum", name="pure") +class PureSum(ONNXForward): + + @staticmethod + def forward_can_be_applied(node: onnx_op.ONNXOp, state: SDFGState, sdfg: SDFG) -> bool: + # check that all shapes are arrays, and that the shapes are all equal + shape = None + for edge in node.iter_inputs_in_onnx_order(state): + desc = in_desc_with_name(node, state, sdfg, edge.dst_conn) + if shape is None: + shape = desc.shape + + if not iterables_equal(shape, desc.shape): + return False + + if not iterables_equal(shape, out_desc_with_name(node, state, sdfg, "sum").shape): + return False + + return True + + @staticmethod + def forward(node: onnx_op.ONNXOp, state: SDFGState, sdfg: SDFG) -> typing.Union[Node, SDFG]: + + nsdfg = dace.SDFG(node.name) + input_names = [] + for e in node.iter_inputs_in_onnx_order(state): + new_desc = copy.deepcopy(in_desc_with_name(node, state, sdfg, e.dst_conn)) + new_desc.transient = False + nsdfg.add_datadesc(e.dst_conn, new_desc) + input_names.append(e.dst_conn) + + new_desc = copy.deepcopy(out_desc_with_name(node, state, sdfg, "sum")) + new_desc.transient = False + nsdfg.add_datadesc("sum", new_desc) + + nstate = nsdfg.add_state() + # we know all shapes are equal to the output shape + shape = out_desc_with_name(node, state, sdfg, "sum").shape + map_ranges = {f"i{i}": f"0:{s}" for i, s in enumerate(shape)} + index_str = f"{', '.join(map_ranges.keys())}" + tasklet, _, _ = nstate.add_mapped_tasklet( + node.name + "_tasklet", + map_ranges=map_ranges, + inputs={f"__{inp}": dace.Memlet(f"{inp}[{index_str}]") + for inp in input_names}, + code=f"__sum = {' + '.join(f'__{inp}' for inp in input_names)}", + outputs={"__sum": dace.Memlet(f"sum[{index_str}]")}, + external_edges=True) + + tasklet.in_connectors = {f"__{inp}": in_desc_with_name(node, state, sdfg, inp).dtype for inp in input_names} + tasklet.out_connectors = {"__sum": out_desc_with_name(node, state, sdfg, "sum").dtype} + return nsdfg diff --git a/dace/libraries/onnx/op_implementations/utils.py b/dace/libraries/onnx/op_implementations/utils.py new file mode 100644 index 0000000000..620d9b3081 --- /dev/null +++ b/dace/libraries/onnx/op_implementations/utils.py @@ -0,0 +1,223 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import inspect +import copy +from typing import Dict, Tuple, Optional, Callable, Union, Any +import functools +import textwrap + +import dace +from dace import SDFGState, SDFG, dtypes, nodes +from dace.frontend.python.parser import DaceProgram +from dace.registry import autoregister + +from dace.libraries.onnx.nodes import onnx_op +from dace.libraries.onnx.forward_implementation_abc import ONNXForward +from dace.libraries.onnx.nodes.node_utils import parse_variadic_param +from dace.sdfg.utils import in_desc_with_name, out_desc_with_name + + +def op_implementation(op, name): + """A decorator that registers an op implementation. + + It should be used on classes that extend :class:`~dace.libraries.onnx.forward_implementation_abc.ONNXForward`. + + :param op: The ONNX name of the op to register for. + :param name: The name of the implementation. + """ + + def dec(cls): + if cls.__doc__ is not None: + cls.__doc__ +=\ + """ + :Implementation name: ``"{}"`` + """.format(name) + else: + cls.__doc__ =\ + """ + :Implementation name: ``"{}"`` + """.format(name) + + return autoregister(cls, op=op, name=name) + + return dec + + +def program_for_node(program, + sdfg: SDFG, + state: SDFGState, + node: onnx_op.ONNXOp, + extra_vars: Optional[Dict[str, Any]] = None) -> SDFG: + """Expand a function to a DaCe program. + + The dtypes for the arguments will be extracted by matching the parameter names to edges. + + All inputs that are not specified as parameters will be removed using + constant_folding.remove_node_and_computation. + + :param program: The function to expand into a DaCe program. + :param sdfg: The parent SDFG. + :param state: The SDFG state containing the node. + :param node: The ONNX node to create a program for. + :param extra_vars: Optional extra variables to add to the program. + :return: A new SDFG implementing the program. + """ + + from dace.transformation.onnx import constant_folding # avoid import loop + input_names = node.schema.non_variadic_inputs() + variadic_input_names = node.schema.variadic_inputs() + + output_names = node.schema.non_variadic_outputs() + variadic_output_names = node.schema.variadic_outputs() + + if set(input_names).intersection(output_names): + # This is currently the case for only one ONNX op + raise ValueError("program_for_node cannot be applied on nodes of this type;" + " '{}' are both an input and an output".format(set(input_names).intersection(output_names))) + + params = inspect.signature(program).parameters + connectors_to_remove = set(input_names).difference(params) + + annotations = {} + for name, param in params.items(): + if name in input_names or ("__" in name and parse_variadic_param(name)[0] in variadic_input_names): + annotations[name] = in_desc_with_name(node, state, sdfg, name) + elif name in output_names or ("__" in name and parse_variadic_param(name)[0] in variadic_output_names): + annotations[name] = out_desc_with_name(node, state, sdfg, name) + else: + raise ValueError("'{}' was not found as an input or output for {}".format(name, node.schema.name)) + + program.__annotations__ = annotations + + program.__name__ = node.label + "_expansion" + result = DaceProgram(program, (), {}, False, dace.DeviceType.CPU) + if extra_vars is not None: + result.global_vars.update(extra_vars) + + for conn in connectors_to_remove: + constant_folding.remove_node_and_computation(sdfg, state, node, conn) + + sdfg = result.to_sdfg() + + if node.schedule in [dtypes.ScheduleType.GPU_Default] + dtypes.GPU_SCHEDULES: + sdfg.apply_gpu_transformations() + + return sdfg + + +def empty_sdfg_for_node( + sdfg: SDFG, + state: SDFGState, + node: onnx_op.ONNXOp, + add_access_nodes=True) -> Tuple[SDFG, SDFGState, Dict[str, nodes.AccessNode], Dict[str, nodes.AccessNode]]: + """Given a node, return an SDFG that can be used as a nested SDFG expansion for that node. + + The dtypes for the arguments will be extracted by matching the parameter names to edges. + + :param sdfg: The parent SDFG. + :param state: The SDFG state containing the node. + :param node: The ONNX node to create an SDFG for. + :param add_access_nodes: Whether to add access nodes to the SDFG. + :return: A tuple containing (nested SDFG, nested state, input nodes dict, output nodes dict). + """ + nsdfg = SDFG(node.label + "_expansion") + nstate = nsdfg.add_state() + + input_nodes = {} + output_nodes = {} + for edge, is_input in node.iter_edges(state, ignore_unknown=True): + if is_input: + conn_name = edge.dst_conn + nsdfg.add_datadesc(conn_name, copy.deepcopy(in_desc_with_name(node, state, sdfg, conn_name))) + if add_access_nodes: + input_nodes[conn_name] = nstate.add_read(conn_name) + else: + conn_name = edge.src_conn + nsdfg.add_datadesc(conn_name, copy.deepcopy(out_desc_with_name(node, state, sdfg, conn_name))) + if add_access_nodes: + output_nodes[conn_name] = nstate.add_write(conn_name) + nsdfg.arrays[conn_name].transient = False + + return nsdfg, nstate, input_nodes, output_nodes + + +@dace.dtypes.paramdec +def python_pure_op_implementation(func, **compute: Dict[str, Callable]): + """A decorator that registers a Python op implementation. + + The name of the function will be the name of the op that is being replaced. + + The compute parameter enables you to compute a variable given the node and + its inputs/outputs. This variable will be namespaced when parsing the function. + + To use this, the argument names of the functions can be either: + + * ``node``, in which case the argument will be passed the node we are expanding, + * or, the name of any connector of the node, in which case the argument will be + the data descriptor for that connector + + For example, the following compute argument instantiation will make + variables ``axis`` and ``shape`` available when the function is parsed. + + + .. highlight:: python + .. code-block:: python + + compute=dict( + # Grabs the axis of a node + axis=lambda node: node.axis + # Grabs the shape of the connector with name 'data' + shape=lambda data: data.shape + ) + + :param func: The function to register as an implementation + :param compute: A dictionary of functions that compute variables. + """ + + @op_implementation(op=func.__name__, name="pure") + class PureImpl(ONNXForward): + + @staticmethod + def forward(node: onnx_op.ONNXOp, state: SDFGState, sdfg: SDFG) -> Union[nodes.Node, SDFG]: + + def compute_argument_resolver(arg: str): + if arg == "node": + return node + elif arg in node.in_connectors: + return in_desc_with_name(node, state, sdfg, arg) + elif arg in node.out_connectors: + return out_desc_with_name(node, state, sdfg, arg) + else: + raise ValueError("Got unknown compute argument {}." + " Arguments to compute can be either 'node'," + " or the name of a connector of the node".format(arg)) + + extra_vars = {} + if compute is not None: + for var_name, function in compute.items(): + + # Get the names of the lambda + argument_names = list(inspect.signature(function).parameters) + + args = map(compute_argument_resolver, argument_names) + var_value = function(*args) + + extra_vars[var_name] = var_value + + return program_for_node(func, sdfg, state, node, extra_vars=extra_vars) + + doc = \ + """ +Pure implementation parsed with +:func:`~dace.libraries.onnx.op_implementations.utils.python_pure_op_implementation`. + +.. code :: python + +""" + doc += textwrap.indent(inspect.getsource(func), prefix=" ") + + PureImpl.__module__ = func.__module__ + PureImpl.__name__ = func.__name__ + PureImpl.__qualname__ = func.__qualname__ + PureImpl.__doc__ = doc + + return PureImpl diff --git a/dace/libraries/onnx/schema.py b/dace/libraries/onnx/schema.py new file mode 100644 index 0000000000..9d9ba3991d --- /dev/null +++ b/dace/libraries/onnx/schema.py @@ -0,0 +1,333 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +ONNX Schema System for DaCe. + +This module provides a Python representation layer for ONNX protobuf schemas, +enabling type-safe interaction with ONNX operations in DaCe. It handles: + +- Converting ONNX protobuf definitions to Python classes +- Type validation and constraint checking for ONNX operations +- Attribute and parameter schema definitions +- Automatic mapping between ONNX types and DaCe types + +Key Components: +- onnx_representation: Decorator for creating Python representations of ONNX protobufs +- ONNXSchema: Complete schema for an ONNX operation +- ONNXAttribute: Attribute definitions (e.g., kernel_shape, strides) +- ONNXParameter: Input/output parameter specifications +- ONNXTypeConstraint: Type constraints for operation parameters +- Enums: ONNXAttributeType, ONNXParameterType for type classification + +The schema system enables: +- Compile-time validation of ONNX operations +- Automatic property generation from schemas +- Type-safe conversion between ONNX and DaCe representations +- Integration with DaCe's property system + +Example: + @onnx_representation(onnx.TensorProto) + class ONNXTensor: + dims: List[int] + data_type: int +""" + +from itertools import chain +from typing import List + +import aenum +import numpy as np +import onnx + +import dace +from dace import config +from dace.dtypes import typeclass +from dace.libraries.onnx.converters import convert_onnx_proto, get_proto_attr, onnx_type_str_to_typeclass +from dace.properties import DictProperty, ListProperty, Property, make_properties + +#: Global registry of known ONNX protobuf types and their Python representations +_KNOWN_ONNX_PROTOS = {} + + +def onnx_representation(represents, **mapping): + """Decorator for python representations of ONNX protobufs. + + The decorator will monkey patch in the following methods: + + * ``__init__`` - a constructor based on the class properties + * ``construct_from_onnx_proto`` + * ``construct_from_json`` + + :param represents: The ONNX protobuf type that the decorated class represents. + :param mapping: A mapping from class property names to either: + + * A string ``s`` - ``convert_onnx_attribute`` will be applied on the + protobuf attribute with the name ``s`` to get the property value. + * A function ``f`` - ``f`` will be called with the protobuf, and the + property value will be set to the return value of that call. + + If a property name is not present in ``mapping``, the property name + itself will be used to access the protobuf attribute. + """ + + def decorator(cls): + + cls = make_properties(cls) + + # initialize the mapping with identity + # this means that by default, we will read the property of the protobuf using the same name as the property name + for name, _ in cls.__properties__.items(): + if name not in mapping: + mapping[name] = name + + def __init__(self, *args, **kwargs): + args = list(args) + for name, prop in self.__properties__.items(): + if len(args) > 0: + # try to init all the positional args first + setattr(self, name, args.pop(0)) + else: + # then try kwargs + setattr(self, name, kwargs[name]) + self._represents = represents + if hasattr(self, "validate"): + self.validate() + + @classmethod + def from_onnx_proto(cls, onnx_proto): + + if type(onnx_proto) is not represents: + raise ValueError("Unexpected protobuf '{}' (type {}), expected protobuf of type {}".format( + onnx_proto, type(onnx_proto), represents)) + + constructor_args = {} + for name, _ in cls.__properties__.items(): + if type(mapping[name]) is str: + # if the value of the mapping for that property is a string, read the attribute with that name + constructor_args[name] = convert_onnx_proto(get_proto_attr(onnx_proto, mapping[name])) + else: + # the value of the mapping should be a function, apply it to the onnx_proto + constructor_args[name] = mapping[name](onnx_proto) + + return cls(**constructor_args) + + @classmethod + def from_json(cls, json, context=None): + + constructor_args = { + name: prop.from_json(json[name] if name in json else prop.default) + for name, prop in cls.__properties__.items() + } + return cls(**constructor_args) + + def to_json(self): + serialized = dace.serialize.all_properties_to_json(self) + serialized["type"] = cls.__name__ + return serialized + + cls.__init__ = __init__ + + # the first line of the init docstring contains the signature of the method. This will be picked up by sphinx + # and means that the generated sphinx docs have a proper signature, and not just *args, **kwargs. + init_docstring = "__init__({})\n\n".format(", ".join(name + "=" + repr(prop._default) + for name, prop in cls.__properties__.items())) + + def get_prop_docstring(name, prop): + return ":param {}: {}\n:type {}: ``{}``, default ``{}``".format( + name, prop.__doc__, name, + prop._dtype.__name__ if prop._dtype is not None else type(prop._default).__name__, repr(prop._default)) + + init_docstring += "\n".join(get_prop_docstring(name, prop) for name, prop in cls.__properties__.items()) + + cls.__init__.__doc__ = init_docstring + + cls.from_onnx_proto = from_onnx_proto + cls.from_json = from_json + cls.to_json = to_json + from_onnx_proto.__func__.__doc__ = " Construct an object from an ONNX proto of type ``{}``. ".format(represents) + from_json.__func__.__doc__ = " Construct an object json ".format(represents) + to_json.__doc__ = " Serialize to json ".format(represents) + + # register so that we're able to load it + _KNOWN_ONNX_PROTOS[represents] = cls + + return cls + + return decorator + + +class ONNXParameterType(aenum.AutoNumberEnum): + Single = () #: single/required parameters + Optional = () #: optional parameters + Variadic = () #: variadic parameters + + +@onnx_representation(onnx.defs.OpSchema.FormalParameter, + type_str='type_str', + param_type='option', + homogeneous="is_homogeneous") +class ONNXParameter: + """ Python representation of an ONNX parameter. """ + + name = Property(dtype=str, desc="The parameter name") + description = Property(dtype=str, desc="A description of the parameter") + type_str = Property(dtype=str, desc="The type string of this parameter") + param_type = Property(choices=ONNXParameterType, + desc="The type of the this parameter", + default=ONNXParameterType.Single) + homogeneous = Property(dtype=bool, desc="Whether this parameter is homogeneous") + + def __repr__(self): + return "{} ({})".format(self.name, str(self.param_type)) + + +class ONNXAttributeType(aenum.AutoNumberEnum): + Int = () #: Integer (python representation is ``int``) + Float = () #: Float (python representation is ``float``) + String = () #: String (python representation is ``str``) + Ints = () #: Ints (python representation is ``List`` [``int``]) + Floats = () #: Floats (python representation is ``List`` [``float``]) + Strings = () #: Strings (python representation is ``List`` [``str``]) + Tensor = () #: Tensor (python representation is ``numpy.ndarray``) + Unsupported = () #: Any unsupported attribute type + + +_ATTR_TYPE_TO_PYTHON_TYPE = { + ONNXAttributeType.Int: int, + ONNXAttributeType.Ints: int, + ONNXAttributeType.Float: float, + ONNXAttributeType.Floats: float, + ONNXAttributeType.String: str, + ONNXAttributeType.Strings: str, + ONNXAttributeType.Tensor: np.ndarray +} + + +@onnx_representation(onnx.defs.OpSchema.Attribute, attribute_type='type') +class ONNXAttribute: + """ Python representation of an ONNX attribute. """ + + name = Property(dtype=str, desc="The attribute name") + description = Property(dtype=str, desc="A description this attribute") + required = Property(dtype=bool, desc="Whether this attribute is required") + attribute_type = Property(choices=ONNXAttributeType, + desc="The type of this attribute", + default=ONNXAttributeType.Int) + default_value = Property(dtype=None, desc="The default value of this attribute", default=None, allow_none=True) + + def validate(self): + if self.required and self.attribute_type == ONNXAttributeType.Unsupported: + raise NotImplementedError("Required attribute '{}' has an unsupported type".format(self.name)) + + def __repr__(self): + return self.name + + +@onnx_representation( + onnx.defs.OpSchema.TypeConstraintParam, + type_str='type_param_str', + types=lambda proto: list( + filter(lambda x: x is not None, map(onnx_type_str_to_typeclass, get_proto_attr(proto, "allowed_type_strs"))))) +class ONNXTypeConstraint: + """ Python representation of an ONNX type constraint. """ + + type_str = Property(dtype=str, desc="The type parameter string") + types = ListProperty(element_type=typeclass, + desc="The possible types. Note that only tensor types are currently supported.") + + def __repr__(self): + return self.type_str + + +@onnx_representation( + onnx.defs.OpSchema, + inputs=lambda proto: list(map(convert_onnx_proto, get_proto_attr(proto, "inputs"))), + outputs=lambda proto: list(map(convert_onnx_proto, get_proto_attr(proto, "outputs"))), + attributes=lambda proto: { + str(k): convert_onnx_proto(v) + for k, v in get_proto_attr(proto, "attributes").items() + }, + type_constraints=lambda proto: + {str(cons.type_param_str): convert_onnx_proto(cons) + for cons in get_proto_attr(proto, "type_constraints")}) +class ONNXSchema: + """Python representation of an ONNX schema""" + + name = Property(dtype=str, desc="The operator name") + domain = Property(dtype=str, desc="The operator domain") + doc = Property(dtype=str, desc="The operator's docstring") + since_version = Property(dtype=int, desc="The version of the operator") + attributes = DictProperty(key_type=str, + value_type=ONNXAttribute, + desc="The operator attributes. Keys should contain the name of the attribute, and values " + "should have type :class:`~dace.libraries.onnx.ONNXAttribute`.") + type_constraints = DictProperty( + key_type=str, + value_type=ONNXTypeConstraint, + desc="The type constraints for inputs and outputs. Keys should contain the type string of the constraint, " + "values should have type :class:`~dace.libraries.onnx.ONNXTypeConstraint`.") + inputs = ListProperty(element_type=ONNXParameter, + desc="The operator input parameter descriptors. Entries should have type" + " :class:`~dace.libraries.onnx.ONNXParameter`.") + outputs = ListProperty(element_type=ONNXParameter, + desc="The operator output parameter descriptors. Entries should have type" + " :class:`~dace.libraries.onnx.ONNXParameter`.") + + def __repr__(self): + return self.domain + "." + self.name + + def non_variadic_inputs(self) -> List[str]: + return [i.name for i in self.inputs if i.param_type is not ONNXParameterType.Variadic] + + def variadic_inputs(self) -> List[str]: + return [i.name for i in self.inputs if i.param_type is ONNXParameterType.Variadic] + + def non_variadic_outputs(self) -> List[str]: + return [i.name for i in self.outputs if i.param_type is not ONNXParameterType.Variadic] + + def variadic_outputs(self) -> List[str]: + return [i.name for i in self.outputs if i.param_type is ONNXParameterType.Variadic] + + def validate(self): + # check all parameters with a type str have a entry in the type constraints + for param in chain(self.inputs, self.outputs): + if param.type_str not in self.type_constraints: + # some operators put a type descriptor here. for those, we will try to insert a new type constraint + cons_name = param.name + "_constraint" + if cons_name in self.type_constraints: + raise ValueError( + "Attempted to insert new type constraint, but the name already existed. Please open an issue.") + parsed_typeclass = onnx_type_str_to_typeclass(param.type_str) + + if parsed_typeclass is None: + if config.Config.get_bool('debugprint'): + print("Could not parse typeStr '{}' for parameter '{}'".format(param.type_str, param.name)) + + cons = ONNXTypeConstraint(cons_name, [parsed_typeclass] if parsed_typeclass is not None else []) + self.type_constraints[cons_name] = cons + param.type_str = cons_name + + # check for required parameters with no supported type + for param in chain(self.inputs, self.outputs): + if ((param.param_type == ONNXParameterType.Single or param.param_type == ONNXParameterType.Variadic) + and len(self.type_constraints[param.type_str].types) == 0): + raise NotImplementedError("None of the types for parameter '{}' are supported".format(param.name)) + + # check that all variadic parameter names do not contain "__" + for param in chain(self.inputs, self.outputs): + if param.param_type == ONNXParameterType.Variadic and "__" in param.name: + raise ValueError( + "Unsupported parameter name '{}': variadic parameter names must not contain '__'".format( + param.name)) + + # check that all inputs and outputs have unique names + seen = set() + for param in self.inputs: + if param.name in seen: + raise ValueError("Got duplicate input parameter name '{}'".format(param.name)) + seen.add(param.name) + + seen = set() + for param in self.outputs: + if param.name in seen: + raise ValueError("Got duplicate output parameter name '{}'".format(param.name)) + seen.add(param.name) diff --git a/dace/libraries/torch/__init__.py b/dace/libraries/torch/__init__.py new file mode 100644 index 0000000000..8cc16a958d --- /dev/null +++ b/dace/libraries/torch/__init__.py @@ -0,0 +1,23 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +DaCe PyTorch Integration Library. + +This module provides integration between DaCe (Data-Centric Parallel Programming) +and PyTorch, enabling: +- Compilation of PyTorch operations to optimized DaCe SDFGs +- Interoperability between PyTorch tensors and DaCe arrays +- Support for both CPU (PyTorch) and GPU (PyTorchGPU) execution +- DLPack-based zero-copy tensor sharing + +The main exports are environment classes that define the PyTorch runtime +dependencies and configuration for code generation. +""" + +try: + from .environments import PyTorch, PyTorchGPU + __all__ = ["PyTorch", "PyTorchGPU"] +except ImportError: + # PyTorch not available + PyTorch = None + PyTorchGPU = None + __all__ = [] diff --git a/dace/libraries/torch/dispatchers/__init__.py b/dace/libraries/torch/dispatchers/__init__.py new file mode 100644 index 0000000000..33b5aeecee --- /dev/null +++ b/dace/libraries/torch/dispatchers/__init__.py @@ -0,0 +1,22 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +PyTorch Dispatchers for DaCe Modules. + +This module provides different dispatcher implementations for executing DaCe SDFGs +from PyTorch. Dispatchers handle: +- Compiling SDFGs to native code +- Initializing runtime state and memory +- Converting between PyTorch tensors and DaCe arrays +- Calling forward and backward SDFG functions +- Managing the integration with PyTorch's autograd system + +Available dispatchers: +- CTypes dispatcher: Uses ctypes for direct C function calls +- C++ PyTorch extension: Registers as a native PyTorch extension with custom autograd +""" + +from .common import DaceTorchFunction +from .cpp_torch_extension import register_and_compile_torch_extension +from .ctypes_module import get_ctypes_dispatcher + +__all__ = ["DaceTorchFunction", "register_and_compile_torch_extension", "get_ctypes_dispatcher"] diff --git a/dace/libraries/torch/dispatchers/common.py b/dace/libraries/torch/dispatchers/common.py new file mode 100644 index 0000000000..80ee1d9f28 --- /dev/null +++ b/dace/libraries/torch/dispatchers/common.py @@ -0,0 +1,112 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +Common utilities for PyTorch-DaCe dispatchers. + +This module provides shared functionality for different dispatcher implementations, +including: +- SDFG compilation and initialization +- Argument list extraction and processing +- State management for forward and backward passes +- Integration with PyTorch's autograd system +""" + +import dataclasses +from typing import Callable, List, Tuple, Union + +import dace +import torch +from dace.codegen.compiled_sdfg import CompiledSDFG +from dace.libraries.onnx.converters import clean_onnx_name +from dace.frontend.ml.onnx.importer import create_output_array + + +@dataclasses.dataclass +class DaceTorchFunction: + """ + An initialized, callable function for a DaceModule and its associated state. + + This dataclass encapsulates a compiled DaCe module with its runtime state, + providing a callable interface for PyTorch integration. + + Attributes: + function: The PyTorch callable function that executes the SDFG. + compiled_sdfgs: The compiled SDFGs holding their runtime states. + ptr: Pointers to the initialized SDFG state handles. These must be + passed as the first arguments to the function. + """ + function: Callable + compiled_sdfgs: List[CompiledSDFG] + ptr: List[torch.Tensor] + + +def get_arglist(module: 'dace.frontend.ml.torch.DaceModule') -> Tuple[List[str], List[str]]: + """Get the list of forward-pass argument names for a module. + + :param module: The DaCe module to extract argument names from. + :return: A tuple of (input_names, output_names) where each is a list of cleaned + argument names suitable for use in generated code. + """ + + arglist = [clean_onnx_name(input_name) for input_name in module.dace_model.inputs] + outputs = [clean_onnx_name(output_name) for output_name in module.dace_model.outputs] + return arglist, outputs + + +def compile_and_init_sdfgs( + module: 'dace.frontend.ml.torch.DaceModule', dummy_inputs +) -> Union[Tuple[CompiledSDFG, torch.Tensor], Tuple[CompiledSDFG, torch.Tensor, CompiledSDFG, torch.Tensor]]: + """Compile SDFGs and initialize them using the provided dummy inputs. + + This function compiles the forward pass SDFG and optionally the backward pass + SDFG if the module has automatic differentiation enabled. It initializes both + SDFGs with the appropriate tensors and parameters. + + :param module: The DaCe module to compile SDFGs for. + :param dummy_inputs: The dummy inputs to use for shape inference and initialization. + :return: If the module has no backward pass: (compiled_sdfg, state_ptr). + If the module has a backward pass: (compiled_fwd_sdfg, fwd_state_ptr, + compiled_bwd_sdfg, bwd_state_ptr). Where state_ptr is a torch.Tensor + containing the pointer to the SDFG state. + """ + + compiled: CompiledSDFG = module.dace_model.compile_and_init() + # Construct the arguments and initialize the SDFG + args = tuple(dummy_inputs) + module._call_params() + args = tuple(arg.detach() for arg in args) + inputs, symbols, outputs = module.dace_model._call_args(args=args, kwargs={}) + + if module.backward: + forwarded_transients = { + name: + create_output_array(symbols, desc, use_torch=True, zeros=True) + if name not in module.dace_model.initialized_parameters else module.dace_model.initialized_parameters[name] + for name, desc in module._ad_inp_arrs.items() + } + else: + forwarded_transients = {} + + all_kwargs = {**inputs, **outputs, **symbols, **forwarded_transients, **module.dace_model.initialized_parameters} + + compiled.initialize(**all_kwargs) + for _, hook in module.post_compile_hooks.items(): + hook(compiled) + handle_ptr = torch.tensor([compiled._libhandle.value]).squeeze(0) + + if module.backward: + # Compile and initialize the backward_sdfg + compiled_bwd: CompiledSDFG = module.backward_sdfg.compile() + + required_grads = { + bwd_name: create_output_array(symbols, compiled_bwd.sdfg.arrays[bwd_name], use_torch=True, zeros=True) + for _, bwd_name in module._ad_result.required_grad_names.items() + } + given_grads = { + bwd_name: create_output_array(symbols, compiled_bwd.sdfg.arrays[bwd_name], use_torch=True, zeros=True) + for _, bwd_name in module._ad_result.given_grad_names.items() + } + + compiled_bwd.initialize(**required_grads, **given_grads, **forwarded_transients) + bwd_handle_ptr = torch.tensor([compiled_bwd._libhandle.value]).squeeze(0) + return compiled, handle_ptr, compiled_bwd, bwd_handle_ptr + else: + return compiled, handle_ptr diff --git a/dace/libraries/torch/dispatchers/cpp_torch_extension.py b/dace/libraries/torch/dispatchers/cpp_torch_extension.py new file mode 100644 index 0000000000..f7449df2f4 --- /dev/null +++ b/dace/libraries/torch/dispatchers/cpp_torch_extension.py @@ -0,0 +1,699 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +"""Code generation for PyTorch C++ dispatched operators.""" +import copy +import dataclasses +from distutils import sysconfig +import hashlib +import itertools +import operator +import os +import sys +from typing import List, Tuple, Callable, Optional, Dict, Union + +import dace.library +import numpy as np +import torch +from torch.utils.cpp_extension import load as torch_load +import dace +from dace import config, dtypes as dt, data +from dace.codegen import targets, compiler +from dace.codegen.codeobject import CodeObject +from dace.codegen.compiled_sdfg import CompiledSDFG +from dace.codegen.prettycode import CodeIOStream +from dace.codegen.common import sym2cpp, platform_library_name + +from dace.autodiff import BackwardResult +from dace.libraries.torch.environments import PyTorch + +from dace.libraries.torch.dispatchers.common import DaceTorchFunction, compile_and_init_sdfgs, get_arglist + +_REPLACED_CTYPES = {dace.int64: "int64_t", dace.uint64: "uint64_t", dace.float16: "at::Half"} + + +def torch_ctype(dtype: dace.typeclass) -> str: + """Convert a DaCe type to the corresponding PyTorch C++ type string. + + :param dtype: The DaCe typeclass to convert. + :return: The corresponding C++ type string for PyTorch. + """ + if isinstance(dtype, dace.pointer): + # assuming pointers are 64 bit + ctype = "int64_t" + elif dtype in _REPLACED_CTYPES: + ctype = _REPLACED_CTYPES[dtype] + else: + ctype = dtype.ctype + return ctype + + +_TYPECLASS_TO_TORCH_DTYPE_STR = { + dt.bool: "kBool", + dt.int8: "kInt8", + dt.uint8: "kUInt8", + dt.int16: "kInt16", + dt.int32: "kInt32", + dt.int64: "kInt64", + dt.float16: "kFloat16", + dt.float32: "kFloat32", + dt.float64: "kFloat64", + dt.complex64: "kComplexFloat", + dt.complex128: "kComplexDouble", +} + + +def typeclass_to_torch_cpp_type(type: dace.typeclass) -> str: + """Convert a DaCe typeclass to PyTorch C++ tensor type string. + + :param type: The DaCe typeclass to convert. + :return: The corresponding PyTorch tensor type string (e.g., 'kFloat32'). + """ + if isinstance(type, dace.pointer): + # assuming pointers are 64 bit + return "kInt64" + else: + return _TYPECLASS_TO_TORCH_DTYPE_STR[type] + + +def tensor_init_for_desc(name: str, desc: data.Data, clean_weights: Dict[str, torch.Tensor], zeros=True) -> str: + """Emit the initialization code for a descriptor. + + :param name: The name of the tensor. + :param desc: The data descriptor. + :param clean_weights: Dictionary of constant weights. + :param zeros: Whether to initialize with zeros (True) or empty (False). + :return: C++ code string for tensor initialization. + """ + + # Check if name is in clean_weights + if name in clean_weights: + # Get the tensor from clean_weights + weight_tensor = clean_weights[name] + + # Convert the tensor to a C++ initializer list format + # Flatten the tensor and convert to list + values = weight_tensor.flatten().tolist() + + # Format the values based on the data type + def format_value(v, dtype): + if dtype in [dt.float32, dt.float16]: + return f'{v}f' + elif dtype == dt.float64: + return str(v) + elif dtype in [dt.int8, dt.int16, dt.int32, dt.int64, dt.uint8]: + return str(int(v)) + elif dtype == dt.bool: + return str(v).lower() + else: + return str(v) + + # Format the values as a C++ initializer list + values_str = ', '.join(format_value(v, desc.dtype) for v in values) + + return f"""\ + Tensor {name} = torch::from_blob( + new float[{len(values)}]{{{values_str}}}, + {{{', '.join(str(s) for s in desc.shape)}}}, + torch::TensorOptions() + .dtype(torch::{typeclass_to_torch_cpp_type(desc.dtype)}) + .device(torch::{'kCUDA' if desc.storage in dace.dtypes.GPU_STORAGES else 'kCPU'}) + .layout(torch::kStrided)).clone(); + """ + else: + # Initialize with zeros or empty + return f"""\ + Tensor {name} = torch::{'zeros' if zeros else 'empty'}( + {{{', '.join(str(s) for s in desc.shape)}}}, + torch::TensorOptions() + .dtype(torch::{typeclass_to_torch_cpp_type(desc.dtype)}) + .device(torch::{'kCUDA' if desc.storage in dace.dtypes.GPU_STORAGES else 'kCPU'}) + .layout(torch::kStrided)); + """ + + +def initialize_outputs_code(module: 'dace.frontend.ml.torch.DaceModule', output_names: List[str], + clean_weights: Dict[str, torch.Tensor]) -> str: + """Generate the code that initializes the output tensors. + + :param module: The module + :param output_names: The output names of the SDFG. + :param clean_weights: Dictionary of constant weights + :return: The code + """ + arglist = module.sdfg.arglist() + code = "" + for name in sorted(output_names): + code += tensor_init_for_desc(name, arglist[name], clean_weights) + + return code + + +def argument_codegen(sdfg: dace.SDFG, + clean_weights: Dict[str, torch.Tensor], + input_names: List[str], + output_names: List[str], + guard_contiguous: Optional[List[str]] = None) -> Tuple[str, str, str]: + """Generate the code that grabs the pointers of inputs and outputs. + + The names of the tensors will match the SDFG tensor names. Tensors that are not created by us (i.e. inputs) + should be named {sdfg_name}_ first, and then .contiguous() will be called on them to yield the tensor that we + require. This is the case for all tensors in ``guard_contiguous``. + + :param sdfg: The SDFG to generate code for + :param clean_weights: The constant weights of the SDFG. + :param input_names: Names of inputs to the torch function. + :param output_names: Names of outputs to the torch function. + :param guard_contiguous: A subset of input_names to call .contiguous on. If None, all input names will be + guarded. + :return: The code for initializing the argument, the SDFG arguments in order, and the init call arguments + """ + arglist = sdfg.arglist() + + guard_contiguous = set(guard_contiguous or input_names) + + assert set(input_names).issubset(arglist.keys()), \ + f"Input names {set(input_names).difference(arglist.keys())} are not SDFG arguments {arglist.keys()}" + + # Initialize the inputs and outputs + ptr_init_code = "\n// Setup input and output pointers\n" + for name in sorted(input_names): + tctype = torch_ctype(arglist[name].dtype) + dctype = arglist[name].dtype + + if isinstance(arglist[name], data.Array) or dt.can_access(dt.ScheduleType.GPU_Device, arglist[name].storage): + if name in guard_contiguous: + if config.Config.get_bool('debugprint'): + ptr_init_code += f""" + if (!{name}_.is_contiguous()) {{ + fprintf(stderr, "{name} was not contiguous!"); + }} + """ + ptr_init_code += '\n' + f"Tensor {name} = {name}_.contiguous();" + + ptr_init_code += '\n' + f"{dctype} *{name}_ptr = reinterpret_cast<{dctype}*>({name}.data_ptr<{tctype}>());" + + elif isinstance(arglist[name], data.Scalar): + if name in guard_contiguous: + ptr_init_code += '\n' + f"{dctype} {name}_ptr = static_cast<{dctype}>({name}_.item().to<{tctype}>());" + else: + ptr_init_code += '\n' + f"{dctype} {name}_ptr = static_cast<{dctype}>({name}.item().to<{tctype}>());" + else: + raise ValueError(f"Unsupported data type {type(arglist[name])} for descriptor {name}") + + ptr_init_code += '\n' + + # Outputs and backward arrays + ptr_init_code += '\n'.join( + f"{arglist[name].dtype.ctype} *{name}_ptr = reinterpret_cast<{arglist[name].dtype.ctype}*>" + f"({name}.data_ptr<{torch_ctype(arglist[name].dtype)}>());" for name in sorted(output_names)) + ptr_init_code += "\n// Setup constant arguments\n" + + all_access_nodes = set() + for state in sdfg.nodes(): + all_access_nodes |= set(n.data for n in state.data_nodes()) + + # Initialize all remaining parameters + remaining = set(arglist).difference(itertools.chain(input_names, output_names)) + for name in sorted(remaining): + # Remaining args must be constants + if name not in clean_weights: + raise ValueError(f"Cannot generate PyTorch module C++ code: SDFG argument {name} is not an input or output" + f" of the PyTorch Module, and not a constant.") + + value = clean_weights[name] + ptr_init_code += f"{constant_initializer_code(name, arglist[name], value)}\n" + + arguments = ", ".join(f"{n}_ptr" for n in arglist) + init_arguments = ", ".join(f"{n}_ptr" for n, desc in arglist.items() if isinstance(desc, data.Scalar)) + + return ptr_init_code, arguments, init_arguments + + +def item_to_cpp_literal(item) -> str: + """Convert a numpy item to a C++ literal string. + + :param item: The numpy item to convert. + :return: The C++ literal representation as a string. + """ + dtype = str(item.dtype) + if np.isneginf(item): + return "-std::numeric_limits::infinity()" + if np.isposinf(item): + return "std::numeric_limits::infinity()" + if dtype == "float32": + return f"{item}f" + elif dtype == "bool": + return f"{str(item).lower()}" + elif dtype == "int64": + return f"{item}l" + elif dtype == "float16": + ctype = dace.dtypes._CTYPES[item.dtype.type] + return f"(({ctype}){item})" + elif dtype in ["float64", "int32", "int16", "int8"]: + return str(item) + else: + raise ValueError(f"Unsupported tensor type {item.dtype}") + + +def constant_initializer_code(name: str, desc: data.Data, value) -> str: + """Generate C++ code for initializing a constant value. + + :param name: The name of the constant. + :param desc: The data descriptor. + :param value: The constant value. + :return: C++ code string for constant initialization. + """ + gpu_storage = dt.can_access(dt.ScheduleType.GPU_Device, desc.storage) + gpu_storage = False + if desc.total_size == 0: + return f"{desc.dtype.ctype} *{name}_ptr = nullptr;" + elif isinstance(desc, data.Array) or gpu_storage: + numpyval = value.cpu().numpy() + if len(numpyval.shape) == 0: + numpyval = numpyval.reshape((1, )) + iterator = np.nditer(numpyval, order="C") + gpu_copy_code = f""" + Tensor {name} = torch::from_blob({name}_ptr_cpu, {{{', '.join(sym2cpp(s) for s in desc.shape)}}}, + {{{', '.join(sym2cpp(s) for s in desc.strides)}}}, torch::{typeclass_to_torch_cpp_type(desc.dtype)}) + .to(torch::kCUDA); + {desc.dtype.ctype} *{name}_ptr = reinterpret_cast<{desc.dtype.ctype}*>({name}.data_ptr<{torch_ctype(desc.dtype)}>()); + """ + return f""" + {desc.dtype.ctype} {name}_ptr{'_cpu' if gpu_storage else ''}[{sym2cpp(desc.total_size)}] = + {{{', '.join(item_to_cpp_literal(e) for e in iterator)}}}; + {gpu_copy_code if gpu_storage else ""} + """ + elif isinstance(desc, data.Scalar): + if str(value.item()) == "-inf": + return f"{desc.dtype.ctype} {name}_ptr = -std::numeric_limits<{desc.dtype.ctype}>::infinity();" + elif str(value.item()) == "inf": + return f"{desc.dtype.ctype} {name}_ptr = std::numeric_limits<{desc.dtype.ctype}>::infinity();" + if desc.dtype.ctype == "bool": + # Special case for bools + bool_str = "true" if value.item() else "false" + return f"{desc.dtype.ctype} {name}_ptr = {bool_str};" + return f"{desc.dtype.ctype} {name}_ptr = {str(value.item())};" + else: + raise ValueError("Unsupported data descriptor") + + +def return_type_str(outputs: List[str]) -> str: + """Generate the return type string for the given outputs. + + :param outputs: List of output names. + :return: The C++ return type string. + """ + return f"""{"Tensor" if len(outputs) == 1 else f"variable_list"}""" + + +def save_non_inputs_outputs(names: List[str]): + """Generate code to save non-input/output tensors for backward pass. + + :param names: List of tensor names to save. + :return: C++ code string for saving tensors. + """ + return "\n".join(f'ctx->saved_data["{n}"] = {n};' for n in names) + + +def recover_saved_inputs_outputs(saved_inputs_outputs: List[str], other_saved: List[str]): + """Generate code to recover saved tensors in backward pass. + + :param saved_inputs_outputs: List of saved input/output tensor names. + :param other_saved: List of other saved tensor names. + :return: C++ code string for recovering saved tensors. + """ + code = "" + if saved_inputs_outputs: + code += "auto saved = ctx->get_saved_variables();\n" + for i, n in enumerate(saved_inputs_outputs): + code += f"\nauto {n} = saved[{i}];" + + for n in other_saved: + code += f'\nauto {n} = ctx->saved_data["{n}"].toTensor();' + + return code + + +def setup_grad_values(backward_result: BackwardResult, sdfg: dace.SDFG, outputs: List[str], + clean_weights: Dict[str, torch.Tensor]) -> str: + """Generate code to setup gradient values for backward pass. + + :param backward_result: The backward pass result containing gradient information. + :param sdfg: The SDFG. + :param outputs: List of output names. + :param clean_weights: Dictionary of constant weights. + :return: C++ code string for gradient setup. + """ + code = "// input grads" + for param_name, grad_name in sorted(backward_result.required_grad_names.items()): + zero_init = backward_result.zero_init.get(param_name, True) + code += "\n" + tensor_init_for_desc(grad_name, sdfg.arrays[grad_name], clean_weights, zeros=zero_init) + + code += "// output grads" + for i, o in enumerate(outputs): + grad_name = backward_result.given_grad_names[o] + code += f'\nauto {grad_name}_ = grad_outputs[{i}];' + + return code + + +def code_for_backward_function(module: 'dace.frontend.ml.torch.DaceModule', forward_sdfg: dace.SDFG, + backward_sdfg: dace.SDFG, backward_result: BackwardResult, + forwarded_arrays: Dict[str, data.Data]) -> str: + """Generate C++ code for a differentiable PyTorch function. + + :param module: The DaCe module. + :param forward_sdfg: The forward SDFG. + :param backward_sdfg: The backward SDFG. + :param backward_result: The backward pass result. + :param forwarded_arrays: Arrays forwarded from forward to backward pass. + :return: Complete C++ code string for the differentiable function. + """ + + inputs, outputs = get_arglist(module) + sdfg_name = forward_sdfg.name + + ret_str = return_type_str(outputs) + + outputs_with_forwarded_outputs = copy.deepcopy(outputs) + outputs_with_forwarded_outputs.extend(n for n in forwarded_arrays if n not in inputs and n not in outputs) + + fwd_ptr_init_code, fwd_sdfg_call_arguments, _ = argument_codegen(forward_sdfg, module.dace_model.clean_weights, + inputs, outputs_with_forwarded_outputs) + + # Inputs are given_grads + forwarded_outputs + bwd_inputs = list(backward_result.given_grad_names.values()) + list(forwarded_arrays) + + # Outputs are required grads + bwd_outputs = list(backward_result.required_grad_names.values()) + + bwd_ptr_init_code, bwd_sdfg_call_arguments, _ = argument_codegen(backward_sdfg, + module.dace_model.clean_weights, + bwd_inputs, + bwd_outputs, + guard_contiguous=list( + backward_result.given_grad_names.values())) + + # Saved inputs/outputs + saved_io_for_backward = [n for n in forwarded_arrays if n in inputs or n in outputs] + other_saved_for_backward = [n for n in forwarded_arrays if n not in inputs and n not in outputs] + return f""" +{get_header(forward_sdfg, backward_sdfg, inputs, outputs, module.use_cuda)} +class {sdfg_name}Function : public torch::autograd::Function<{sdfg_name}Function> {{ + public: + static + {ret_str} + forward( + AutogradContext *ctx, + int64_t fwd_handle_ptr, int64_t bwd_handle_ptr, {", ".join(f"const Tensor& {name}_" for name in inputs)}) {{ + + at::AutoDispatchBelowADInplaceOrView g; + + // initialize outputs + {initialize_outputs_code(module, outputs_with_forwarded_outputs, module.dace_model.clean_weights)} + + {fwd_ptr_init_code} + + // get SDFG state handle + {forward_sdfg.name}Handle_t handle = reinterpret_cast<{forward_sdfg.name}Handle_t>(fwd_handle_ptr); + + + // call SDFG + __program_{forward_sdfg.name}(handle, {fwd_sdfg_call_arguments}); + + // save inputs/outputs for backward + { + f"ctx->save_for_backward({{{', '.join(f'{n}' for n in saved_io_for_backward)}}});" + if saved_io_for_backward else "" + } + + // save non-inputs/outputs + {save_non_inputs_outputs(other_saved_for_backward)} + + // save bwd handle + ctx->saved_data["bwd_handle"] = bwd_handle_ptr; + + // return to torch + return {f"{outputs[0]}" if len(outputs) == 1 + else f"{{{', '.join(o for o in outputs)}}}"}; + }} + + static tensor_list backward(AutogradContext *ctx, tensor_list grad_outputs) {{ + // recover bwd_handle_ptr + int64_t bwd_handle_ptr = ctx->saved_data.find("bwd_handle")->second.toInt(); + + // recover saved values + {recover_saved_inputs_outputs(saved_io_for_backward, other_saved_for_backward)} + + // create grad values + // NOTE, it might make sense take these from .grad() + {setup_grad_values(backward_result, backward_sdfg, outputs, module.dace_model.clean_weights)} + + {bwd_ptr_init_code} + + // get SDFG state handle + {backward_sdfg.name}Handle_t handle = reinterpret_cast<{backward_sdfg.name}Handle_t>(bwd_handle_ptr); + + // call bwd SDFG + __program_{backward_sdfg.name}(handle, {bwd_sdfg_call_arguments}); + + // return calculated grads in correct order + // first two grads are None (these are the grads for the handle ptrs) + return {{ + Tensor(), Tensor(), {', '.join(backward_result.required_grad_names[i] if i in backward_result.required_grad_names else 'Tensor()' for i in inputs )} + }}; +}} +}}; + +{ret_str} +{sdfg_name}_autograd(int64_t handle_ptr, int64_t bwd_handle_ptr, {",".join(f"const Tensor& {name}_" for name in inputs)}) {{ +return {sdfg_name}Function::apply( +handle_ptr, bwd_handle_ptr, {", ".join(f"{name}_" for name in inputs)} +); +}} + +TORCH_LIBRARY_IMPL(dace_{sdfg_name}, Autograd{'CUDA' if module.use_cuda else 'CPU'}, m) {{ +m.impl("{sdfg_name}", {sdfg_name}_autograd); +}} +""" + + +def code_for_module(module: 'dace.frontend.ml.torch.DaceModule', compiled_sdfg: CompiledSDFG) -> str: + """Generate the code for an operator that calls the SDFGs in the module. + + :param module: The module. + :param compiled_sdfg: The compiled SDFG. + """ + + inputs, outputs = get_arglist(module) + sdfg_name = compiled_sdfg.sdfg.name + + ret_str = return_type_str(outputs) + ptr_init_code, sdfg_call_arguments, init_arguments = argument_codegen(compiled_sdfg.sdfg, + module.dace_model.clean_weights, inputs, + outputs) + return f""" +{get_header(compiled_sdfg.sdfg, None, inputs, outputs, module.use_cuda)} + +// Function definition +{ret_str} +{sdfg_name}(int64_t handle_ptr, {",".join(f"const Tensor& {name}_" for name in inputs)}) {{ + + // Initialize outputs + {initialize_outputs_code(module, outputs, module.dace_model.clean_weights)} + + {ptr_init_code} + + // Get SDFG state handle + {sdfg_name}Handle_t handle = reinterpret_cast<{sdfg_name}Handle_t>(handle_ptr); + + // Call SDFG + __program_{sdfg_name}(handle, {sdfg_call_arguments}); + + // Return to torch + return {f"{outputs[0]}" if len(outputs) == 1 + else f"{{{', '.join(o for o in outputs)}}}"}; +}} + +TORCH_LIBRARY_IMPL(dace_{sdfg_name}, {'CUDA' if module.use_cuda else 'CPU'}, m) {{ + m.impl("{sdfg_name}", {sdfg_name}); +}} + """ + + +def get_header(fwd_sdfg: dace.SDFG, bwd_sdfg: Optional[dace.SDFG], inputs, outputs, use_cuda: bool) -> str: + """Generate the C++ header code for the PyTorch extension. + + :param fwd_sdfg: The forward SDFG. + :param bwd_sdfg: The backward SDFG (optional). + :param inputs: List of input names. + :param outputs: List of output names. + :param use_cuda: Whether CUDA is used. + :return: C++ header code string. + """ + return f""" +#include +#include +#include "{fwd_sdfg.name}.h" +{"" if bwd_sdfg is None else f'#include "{bwd_sdfg.name}.h"'} +using torch::Tensor; +using torch::DeviceType; +using torch::autograd::tensor_list; +using torch::autograd::variable_list; +using torch::autograd::AutogradContext; + +TORCH_LIBRARY(dace_{fwd_sdfg.name}, m) {{ + m.def("{fwd_sdfg.name}(int handle_ptr,{"int bwd_handle_ptr," if bwd_sdfg else ""} {", ".join('Tensor ' + arg for arg in inputs)}) -> {'Tensor' if len(outputs) == 1 else 'Tensor[]'}"); +}} +""" + + +def _torch_ext_root() -> str: + """Resolve the torch extensions root without using private PyTorch APIs.""" + env = os.environ.get("TORCH_EXTENSIONS_DIR") + if env: + return env + + return os.path.join(os.path.expanduser("~"), ".cache", "torch_extensions") + + +def register_and_compile_torch_extension(module: 'dace.frontend.ml.torch.DaceModule', + dummy_inputs) -> DaceTorchFunction: + """Get a torch callable for the module. This will compile the SDFG, compile a PyTorch C++ operator, register it + with PyTorch and return the function that calls it. + + This function handles code generation for both the forward and backward pass. + + :param module: The module. + :param dummy_inputs: Dummy inputs to initialize the model with. + :return: The callable function for the SDFG. + """ + + # Build the SDFG + # Set all states to not-sync + for state in module.sdfg.nodes(): + state.nosync = True + + environments = { + PyTorch.full_class_path(), + } + if module.backward: + compiled, handle_ptr, compiled_bwd, bwd_handle_ptr = compile_and_init_sdfgs(module, dummy_inputs) + compiled_sdfgs = [compiled, compiled_bwd] if compiled_bwd is not None else [compiled] + ptrs = [handle_ptr, bwd_handle_ptr] if compiled_bwd is not None else [handle_ptr] + if compiled_bwd is not None: + environments.add(get_env_for_sdfg(compiled_bwd).full_class_path()) + bwd_sdfg = compiled_bwd.sdfg + code = code_for_backward_function(module, compiled.sdfg, bwd_sdfg, module._ad_result, module._ad_inp_arrs) + else: + bwd_sdfg = module.backward_sdfg + code = code_for_module(module, compiled) + else: + compiled, handle_ptr = compile_and_init_sdfgs(module, dummy_inputs) + compiled_sdfgs = [compiled] + ptrs = [handle_ptr] + code = code_for_module(module, compiled) + + environments.add(get_env_for_sdfg(compiled).full_class_path()) + code = indent_code(code) + + # ---------- Build the PyTorch module ---------- + base_libname = f"torch_{compiled.sdfg.name}" + program = CodeObject(base_libname, + code, + "cpp", + targets.cpu.CPUCodeGen, + f"Torch{module.sdfg_name}", + environments=environments) + + torch_module_build_path = os.path.join('.dacecache', base_libname) + parts = os.path.normpath(compiled.filename).split(os.sep) + sdfg_folder_name = parts[parts.index('.dacecache') + 1] + + # Treat the case where a hash is added to the SDFG folder dir + backward_sdfg_folder_name = f"{compiled.sdfg.name}_backward_{sdfg_folder_name.removeprefix(compiled.sdfg.name + '_')}" if sdfg_folder_name != compiled.sdfg.name else f"{compiled.sdfg.name}_backward" + compiler.generate_program_folder(None, [program], torch_module_build_path) + + include_path = os.path.abspath(os.path.join('.dacecache', sdfg_folder_name, "include")) + include_path_bwd = os.path.abspath(os.path.join('.dacecache', backward_sdfg_folder_name, "include")) + dace_include_path = os.path.abspath(os.path.join(os.path.dirname(dace.__file__), "runtime", "include")) + dace_include_onnx = os.path.abspath(os.path.join(os.path.dirname(dace.__file__), "libraries", "onnx", "include")) + dace_include_blas = os.path.abspath(os.path.join(os.path.dirname(dace.__file__), "libraries", "blas", "include")) + + code_path = os.path.join('.dacecache', sdfg_folder_name, "src", "cpu", f"{compiled.sdfg.name}.cpp") + code_path_bwd = os.path.join('.dacecache', backward_sdfg_folder_name, "src", "cpu", + f"{compiled.sdfg.name}_backward.cpp") + torch_code_path = os.path.join('.dacecache', base_libname, "src", "cpu", f"{base_libname}.cpp") + + sources = [p for p in [code_path, torch_code_path, code_path_bwd] if os.path.exists(p)] + + pid = os.getpid() + salt = hashlib.sha1(("".join(sources)).encode("utf-8")).hexdigest()[:8] + base_libname = f"torch_{compiled.sdfg.name}" + unique_name = f"{base_libname}_p{pid}_{salt}" + + build_root = _torch_ext_root() # <- uses our helper + unique_build_dir = os.path.join(build_root, unique_name) + os.makedirs(unique_build_dir, exist_ok=True) + + # We pass unique name + unique build directory to avoid FileBaton contention + torch_load( + name=unique_name, + sources=sources, + build_directory=unique_build_dir, + extra_cflags=["-g"], + extra_include_paths=[ + p for p in { + include_path, + include_path_bwd if os.path.exists(include_path_bwd) else None, + dace_include_path, + dace_include_blas, + dace_include_onnx, + } if p + ], + is_python_module=False, + ) + + torch_function = operator.attrgetter(f"dace_{compiled.sdfg.name}.{compiled.sdfg.name}")(torch.ops) + + return DaceTorchFunction(function=torch_function, compiled_sdfgs=compiled_sdfgs, ptr=ptrs) + + +def get_env_for_sdfg(compiled: CompiledSDFG): + """Create an environment for the given compiled SDFG. + + :param compiled: The compiled SDFG. + :return: The environment class for the SDFG. + """ + sdfg_build_path = os.path.abspath(compiled.sdfg.build_folder) + + class SDFGEnvironment: + """Environment for the SDFG.""" + + cmake_minimum_version = None + cmake_packages = [] + cmake_variables = {} + cmake_includes = [os.path.join(sdfg_build_path, "include")] + cmake_compile_flags = [] + cmake_link_flags = [] + cmake_files = [] + cmake_libraries = [os.path.join(sdfg_build_path, "build", platform_library_name(compiled.sdfg.name))] + state_fields = [] + dependencies = [] + headers = [] + init_code = "" + finalize_code = "" + + SDFGEnvironment.__name__ = compiled.sdfg.name + dace.library.environment(SDFGEnvironment) + return SDFGEnvironment + + +def indent_code(code: str) -> str: + """Indent the given code string properly. + + :param code: The code string to indent. + :return: The indented code string. + """ + stream = CodeIOStream() + stream.write(code) + return stream.getvalue() diff --git a/dace/libraries/torch/dispatchers/ctypes_module.py b/dace/libraries/torch/dispatchers/ctypes_module.py new file mode 100644 index 0000000000..fb73179bc1 --- /dev/null +++ b/dace/libraries/torch/dispatchers/ctypes_module.py @@ -0,0 +1,222 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +A torch python autograd function that calls the SDFG using ctypes. + +This can be as an alternative to the C++ registration for large neural nets to +get around the 64 parameter limit of torch's dispatcher. +""" +import copy +import itertools +from typing import List, Dict, Tuple + +from dace import data +import torch +from dace.codegen.compiled_sdfg import CompiledSDFG + +import dace +from dace.autodiff import BackwardResult +from dace.frontend.ml.onnx.importer import create_output_array +from dace.libraries.torch.dispatchers import DaceTorchFunction +from dace.libraries.torch.dispatchers.common import compile_and_init_sdfgs, \ + get_arglist + + +def init_remaining_parameters(module, fwd_arglist, input_names, output_names): + """Initialize remaining parameters that are not inputs or outputs. + + :param module: The DaCe module containing the weights. + :param fwd_arglist: Forward pass argument list. + :param input_names: Names of input tensors. + :param output_names: Names of output tensors. + :return: Dictionary of constant parameters. + :raises ValueError: If a parameter is neither an input/output nor a constant. + """ + # initialize all remaining parameters + remaining = set(fwd_arglist).difference(itertools.chain(input_names, output_names)) + constants = {} + for name in remaining: + # remaining arguments must be constant + if name not in module.dace_model.clean_weights: + raise ValueError(f"Cannot generate ctypes dispatcher: SDFG argument {name} is " + f"not an input or output of the PyTorch Module, and not a" + f" constant.") + constants[name] = module.dace_model.clean_weights[name] + if fwd_arglist[name].storage in dace.dtypes.GPU_STORAGES: + constants[name] = constants[name].cuda() + return constants + + +def callable_for_fwd_module(module: 'dace.frontend.ml.torch.DaceModule', forward_compiled: CompiledSDFG): + """Create a callable for forward pass execution. + + :param module: The DaCe module containing the model. + :param forward_compiled: Compiled SDFG for forward pass. + :return: Function that executes the forward pass. + """ + assert forward_compiled._initialized + + fwd_arglist = forward_compiled.sdfg.arglist() + + input_names, output_names = get_arglist(module) + + constants = init_remaining_parameters(module, fwd_arglist, input_names, output_names) + + def forward(*inputs): + kwargs = {} + + # set the inputs + for i, input_name in enumerate(input_names): + kwargs[input_name] = inputs[i].contiguous() + + # initialize the outputs + for name in output_names: + output_desc = forward_compiled.sdfg.arrays[name] + kwargs[name] = create_output_array( + {}, output_desc, use_torch=True, zeros=False + ) if name not in module.dace_model.initialized_parameters else module.dace_model.initialized_parameters[name] + + # call the SDFG + return forward_compiled(**kwargs, **constants) + + return forward + + +def callable_for_bwd_module(module: 'dace.frontend.ml.torch.DaceModule', forward_compiled: CompiledSDFG, + backward_compiled: CompiledSDFG, backward_result: BackwardResult, + forwarded_arrays: Dict[str, data.Data]): + + assert forward_compiled._initialized + assert backward_compiled._initialized + + fwd_arglist = forward_compiled.sdfg.arglist() + + input_names, output_names = get_arglist(module) + + # arrays that we will forward to the backward pass using saved_for_backward + forwarded_io_names: List[str] = [name for name in forwarded_arrays if name in output_names or name in input_names] + + # non input/output arrays that we are forwarding + forwarded_non_io_names: List[str] = [ + name for name in forwarded_arrays if name not in output_names and name not in input_names + ] + + # for each gradient array that is required, this contains the: + # * name of the gradient + # * whether the array requires zero initialization + # * the descriptor for the array + gradient_descriptors: List[Tuple[str, bool, data.Data]] = [] + + for _, grad_name in backward_result.required_grad_names.items(): + zero_init = backward_result.zero_init.get(grad_name, True) + desc = backward_compiled.sdfg.arrays[grad_name] + + gradient_descriptors.append((grad_name, zero_init, desc)) + + outputs_with_forwarded_outputs: List[str] = copy.deepcopy(output_names) + outputs_with_forwarded_outputs.extend(n for n in forwarded_arrays if n not in input_names and n not in output_names) + + output_gradient_names: List[str] = [ + backward_result.given_grad_names[output] if output in backward_result.given_grad_names else None + for output in output_names + ] + input_gradient_names: List[str] = [ + backward_result.required_grad_names[input] if input in backward_result.required_grad_names else None + for input in input_names + ] + + constants = init_remaining_parameters(module, fwd_arglist, input_names, outputs_with_forwarded_outputs) + + class DifferentiableFunction(torch.autograd.Function): + + @staticmethod + def forward(ctx, *inputs): + kwargs = {} + + # set the inputs + for i, input_name in enumerate(input_names): + kwargs[input_name] = inputs[i].contiguous() + + # initialize the outputs + for name in outputs_with_forwarded_outputs: + output_desc = forward_compiled.sdfg.arrays[name] + kwargs[name] = create_output_array( + {}, output_desc, use_torch=True, zeros=True + ) if name not in module.dace_model.initialized_parameters else module.dace_model.initialized_parameters[ + name] + + # call the SDFG + outputs = forward_compiled(**kwargs, **constants) + + # save inputs/outputs for backward + ctx.save_for_backward(*(kwargs[name] for name in forwarded_io_names)) + + # save non- input/output values for backward + for name in forwarded_non_io_names: + setattr(ctx, f"dace_saved_{name}", kwargs[name]) + + return outputs + + @staticmethod + def backward(ctx, *grad_outputs): + kwargs = {} + + # recover saved values + saved = ctx.saved_tensors + for value_name, saved_value in zip(forwarded_io_names, saved): + kwargs[value_name] = saved_value + + for value_name in forwarded_non_io_names: + kwargs[value_name] = getattr(ctx, f"dace_saved_{value_name}") + + # create gradient buffers of inputs + for grad_name, zero_init, desc in gradient_descriptors: + kwargs[grad_name] = create_output_array({}, desc, use_torch=True, zeros=zero_init) + + # grab gradient buffers of outputs + for grad_name, grad_value in zip(output_gradient_names, grad_outputs): + kwargs[grad_name] = grad_value.contiguous() + + # call bwd sdfg + backward_compiled(**kwargs) + + # return grads + grads = tuple(None if name is None else kwargs[name] for name in input_gradient_names) + if len(grads) == 1: + return grads[0] + return grads + + return lambda *args: DifferentiableFunction.apply(*args) + + +def get_ctypes_dispatcher(module: 'dace.frontend.ml.torch.DaceModule', dummy_inputs) -> DaceTorchFunction: + """ + Get a torch callable for the module. This will compile the sdfg and create a + wrapper python callable that can be used with PyTorch. + + :param module: the module. + :param dummy_inputs: dummy inputs to initialize the model with. + :return: the callable function for the SDFG. + """ + + # build the SDFG + # set all states to not-sync + for state in module.sdfg.nodes(): + state.nosync = True + + if module.backward: + # TODO we could return the inferred symbols here + compiled, _, compiled_bwd, _ = compile_and_init_sdfgs(module, dummy_inputs) + + function = callable_for_bwd_module(module, compiled, compiled_bwd, module._ad_result, module._ad_inp_arrs) + compiled_sdfgs = [compiled, compiled_bwd] + else: + compiled, _ = compile_and_init_sdfgs(module, dummy_inputs) + function = callable_for_fwd_module(module, compiled) + compiled_sdfgs = [compiled] + + result = DaceTorchFunction( + function=function, + compiled_sdfgs=compiled_sdfgs, + # no pointers required for ctypes dispatcher + ptr=[]) + return result diff --git a/dace/libraries/torch/dlpack.py b/dace/libraries/torch/dlpack.py new file mode 100644 index 0000000000..614e1a5345 --- /dev/null +++ b/dace/libraries/torch/dlpack.py @@ -0,0 +1,188 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +Interface for integrating with DLPack. + +Some of the following code is derived from the following resources: +https://github.com/dmlc/dlpack/blob/main/apps/from_numpy/main.py +https://github.com/vadimkantorov/pydlpack/blob/master/dlpack.py +""" + +import ctypes + +import dace +import torch +import torch.utils.dlpack +from dace import data, dtypes + + +class DLDeviceType(ctypes.c_int): + """DLPack device type enumeration.""" + kDLCPU = 1 + kDLGPU = 2 + kDLCPUPinned = 3 + kDLOpenCL = 4 + kDLVulkan = 7 + kDLMetal = 8 + kDLVPI = 9 + kDLROCM = 10 + kDLExtDev = 12 + + +class DLDataTypeCode(ctypes.c_uint8): + """DLPack data type code enumeration.""" + kDLInt = 0 + kDLUInt = 1 + kDLFloat = 2 + kDLBfloat = 4 + + +class DLDataType(ctypes.Structure): + """DLPack data type structure.""" + _fields_ = [('type_code', DLDataTypeCode), ('bits', ctypes.c_uint8), ('lanes', ctypes.c_uint16)] + + +dace_to_dldtype_dict = { + dace.float32: DLDataType(DLDataTypeCode.kDLFloat, 32, 1), + dace.float64: DLDataType(DLDataTypeCode.kDLFloat, 64, 1), + dace.uint8: DLDataType(DLDataTypeCode.kDLUInt, 8, 1), + dace.uint16: DLDataType(DLDataTypeCode.kDLUInt, 16, 1), + dace.uint32: DLDataType(DLDataTypeCode.kDLUInt, 32, 1), + dace.uint64: DLDataType(DLDataTypeCode.kDLUInt, 64, 1), + dace.int8: DLDataType(DLDataTypeCode.kDLInt, 8, 1), + dace.int16: DLDataType(DLDataTypeCode.kDLInt, 16, 1), + dace.int32: DLDataType(DLDataTypeCode.kDLInt, 32, 1), + dace.int64: DLDataType(DLDataTypeCode.kDLInt, 64, 1), +} + + +class DLContext(ctypes.Structure): + """DLPack context structure for device information.""" + _fields_ = [('device_type', DLDeviceType), ('device_id', ctypes.c_int)] + + +class DLTensor(ctypes.Structure): + """DLPack tensor structure.""" + _fields_ = [('data', ctypes.c_void_p), ('ctx', DLContext), ('ndim', ctypes.c_int), ('dtype', DLDataType), + ('shape', ctypes.POINTER(ctypes.c_int64)), ('strides', ctypes.POINTER(ctypes.c_int64)), + ('byte_offset', ctypes.c_uint64)] + + +class DLManagedTensor(ctypes.Structure): + """DLPack managed tensor structure.""" + pass + + +DLManagedTensorHandle = ctypes.POINTER(DLManagedTensor) + +DeleterFunc = ctypes.CFUNCTYPE(None, DLManagedTensorHandle) + +DLManagedTensor._fields_ = [("dl_tensor", DLTensor), ("manager_ctx", ctypes.c_void_p), ("deleter", DeleterFunc)] + + +def make_manager_ctx(obj) -> ctypes.c_void_p: + """Create a manager context from a Python object. + + This function wraps a Python object in a ctypes void pointer and increments + its reference count to prevent garbage collection while in use by DLPack. + + :param obj: The Python object to create a context for. + :return: A ctypes void pointer to the object. + """ + pyobj = ctypes.py_object(obj) + void_p = ctypes.c_void_p.from_buffer(pyobj) + ctypes.pythonapi.Py_IncRef(pyobj) + return void_p + + +@DeleterFunc +def dl_managed_tensor_deleter(_dl_managed_tensor_handle) -> None: + """Deleter function for DLPack managed tensors. + + This is a no-op deleter because the underlying data is managed by DaCe + and will be freed when the SDFG state struct is deallocated. + + :param _dl_managed_tensor_handle: Handle to the managed tensor (unused). + """ + # Do nothing: the data is freed in the state struct + pass + + +class PyCapsule: + """Python capsule interface for DLPack integration.""" + New = ctypes.pythonapi.PyCapsule_New + New.restype = ctypes.py_object + New.argtypes = (ctypes.c_void_p, ctypes.c_char_p, ctypes.c_void_p) + + SetContext = ctypes.pythonapi.PyCapsule_SetContext + SetContext.restype = ctypes.c_int + SetContext.argtypes = (ctypes.py_object, ctypes.c_void_p) + + GetContext = ctypes.pythonapi.PyCapsule_GetContext + GetContext.restype = ctypes.c_void_p + GetContext.argtypes = (ctypes.py_object, ) + + GetPointer = ctypes.pythonapi.PyCapsule_GetPointer + GetPointer.restype = ctypes.c_void_p + GetPointer.argtypes = (ctypes.py_object, ctypes.c_char_p) + + Destructor = ctypes.CFUNCTYPE(None, ctypes.py_object) + + SetDestructor = ctypes.pythonapi.PyCapsule_SetDestructor + SetDestructor.argtypes = (ctypes.py_object, Destructor) + SetDestructor.restype = ctypes.c_int + + +def array_to_torch_tensor(ptr: ctypes.c_void_p, desc: data.Array) -> torch.Tensor: + """Convert a DaCe array descriptor to a PyTorch tensor that points to the same data. + + This function performs zero-copy conversion using the DLPack protocol, + allowing PyTorch to access DaCe arrays without data duplication. + + :param ptr: The pointer to the memory of the array. + :param desc: The DaCe array descriptor containing shape, strides, and dtype information. + :return: A PyTorch tensor that shares memory with the DaCe array. + :raises ValueError: If the storage type or dtype is unsupported. + """ + + if desc.storage is dtypes.StorageType.GPU_Global: + device_type = DLDeviceType.kDLGPU + elif desc.storage in [dtypes.StorageType.CPU_Heap, dtypes.StorageType.Default]: + device_type = DLDeviceType.kDLCPU + else: + raise ValueError(f"Unsupported storage type {desc.storage}") + + context = DLContext(device_type=device_type, device_id=0) + + if desc.dtype not in dace_to_dldtype_dict: + raise ValueError(f"Unsupported dtype {desc.dtype}") + dtype = dace_to_dldtype_dict[desc.dtype] + + shape = (ctypes.c_int64 * len(desc.shape))() + for i, s in enumerate(desc.shape): + shape[i] = s + + strides = (ctypes.c_int64 * len(desc.shape))() + for i, s in enumerate(desc.strides): + strides[i] = s + + dltensor = DLTensor(data=ptr, + ctx=context, + ndim=len(desc.shape), + dtype=dtype, + shape=shape, + strides=strides, + byte_offset=0) + + c_obj = DLManagedTensor() + c_obj.dl_tensor = dltensor + c_obj.manager_ctx = ctypes.c_void_p(0) + c_obj.deleter = dl_managed_tensor_deleter + + # The capsule must be used in the same stack frame, otherwise it will be deallocated and the capsule will + # point to invalid data. + capsule = PyCapsule.New(ctypes.byref(c_obj), b"dltensor", None) + tensor: torch.Tensor = torch.utils.dlpack.from_dlpack(capsule) + + # Store the dltensor as an attribute of the tensor so that the tensor takes ownership + tensor._dace_dlpack = c_obj + return tensor diff --git a/dace/libraries/torch/environments/__init__.py b/dace/libraries/torch/environments/__init__.py new file mode 100644 index 0000000000..3af1d61e60 --- /dev/null +++ b/dace/libraries/torch/environments/__init__.py @@ -0,0 +1,2 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +from .pytorch_env import PyTorch, PyTorchGPU diff --git a/dace/libraries/torch/environments/pytorch_env.py b/dace/libraries/torch/environments/pytorch_env.py new file mode 100644 index 0000000000..1d82beebdf --- /dev/null +++ b/dace/libraries/torch/environments/pytorch_env.py @@ -0,0 +1,100 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import os + +try: + import torch.utils.cpp_extension +except ImportError as e: + raise ImportError("PyTorch is required for torch integration. Install with: pip install dace[ml]") from e + +import dace.library + +from dace.codegen.common import platform_library_name, get_gpu_backend + + +@dace.library.environment +class PyTorch: + """Environment used to build PyTorch C++ Operators.""" + + cmake_minimum_version = None + cmake_packages = [] + cmake_variables = {} + cmake_includes = torch.utils.cpp_extension.include_paths() + + @staticmethod + def cmake_libraries(): + """Get the required PyTorch library paths for linking. + + :return: List of library paths for PyTorch CPU libraries. + :raises RuntimeError: If a required library cannot be found. + """ + library_names = ["c10", "torch", "torch_cpu", "torch_python"] + library_paths = [] + + for name in library_names: + for path in torch.utils.cpp_extension.library_paths(): + path = os.path.join(path, platform_library_name(name)) + if os.path.isfile(path): + library_paths.append(path) + break + else: + raise RuntimeError(f"Couldn't locate shared library {name} in PyTorch library paths") + + return library_paths + + cmake_compile_flags = ["-D_GLIBCXX_USE_CXX11_ABI=0"] + cmake_link_flags = [] + cmake_files = [] + state_fields = [] + dependencies = [] + + headers = [] + init_code = "" + finalize_code = "" + + +@dace.library.environment +class PyTorchGPU: + """Environment used to build PyTorch C++ Operators (with CUDA/HIP).""" + + cmake_minimum_version = None + cmake_packages = [] + cmake_variables = {} + cmake_includes = torch.utils.cpp_extension.include_paths() + + @staticmethod + def cmake_libraries(): + """ + Get the required PyTorch library paths for linking with GPU support. + + :return: List of library paths for PyTorch GPU libraries. + :raises RuntimeError: If a required library cannot be found. + """ + backend = get_gpu_backend() + if backend == 'hip': + library_names = ["c10", "torch", "torch_cpu", "torch_hip", "torch_python", "c10_hip"] + runtime_lib = "amdhip64" + else: + library_names = ["c10", "torch", "torch_cpu", "torch_cuda", "torch_python", "c10_cuda"] + runtime_lib = "cudart" + + library_paths = [] + for name in library_names: + for path in torch.utils.cpp_extension.library_paths(device_type=backend): + path = os.path.join(path, platform_library_name(name)) + if os.path.isfile(path): + library_paths.append(path) + break + else: + raise RuntimeError(f"Couldn't locate shared library {name} in PyTorch library paths") + + return library_paths + [runtime_lib] + + cmake_compile_flags = ["-D_GLIBCXX_USE_CXX11_ABI=0"] + cmake_link_flags = [] + cmake_files = [] + state_fields = [] + dependencies = [] + + headers = [] + init_code = "" + finalize_code = "" diff --git a/dace/libraries/torch/torch.md b/dace/libraries/torch/torch.md new file mode 100644 index 0000000000..1a83857c82 --- /dev/null +++ b/dace/libraries/torch/torch.md @@ -0,0 +1,1254 @@ +Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +# DaCe PyTorch Integration Library - Design Document + +## Table of Contents + +1. [Introduction](#1-introduction) +2. [Architecture Overview](#2-architecture-overview) +3. [Directory Structure](#3-directory-structure) +4. [Core Components](#4-core-components) +5. [Dispatcher Strategies](#5-dispatcher-strategies) +6. [Integration Pipeline](#6-integration-pipeline) +7. [Zero-Copy Tensor Sharing](#7-zero-copy-tensor-sharing) +8. [Autograd Integration](#8-autograd-integration) + +--- + +## 1. Introduction + +### 1.1 Purpose + +The DaCe PyTorch Integration Library provides **bidirectional integration** between PyTorch's neural network framework and DaCe's high-performance SDFG execution engine. It enables: + +- **Optimizing PyTorch models** using DaCe's dataflow transformations +- **Accelerating training and inference** with optimized compiled code + +### 1.2 Current Capabilities + +- **Model Optimization**: Convert PyTorch `nn.Module` to optimized DaCe SDFGs +- **Automatic Differentiation**: Integration with PyTorch's autograd system +- **Dual Dispatch**: C++ extension (performance) or CTypes (flexibility) +- **Training Support**: Backward pass generation and gradient computation + +### 1.3 Integration Directions + +The library supports bidirectional data flow: + +**1. PyTorch → DaCe (Primary Direction)**: +```python +# Wrap PyTorch model for DaCe optimization +dace_module = DaceModule(pytorch_model, dummy_inputs, backward=True) + +# Use as drop-in replacement +output = dace_module(input_tensor) +loss.backward() # Autograd works! +``` + +**Workflow**: PyTorch Model → ONNX Export → DaCe SDFG → Compiled Code → PyTorch Operator + +**2. DaCe → PyTorch (Zero-Copy Access)**: +```python +# DaCe arrays accessible as PyTorch tensors (no copy) +torch_tensor = array_to_torch_tensor(ptr, dace_descriptor) +``` + +**Mechanism**: DLPack protocol for memory sharing + +### 1.4 Use Cases + +1. **Neural Network Optimization**: Speed up inference for production deployment +2. **Training Acceleration**: Optimize forward and backward passes for faster training +3. **Custom Operators**: Implement custom PyTorch operations with DaCe +4. **Research**: Experiment with dataflow-level optimizations on ML models +5. **Mixed Workflows**: Combine PyTorch layers with DaCe-optimized modules + +--- + +## 2. Architecture Overview + +### 2.1 High-Level System Diagram + +``` +┌───────────────────────────────────────────────────────────┐ +│ USER INTERFACE │ +│ ┌────────────────────────────────────────────────────┐ │ +│ │ DaceModule (pytorch_model, dummy_inputs, ...) │ │ +│ │ • Wraps PyTorch nn.Module │ │ +│ │ • Provides PyTorch-compatible interface │ │ +│ │ • Supports forward + backward passes │ │ +│ └──────────────────┬─────────────────────────────────┘ │ +└─────────────────────┼─────────────────────────────────────┘ + │ + ▼ +┌───────────────────────────────────────────────────────────┐ +│ ONNX EXPORT PIPELINE │ +│ ┌────────────────────────────────────────────────────┐ │ +│ │ torch.onnx.export() │ │ +│ │ PyTorch Model → ONNX ModelProto │ │ +│ └──────────────────┬─────────────────────────────────┘ │ +└─────────────────────┼─────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────┐ +│ SDFG CONSTRUCTION │ +│ ┌────────────────────────────────────────────────────┐ │ +│ │ ONNXModel (onnx_proto) │ │ +│ │ ONNX → DaCe SDFG (Forward) │ │ +│ └────────────────────┬───────────────────────────────┘ │ +│ │ │ +│ ▼ (if backward=True) │ +│ ┌────────────────────────────────────────────────────┐ │ +│ │ BackwardPassGenerator │ │ +│ │ Forward SDFG → Backward SDFG │ │ +│ └──────────────────┬─────────────────────────────────┘ │ +└─────────────────────┼───────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────┐ +│ DISPATCHER SELECTION │ +│ ┌──────────────────┐ ┌──────────────────┐ │ +│ │ C++ Extension │ OR │ CTypes Module │ │ +│ ├──────────────────┤ ├──────────────────┤ │ +│ │ • Performance │ │ • No param limit │ │ +│ │ • Native PyTorch │ │ • Faster compile │ │ +│ │ • 64 param limit │ │ • Pure Python │ │ +│ └────────┬─────────┘ └────────┬─────────┘ │ +└───────────┼──────────────────────────────┼──────────────────┘ + │ │ + └──────────┬───────────────────┘ + ▼ +┌─────────────────────────────────────────────────────────┐ +│ CODE GENERATION & COMPILATION │ +│ ┌────────────────────────────────────────────────────┐ │ +│ │ SDFG.compile() → Shared Library (.so) │ │ +│ │ C++ Codegen → PyTorch Operator Registration │ │ +│ │ State Initialization → Handle Creation │ │ +│ └──────────────────┬─────────────────────────────────┘ │ +└─────────────────────┼───────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────┐ +│ CALLABLE PYTORCH OPERATOR │ +│ ┌──────────────────────────────────────────────────────┐ │ +│ │ dace_module(inputs) → outputs │ │ +│ │ • Zero-copy tensor access via DLPack │ │ +│ │ • Stateful execution via handles │ │ +│ └──────────────────────────────────────────────────────┘ │ +└─────────────────────────────────────────────────────────────┘ +``` + +### 2.2 Component Interaction Flow + +``` +User Code: dace_module = DaceModule(model, dummy_inputs, backward=True) + ↓ +1. Store model and configuration + ↓ +User Code: output = dace_module(actual_input) # First call + ↓ +2. Detect function is None → Initialize SDFG + ↓ +3. Export to ONNX + a. torch.onnx.export(model, dummy_inputs) + b. Save parameters and model structure + ↓ +4. Import ONNX to DaCe + a. ONNXModel(onnx_proto) + b. Create forward SDFG + ↓ +5. Generate Backward (if backward=True) + a. Determine required gradients + b. BackwardPassGenerator.backward() + c. Create backward SDFG + d. Identify forwarded transients + ↓ +6. Compile SDFGs + a. forward_sdfg.compile() + b. backward_sdfg.compile() (if applicable) + c. Initialize with dummy inputs + ↓ +7. Select and Initialize Dispatcher + If compile_torch_extension: + a. Generate C++ code with autograd + b. Compile as PyTorch extension + c. Register as torch.ops.dace_{name}.{name} + Else: + a. Create Python autograd.Function + b. Wrap with CTypes calls + ↓ +8. Create Wrapper Function + a. Accept user inputs + b. Pass state handles as first args + c. Call compiled operator + d. Return outputs + ↓ +9. Execute and Return + a. Zero-copy tensor access via DLPack + b. Call native SDFG code + c. Return PyTorch tensors + ↓ +User Code: loss.backward() # Backward pass + ↓ +10. PyTorch calls backward function + a. Recover saved tensors from context + b. Allocate gradient buffers + c. Call backward SDFG + d. Return input gradients +``` + +--- + +## 3. Directory Structure + +### 3.1 File Organization + +``` +dace/libraries/torch/ +├── __init__.py # Library exports +│ └── Exports: PyTorch, PyTorchGPU environment classes +│ +├── dlpack.py # Zero-copy tensor sharing +│ ├── DLPack structure definitions +│ ├── Type conversion mappings +│ └── array_to_torch_tensor() - Main conversion function +│ +├── dispatchers/ # Dispatcher implementations +│ ├── __init__.py # Package exports +│ │ └── Exports: DaceTorchFunction, get_ctypes_dispatcher, register_and_compile_torch_extension +│ │ +│ ├── common.py # Shared utilities +│ │ ├── DaceTorchFunction dataclass +│ │ ├── get_arglist() +│ │ └── compile_and_init_sdfgs() +│ │ +│ ├── cpp_torch_extension.py # C++ extension generator +│ │ ├── Type conversion utilities +│ │ ├── C++ code generation for forward/backward +│ │ ├── Autograd function generation +│ │ ├── Tensor initialization +│ │ └── register_and_compile_torch_extension() +│ │ +│ └── ctypes_module.py # CTypes dispatcher +│ ├── init_remaining_parameters() +│ ├── callable_for_fwd_module() +│ ├── callable_for_bwd_module() +│ └── get_ctypes_dispatcher() +│ +└── environments/ # Build configuration + ├── __init__.py # Package exports (1 line) + └── pytorch_env.py # PyTorch environments + ├── PyTorch (CPU environment) + └── PyTorchGPU (GPU environment) + +``` + +### 3.2 Component Responsibilities + +| Component | Lines | Purpose | +|-----------|-------|---------| +| `cpp_torch_extension.py` | 699 | C++ code generation for PyTorch operators | +| `ctypes_module.py` | 222 | CTypes-based dispatcher for large models | +| `dlpack.py` | 188 | Zero-copy tensor sharing via DLPack | +| `common.py` | 112 | Shared dispatcher utilities | +| `pytorch_env.py` | 100 | CMake build configuration | +| `__init__.py` (dispatchers) | 22 | Dispatcher exports | +| `__init__.py` (main) | 23 | Library exports | + +**Note**: DaceModule is located at [dace/frontend/ml/torch/module.py](../../frontend/ml/torch/module.py) (581 lines). + +--- + +## 4. Core Components + +### 4.1 DaceModule: The Main Entry Point + +**Location**: [dace/frontend/ml/torch/module.py](../../frontend/ml/torch/module.py) + +#### Constructor Signature + +```python +class DaceModule: + def __init__( + self, + module: torch.nn.Module, + dummy_inputs: Tuple[torch.Tensor, ...], + cuda: bool = False, + backward: bool = False, + compile_torch_extension: bool = True, + auto_optimize: bool = True, + **onnx_kwargs + ): + """ + Wrap a PyTorch module for DaCe optimization. + + Args: + module: PyTorch module to optimize + dummy_inputs: Sample inputs for shape inference and tracing + cuda: Enable GPU execution + backward: Generate backward pass for training + compile_torch_extension: Use C++ extension (True) or CTypes (False) + auto_optimize: Apply DaCe optimizations + **onnx_kwargs: Additional arguments for torch.onnx.export() + """ +``` + +#### Key Methods + +- **`__call__(*inputs)`**: Execute the optimized module +- **`_initialize_sdfg(inputs)`**: Lazy compilation on first call +- **`_call_params()`**: Get model parameters as tensors + +#### Workflow Summary + +1. **Initialization**: Store model and configuration +2. **First Call**: Export to ONNX → Import to SDFG → Compile → Create dispatcher +3. **Subsequent Calls**: Direct execution via the compiled operator +4. **Backward**: Automatic execution via PyTorch autograd integration + +--- + +### 4.2 DLPack Bridge: Zero-Copy Tensor Sharing + +**Location**: [dlpack.py](dlpack.py) + +The DLPack bridge enables **zero-copy conversion** between DaCe arrays and PyTorch tensors. + +#### DLPack Structure Definitions + +**Type System**: +```python +class DLDeviceType(ctypes.c_int): + kDLCPU = 1 + kDLGPU = 2 + # ... other devices + +class DLDataTypeCode(ctypes.c_uint8): + kDLInt = 0 + kDLUInt = 1 + kDLFloat = 2 + kDLBfloat = 4 + +class DLDataType(ctypes.Structure): + _fields_ = [('type_code', DLDataTypeCode), + ('bits', ctypes.c_uint8), + ('lanes', ctypes.c_uint16)] +``` + +**Tensor Representation**: +```python +class DLTensor(ctypes.Structure): + _fields_ = [ + ('data', ctypes.c_void_p), # Raw pointer + ('ctx', DLContext), # Device info + ('ndim', ctypes.c_int), # Number of dimensions + ('dtype', DLDataType), # Data type + ('shape', ctypes.POINTER(ctypes.c_int64)), # Shape array + ('strides', ctypes.POINTER(ctypes.c_int64)), # Strides array + ('byte_offset', ctypes.c_uint64) # Byte offset + ] +``` + +#### Zero-Copy Conversion + +**Function**: `array_to_torch_tensor(ptr, desc)` + +**Process**: +1. Map the DaCe storage type to DLDeviceType +2. Convert the DaCe dtype to DLDataType +3. Create shape and strides arrays +4. Build the DLTensor structure +5. Wrap in DLManagedTensor with a no-op deleter +6. Create a PyCapsule with name "dltensor" +7. Call `torch.utils.dlpack.from_dlpack(capsule)` to create the PyTorch tensor +8. Store the DLPack structure as a tensor attribute (prevents garbage collection) + +**Memory Ownership**: +- Data is owned by the DaCe SDFG state struct +- No-op deleter ensures that DaCe manages deallocation +- PyTorch tensor is a **view** into DaCe memory (zero-copy) + +**Type Mapping**: +```python +dace_to_dldtype_dict = { + dace.float32: DLDataType(kDLFloat, 32, 1), + dace.float64: DLDataType(kDLFloat, 64, 1), + dace.int32: DLDataType(kDLInt, 32, 1), + # ... complete mapping +} +``` + +--- + +### 4.3 Common Dispatcher Utilities + +**Location**: [dispatchers/common.py](dispatchers/common.py) + +#### DaceTorchFunction Dataclass + +```python +@dataclasses.dataclass +class DaceTorchFunction: + """Encapsulates a compiled DaCe module with PyTorch interface.""" + function: Callable # The callable (torch op or Python function) + compiled_sdfgs: List[CompiledSDFG] # [forward, backward] (or just [forward]) + ptr: List[torch.Tensor] # State handle pointers [fwd_handle, bwd_handle] +``` + +**Purpose**: Provides a uniform interface regardless of dispatcher choice (C++ or CTypes) + +#### compile_and_init_sdfgs() + +**Function Signature**: +```python +def compile_and_init_sdfgs( + module: DaceModule, + dummy_inputs: Tuple[torch.Tensor, ...] +) -> Union[ + Tuple[CompiledSDFG, torch.Tensor], # No backward + Tuple[CompiledSDFG, torch.Tensor, # With backward + CompiledSDFG, torch.Tensor] +]: +``` + +**Process**: +1. Compile the forward SDFG +2. Construct arguments from dummy inputs and parameters +3. Infer symbols from input shapes +4. Allocate forwarded transients (for backward pass) +5. Initialize the forward SDFG state +6. Extract the state handle as `torch.tensor([libhandle])` +7. If backward is enabled: + - Compile the backward SDFG + - Allocate gradient buffers + - Initialize the backward SDFG state + - Extract the backward handle +8. Return the compiled SDFGs and handles + +#### get_arglist() + +**Function**: +```python +def get_arglist(module: DaceModule) -> Tuple[List[str], List[str]]: + """Extracts input and output names with ONNX name cleaning.""" + inputs = [clean_onnx_name(name) for name in module.dace_model.inputs] + outputs = [clean_onnx_name(name) for name in module.dace_model.outputs] + return inputs, outputs +``` + +--- + +### 4.4 PyTorch Environment Configuration + +**Location**: [environments/pytorch_env.py](environments/pytorch_env.py) + +Defines the CMake build configuration for linking against PyTorch libraries. + +#### PyTorch Environment (CPU) + +```python +@dace.library.environment +class PyTorch: + """Environment for building PyTorch C++ operators (CPU).""" + + cmake_includes = torch.utils.cpp_extension.include_paths() + + @staticmethod + def cmake_libraries(): + """Locate and return PyTorch library paths.""" + library_names = ["c10", "torch", "torch_cpu", "torch_python"] + # Search in torch.utils.cpp_extension.library_paths() + return library_paths + + cmake_compile_flags = ["-D_GLIBCXX_USE_CXX11_ABI=0"] # ABI compatibility +``` + +#### PyTorchGPU Environment (GPU) + +```python +@dace.library.environment +class PyTorchGPU: + """Environment for building PyTorch C++ operators (CUDA).""" + + cmake_includes = torch.utils.cpp_extension.include_paths() + + @staticmethod + def cmake_libraries(): + """Locate and return PyTorch CUDA library paths.""" + library_names = ["c10", "torch", "torch_cpu", "torch_cuda", + "torch_python", "c10_cuda"] + return library_paths + ["cudart"] +``` + +**Integration with DaCe**: +- Registered via the `@dace.library.environment` decorator +- DaCe's CMake generator uses these settings for linker configuration +- Ensures that compiled code can call the PyTorch C++ API + +--- + +## 5. Dispatcher Strategies + +### 5.1 Why Two Dispatchers? + +The library provides two dispatcher implementations to handle different use cases: + +| Feature | C++ Extension | CTypes Module | +|---------|--------------|---------------| +| **Performance** | High (native call) | Good (small overhead) | +| **Parameter Limit** | 64 parameters | Unlimited | +| **Compilation Time** | Slower (C++ compile) | Faster (no codegen) | +| **Registration** | `torch.ops.dace_name` | Python function | + +### 5.2 C++ PyTorch Extension + +**Location**: [dispatchers/cpp_torch_extension.py](dispatchers/cpp_torch_extension.py) + +#### Overview + +Generates C++ code that registers a custom PyTorch operator with native autograd support. + +#### Type Conversion Utilities + +**DaCe → PyTorch C++ Types**: +```python +_REPLACED_CTYPES = { + dace.int64: "int64_t", + dace.uint64: "uint64_t", + dace.float16: "at::Half" +} + +def torch_ctype(dtype: dace.typeclass) -> str: + """Convert DaCe type to PyTorch C++ type string.""" + if isinstance(dtype, dace.pointer): + return "int64_t" + elif dtype in _REPLACED_CTYPES: + return _REPLACED_CTYPES[dtype] + else: + return dtype.ctype # e.g., "float", "double" +``` + +**DaCe → PyTorch Tensor Dtype**: +```python +_TYPECLASS_TO_TORCH_DTYPE_STR = { + dt.bool: "kBool", + dt.int8: "kInt8", + dt.float32: "kFloat32", + dt.float64: "kFloat64", + # ... complete mapping +} +``` + +#### Tensor Initialization Code Generation + +**Function**: `tensor_init_for_desc()` + +**Purpose**: Generates C++ code to allocate PyTorch tensors + +**Approach**: +- Checks if tensor is a constant (from weights) +- If constant: embeds values as a C++ initializer list +- If output: allocates with `torch::zeros()` or `torch::empty()` +- Sets proper dtype, device (CPU/CUDA), and layout + +**Example Output**: +```cpp +Tensor output = torch::zeros( + {10, 256}, + torch::TensorOptions() + .dtype(torch::kFloat32) + .device(torch::kCPU) + .layout(torch::kStrided) +); +``` + +#### Forward Function Code Generation + +**Generated Structure**: +```cpp +Tensor forward_function( + int64_t fwd_handle_ptr, + int64_t bwd_handle_ptr, // if backward + const Tensor& input_0_, + const Tensor& input_1_, + // ... more inputs +) { + // 1. Initialize outputs + Tensor output = torch::zeros({...}, torch::TensorOptions()...); + + // 2. Ensure inputs are contiguous + Tensor input_0 = input_0_.contiguous(); + + // 3. Extract pointers + float *input_0_ptr = reinterpret_cast(input_0.data_ptr()); + float *output_ptr = reinterpret_cast(output.data_ptr()); + + // 4. Call SDFG + MySDFGHandle_t handle = reinterpret_cast(fwd_handle_ptr); + __program_my_sdfg(handle, input_0_ptr, output_ptr); + + // 5. Return outputs + return output; // or std::make_tuple(...) for multiple +} +``` + +#### Autograd Function Code Generation + +**Generated Structure**: +```cpp +class MySDFGFunction : public torch::autograd::Function { +public: + static Tensor forward( + AutogradContext *ctx, + int64_t fwd_handle_ptr, + int64_t bwd_handle_ptr, + const Tensor& input_ + ) { + // Run forward pass + Tensor output = forward_function(fwd_handle_ptr, bwd_handle_ptr, input_); + + // Save for backward + ctx->save_for_backward({input_, output}); + + // Save non-I/O transients + ctx->saved_data["intermediate"] = intermediate_value; + + // Save backward handle + ctx->saved_data["bwd_handle"] = bwd_handle_ptr; + + return output; + } + + static tensor_list backward( + AutogradContext *ctx, + tensor_list grad_outputs + ) { + // 1. Recover saved tensors + auto saved = ctx->get_saved_variables(); + auto input = saved[0]; + auto intermediate = ctx->saved_data["intermediate"].toTensor(); + + // 2. Get backward handle + int64_t bwd_handle_ptr = ctx->saved_data["bwd_handle"].toInt(); + MySDFGBackwardHandle_t bwd_handle = + reinterpret_cast(bwd_handle_ptr); + + // 3. Allocate gradient buffers + Tensor grad_input = torch::zeros({...}); // or empty if zero_init=False + + // 4. Get output gradients + Tensor grad_output = grad_outputs[0].contiguous(); + + // 5. Call backward SDFG + __program_my_sdfg_backward( + bwd_handle, + grad_output.data_ptr(), + input.data_ptr(), + intermediate.data_ptr(), + grad_input.data_ptr() + ); + + // 6. Return gradients (None for non-differentiable) + return {Tensor(), Tensor(), grad_input}; // None for handles, grad for input + } +}; +``` + +#### Operator Registration + +**Generated Code**: +```cpp +// Register operator +TORCH_LIBRARY(dace_my_sdfg, m) { + m.def("my_sdfg", forward_function); +} + +// Register autograd if backward enabled +TORCH_LIBRARY_IMPL(dace_my_sdfg, Autograd, m) { + m.impl("my_sdfg", MySDFGFunction::apply); +} +``` + +#### Compilation Process + +**Function**: `register_and_compile_torch_extension()` + +**Steps**: +1. Generate the complete C++ source code +2. Write to a temporary file +3. Use `torch.utils.cpp_extension.load()` for JIT compilation +4. Link against: + - PyTorch libraries (from environment) + - Compiled SDFG shared library +5. Return the operator accessible via `torch.ops.dace_name.name` + +**Limitations**: +- PyTorch dispatcher has **64 parameter limit** +- Longer compilation time (~seconds) +- Requires C++ compiler + +--- + +### 5.3 CTypes Module + +**Location**: [dispatchers/ctypes_module.py](dispatchers/ctypes_module.py) + +#### Overview + +A pure Python dispatcher that calls compiled SDFGs via ctypes, avoiding C++ code generation. + +#### When to Use + +- Models with >64 parameters +- Rapid development/iteration +- Environments where C++ compilation is problematic +- Prototyping and debugging + +#### Forward-Only Callable + +**Function**: `callable_for_fwd_module()` + +**Generated Function**: +```python +def forward(*inputs): + kwargs = {} + + # Set inputs + for i, input_name in enumerate(input_names): + kwargs[input_name] = inputs[i].contiguous() + + # Initialize outputs + for name in output_names: + kwargs[name] = create_output_array( + {}, + forward_compiled.sdfg.arrays[name], + use_torch=True, + zeros=False + ) + + # Add constants + kwargs.update(constants) + + # Call SDFG (ctypes handles conversion) + return forward_compiled(**kwargs) +``` + +#### Forward+Backward Callable + +**Function**: `callable_for_bwd_module()` + +**Generated Autograd Function**: +```python +class DifferentiableFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, *inputs): + kwargs = {} + + # Set inputs + for i, input_name in enumerate(input_names): + kwargs[input_name] = inputs[i].contiguous() + + # Initialize outputs + forwarded transients + for name in outputs_and_forwarded: + kwargs[name] = create_output_array(...) + + # Call forward SDFG + outputs = forward_compiled(**kwargs, **constants) + + # Save I/O tensors for backward + ctx.save_for_backward(*(kwargs[name] for name in forwarded_io_names)) + + # Save non-I/O transients as attributes + for name in forwarded_non_io_names: + setattr(ctx, f"dace_saved_{name}", kwargs[name]) + + return outputs + + @staticmethod + def backward(ctx, *grad_outputs): + kwargs = {} + + # Recover saved I/O tensors + saved = ctx.saved_tensors + for name, val in zip(forwarded_io_names, saved): + kwargs[name] = val + + # Recover non-I/O transients + for name in forwarded_non_io_names: + kwargs[name] = getattr(ctx, f"dace_saved_{name}") + + # Allocate gradient buffers + for grad_name, zero_init, desc in gradient_descriptors: + kwargs[grad_name] = create_output_array(..., zeros=zero_init) + + # Set output gradients from PyTorch + for grad_name, grad_val in zip(output_gradient_names, grad_outputs): + kwargs[grad_name] = grad_val.contiguous() + + # Call backward SDFG + backward_compiled(**kwargs) + + # Return input gradients (None for non-differentiable) + return tuple(kwargs.get(grad_name) for grad_name in input_gradient_names) + +return DifferentiableFunction.apply +``` + +#### Parameter Handling + +**Function**: `init_remaining_parameters()` + +**Purpose**: Extracts constant parameters (model weights) that are neither inputs nor outputs + +**Process**: +1. Identify parameters not in the input/output lists +2. Verify they exist in `module.dace_model.clean_weights` +3. Transfer to CUDA if needed +4. Return as a constants dictionary + +--- + +## 6. Integration Pipeline + +### 6.1 Complete Workflow + +``` +┌─────────────────────────────────────────────────────────┐ +│ Phase 1: Initialization │ +├─────────────────────────────────────────────────────────┤ +│ dace_module = DaceModule(model, dummy_inputs, ...) │ +│ │ +│ 1. Store PyTorch model reference │ +│ 2. Store configuration (cuda, backward, dispatcher) │ +│ 3. Set function = None (lazy compilation) │ +└──────────────────────┬──────────────────────────────────┘ + ▼ +┌─────────────────────────────────────────────────────────┐ +│ Phase 2: First Forward Call │ +├─────────────────────────────────────────────────────────┤ +│ output = dace_module(actual_input) │ +│ │ +│ Detect function is None → Trigger _initialize_sdfg() │ +└──────────────────────┬──────────────────────────────────┘ + ▼ +┌─────────────────────────────────────────────────────────┐ +│ Phase 3: ONNX Export │ +├─────────────────────────────────────────────────────────┤ +│ 1. Call torch.onnx.export(model, dummy_inputs, ...) │ +│ 2. Save exported ONNX ModelProto │ +│ 3. Extract and save model parameters │ +│ 4. Remove initializers that overlap with inputs │ +└──────────────────────┬──────────────────────────────────┘ + ▼ +┌─────────────────────────────────────────────────────────┐ +│ Phase 4: ONNX → DaCe SDFG │ +├─────────────────────────────────────────────────────────┤ +│ 1. Create ONNXModel(onnx_proto) │ +│ - Import ONNX graph to SDFG │ +│ - Run shape inference │ +│ - Apply simplifications │ +│ 2. Store forward SDFG as module.sdfg │ +│ 3. Apply post_onnx_hooks (if any) │ +└──────────────────────┬──────────────────────────────────┘ + ▼ +┌─────────────────────────────────────────────────────────┐ +│ Phase 5: Backward SDFG Generation (if backward=True) │ +├─────────────────────────────────────────────────────────┤ +│ 1. Determine required gradients: │ +│ - Model inputs (if not in clean_weights) │ +│ - Parameters with requires_grad=True │ +│ │ +│ 2. Call make_backward_function(): │ +│ a. Create backward SDFG │ +│ b. Initialize BackwardPassGenerator │ +│ c. Generate reverse operations │ +│ d. Identify forwarded transients │ +│ │ +│ 3. Modify forward SDFG: │ +│ - Make forwarded arrays non-transient (outputs) │ +│ - Convert scalars to size-1 arrays │ +│ │ +│ 4. Store: │ +│ - module.forward_sdfg │ +│ - module.backward_sdfg │ +│ - module._ad_result (BackwardResult) │ +│ - module._ad_inp_arrs (forwarded arrays) │ +│ │ +│ 5. Apply post_autodiff_hooks (if any) │ +└──────────────────────┬──────────────────────────────────┘ + ▼ +┌─────────────────────────────────────────────────────────┐ +│ Phase 6: SDFG Compilation │ +├─────────────────────────────────────────────────────────┤ +│ Call compile_and_init_sdfgs(module, dummy_inputs): │ +│ │ +│ 1. Compile forward SDFG → forward_compiled │ +│ 2. Construct arguments from dummy inputs + parameters │ +│ 3. Call _call_args() to infer symbols │ +│ 4. Allocate forwarded transients (if backward) │ +│ 5. Initialize forward SDFG state │ +│ 6. Extract state handle: fwd_handle=compiled._libhandle │ +│ │ +│ 7. If backward: │ +│ a. Compile backward SDFG → backward_compiled │ +│ b. Allocate gradient buffers │ +│ c. Initialize backward SDFG state │ +│ d. Extract backward handle │ +│ │ +│ 8. Apply post_compile_hooks (if any) │ +│ │ +│ 9. Return compiled SDFGs and handles │ +└──────────────────────┬──────────────────────────────────┘ + ▼ +┌─────────────────────────────────────────────────────────┐ +│ Phase 7: Dispatcher Generation │ +├─────────────────────────────────────────────────────────┤ +│ If compile_torch_extension: │ +│ ├─→ register_and_compile_torch_extension() │ +│ │ 1. Generate C++ code with autograd │ +│ │ 2. Compile as PyTorch extension │ +│ │ 3. Register operator │ +│ │ 4. Return torch.ops.dace_name.name │ +│ │ │ +│ Else: │ +│ └─→ get_ctypes_dispatcher() │ +│ 1. Create Python autograd.Function │ +│ 2. Wrap compiled SDFGs with ctypes calls │ +│ 3. Return callable │ +│ │ +│ Return DaceTorchFunction(function,compiled_sdfgs,ptrs)│ +└──────────────────────┬──────────────────────────────────┘ + ▼ +┌─────────────────────────────────────────────────────────┐ +│ Phase 8: Wrapper Function Creation │ +├─────────────────────────────────────────────────────────┤ +│ Create forward() wrapper: │ +│ │ +│ def forward(*args): │ +│ return compiled_function.function( │ +│ *compiled_function.ptr, # State handles │ +│ *args, # User inputs │ +│ *parameters_to_pass) # Model params │ +│ │ +│ Store as module.function │ +└──────────────────────┬──────────────────────────────────┘ + ▼ +┌─────────────────────────────────────────────────────────┐ +│ Phase 9: Execution │ +├─────────────────────────────────────────────────────────┤ +│ Forward Pass: │ +│ 1. User calls dace_module(input) │ +│ 2. Wrapper extracts .contiguous() tensors │ +│ 3. Zero-copy access via DLPack (if needed) │ +│ 4. Call compiled SDFG with pointers │ +│ 5. Return PyTorch tensors │ +│ │ +│ Backward Pass (if backward=True): │ +│ 1. User calls loss.backward() │ +│ 2. PyTorch autograd calls backward function │ +│ 3. Recover saved tensors from context │ +│ 4. Allocate gradient buffers │ +│ 5. Call backward SDFG │ +│ 6. Return input gradients to PyTorch │ +└─────────────────────────────────────────────────────────┘ +``` + +### 6.2 Data Transformations + +**Input Transformation** (PyTorch → DaCe): +``` +torch.Tensor (user input) + ↓ .contiguous() +torch.Tensor (contiguous memory) + ↓ .data_ptr() (C++ extension) or direct pass (CTypes) +Raw pointer / PyTorch tensor + ↓ Passed to SDFG +SDFG operates on memory +``` + +**Output Transformation** (DaCe → PyTorch): +``` +Allocate torch.Tensor (zeros or empty) + ↓ Extract .data_ptr() +Raw pointer + ↓ Pass to SDFG as output parameter +SDFG fills memory + ↓ No copy needed +Return torch.Tensor (already owns memory) +``` + +**Constant Transformation**: +``` +PyTorch model parameters + ↓ Extract in ONNX export +ONNX initializers + ↓ Save as clean_weights +Embed in C++ (C++ extension) or pass as kwargs (CTypes) +``` + +--- + +## 7. Zero-Copy Tensor Sharing + +### 7.1 The DLPack Protocol + +**Purpose**: Industry-standard protocol for zero-copy tensor exchange between frameworks + +**Key Concept**: Shares memory pointers and metadata between frameworks without copying data + +### 7.2 DaCe → PyTorch Conversion + +**Function**: `array_to_torch_tensor(ptr, desc)` + +**Complete Process**: + +**Step 1: Device Mapping** +```python +if desc.storage == dtypes.StorageType.GPU_Global: + device_type = DLDeviceType.kDLGPU +elif desc.storage in [StorageType.CPU_Heap, StorageType.Default]: + device_type = DLDeviceType.kDLCPU +``` + +**Step 2: Type Conversion** +```python +dtype = dace_to_dldtype_dict[desc.dtype] +# e.g., dace.float32 → DLDataType(kDLFloat, 32, 1) +``` + +**Step 3: Shape and Strides** +```python +shape = (ctypes.c_int64 * len(desc.shape))(*desc.shape) +strides = (ctypes.c_int64 * len(desc.shape))(*desc.strides) +``` + +**Step 4: DLTensor Construction** +```python +dltensor = DLTensor( + data=ptr, # Raw pointer from DaCe + ctx=DLContext(device_type, device_id=0), + ndim=len(desc.shape), + dtype=dtype, + shape=shape, + strides=strides, + byte_offset=0 +) +``` + +**Step 5: Managed Tensor Wrapper** +```python +managed = DLManagedTensor( + dl_tensor=dltensor, + manager_ctx=None, + deleter=no_op_deleter # DaCe owns memory +) +``` + +**Step 6: PyCapsule Creation** +```python +capsule = PyCapsule.New( + ctypes.byref(managed), + b"dltensor", + None +) +``` + +**Step 7: PyTorch Conversion** +```python +tensor = torch.utils.dlpack.from_dlpack(capsule) +tensor._dace_dlpack = managed # Prevent GC +``` + +### 7.3 Memory Lifecycle + +**Ownership**: +- The DaCe SDFG state struct owns the memory +- PyTorch tensor is a **view** that shares the memory +- No-op deleter ensures that DaCe handles deallocation + +**Safety**: +- Keep the SDFG state alive as long as tensors exist +- State handles are stored as `torch.Tensor` objects (ref-counted) +- PyTorch's garbage collector won't free memory prematurely + +**Use Cases**: +- Return DaCe outputs as PyTorch tensors +- Access intermediate SDFG arrays from PyTorch +- Enable PyTorch operations on DaCe memory + +--- + +## 8. Autograd Integration + +### 8.1 Backward Pass Generation + +**Entry Point**: `make_backward_function()` (in `dace/autodiff/torch.py`) + +**Workflow**: + +**Step 1: Determine Required Gradients** +```python +required_grads = [] +for param_name in model.parameters(): + if param_name.requires_grad and param_name not in inputs: + required_grads.append(param_name) +``` + +**Step 2: Create Backward SDFG** +```python +generator = BackwardPassGenerator( + forward_sdfg, + backward_sdfg, + given_gradients=model_outputs, + required_gradients=model_inputs + required_params +) +backward_result = generator.backward() +``` + +**Step 3: Identify Forwarded Transients** +- Identifies values needed for gradient computation +- Example: For `y = x * w`, the backward pass needs both `x` and `w` +- These are marked as non-transient (outputs) in the forward SDFG + +**Step 4: Modify Forward SDFG** +- Makes forwarded arrays non-transient +- Converts scalar outputs to size-1 arrays +- Ensures proper storage types + +### 8.2 C++ Extension Autograd + +**Forward Method**: +```cpp +static Tensor forward(AutogradContext *ctx, int64_t fwd_handle, + int64_t bwd_handle, Tensor input) { + // Execute forward SDFG + Tensor output = forward_function(fwd_handle, bwd_handle, input); + + // Save I/O tensors + ctx->save_for_backward({input, output}); + + // Save non-I/O transients (not saved by PyTorch) + ctx->saved_data["intermediate"] = intermediate_value; + + // Save backward handle + ctx->saved_data["bwd_handle"] = bwd_handle; + + return output; +} +``` + +**Backward Method**: +```cpp +static tensor_list backward(AutogradContext *ctx, tensor_list grad_outputs) { + // 1. Recover saved tensors + auto saved = ctx->get_saved_variables(); + auto input = saved[0]; + auto intermediate = ctx->saved_data["intermediate"].toTensor(); + + // 2. Get handles + int64_t bwd_handle = ctx->saved_data["bwd_handle"].toInt(); + + // 3. Allocate gradient buffers + Tensor grad_input = torch::zeros({...}); // If zero_init=True + // OR + Tensor grad_input = torch::empty({...}); // If zero_init=False + + // 4. Get output gradients + Tensor grad_output = grad_outputs[0].contiguous(); + + // 5. Call backward SDFG + __program_backward( + bwd_handle, + grad_output.data_ptr(), + input.data_ptr(), + intermediate.data_ptr(), + grad_input.data_ptr() + ); + + // 6. Return gradients (None for handles, grads for inputs) + return {Tensor(), Tensor(), grad_input}; +} +``` + +### 8.3 CTypes Autograd + +**Forward Method**: +```python +@staticmethod +def forward(ctx, *inputs): + kwargs = {} + + # Set inputs + for i, name in enumerate(input_names): + kwargs[name] = inputs[i].contiguous() + + # Allocate outputs + forwarded transients + for name in all_output_names: + kwargs[name] = create_output_array(...) + + # Call forward SDFG + forward_compiled(**kwargs, **constants) + + # Save I/O for backward + ctx.save_for_backward(*(kwargs[n] for n in forwarded_io_names)) + + # Save non-I/O transients as attributes + for name in forwarded_non_io_names: + setattr(ctx, f"dace_saved_{name}", kwargs[name]) + + return tuple(kwargs[n] for n in model_output_names) +``` + +**Backward Method**: +```python +@staticmethod +def backward(ctx, *grad_outputs): + kwargs = {} + + # Recover I/O tensors + saved = ctx.saved_tensors + for name, val in zip(forwarded_io_names, saved): + kwargs[name] = val + + # Recover non-I/O transients + for name in forwarded_non_io_names: + kwargs[name] = getattr(ctx, f"dace_saved_{name}") + + # Allocate gradient buffers + for grad_name, zero_init in gradient_specs: + kwargs[grad_name] = create_output_array(..., zeros=zero_init) + + # Set output gradients + for grad_name, grad_val in zip(out_grad_names, grad_outputs): + kwargs[grad_name] = grad_val.contiguous() + + # Call backward SDFG + backward_compiled(**kwargs) + + # Return input gradients + return tuple(kwargs.get(g) for g in input_grad_names) +``` + +### 8.4 Gradient Accumulation + +**BackwardResult Structure**: +```python +required_grad_names = { + "input_0": "grad_input_0", + "param_weight": "grad_param_weight" +} + +given_grad_names = { + "output": "grad_output" +} + +zero_init = { + "grad_input_0": True, # Initialize to zero + "grad_param_weight": False # Don't initialize (accumulate) +} +``` + +**Usage**: +- `zero_init=True`: First gradient computation (allocate and initialize to zeros) +- `zero_init=False`: Accumulate into existing buffer (for gradient accumulation) + +--- diff --git a/dace/ml/__init__.py b/dace/ml/__init__.py new file mode 100644 index 0000000000..2e8dc8c341 --- /dev/null +++ b/dace/ml/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. + +# Import PyTorch frontend +try: + from dace.frontend.ml.torch import DaceModule, module +except ImportError: + DaceModule = None + module = None + +# Import ONNX frontend +try: + from dace.frontend.ml.onnx import ONNXModel +except ImportError: + ONNXModel = None + +__all__ = ['DaceModule', 'module', 'ONNXModel'] diff --git a/dace/sdfg/sdfg.py b/dace/sdfg/sdfg.py index 44a085603d..bda9d8707e 100644 --- a/dace/sdfg/sdfg.py +++ b/dace/sdfg/sdfg.py @@ -2110,7 +2110,7 @@ def _add_symbols(sdfg: SDFG, desc: dt.Data): if isinstance(v, dt.Data): _add_symbols(sdfg, v) for sym in desc.free_symbols: - if sym.name not in sdfg.symbols: + if sym.name not in sdfg.symbols and sym.name not in sdfg.arg_names: sdfg.add_symbol(sym.name, sym.dtype) # Add the data descriptor to the SDFG and all symbols that are not yet known. diff --git a/dace/sdfg/utils.py b/dace/sdfg/utils.py index ad13aefd51..84660da9a6 100644 --- a/dace/sdfg/utils.py +++ b/dace/sdfg/utils.py @@ -1099,6 +1099,8 @@ def get_view_edge(state: SDFGState, view: nd.AccessNode) -> gr.MultiConnectorEdg """ in_edges = state.in_edges(view) + # We should ignore empty synchronization edges + in_edges = [e for e in in_edges if not e.data.is_empty()] out_edges = state.out_edges(view) # Invalid case: No data to view @@ -2689,3 +2691,87 @@ def specialize_scalar(sdfg: 'dace.SDFG', scalar_name: str, scalar_val: Union[flo assert isinstance(scalar_name, str) assert isinstance(scalar_val, (float, int, str)) _specialize_scalar_impl(sdfg, sdfg, scalar_name, scalar_val) + + +def in_edge_with_name(node: nd.Node, state: SDFGState, name: str) -> MultiConnectorEdge: + """ + Find the edge that connects to input connector `name` on `node`. + + :param node: the node. + :param state: the state. + :param name: the input connector name. + :return: the edge that connects to connector `name`. + """ + cands = list(state.in_edges_by_connector(node, name)) + if len(cands) != 1: + raise ValueError("Expected to find exactly one edge with name '{}', found {}".format(name, len(cands))) + return cands[0] + + +def out_edge_with_name(node: nd.Node, state: SDFGState, name: str) -> MultiConnectorEdge: + """ + Find the edge that connects to output connector `name` on `node`. + + :param node: the node. + :param state: the state. + :param name: the output connector name. + :return: the edge that connects to connector `name`. + """ + cands = list(state.out_edges_by_connector(node, name)) + if len(cands) != 1: + raise ValueError("Expected to find exactly one edge with name '{}', found {}".format(name, len(cands))) + return cands[0] + + +def in_desc_with_name(node: nd.Node, state: SDFGState, sdfg: SDFG, name: str) -> dt.Data: + """ + Find the descriptor of the data that connects to input connector `name`. + + :param node: the node. + :param state: the state. + :param sdfg: the sdfg. + :param name: the input connector name. + :return: the descriptor of the data that connects to connector `name`. + """ + return sdfg.arrays[in_edge_with_name(node, state, name).data.data] + + +def out_desc_with_name(node: nd.Node, state: SDFGState, sdfg: SDFG, name: str) -> dt.Data: + """ + Find the descriptor of the data that connects to output connector `name`. + + :param node: the node. + :param state: the state. + :param sdfg: the sdfg. + :param name: the output connector name. + :return: the descriptor of the data that connects to connector `name`. + """ + return sdfg.arrays[out_edge_with_name(node, state, name).data.data] + + +def expand_nodes(sdfg: SDFG, predicate: Callable[[nd.Node], bool]): + """ + Recursively expand library nodes in the SDFG using a given predicate. + + :param sdfg: the sdfg to expand nodes on. + :param predicate: a predicate that will be called to check if a node should be expanded. + """ + if sdfg is None: + return + states = list(sdfg.states()) + while len(states) > 0: + state = states.pop() + expanded_something = False + for node in list(state.nodes()): + if isinstance(node, nd.NestedSDFG): + expand_nodes(node.sdfg, predicate=predicate) + elif isinstance(node, nd.LibraryNode): + if predicate(node): + impl_name = node.expand(sdfg, state) + if config.Config.get_bool('debugprint'): + print("Automatically expanded library node \"{}\" with implementation \"{}\".".format( + str(node), impl_name)) + expanded_something = True + + if expanded_something: + states.append(state) diff --git a/dace/transformation/auto/auto_optimize.py b/dace/transformation/auto/auto_optimize.py index 561673dbeb..ee8fb57e44 100644 --- a/dace/transformation/auto/auto_optimize.py +++ b/dace/transformation/auto/auto_optimize.py @@ -590,6 +590,7 @@ def auto_optimize(sdfg: SDFG, # Simplification and loop parallelization transformed = True sdfg.apply_transformations_repeated(TrivialMapElimination, validate=validate, validate_all=validate_all) + while transformed: sdfg.simplify(validate=False, validate_all=validate_all) l2ms = sdfg.apply_transformations_repeated((LoopToMap, RefineNestedAccess), @@ -612,7 +613,6 @@ def auto_optimize(sdfg: SDFG, # fuse subgraphs greedily sdfg.simplify() sdfg.reset_cfg_list() - greedy_fuse(sdfg, device=device, validate_all=validate_all) # fuse stencils greedily diff --git a/dace/transformation/onnx/__init__.py b/dace/transformation/onnx/__init__.py new file mode 100644 index 0000000000..651cd8eed4 --- /dev/null +++ b/dace/transformation/onnx/__init__.py @@ -0,0 +1,10 @@ +try: + from .constant_folding import ConstantFolding + from .parameter_to_transient import parameter_to_transient + from .optimize import expand_onnx_nodes, auto_optimize_onnx +except ImportError: + # ONNX transformations not available + ConstantFolding = None + parameter_to_transient = None + expand_onnx_nodes = None + auto_optimize_onnx = None diff --git a/dace/transformation/onnx/constant_folding.py b/dace/transformation/onnx/constant_folding.py new file mode 100644 index 0000000000..53b54daf5c --- /dev/null +++ b/dace/transformation/onnx/constant_folding.py @@ -0,0 +1,158 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +from typing import Optional, TYPE_CHECKING + +import numpy as np + +import dace +import torch +from dace import config +from dace.properties import make_properties +from dace.transformation import transformation +from dace.sdfg import nodes as nd +from dace.sdfg import utils as sdutil + +import dace.libraries.onnx as donnx +from dace.libraries.onnx.converters import clean_onnx_name +from dace.libraries.onnx.nodes.onnx_op import ONNXOp + +if TYPE_CHECKING: + from dace.frontend.ml.onnx import ONNXModel + +# blocklist of nondeterministic ops +# yapf: disable +NONDETERMINISTIC_OPS = {'ONNXDropout', + 'ONNXGradient', + 'ONNXGraphCall', + 'ONNXIf', + 'ONNXLoop', + 'ONNXMomentum', + 'ONNXMultinomial', + 'ONNXRandomNormal', + 'ONNXRandomNormalLike', + 'ONNXRandomUniform', + 'ONNXRandomUniformLike', + 'ONNXSVMClassifier', + 'ONNXSVMRegressor', + 'ONNXScan', + 'ONNXTreeEnsembleClassifier', + 'ONNXTreeEnsembleRegressor'} +# yapf: enable + + +@make_properties +class ConstantFolding(transformation.SingleStateTransformation): + """ Remove nodes where all inputs are known and replace them with constant nodes by precomputing the output. + """ + # pattern matching only checks that the type of the node matches, + onnx_node = transformation.PatternNode(ONNXOp) + + @classmethod + def expressions(cls): + return [sdutil.node_path_graph(cls.onnx_node)] + + @staticmethod + def is_constant(sdfg: dace.SDFG, state: dace.SDFGState, node) -> bool: + if len(state.in_edges(node)) > 0: + return False + + # the ONNX importer adds a _parent_onnx_model attribute to the sdfg + if isinstance(node, nd.AccessNode) and node.data in sdfg._parent_onnx_model.clean_weights: + return True + + return False + + def can_be_applied(self, + graph: dace.sdfg.graph.OrderedMultiDiConnectorGraph, + expr_index: int, + sdfg, + permissive: bool = False): + + node = self.onnx_node + + # SDFG must be imported from an ONNXModel + if not hasattr(sdfg, "_parent_onnx_model"): + return False + + if not 'ONNX' + node.schema.name not in NONDETERMINISTIC_OPS: + return False + + if isinstance(node, donnx.ONNXShape): + assert len(graph.in_edges(node)) == 1 + shape_in_edge = graph.in_edges(node)[0] + assert shape_in_edge.dst_conn == "data" + shape_desc = sdfg.arrays[shape_in_edge.src.data] + try: + np.array(shape_desc.shape, np.int64) + except Exception: + # this happens if the shape is symbolic, for example + return False + + return True + + # all inputs are constant + for edge in graph.in_edges(node): + if not ConstantFolding.is_constant(sdfg, graph, edge.src): + return False + + return True + + @classmethod + def match_to_str(cls, graph): + node: ONNXOp = cls.onnx_node + return "Precompute outputs of {}".format(node) + + def apply(self, state: dace.SDFGState, sdfg: dace.SDFG): + parent: "ONNXModel" = sdfg._parent_onnx_model + node = self.onnx_node + if config.Config.get_bool('debugprint'): + print(f"Applying constant folding: {node} in {state}") + + if isinstance(node, donnx.ONNXShape): + # if we have a shape node, replace it with a constant + assert len(state.in_edges(node)) == 1 + shape_in_edge = state.in_edges(node)[0] + assert shape_in_edge.dst_conn == "data" + shape_desc = sdfg.arrays[shape_in_edge.src.data] + + constant_name = sdfg.temp_data_name() + clean_constant_name = clean_onnx_name(constant_name) + sdfg.add_array(clean_constant_name, (len(shape_desc.shape), ), dace.int64) + + assert constant_name not in parent.clean_weights + parent.weights[constant_name] = torch.from_numpy(np.array(shape_desc.shape, np.int64)) + + assert len(state.out_edges(node)) == 1 + output_edge = state.out_edges(node)[0] + access_shape = state.add_access(clean_constant_name) + state.add_edge(access_shape, None, output_edge.dst, output_edge.dst_conn, + sdfg.make_array_memlet(clean_constant_name)) + + # remove all now useless nodes with a reverse BFS + remove_node_and_computation(sdfg, state, node) + + +def remove_node_and_computation(sdfg: dace.SDFG, state: dace.SDFGState, node: nd.Node, connector: Optional[str] = None): + """ Remove a node and the parent nodes that compute this node, if the outputs are not used elsewhere. + + :param sdfg: the sdfg containing the node. + :param state: the state containing the node. + :param node: the node to remove + :param connector: if not None, the computation of the connector of + ``node`` will be removed, but not ``node`` itself. + """ + if connector is not None: + if connector not in node.in_connectors: + return + node.remove_in_connector(connector) + edges = state.in_edges_by_connector(node, connector) + for e in edges: + state.remove_edge(e) + else: + edges = state.out_edges(node) + for e in edges: + state.remove_edge(e) + + # remove dangling nodes, this can happen with non-transients + for node, parent in sdfg.all_nodes_recursive(): + if (isinstance(node, nd.AccessNode) and parent.in_degree(node) + parent.out_degree(node) == 0): + parent.remove_node(node) diff --git a/dace/transformation/onnx/optimize.py b/dace/transformation/onnx/optimize.py new file mode 100644 index 0000000000..dc266aa4c0 --- /dev/null +++ b/dace/transformation/onnx/optimize.py @@ -0,0 +1,65 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +from typing import Optional, Callable + +import dace +from dace import config, nodes as nd +from dace.libraries import blas +from dace.sdfg.utils import expand_nodes +from dace.transformation import dataflow +from dace.transformation.auto.auto_optimize import set_fast_implementations +from dace.transformation.dataflow import CopyToMap + + +def expand_onnx_nodes(sdfg: dace.SDFG, predicate: Optional[Callable[[nd.Node], bool]] = None): + """ Recursively expand all onnx library nodes in the SDFG, resulting in an SDFG that can be optimized by + dace transformations. Will also specialize dace matmuls. + + :param sdfg: the sdfg to expand nodes on. + :param predicate: a predicate that will be called to check if a node should be expanded. + """ + + try: + from dace.libraries.onnx.nodes.onnx_op import ONNXOp # avoid import loop + except ImportError: + raise ImportError("expand_onnx_nodes requires ONNX. Install with: pip install dace[ml]") + + if predicate is None: + new_predicate = lambda n: isinstance(n, (ONNXOp, blas.MatMul)) + else: + new_predicate = lambda n: predicate(n) and isinstance(n, (ONNXOp, blas.MatMul)) + + expand_nodes(sdfg, new_predicate) + + +def auto_optimize_onnx(sdfg: dace.SDFG, cuda, simplify=False, fold_constants=True): + """ Automatically optimize ``sdfg``. + + :param sdfg: the sdfg to optimize (inplace). + :param cuda: whether to optimize for cuda. + :param simplify: whether to apply simplification transformations to the sdfg after optimization. + :param fold_constants: whether to apply constant folding. + """ + + try: + from dace.transformation.onnx import ConstantFolding # avoid import loop + except ImportError: + raise ImportError("auto_optimize_onnx requires ONNX. Install with: pip install dace[ml]") + + if config.Config.get_bool('debugprint'): + print("Applying automatic optimizations") + if fold_constants: + if config.Config.get_bool('debugprint'): + print("Applying constant folding") + sdfg.apply_transformations_repeated([ConstantFolding, dataflow.RedundantSecondArray], validate_all=True) + if config.Config.get_bool('debugprint'): + print("Expanding ONNX nodes") + expand_onnx_nodes(sdfg) + if config.Config.get_bool('debugprint'): + print("Setting fast implementations") + set_fast_implementations(sdfg, dace.DeviceType.GPU if cuda else dace.DeviceType.CPU) + if simplify: + if config.Config.get_bool('debugprint'): + print("Applying simplification transforms") + sdfg.simplify() + if cuda: + sdfg.apply_transformations_once_everywhere(CopyToMap) diff --git a/dace/transformation/onnx/parameter_to_transient.py b/dace/transformation/onnx/parameter_to_transient.py new file mode 100644 index 0000000000..39988b83b4 --- /dev/null +++ b/dace/transformation/onnx/parameter_to_transient.py @@ -0,0 +1,83 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import operator + +import dace +from dace import config, dtypes, nodes + +from dace.libraries.onnx.converters import clean_onnx_name +from dace.libraries.torch import dlpack + + +def parameter_to_transient(dace_module: 'dace.frontend.ml.torch', parameter_path: str): + """ Convert the dace array for pytorch parameter found at parameter_path to a persistently allocated transient. + + :param dace_module: the module containing the weight to transform. + :param weight_path: the dotted path to the weight + """ + + if config.Config.get_bool('debugprint'): + print(f"Converting parameter {parameter_path} to a transient") + + pt_weight_name = parameter_path + pt_tensor = operator.attrgetter(pt_weight_name)(dace_module.model) + array_name = clean_onnx_name(pt_weight_name) + dace_module.dace_model.inputs.remove(parameter_path) + + # the the access node for this array of this array + cands = [(node, parent) for (node, parent) in dace_module.sdfg.all_nodes_recursive() + if isinstance(node, nodes.AccessNode) and node.data == array_name] + + if len(cands) == 0: + if config.Config.get_bool('debugprint'): + print(f"Warning: Could not find access node with name '{array_name}', skipping parameter to transient") + return + + if len(cands) != 1: + raise ValueError("parameter_to_transient does not work when the target array has multiple AccessNodes") + + if array_name not in dace_module.sdfg.arrays: + raise ValueError(f"Could not find parameter {array_name} in sdfg.") + + if dace_module.sdfg.arrays[array_name].storage is dtypes.StorageType.GPU_Global: + dace_module.sdfg.arrays[array_name].transient = True + dace_module.sdfg.arrays[array_name].lifetime = dtypes.AllocationLifetime.Persistent + gpu_array_name = array_name + else: + + # find the GPU transient of this array + state: dace.SDFGState + cand, state = cands[0] + if state.out_degree(cand) != 1: + raise ValueError(f"expected one out edge coming out of {cand}, found {state.out_degree(cand)}") + _, _, dst_node, _, _ = state.out_edges(cand)[0] + if (not isinstance(dst_node, nodes.AccessNode) + or dace_module.sdfg.arrays[dst_node.data].storage is not dtypes.StorageType.GPU_Global): + raise ValueError(f"parameter_to_transient only works for arrays that are copied to GPU_Global arrays," + f" but array {array_name} was connected to {dst_node}") + + gpu_array_name = dst_node.data + + # since it is parsable, proceed with the transformation + dace_module.sdfg.arrays[gpu_array_name].transient = True + dace_module.sdfg.arrays[gpu_array_name].lifetime = dtypes.AllocationLifetime.Persistent + + # remove the CPU node + state.remove_node(cand) + del dace_module.sdfg[array_name] + + def post_compile_hook(compiled_sdfg): + + struct = compiled_sdfg.get_state_struct() + + param_sdfg = compiled_sdfg.sdfg + struct_entry_name = f'__{param_sdfg.sdfg_id}_{gpu_array_name}' + + if not hasattr(struct, struct_entry_name): + raise ValueError(f"Could not parse parameter {gpu_array_name} from state_struct.") + + ptr = getattr(struct, struct_entry_name) + # copy the data into the torch parameter tensor + torch_tensor = dlpack.array_to_torch_tensor(ptr, param_sdfg.arrays[gpu_array_name]) + torch_tensor[:] = pt_tensor + + dace_module.post_compile_hooks["init_" + pt_weight_name] = post_compile_hook diff --git a/dace/transformation/onnx/replacement.py b/dace/transformation/onnx/replacement.py new file mode 100644 index 0000000000..ce51fdd14c --- /dev/null +++ b/dace/transformation/onnx/replacement.py @@ -0,0 +1,159 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" General class for pattern replacement transformations. """ +import abc +import dace +from dace import registry, nodes, data as dt +from dace.transformation import transformation, helpers as xfh +from typing import Any, Dict, List, Optional, Tuple, Union +from dace.sdfg import utils as sdutil +from dace.libraries.onnx import nodes as onnx_op +from dace.sdfg import graph as gr + + +def make_onnx_path(*path_nodes: nodes.Node) -> gr.OrderedDiGraph: + result = gr.OrderedDiGraph() + + # First add the nodes in order, so that they can be accessed + path_nodes = [transformation.PatternNode(n) for n in path_nodes] + result.add_nodes_from(path_nodes) + + # Then make a path and add access nodes as necessary + last_node = None + for node in path_nodes: + if last_node is not None: + result.add_edge(last_node, node) + last_node = node + + return result + + +def add_connecting_access_nodes(graph: gr.OrderedDiGraph): + edges_to_remove = [] + outputs = {} + for pnode in graph.nodes(): + if issubclass(pnode.node, (nodes.LibraryNode, nodes.NestedSDFG)): + if any(issubclass(e.dst.node, (nodes.LibraryNode, nodes.NestedSDFG)) for e in graph.out_edges(pnode)): + # Make new output node that everyone will link from + new_node = transformation.PatternNode(nodes.AccessNode) + graph.add_node(new_node) + graph.add_edge(pnode, new_node) + outputs[pnode] = new_node + + for e in graph.edges(): + if (issubclass(e.src.node, (nodes.LibraryNode, nodes.NestedSDFG)) + and issubclass(e.dst.node, (nodes.LibraryNode, nodes.NestedSDFG))): + # Direct path between two library nodes means that there is at least + # another access node in between + if e.src in outputs: + graph.add_edge(outputs[e.src], e.dst) + edges_to_remove.append(e) + else: + raise ValueError('Found directly connected library nodes with source not designated as output') + for e in edges_to_remove: + graph.remove_edge(e) + + +def onnx_constant_or_none(sdfg: dace.SDFG, node_or_name: Union[nodes.AccessNode, str]) -> Optional[Any]: + name = node_or_name if isinstance(node_or_name, str) else node_or_name.data + if name not in sdfg._parent_onnx_model.clean_weights: + return None + cten = sdfg._parent_onnx_model.clean_weights[name] + return cten.item() if cten.numel() == 1 else cten.tolist() + + +class ReplacementTransformation(transformation.SingleStateTransformation, abc.ABC): + + @classmethod + @abc.abstractmethod + def pattern(cls) -> gr.OrderedDiGraph[nodes.Node, dace.Memlet]: + """ Returns a pattern to match as a directed graph. """ + raise NotImplementedError + + @abc.abstractmethod + def replacement(self, subgraph: List[nodes.Node], sdfg: dace.SDFG, + state: dace.SDFGState) -> Tuple[nodes.Node, Dict[str, Tuple[nodes.Node, Union[str, dt.Data]]]]: + """ + Defines replacement behavior for the transformation. This method returns + a node (which could also be a nested SDFG if a subgraph should be + returned), accompanied by instructions for reconnecting the surrounding + nodes and creating new data (arrays). + :param subgraph: The list of nodes in the matched state with the same + IDs as the pattern subgraph. + :param sdfg: The SDFG in which to perform the replacement. + :param state: The state in which the subgraph was found. + :return: A 2-tuple of (new node, mapping), where the latter maps a + connector name on the new node to either a pair of + (old node, old connector) to redirect from, or + (None, data descriptor) if a new one shall be created. + """ + raise NotImplementedError + + @classmethod + def expressions(cls): + if hasattr(cls, '_pattern'): + return [cls._pattern] + result = cls.pattern() + add_connecting_access_nodes(result) + + # Set subgraph as class property + cls._pattern = result + # Set pattern nodes as class properties + for i, node in enumerate(result.nodes()): + setattr(cls, f'_pnode{i}', node) + return [result] + + def can_be_applied(self, graph: Union[dace.SDFG, dace.SDFGState], candidate: Dict[transformation.PatternNode, int], + expr_index: int, sdfg: dace.SDFG, simplify: bool) -> bool: + # All internal nodes must not be global (non-transient) or reused + # anywhere else + subgraph = gr.SubgraphView(graph, [graph.node(id) for id in candidate.values()]) + for node in subgraph.nodes(): + # Check for internal nodes + if node in subgraph.source_nodes() or node in subgraph.sink_nodes(): + continue + if not isinstance(node, nodes.AccessNode): + continue + if not node.desc(sdfg).transient: + return False + other_data_nodes_with_same_name = [ + n for s in sdfg.nodes() for n in s.nodes() + if isinstance(n, nodes.AccessNode) and n.data == node.data and n not in subgraph.nodes() + ] + if len(other_data_nodes_with_same_name) > 0: + return False + return True + + def apply(self, sdfg: dace.SDFG) -> nodes.Node: + state: dace.SDFGState = sdfg.node(self.state_id) + matcher = self.expressions()[0] + subgraph = [state.node(self.subgraph[n]) for n in matcher.nodes()] + new_node, reconnection = self.replacement(subgraph, sdfg, state) + + # Remap edges and add new arrays + for new_conn, (node, old_conn) in reconnection.items(): + # Make new array + if node is None: + desc = old_conn + name = sdfg.add_datadesc('_' + new_conn, desc, find_new_name=True) + node = state.add_access(name) + if new_conn in new_node.in_connectors: + state.add_edge(node, None, new_node, new_conn, dace.Memlet(name)) + elif new_conn in new_node.out_connectors: + state.add_edge(new_node, new_conn, node, None, dace.Memlet(name)) + continue + # END of new array + + if new_conn in new_node.in_connectors: + e = next(state.in_edges_by_connector(node, old_conn)) + xfh.redirect_edge(state, e, new_dst=new_node, new_dst_conn=new_conn) + elif new_conn in new_node.out_connectors: + e = next(state.out_edges_by_connector(node, old_conn)) + xfh.redirect_edge(state, e, new_src=new_node, new_src_conn=new_conn) + + # Remove subgraph nodes that are not connected from outside + sgview = gr.SubgraphView(state, subgraph) + state.remove_nodes_from( + [n for n in subgraph if isinstance(n, nodes.CodeNode) or state.degree(n) == sgview.degree(n)]) + # Remove orphan nodes + state.remove_nodes_from([n for n in state.nodes() if isinstance(n, nodes.AccessNode) and state.degree(n) == 0]) + return new_node diff --git a/dace/transformation/passes/scalar_to_symbol.py b/dace/transformation/passes/scalar_to_symbol.py index 2514d8412d..693a4a7777 100644 --- a/dace/transformation/passes/scalar_to_symbol.py +++ b/dace/transformation/passes/scalar_to_symbol.py @@ -145,6 +145,10 @@ def find_promotable_scalars(sdfg: sd.SDFG, transients_only: bool = True, integer continue edge = state.in_edges(node)[0] + # Edge must not be empty + if edge.data.is_empty(): + continue + # Edge must not be WCR if edge.data.wcr is not None: candidates.remove(candidate) @@ -169,11 +173,16 @@ def find_promotable_scalars(sdfg: sd.SDFG, transients_only: bool = True, integer if state.out_degree(edge.src) > 1: candidates.remove(candidate) continue - # If inputs to tasklets are not arrays, skip + for tinput in state.in_edges(edge.src): + # If inputs to tasklets are not arrays, skip if not isinstance(tinput.src, nodes.AccessNode): candidates.remove(candidate) break + # If edge memlet is empty, skip + if tinput.data.is_empty(): + candidates.remove(candidate) + break if isinstance(sdfg.arrays[tinput.src.data], dt.Stream): candidates.remove(candidate) break diff --git a/dace/transformation/passes/simplification/control_flow_raising.py b/dace/transformation/passes/simplification/control_flow_raising.py index 0749b03b82..2a47221116 100644 --- a/dace/transformation/passes/simplification/control_flow_raising.py +++ b/dace/transformation/passes/simplification/control_flow_raising.py @@ -50,7 +50,7 @@ def _lift_returns(self, sdfg: SDFG) -> int: region, and not the entire SDFG. :param sdfg: The SDFG in which to lift returns - :returns: The number of return blocks lifted + :return: The number of return blocks lifted """ returns_lifted = 0 for nd in sdfg.nodes(): @@ -220,7 +220,7 @@ def _lift_unstructured(self, sdfg: SDFG) -> int: cycles represent unstructured control flow. :param sdfg: The SDFG in which to lift unstructured control flow - :returns: The number of unstructured control flow blocks lifted + :return: The number of unstructured control flow blocks lifted """ lifted = 0 for cfg in sdfg.all_control_flow_regions(): diff --git a/dace/transformation/subgraph/stencil_tiling.py b/dace/transformation/subgraph/stencil_tiling.py index 93a6517720..caaa367fdb 100644 --- a/dace/transformation/subgraph/stencil_tiling.py +++ b/dace/transformation/subgraph/stencil_tiling.py @@ -256,6 +256,11 @@ def can_be_applied(sdfg, subgraph) -> bool: if data_name in coverages[child_entry][0]: children_coverage = subsets.union(children_coverage, coverages[child_entry][0][data_name]) + # TODO: Is there a better fix for this? + if children_coverage is None: + # no coverage for this data_name in children + # this is not supported + return False # extend mapping map_parameter -> coverage # by the previous mapping diff --git a/pytest.ini b/pytest.ini index b0aa6e9b8f..2649e6c4c5 100644 --- a/pytest.ini +++ b/pytest.ini @@ -15,8 +15,12 @@ markers = datainstrument: Test uses data instrumentation (select with '-m datainstrument') hptt: Test requires the HPTT library (select with '-m "hptt') long: Test runs for a long time and is skipped in CI (select with '-m "long"') + torch: Test for the PyTorch/ONNX frontend (select with '-m "torch"') + autodiff: Test for automatic differentiation (select with '-m "autodiff"') + onnx: Test for the ONNX frontend (select with '-m "onnx"') sequential: Test must be run sequentially (select with '-m "sequential"') python_files = + test_*.py *_test.py *_cudatest.py addopts = --ignore=dace/external --color=yes diff --git a/setup.py b/setup.py index aaf120c8f0..bac530c79e 100644 --- a/setup.py +++ b/setup.py @@ -78,9 +78,22 @@ 'typing-compat; python_version < "3.8"', 'packaging' ] + cmake_requires, extras_require={ + 'ml': ['onnx', 'torch', 'onnxsim', 'onnxscript', 'onnxruntime', 'protobuf', 'ninja'], 'testing': [ + 'coverage', + 'pytest-cov', + 'scipy', + 'absl-py', + 'opt_einsum', + 'pymlir', + 'click', + 'ipykernel', + 'nbconvert', + 'pytest-timeout', + ], + 'ml-testing': [ 'coverage', 'pytest-cov', 'scipy', 'absl-py', 'opt_einsum', 'pymlir', 'click', 'ipykernel', 'nbconvert', - 'pytest-timeout' + 'pytest-timeout', 'transformers == 4.50', 'jax <= 0.6.2', 'efficientnet_pytorch' ], 'docs': ['jinja2<3.2.0', 'sphinx-autodoc-typehints', 'sphinx-rtd-theme>=0.5.1'], 'linting': ['pre-commit==4.1.0', 'yapf==0.43.0'], diff --git a/tests/.gitignore b/tests/.gitignore index 748014c0fd..6553fb299d 100644 --- a/tests/.gitignore +++ b/tests/.gitignore @@ -2,3 +2,6 @@ *.from_serialized *.serialized *.txt + +# Ignore downloaded files for tests +data/ diff --git a/tests/autodiff/test_multi_state.py b/tests/autodiff/test_multi_state.py new file mode 100644 index 0000000000..db23109d63 --- /dev/null +++ b/tests/autodiff/test_multi_state.py @@ -0,0 +1,304 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import numpy as np +import pytest + +pytest.importorskip("torch", reason="PyTorch not installed. Please install with: pip install dace[ml]") + +import torch + +import dace +from dace import SDFG, InterstateEdge, Memlet +from test_single_state import SDFGBackwardRunner, run_correctness + + +@pytest.mark.autodiff +@run_correctness +def test_two_state_add_mul(): + """ + Test a two-state SDFG: + - State 1: Z = X + Y (element-wise addition) + - State 2: S = sum(Z * Z) (element-wise multiplication then sum) + """ + + sdfg = SDFG("two_state_add_mul") + + sdfg.add_array("X", [3, 3], dace.float32) + sdfg.add_array("Y", [3, 3], dace.float32) + sdfg.add_array("Z", [3, 3], dace.float32, transient=False) + sdfg.add_array("S", [1], dace.float32) + + state1 = sdfg.add_state("state1") + X_read = state1.add_access("X") + Y_read = state1.add_access("Y") + Z_write = state1.add_access("Z") + + map_entry, map_exit = state1.add_map("add_map", dict(i="0:3", j="0:3")) + + tasklet_add = state1.add_tasklet("add", {"x", "y"}, {"z"}, "z = x + y") + + state1.add_memlet_path(X_read, map_entry, tasklet_add, dst_conn="x", memlet=Memlet("X[i, j]")) + state1.add_memlet_path(Y_read, map_entry, tasklet_add, dst_conn="y", memlet=Memlet("Y[i, j]")) + state1.add_memlet_path(tasklet_add, map_exit, Z_write, src_conn="z", memlet=Memlet("Z[i, j]")) + + state2 = sdfg.add_state("state2") + Z_read = state2.add_access("Z") + S_write = state2.add_access("S") + + map_entry2, map_exit2 = state2.add_map("mul_map", dict(i="0:3", j="0:3")) + + tasklet_mul = state2.add_tasklet("mul", {"z"}, {"s"}, "s = z * z") + + state2.add_memlet_path(Z_read, map_entry2, tasklet_mul, dst_conn="z", memlet=Memlet("Z[i, j]")) + state2.add_memlet_path(tasklet_mul, + map_exit2, + S_write, + src_conn="s", + memlet=Memlet("S[0]", wcr="lambda a, b: a + b")) + + sdfg.add_edge(state1, state2, InterstateEdge()) + + # PyTorch reference implementation + def torch_func(*, X, Y): + Z = X + Y + S = (Z * Z).sum() + S.backward() + return dict(gradient_X=X.grad, gradient_Y=Y.grad) + + return ( + SDFGBackwardRunner(sdfg, "S"), + torch_func, + dict( + X=np.random.rand(3, 3).astype(np.float32), + Y=np.random.rand(3, 3).astype(np.float32), + ), + ) + + +@pytest.mark.autodiff +@run_correctness +def test_conditional_simple(): + """ + Test a Python program with a simple conditional in the forward pass: + if X[0, 0] > 0.5: + Y = X * 2 + else: + Y = X * 3 + S = sum(Y) + """ + + @dace.program + def conditional_program(X: dace.float32[3, 3], Y: dace.float32[3, 3], S: dace.float32[1]): + if X[0, 0] > 0.5: + Y[:] = X * 2.0 + else: + Y[:] = X * 3.0 + S[0] = np.sum(Y) + + sdfg = conditional_program.to_sdfg(simplify=True) + + # PyTorch reference implementation + def torch_func(*, X): + Y = torch.where(X[0, 0] > 0.5, X * 2.0, X * 3.0) + S = Y.sum() + S.backward() + return dict(gradient_X=X.grad) + + return ( + SDFGBackwardRunner(sdfg, "S", simplify=False), + torch_func, + dict(X=np.random.rand(3, 3).astype(np.float32)), + ) + + +@pytest.mark.autodiff +@run_correctness +def test_for_loop(): + """ + Test a simple for loop similar to jacobi_1d, but simplified: + for i in range(3): + A = A + B + S = sum(A) + """ + + @dace.program + def for_loop_program(A: dace.float32[10], B: dace.float32[10]): + for i in range(3): + A[:] = A + B + return np.sum(A) + + sdfg = for_loop_program.to_sdfg() + + # PyTorch reference implementation + def torch_func(*, A, B): + A_result = A.clone() + for i in range(3): + A_result = A_result + B + S = A_result.sum() + S.backward() + return dict(gradient_A=A.grad, gradient_B=B.grad) + + return ( + SDFGBackwardRunner(sdfg, "__return"), + torch_func, + dict( + A=np.random.rand(10).astype(np.float32), + B=np.random.rand(10).astype(np.float32), + ), + ) + + +@pytest.mark.autodiff +@run_correctness +def test_diamond_pattern_conditional(): + """ + Test an SDFG with a diamond pattern control flow using GOTOs. + + Structure: + state1: Y = X * 2 + if X[0] > 0.5: + goto state3 + else: + goto state2 + state2: Y = Y + 1 + state3: S = sum(Y) + + This creates a diamond pattern where both paths can reach state3. + """ + + sdfg = SDFG("irreducible_cf") + + # Add arrays + sdfg.add_array("X", [5], dace.float32) + sdfg.add_array("Y", [5], dace.float32, transient=False) + sdfg.add_array("S", [1], dace.float32) + + # State 1: Y = X * 2 + state1 = sdfg.add_state("state1") + X_read1 = state1.add_access("X") + Y_write1 = state1.add_access("Y") + + map_entry1, map_exit1 = state1.add_map("mul_map", dict(i="0:5")) + tasklet1 = state1.add_tasklet("mul", {"x"}, {"y"}, "y = x * 2.0") + + state1.add_memlet_path(X_read1, map_entry1, tasklet1, dst_conn="x", memlet=Memlet("X[i]")) + state1.add_memlet_path(tasklet1, map_exit1, Y_write1, src_conn="y", memlet=Memlet("Y[i]")) + + # State 2: Y = Y + 1 + state2 = sdfg.add_state("state2") + Y_read2 = state2.add_access("Y") + Y_write2 = state2.add_access("Y") + + map_entry2, map_exit2 = state2.add_map("add_map", dict(i="0:5")) + tasklet2 = state2.add_tasklet("add", {"y_in"}, {"y_out"}, "y_out = y_in + 1.0") + + state2.add_memlet_path(Y_read2, map_entry2, tasklet2, dst_conn="y_in", memlet=Memlet("Y[i]")) + state2.add_memlet_path(tasklet2, map_exit2, Y_write2, src_conn="y_out", memlet=Memlet("Y[i]")) + + # State 3: S = sum(Y) + state3 = sdfg.add_state("state3") + Y_read3 = state3.add_access("Y") + S_write3 = state3.add_access("S") + + map_entry3, map_exit3 = state3.add_map("sum_map", dict(i="0:5")) + tasklet3 = state3.add_tasklet("sum", {"y"}, {"s"}, "s = y") + + state3.add_memlet_path(Y_read3, map_entry3, tasklet3, dst_conn="y", memlet=Memlet("Y[i]")) + state3.add_memlet_path(tasklet3, map_exit3, S_write3, src_conn="s", memlet=Memlet("S[0]", wcr="lambda a, b: a + b")) + + # Create conditional edges (irreducible control flow) + # Add condition: if X[0] > 0.5 goto state3, else goto state2 + sdfg.add_edge(state1, state3, InterstateEdge(condition="X[0] > 0.5")) + sdfg.add_edge(state1, state2, InterstateEdge(condition="X[0] <= 0.5")) + sdfg.add_edge(state2, state3, InterstateEdge()) + + # PyTorch reference implementation + def torch_func(*, X): + Y = X * 2.0 + Y = torch.where(X[0] > 0.5, Y, Y + 1.0) + S = Y.sum() + S.backward() + return dict(gradient_X=X.grad) + + return ( + SDFGBackwardRunner(sdfg, "S", simplify=False), + torch_func, + dict(X=np.random.rand(5).astype(np.float32)), + ) + + +@pytest.mark.autodiff +@run_correctness +def test_multi_output_state(): + """ + Test a two-state SDFG where the first state produces multiple outputs: + State 1: Y = X * 2, Z = X + 1 + State 2: S = sum(Y * Z) + """ + + # Build SDFG using API + sdfg = SDFG("multi_output_state") + + # Add arrays + sdfg.add_array("X", [5], dace.float32) + sdfg.add_array("Y", [5], dace.float32, transient=False) + sdfg.add_array("Z", [5], dace.float32, transient=False) + sdfg.add_array("S", [1], dace.float32) + + # State 1: Compute Y and Z + state1 = sdfg.add_state("state1") + X_read1 = state1.add_access("X") + Y_write1 = state1.add_access("Y") + Z_write1 = state1.add_access("Z") + + map_entry1, map_exit1 = state1.add_map("compute_map", dict(i="0:5")) + tasklet_y = state1.add_tasklet("compute_y", {"x"}, {"y"}, "y = x * 2.0") + tasklet_z = state1.add_tasklet("compute_z", {"x"}, {"z"}, "z = x + 1.0") + + state1.add_memlet_path(X_read1, map_entry1, tasklet_y, dst_conn="x", memlet=Memlet("X[i]")) + state1.add_memlet_path(tasklet_y, map_exit1, Y_write1, src_conn="y", memlet=Memlet("Y[i]")) + + X_read2 = state1.add_access("X") + state1.add_memlet_path(X_read2, map_entry1, tasklet_z, dst_conn="x", memlet=Memlet("X[i]")) + state1.add_memlet_path(tasklet_z, map_exit1, Z_write1, src_conn="z", memlet=Memlet("Z[i]")) + + # State 2: Multiply and sum + state2 = sdfg.add_state("state2") + Y_read2 = state2.add_access("Y") + Z_read2 = state2.add_access("Z") + S_write2 = state2.add_access("S") + + map_entry2, map_exit2 = state2.add_map("mul_sum_map", dict(i="0:5")) + tasklet_mul = state2.add_tasklet("mul", {"y", "z"}, {"s"}, "s = y * z") + + state2.add_memlet_path(Y_read2, map_entry2, tasklet_mul, dst_conn="y", memlet=Memlet("Y[i]")) + state2.add_memlet_path(Z_read2, map_entry2, tasklet_mul, dst_conn="z", memlet=Memlet("Z[i]")) + state2.add_memlet_path(tasklet_mul, + map_exit2, + S_write2, + src_conn="s", + memlet=Memlet("S[0]", wcr="lambda a, b: a + b")) + + # Connect states + sdfg.add_edge(state1, state2, InterstateEdge()) + + # PyTorch reference implementation + def torch_func(*, X): + Y = X * 2.0 + Z = X + 1.0 + S = (Y * Z).sum() + S.backward() + return dict(gradient_X=X.grad) + + return ( + SDFGBackwardRunner(sdfg, "S"), + torch_func, + dict(X=np.random.rand(5).astype(np.float32)), + ) + + +if __name__ == "__main__": + test_two_state_add_mul() + test_conditional_simple() + test_for_loop() + test_diamond_pattern_conditional() + test_multi_output_state() diff --git a/tests/autodiff/test_nested.py b/tests/autodiff/test_nested.py new file mode 100644 index 0000000000..427efce8c2 --- /dev/null +++ b/tests/autodiff/test_nested.py @@ -0,0 +1,174 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import numpy as np +import pytest + +pytest.importorskip("onnx", reason="ONNX not installed. Please install with: pip install dace[ml]") +pytest.importorskip("torch", reason="PyTorch not installed. Please install with: pip install dace[ml]") + +import torch + +import dace +from dace.transformation.interstate import StateFusion + +import dace.libraries.onnx as donnx +from test_single_state import SDFGBackwardRunner, run_correctness + + +@dace.program +def inner_sdfg(Z: dace.float32[3, 3], W: dace.float32[3, 3]): + W[:] = dace.elementwise(lambda x: log(x), Z) + + +@dace.program +def inner_sdfg_with_intermediate(Z: dace.float32[3, 3], W: dace.float32[3, 3]): + intermediate = dace.define_local([3, 3], dace.float32) + intermediate[:] = dace.elementwise(lambda x: sqrt(x), Z) + W[:] = dace.elementwise(lambda x: log(x), intermediate) + + +@dace.program +def middle_sqrt(Y: dace.float32[3, 3]): + intermediate = dace.define_local([3, 3], dace.float32) + W = dace.define_local([3, 3], dace.float32) + intermediate[:] = dace.elementwise(lambda x: sqrt(x), Y) + inner_sdfg(intermediate, W) + Z = np.sum(W) + return Z + + +@pytest.mark.autodiff +@run_correctness +def test_nested(): + sdfg = middle_sqrt.to_sdfg(simplify=True) + + def torch_func(*, Y): + inter = torch.sqrt(Y) + W = torch.log(inter) + Z = torch.sum(W) + Z.backward() + return dict(gradient_Y=Y.grad) + + return (SDFGBackwardRunner(sdfg, "__return", + simplify=False), torch_func, dict(Y=np.random.rand(3, 3).astype(np.float32))) + + +@dace.program +def middle_sqrt_with_intermediate(Y: dace.float32[3, 3]): + intermediate = dace.define_local([3, 3], dace.float32) + W = dace.define_local([3, 3], dace.float32) + intermediate[:] = dace.elementwise(lambda x: sqrt(x), Y) + inner_sdfg_with_intermediate(intermediate, W) + Z = np.sum(W) + return Z + + +@pytest.mark.autodiff +@run_correctness +def test_nested_forwarding(): + sdfg = middle_sqrt_with_intermediate.to_sdfg(simplify=True) + + def torch_func(*, Y): + inter = torch.sqrt(Y) + inter2 = torch.sqrt(inter) + W = torch.log(inter2) + Z = torch.sum(W) + Z.backward() + return dict(gradient_Y=Y.grad) + + return (SDFGBackwardRunner(sdfg, "__return", + simplify=False), torch_func, dict(Y=np.random.rand(3, 3).astype(np.float32))) + + +@dace.program +def middle_sqrt_no_sum(Y: dace.float32[3, 3]): + intermediate = dace.define_local([3, 3], dace.float32) + W = dace.define_local([3, 3], dace.float32) + intermediate[:] = dace.elementwise(lambda x: sqrt(x), Y) + inner_sdfg_with_intermediate(intermediate, W) + return W + + +@dace.program +def outer_sqrt_with_intermediate(Y: dace.float32[3, 3]): + intermediate = dace.define_local([3, 3], dace.float32) + W = dace.define_local([3, 3], dace.float32) + intermediate[:] = dace.elementwise(lambda x: sqrt(x), Y) + W[:] = middle_sqrt_no_sum(intermediate) + Z = np.sum(W) + return Z + + +@pytest.mark.autodiff +@run_correctness +def test_triple_nested_forwarding(): + sdfg = outer_sqrt_with_intermediate.to_sdfg(simplify=True) + + def torch_func(*, Y): + inter = torch.sqrt(Y) + inter2 = torch.sqrt(inter) + inter3 = torch.sqrt(inter2) + W = torch.log(inter3) + Z = torch.sum(W) + Z.backward() + return dict(gradient_Y=Y.grad) + + return (SDFGBackwardRunner(sdfg, "__return", + simplify=False), torch_func, dict(Y=np.random.rand(3, 3).astype(np.float32))) + + +@pytest.mark.autodiff +@run_correctness +def test_view_forwarding(): + # Prepare the inner sdfg + old_default = donnx.default_implementation + donnx.default_implementation = "pure" + + @dace.program + def add_reshape_grad_test_nested(inp1: dace.float64[9], bias: dace.float64[3], target_shape: dace.int64[2], + result: dace.float64): + reshaped = dace.define_local([3, 3], dace.float64) + added = inp1 + 1 + donnx.ONNXReshape(data=added, shape=target_shape, reshaped=reshaped) + Z = reshaped * bias + Zl = dace.elementwise(lambda x: log(x + 1), Z) + result[:] = np.sum(Zl) + + sdfg = add_reshape_grad_test_nested.to_sdfg(simplify=True) + + sdfg.expand_library_nodes() + del sdfg.arrays["target_shape"] + + donnx.default_implementation = old_default + + # Prepare the outer SDFG + @dace.program + def inner_view_forwarding(inp1: dace.float64[9], bias: dace.float64[3]): + result = dace.define_local_scalar(dace.float64) + # target shape gets removed by the pure reshape expansion + sdfg(inp1=inp1, bias=bias, result=result) + return result + 1 + + # This generates a FunctionCallRegion in the current frontned + # We need to simplify. + outer_sdfg = inner_view_forwarding.to_sdfg(simplify=True) + outer_sdfg.apply_transformations_repeated([StateFusion]) + + def torch_func(*, inp1, bias): + reshaped = torch.reshape(inp1 + 1, [3, 3]) + + Z = reshaped * bias + Zl = torch.log(Z + 1) + S = Zl.sum() + 1 + + S.backward() + return dict(gradient_inp1=inp1.grad, gradient_bias=bias.grad) + + return (SDFGBackwardRunner(outer_sdfg, "__return", simplify=False), torch_func, + dict(inp1=np.random.rand(9).astype(np.float64), bias=np.random.rand(3).astype(np.float64))) + + +if __name__ == "__main__": + test_nested() + test_nested_forwarding() + test_triple_nested_forwarding() + test_view_forwarding() diff --git a/tests/autodiff/test_single_state.py b/tests/autodiff/test_single_state.py new file mode 100644 index 0000000000..aeeeb0eb1c --- /dev/null +++ b/tests/autodiff/test_single_state.py @@ -0,0 +1,635 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import numpy as np +import pytest + +pytest.importorskip("onnx", reason="ONNX not installed. Please install with: pip install dace[ml]") +pytest.importorskip("torch", reason="PyTorch not installed. Please install with: pip install dace[ml]") + +import torch + +import dace +import dace.sdfg.nodes as nd +from dace.transformation.interstate import StateFusion + +import dace.libraries.onnx as donnx +from dace.autodiff import add_backward_pass + + +################################## +# Testing utilities +def run_correctness(func): + + def test_correctness(): + runner, pytorch_func, inputs = func() + sdfg_dict = {name: arr.copy() for name, arr in inputs.items()} + torch_dict = {name: torch.tensor(arr.copy(), requires_grad=True) for name, arr in inputs.items()} + + sdfg_results = runner.run(**sdfg_dict) + torch_results = pytorch_func(**torch_dict) + + for k, v in torch_results.items(): + v = v.detach().numpy() + diff = np.linalg.norm(sdfg_results[k] - v) / np.prod(v.shape) + assert diff < 1e-5, f"Gradient mismatch for '{k}': normalized difference {diff:.2e} exceeds tolerance 1e-5" + + return test_correctness + + +class SDFGBackwardRunner: + + def __init__(self, sdfg, target, simplify=True): + if simplify: + sdfg.simplify() + self.sdfg: dace.SDFG = sdfg + self.target = target + + # Collect all non-transient float arrays from all states as required gradients + required_grads = [] + seen_names = set() + for state in sdfg.states(): + for node in state.nodes(): + if isinstance(node, nd.AccessNode): + arr = node.desc(sdfg) + if (arr.dtype in [dace.float32, dace.float64] and not arr.transient + and node.data not in seen_names): + required_grads.append(node) + seen_names.add(node.data) + + add_backward_pass(sdfg=self.sdfg, outputs=[self.target], inputs=required_grads, simplify=simplify) + + def run(self, **inputs): + + # Zero out all arrays + intermediate_arrs = {} + gradient_target = "gradient_" + self.target + + for name, arr in self.sdfg.arrays.items(): + # Skip gradient target, dunder names, inputs, and transients + if (name == gradient_target or name.startswith("__") or name in inputs or arr.transient): + continue + + dtype = getattr(np, arr.dtype.to_string()) + intermediate_arrs[name] = np.zeros(arr.shape, dtype=dtype) + + inputs.update(intermediate_arrs) + inputs["gradient_" + self.target] = np.ones((1, ), + dtype=getattr(np, self.sdfg.arrays[self.target].dtype.to_string())) + + self.sdfg(**inputs) + + results = {name: arr for name, arr in inputs.items()} + return results + + +################################## +# Tests +@pytest.mark.autodiff +@run_correctness +def test_gemm(): + + def torch_gemm(*, X, Y): + Z = X @ Y + S = Z.sum() + S.backward() + return dict(gradient_X=X.grad, gradient_Y=Y.grad) + + @dace.program + def dace_gemm( + X: dace.float32[5, 4], + Y: dace.float32[4, 3], + Z: dace.float32[5, 3], + S: dace.float32[1], + ): + + Z[:] = X @ Y + + @dace.map(_[0:5, 0:3]) + def summap(i, j): + s >> S(1, lambda x, y: x + y)[0] + z << Z[i, j] + s = z + + sdfg = dace_gemm.to_sdfg() + + return ( + SDFGBackwardRunner(sdfg, "S"), + torch_gemm, + dict( + X=np.random.rand(5, 4).astype(np.float32), + Y=np.random.rand(4, 3).astype(np.float32), + ), + ) + + +@pytest.mark.autodiff +@run_correctness +def test_sum(): + + def torch_sum(*, X, Y): + Z = X + Y + Z = Z * Z + S = Z.sum() + S.backward() + return dict(gradient_X=X.grad, gradient_Y=Y.grad) + + @dace.program + def dace_sum( + X: dace.float32[3, 3], + Y: dace.float32[3, 3], + Z: dace.float32[3, 3], + S: dace.float32[1], + ): + + Z[:] = X + Y + + @dace.map(_[0:3, 0:3]) + def summap(i, j): + s >> S(1, lambda x, y: x + y)[0] + z << Z[i, j] + s = z * z + + sdfg = dace_sum.to_sdfg() + + return ( + SDFGBackwardRunner(sdfg, "S"), + torch_sum, + dict( + X=np.random.rand(3, 3).astype(np.float32), + Y=np.random.rand(3, 3).astype(np.float32), + ), + ) + + +@pytest.mark.autodiff +@run_correctness +def test_complex_tasklet(): + + def torch_sum(*, X, Y): + Z = X + Y + Z = Z * Z + S = Z.sum() + S.backward() + return dict(gradient_X=X.grad, gradient_Y=Y.grad) + + @dace.program + def dace_sum_complex( + X: dace.float32[3, 3], + Y: dace.float32[3, 3], + Z: dace.float32[3, 3], + S: dace.float32[1], + ): + + Z[:] = X + Y + + @dace.map(_[0:3, 0:3]) + def summap(i, j): + s >> S(1, lambda x, y: x + y)[0] + z << Z[i, j] + + z1 = z + 1 + log(3) # random expr + z2 = z - 1 * (2 / 2) + # hello world 1, 2, 3 + s = z1 * z2 + + sdfg = dace_sum_complex.to_sdfg() + + return ( + SDFGBackwardRunner(sdfg, "S"), + torch_sum, + dict( + X=np.random.rand(3, 3).astype(np.float32), + Y=np.random.rand(3, 3).astype(np.float32), + ), + ) + + +@pytest.mark.autodiff +@run_correctness +def test_tasklets_only_reuse(): + + def torch_func(*, A): + tmp_a = torch.sqrt(A) + tmp_b = torch.log(A + 1) + + C = tmp_a * tmp_b + + C.backward() + return dict(gradient_A=A.grad) + + @dace.program + def tasklets_only_reuse(A: dace.float32[1], C: dace.float32[1]): + tmp_a = dace.define_local_scalar(dace.float32) + tmp_b = dace.define_local_scalar(dace.float32) + + with dace.tasklet: + a << A[0] + a_out >> tmp_a + + a_out = sqrt(a) + + with dace.tasklet: + a << A[0] + a_out >> tmp_b + + a_out = log(a + 1) + + with dace.tasklet: + a << tmp_a + b << tmp_b + c >> C[0] + c = a * b + + sdfg = tasklets_only_reuse.to_sdfg(simplify=False) + sdfg.simplify() + return ( + SDFGBackwardRunner(sdfg, "C"), + torch_func, + dict(A=np.random.rand(1).astype(np.float32)), + ) + + +@pytest.mark.autodiff +@run_correctness +def test_tasklets_multioutput(): + + def torch_func(*, A, B): + tmp_a = torch.sqrt(A) + tmp_b = torch.log(B + 1) + + C = tmp_a * tmp_b * B + + C.backward() + return dict(gradient_A=A.grad, gradient_B=B.grad) + + @dace.program + def tasklets_multioutput(A: dace.float32[1], B: dace.float32[1], C: dace.float32[1]): + tmp_a = dace.define_local_scalar(dace.float32) + tmp_b = dace.define_local_scalar(dace.float32) + tmp_d = dace.define_local_scalar(dace.float32) + + with dace.tasklet: + a << A[0] + a_out >> tmp_a + + a_out = sqrt(a) + + with dace.tasklet: + b << B[0] + b_out >> tmp_b + d_out >> tmp_d + + b_out = log(b + 1) + d_out = b + + with dace.tasklet: + a << tmp_a + b << tmp_b + d << tmp_d + c >> C[0] + c = a * b * d + + sdfg = tasklets_multioutput.to_sdfg(simplify=False) + sdfg.simplify() + + return ( + SDFGBackwardRunner(sdfg, "C"), + torch_func, + dict( + A=np.random.rand(1).astype(np.float32), + B=np.random.rand(1).astype(np.float32), + ), + ) + + +@pytest.mark.autodiff +@run_correctness +def test_tasklets_only(): + + def torch_func(*, A, B): + tmp_a = torch.sqrt(A) + tmp_b = torch.log(B + 1) + + C = tmp_a * tmp_b + + C.backward() + return dict(gradient_A=A.grad, gradient_B=B.grad) + + @dace.program + def tasklets_only(A: dace.float32[1], B: dace.float32[1], C: dace.float32[1]): + tmp_a = dace.define_local_scalar(dace.float32) + tmp_b = dace.define_local_scalar(dace.float32) + + with dace.tasklet: + a << A[0] + a_out >> tmp_a + + a_out = sqrt(a) + + with dace.tasklet: + a << B[0] + a_out >> tmp_b + + a_out = log(a + 1) + + with dace.tasklet: + a << tmp_a + b << tmp_b + c >> C[0] + c = a * b + + sdfg = tasklets_only.to_sdfg(simplify=False) + sdfg.simplify() + + return ( + SDFGBackwardRunner(sdfg, "C"), + torch_func, + dict( + A=np.random.rand(1).astype(np.float32), + B=np.random.rand(1).astype(np.float32), + ), + ) + + +@pytest.mark.autodiff +@run_correctness +def test_add_mmul_transpose_log(): + + def torch_func(*, X, Y, W): + + Xt = X.T + YW = W * Y + Z = Xt @ YW + Zl = torch.log(Z + 1) + + S = Zl.sum() + S.backward() + return dict(gradient_X=X.grad, gradient_Y=Y.grad, gradient_W=W.grad) + + @dace.program + def add_mmul_transpose_log( + X: dace.float32[4, 5], + Y: dace.float32[4, 3], + W: dace.float32[4, 3], + S: dace.float32[1], + ): + + Xt = np.transpose(X) + YW = W * Y + Z = Xt @ YW + + @dace.map(_[0:5, 0:3]) + def summap(i, j): + s >> S(1, lambda x, y: x + y)[0] + z << Z[i, j] + s = log(z + 1) + + sdfg = add_mmul_transpose_log.to_sdfg() + + return ( + SDFGBackwardRunner(sdfg, "S"), + torch_func, + dict( + X=np.random.rand(4, 5).astype(np.float32), + W=np.random.rand(4, 3).astype(np.float32), + Y=np.random.rand(4, 3).astype(np.float32), + ), + ) + + +@pytest.mark.autodiff +@run_correctness +def test_reduce_node_1_axis_and_none_axis(): + + def torch_func(*, X, Y, W): + + Xt = X.T + YW = torch.sum(W, dim=0) * Y + Z = Xt @ YW + Zl = torch.log(Z + 1) + + S = Zl.sum() + S.backward() + return dict(gradient_X=X.grad, gradient_Y=Y.grad, gradient_W=W.grad) + + @dace.program + def reduce_node_1_axis_and_none_axis(X: dace.float32[4, 5], Y: dace.float32[4, 3], W: dace.float32[7, 4, 3]): + + Xt = np.transpose(X) + YW = np.sum(W, axis=0) * Y + Z = Xt @ YW + + Zl = dace.elementwise(lambda x: log(x + 1), Z) + S = np.sum(Zl) + return S + + sdfg = reduce_node_1_axis_and_none_axis.to_sdfg() + + return ( + SDFGBackwardRunner(sdfg, "__return"), + torch_func, + dict( + X=np.random.rand(4, 5).astype(np.float32), + W=np.random.rand(7, 4, 3).astype(np.float32), + Y=np.random.rand(4, 3).astype(np.float32), + ), + ) + + +@pytest.mark.autodiff +@run_correctness +def test_reduce_max_simple(): + + def torch_func(*, W): + + Z = torch.max(W, dim=1) + S = Z.values.sum() + S.backward() + return dict(gradient_W=W.grad) + + @dace.program + def reduce_max_simple(W: dace.float32[4, 5]): + + Z = np.max(W, axis=1) + S = np.sum(Z) + return S + + sdfg = reduce_max_simple.to_sdfg() + + return ( + SDFGBackwardRunner(sdfg, "__return"), + torch_func, + dict(W=np.random.rand(4, 5).astype(np.float32)), + ) + + +@pytest.mark.autodiff +@run_correctness +def test_reduce_max_node_1_axis(): + + def torch_func(*, X, Y, W): + + Xt = X.T + YW = torch.min(W, dim=0).values * Y + Z = Xt @ YW + Zl = torch.log(Z + 1) + + S = Zl.sum() + S.backward() + return dict(gradient_X=X.grad, gradient_Y=Y.grad, gradient_W=W.grad) + + @dace.program + def dace_func(X: dace.float64[4, 5], Y: dace.float64[4, 3], W: dace.float64[7, 4, 3]): + + Xt = np.transpose(X) + YW = np.min(W, axis=0) * Y + Z = Xt @ YW + + Zl = dace.elementwise(lambda x: log(x + 1), Z) + S = np.sum(Zl) + return S + + sdfg = dace_func.to_sdfg() + + return ( + SDFGBackwardRunner(sdfg, "__return"), + torch_func, + dict( + X=np.random.rand(4, 5).astype(np.float64), + W=np.random.rand(7, 4, 3).astype(np.float64), + Y=np.random.rand(4, 3).astype(np.float64), + ), + ) + + +@pytest.mark.autodiff +@run_correctness +def test_reshape(): + + @dace.program + def single_state_reshape(inp: dace.float64[9], bias: dace.float64[3], target_shape: dace.int64[2]): + reshaped = dace.define_local([3, 3], dace.float64) + donnx.ONNXReshape(data=inp, shape=target_shape, reshaped=reshaped) + Z = reshaped + bias + Zl = dace.elementwise(lambda x: log(x + 1), Z) + S = np.sum(Zl) + return S + + sdfg = single_state_reshape.to_sdfg(simplify=False) + + sdfg.apply_transformations_repeated([StateFusion]) + + def torch_func(*, inp, bias): + reshaped = torch.reshape(inp, [3, 3]) + + Z = reshaped + bias + Zl = torch.log(Z + 1) + S = Zl.sum() + + S.backward() + return dict(gradient_inp=inp.grad, gradient_bias=bias.grad) + + return ( + SDFGBackwardRunner(sdfg, "__return", simplify=False), + torch_func, + dict( + inp=np.random.rand(9).astype(np.float64), + bias=np.random.rand(3).astype(np.float64), + ), + ) + + +@pytest.mark.autodiff +@run_correctness +def test_reshape_on_memlet_path(): + old_default = donnx.default_implementation + donnx.default_implementation = "pure" + + @dace.program + def single_state_reshape_memlet_path(inp1: dace.float64[9], bias: dace.float64[3], target_shape: dace.int64[2]): + reshaped = dace.define_local([3, 3], dace.float64) + donnx.ONNXReshape(data=inp1, shape=target_shape, reshaped=reshaped) + Z = reshaped + bias + Zl = dace.elementwise(lambda x: log(x + 1), Z) + S = np.sum(Zl) + return S + + sdfg = single_state_reshape_memlet_path.to_sdfg(simplify=False) + + sdfg.expand_library_nodes() + sdfg.apply_transformations_repeated([StateFusion]) + + donnx.default_implementation = old_default + + def torch_func(*, inp1, bias): + reshaped = torch.reshape(inp1, [3, 3]) + + Z = reshaped + bias + Zl = torch.log(Z + 1) + S = Zl.sum() + + S.backward() + return dict(gradient_inp1=inp1.grad, gradient_bias=bias.grad) + + return ( + SDFGBackwardRunner(sdfg, "__return", simplify=False), + torch_func, + dict( + inp1=np.random.rand(9).astype(np.float64), + bias=np.random.rand(3).astype(np.float64), + ), + ) + + +@pytest.mark.autodiff +@run_correctness +def test_reshape_reuse_in_same_state(): + old_default = donnx.default_implementation + donnx.default_implementation = "pure" + + @dace.program + def single_state_reshape_same_state(inp: dace.float64[9], target_shape: dace.int64[2]): + reshaped = dace.define_local([3, 3], dace.float64) + donnx.ONNXReshape(data=inp, shape=target_shape, reshaped=reshaped) + Zl = dace.elementwise(lambda x: log(x + 1), reshaped) + S = np.sum(Zl) + return S + + sdfg = single_state_reshape_same_state.to_sdfg(simplify=False) + + sdfg.expand_library_nodes() + sdfg.apply_transformations_repeated([StateFusion]) + + donnx.default_implementation = old_default + + def torch_func(*, inp): + reshaped = torch.reshape(inp, [3, 3]) + + Z = reshaped + Zl = torch.log(Z + 1) + S = Zl.sum() + + S.backward() + return dict(gradient_inp=inp.grad) + + return ( + SDFGBackwardRunner(sdfg, "__return", simplify=False), + torch_func, + dict(inp=np.random.rand(9).astype(np.float64), ), + ) + + +if __name__ == "__main__": + test_gemm() + test_sum() + test_complex_tasklet() + test_tasklets_only_reuse() + test_tasklets_multioutput() + test_tasklets_only() + test_add_mmul_transpose_log() + test_reduce_node_1_axis_and_none_axis() + test_reduce_max_simple() + test_reduce_max_node_1_axis() + test_reshape() + test_reshape_on_memlet_path() + test_reshape_reuse_in_same_state() diff --git a/tests/autodiff/torch_backward/test_dont_compute_input_grads.py b/tests/autodiff/torch_backward/test_dont_compute_input_grads.py new file mode 100644 index 0000000000..d64a9e0400 --- /dev/null +++ b/tests/autodiff/torch_backward/test_dont_compute_input_grads.py @@ -0,0 +1,68 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import pytest + +pytest.importorskip("torch", reason="PyTorch not installed. Please install with: pip install dace[ml]") +import torch +from torch import nn + +from dace.ml import DaceModule +from tests.utils import torch_tensors_close + + +@pytest.mark.torch +@pytest.mark.autodiff +def test_skip_input_grads(use_cpp_dispatcher: bool): + + class Module(torch.nn.Module): + + def __init__(self): + super(Module, self).__init__() + self.fc1 = nn.Parameter(torch.rand(10, 10)) + + def forward(self, x): + return x @ self.fc1 + + dace_module = Module() + pt_module = Module() + pt_module.load_state_dict(dace_module.state_dict()) + + shape = [8, 10] + input_value = torch.rand(*shape, dtype=torch.float32) + + pytorch_input = torch.empty( + *shape, + dtype=torch.float32, + requires_grad=False, + ) + pytorch_input.copy_(input_value) + dace_input = torch.empty(*shape, dtype=torch.float32, requires_grad=False) + dace_input.copy_(input_value) + + # TODO: provide a better API for input names + dispatcher_suffix = "cpp" if use_cpp_dispatcher else "ctypes" + dace_module = DaceModule(dace_module, + sdfg_name=f"test_skip_input_grads_{dispatcher_suffix}", + backward=True, + inputs_to_skip=["onnx::MatMul_0"], + compile_torch_extension=use_cpp_dispatcher) + + dy = torch.rand(8, 10) + + dace_output = dace_module(dace_input) + pt_output = pt_module(pytorch_input) + + torch_tensors_close("output", pt_output, dace_output) + + # check that fc1.grad is being computed + dace_output.backward(dy) + pt_output.backward(dy) + torch_tensors_close("param_grad", pt_module.fc1.grad, dace_module.model.fc1.grad) + + # Make sure that input grad is not being computed + assert len(dace_module.backward_sdfg.node(0).sink_nodes()) == 1, \ + f"Expected 1 sink node (no input gradient), got {len(dace_module.backward_sdfg.node(0).sink_nodes())}" + + +if __name__ == "__main__": + test_skip_input_grads(use_cpp_dispatcher=True) + test_skip_input_grads(use_cpp_dispatcher=False) diff --git a/tests/autodiff/torch_backward/test_dropout.py b/tests/autodiff/torch_backward/test_dropout.py new file mode 100644 index 0000000000..7e360ccf26 --- /dev/null +++ b/tests/autodiff/torch_backward/test_dropout.py @@ -0,0 +1,69 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +from typing import Literal, Union +import pytest + +pytest.importorskip("torch", reason="PyTorch not installed. Please install with: pip install dace[ml]") +import torch +from torch import nn +from dace.ml import DaceModule +from tests.utils import torch_tensors_close + + +@pytest.mark.torch +def test_dropout_fwd_training(): + p = 0.5 + module = nn.Dropout(p=p).train() + dace_module = DaceModule(module, + sdfg_name="test_dropout_fwd_training", + dummy_inputs=(torch.ones(10, 10), ), + training=True) + + # dropout will set some of these to zero + test_data = torch.randint(1, 10, (10, 10)).float() + + out = dace_module(torch.clone(test_data)) + zeroed = out == 0 + + scale = 1 / (1 - p) + torch_tensors_close("output", test_data[~zeroed] * scale, out[~zeroed]) + + +@pytest.mark.torch +@pytest.mark.autodiff +@pytest.mark.parametrize("p", [0, 0.99, 0.6, 0.5]) +def test_dropout_bwd(p: Union[float, Literal[0]]): + module = nn.Dropout(p=p).train() + sdfg_name = f"test_dropout_{str(p).replace('.', '_')}_bwd" + dace_module = DaceModule(module, + sdfg_name=sdfg_name, + dummy_inputs=(torch.ones(10, 10), ), + backward=True, + training=True) + + test_data = torch.randint(1, 10, (10, 10)).float() + test_data.requires_grad = True + dy = torch.rand_like(test_data) + + out = dace_module(torch.clone(test_data)) + + zeroed = out == 0 + scale = 1 / (1 - p) + # check that fwd was correct + torch_tensors_close("output", test_data[~zeroed] * scale, out[~zeroed]) + + out.backward(dy) + + # check that the gradient is correct: + zeros = torch.zeros_like(test_data.grad) + # check that zeroed values are zero in the grad + torch_tensors_close("grad_zeroed", zeros[zeroed], test_data.grad[zeroed]) + + # check that non-zeroed values are correct + torch_tensors_close("grad_zeroed", dy[~zeroed] * scale, test_data.grad[~zeroed]) + + +if __name__ == "__main__": + test_dropout_fwd_training() + # Test with different dropout probabilities + for p in [0, 0.99, 0.6, 0.5]: + test_dropout_bwd(p=p) diff --git a/tests/autodiff/torch_backward/test_extremal_reduction_backward.py b/tests/autodiff/torch_backward/test_extremal_reduction_backward.py new file mode 100644 index 0000000000..8758cdec43 --- /dev/null +++ b/tests/autodiff/torch_backward/test_extremal_reduction_backward.py @@ -0,0 +1,160 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. + +import numpy as np +import pytest + +pytest.importorskip("torch", reason="PyTorch not installed") +import torch + +import dace +from tests.autodiff.test_single_state import SDFGBackwardRunner + + +def run_max_reduction_test(dace_func, torch_func, inputs, rtol=1e-5, atol=1e-5): + sdfg = dace_func.to_sdfg() + runner = SDFGBackwardRunner(sdfg, "__return") + + sdfg_dict = {name: arr.copy() for name, arr in inputs.items()} + torch_dict = {name: torch.tensor(arr.copy(), requires_grad=True) for name, arr in inputs.items()} + + sdfg_results = runner.run(**sdfg_dict) + torch_results = torch_func(**torch_dict) + + for k, v in torch_results.items(): + v = v.detach().numpy() + assert np.allclose(sdfg_results[k], v, rtol=rtol, atol=atol), \ + f"Gradient mismatch for '{k}':\n DaCe: {sdfg_results[k]}\n PyTorch: {v}" + + +@pytest.mark.autodiff +def test_max_single_maximum(): + """Max reduction with single maximum - no ties.""" + + def torch_func(*, W): + Z = torch.amax(W, dim=0) + S = Z.sum() + S.backward() + return dict(gradient_W=W.grad) + + @dace.program + def dace_func(W: dace.float32[4]): + Z = np.max(W, axis=0) + S = np.sum(Z) + return S + + inputs = dict(W=np.array([1.0, 3.0, 2.0, 0.0], dtype=np.float32)) + run_max_reduction_test(dace_func, torch_func, inputs) + + +@pytest.mark.autodiff +def test_max_tied_values_2d(): + """Max reduction with tied values along an axis. + + For input [[1, 3], [3, 2]] with max along axis=0: + - Column 0: max=3 at row 1 only -> grad [0, 1] + - Column 1: max=3 at row 0 only -> grad [1, 0] + """ + + def torch_func(*, W): + Z = torch.amax(W, dim=0) + S = Z.sum() + S.backward() + return dict(gradient_W=W.grad) + + @dace.program + def dace_func(W: dace.float32[2, 2]): + Z = np.max(W, axis=0) + S = np.sum(Z) + return S + + inputs = dict(W=np.array([[1.0, 3.0], [3.0, 2.0]], dtype=np.float32)) + run_max_reduction_test(dace_func, torch_func, inputs) + + +@pytest.mark.autodiff +def test_max_tied_values_same_column(): + """Max reduction with tied values in the same column. + + For input [[3, 1], [3, 2]] with max along axis=0: + - Column 0: max=3 at rows 0 AND 1 -> split grad equally: [0.5, 0.5] + - Column 1: max=2 at row 1 only -> grad [0, 1] + + Expected gradient: [[0.5, 0], [0.5, 1]] + """ + + def torch_func(*, W): + Z = torch.amax(W, dim=0) + S = Z.sum() + S.backward() + return dict(gradient_W=W.grad) + + @dace.program + def dace_func(W: dace.float32[2, 2]): + Z = np.max(W, axis=0) + S = np.sum(Z) + return S + + inputs = dict(W=np.array([[3.0, 1.0], [3.0, 2.0]], dtype=np.float32)) + run_max_reduction_test(dace_func, torch_func, inputs) + + +@pytest.mark.autodiff +def test_max_all_equal_column(): + """Max reduction where entire column has same value. + + For input [[3, 1], [3, 2], [3, 0]] with max along axis=0: + - Column 0: all values are 3 -> split equally: [1/3, 1/3, 1/3] + - Column 1: max=2 at row 1 only -> grad [0, 1, 0] + + Expected gradient: [[1/3, 0], [1/3, 1], [1/3, 0]] + """ + + def torch_func(*, W): + Z = torch.amax(W, dim=0) + S = Z.sum() + S.backward() + return dict(gradient_W=W.grad) + + @dace.program + def dace_func(W: dace.float32[3, 2]): + Z = np.max(W, axis=0) + S = np.sum(Z) + return S + + inputs = dict(W=np.array([[3.0, 1.0], [3.0, 2.0], [3.0, 0.0]], dtype=np.float32)) + run_max_reduction_test(dace_func, torch_func, inputs) + + +@pytest.mark.autodiff +def test_min_tied_values(): + """Min reduction with tied values. + + For input [[1, 2], [1, 3], [2, 1]] with min along axis=0: + - Column 0: min=1 at rows 0 AND 1 -> split: [0.5, 0.5, 0] + - Column 1: min=1 at row 2 only -> grad [0, 0, 1] + + Expected gradient: [[0.5, 0], [0.5, 0], [0, 1]] + """ + + def torch_func(*, W): + Z = torch.amin(W, dim=0) + S = Z.sum() + S.backward() + return dict(gradient_W=W.grad) + + @dace.program + def dace_func(W: dace.float32[3, 2]): + Z = np.min(W, axis=0) + S = np.sum(Z) + return S + + inputs = dict(W=np.array([[1.0, 2.0], [1.0, 3.0], [2.0, 1.0]], dtype=np.float32)) + run_max_reduction_test(dace_func, torch_func, inputs) + + +if __name__ == "__main__": + test_max_single_maximum() + test_max_tied_values_2d() + test_max_tied_values_same_column() + test_max_all_equal_column() + test_min_tied_values() diff --git a/tests/autodiff/torch_backward/test_full_training_graph.py b/tests/autodiff/torch_backward/test_full_training_graph.py new file mode 100644 index 0000000000..9e97b10b65 --- /dev/null +++ b/tests/autodiff/torch_backward/test_full_training_graph.py @@ -0,0 +1,227 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import copy + +import pytest + +import numpy as np + +pytest.importorskip("torch", reason="PyTorch not installed. Please install with: pip install dace[ml]") +import torch + +import dace + +from dace.ml import DaceModule +from tests.utils import torch_tensors_close, tensors_close + + +@pytest.mark.torch +@pytest.mark.autodiff +def test_module(): + gpu = False + module = torch.nn.Sequential(torch.nn.Linear(12, 2, bias=False)) + + torch_module = copy.deepcopy(module) + dace_module = copy.deepcopy(module) + + dace_module = DaceModule(dace_module, + sdfg_name="test_full_training_graph_module", + simplify=False, + backward=True, + training=True, + auto_optimize=False) + + x = torch.randn(8, 12) + + expected_output = torch_module(x) + result = dace_module(x) + torch_tensors_close('output', expected_output, result) + + dc_loss = dace_module(x).sum() + dc_loss.backward() + + pt_loss = torch_module(x).sum() + pt_loss.backward() + + tensors_close("loss", pt_loss, dc_loss) + assert all(hasattr(p, 'grad') and p.grad is not None for p in dace_module.parameters()), \ + "Not all parameters have gradients computed" + + for d, t in zip(dace_module.parameters(), torch_module.parameters()): + torch_tensors_close("param", t.grad, d.grad) + + +@pytest.mark.torch +@pytest.mark.autodiff +def test_parse_backward_simple(): + x = torch.randn(10, 5, dtype=torch.float64) + dy = torch.randn(10, dtype=torch.float64) + + @dace.program + def train_step(x: dace.float64[10, 5], dy: dace.float64[10]): + x.requires_grad_() + red = np.add.reduce(x, axis=1) + torch.autograd.backward(red, dy) + return x.grad + + sdfg = train_step.to_sdfg() + sdfg.expand_library_nodes() + sdfg.validate() + + result = sdfg(x.clone(), dy.clone()) + tensors_close('x.grad', dy.reshape(10, 1).expand(10, 5), result) + + +@pytest.mark.torch +@pytest.mark.autodiff +def test_parse_backward_scalar(): + x = torch.randn(10, 5, dtype=torch.float64) + + @dace.program + def train_step(x: dace.float64[10, 5]): + x.requires_grad_() + red = np.add.reduce(x, axis=[0, 1]) + torch.autograd.backward(red) + return x.grad + + sdfg = train_step.to_sdfg() + sdfg.expand_library_nodes() + sdfg.validate() + + result = sdfg(x.clone()) + tensors_close('x.grad', 1, result) + + +@pytest.mark.torch +@pytest.mark.autodiff +def test_parse_backward_with_forwarding(): + x = torch.randn(10, 5, dtype=torch.float64) + dy = torch.randn(10, dtype=torch.float64) + + @dace.program + def train_step(x: dace.float64[10, 5]): + x.requires_grad_() + y = x + 1 + red = np.add.reduce(x, axis=1, keepdims=True) + z = red * y + loss = np.add.reduce(z, axis=[0, 1]) + torch.autograd.backward(loss) + return x.grad + + def torch_fn(x): + x.requires_grad_() + y = x + 1 + red = x.sum(axis=1, keepdims=True) + z = red * y + loss = z.sum() + loss.backward() + return x.grad + + sdfg = train_step.to_sdfg() + sdfg.expand_library_nodes() + sdfg.validate() + + result = sdfg(x.clone()) + expected = torch_fn(x.clone()) + tensors_close('x.grad', expected, result) + + +@pytest.mark.torch +@pytest.mark.autodiff +def test_two_backward_passes(): + + @dace.program + def train_step(x1: dace.float64[10, 5], x2: dace.float64[5], dy: dace.float64[10]): + x1.requires_grad_() + x2.requires_grad_() + + z1 = x1 + 1 + y1 = np.log(z1) + l1 = np.add.reduce(y1, axis=1) + + z2 = x2 * 2 + y2 = np.log(z2) + l2 = y2.sum() + + l2.backward() + l1.backward(dy) + return x1.grad, x2.grad + + def torch_fn(x1, x2, dy): + x1.requires_grad_() + x2.requires_grad_() + z1 = x1 + 1 + y1 = torch.log(z1).sum(axis=1) + + z2 = x2 * 2 + y2 = torch.log(z2).sum() + y2.backward() + y1.backward(dy) + return x1.grad, x2.grad + + sdfg = train_step.to_sdfg() + sdfg.validate() + sdfg.expand_library_nodes() + sdfg.validate() + + x1 = torch.randn(10, 5, dtype=torch.float64) + x2 = torch.randn(5, dtype=torch.float64) + dy = torch.randn(10, dtype=torch.float64) + + r1, r2 = sdfg(x1.clone(), x2.clone(), dy.clone()) + ex_1, ex_2 = torch_fn(x1.clone(), x2.clone(), dy.clone()) + tensors_close('x2.grad', ex_2, r2) + tensors_close('x1.grad', ex_1, r1) + + +@pytest.mark.torch +@pytest.mark.autodiff +def test_two_backward_passes_accumulate(): + + @dace.program + def train_step(x: dace.float64[10, 5], dy: dace.float64[10]): + x.requires_grad_() + + z1 = x + 1 + y1 = np.log(z1) + l1 = np.add.reduce(y1, axis=1) + + z2 = x * 2 + y2 = np.log(z2) + l2 = y2.sum() + + l2.backward() + l1.backward(dy) + return x.grad + + def torch_fn(x, dy): + x.requires_grad = True + z1 = x + 1 + y1 = torch.log(z1).sum(axis=1) + + z2 = x * 2 + y2 = torch.log(z2).sum() + y2.backward() + y1.backward(dy) + return x.grad + + sdfg = train_step.to_sdfg() + sdfg.validate() + sdfg.expand_library_nodes() + sdfg.validate() + + x1 = torch.randn(10, 5, dtype=torch.float64) + dy = torch.randn(10, dtype=torch.float64) + + result = sdfg(x1.clone(), dy.clone()) + expected = torch_fn(x1.clone(), dy.clone()) + + tensors_close('x.grad', expected, result) + + +if __name__ == "__main__": + test_module() + test_parse_backward_simple() + test_parse_backward_scalar() + test_parse_backward_with_forwarding() + test_two_backward_passes() + test_two_backward_passes_accumulate() diff --git a/tests/autodiff/torch_backward/test_llama_decoder_backward.py b/tests/autodiff/torch_backward/test_llama_decoder_backward.py new file mode 100644 index 0000000000..79cb8621b1 --- /dev/null +++ b/tests/autodiff/torch_backward/test_llama_decoder_backward.py @@ -0,0 +1,107 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import pytest + +pytest.importorskip("torch", reason="PyTorch not installed. Please install with: pip install dace[ml]") +pytest.importorskip("transformers", + reason="transformers not installed. Please install with: pip install dace[ml-testing]") +import torch +import torch.nn as nn +from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaConfig +from dace.ml import DaceModule +from tests.utils import torch_tensors_close + + +# Create a wrapper module that handles the position embeddings internally +class LlamaDecoderLayerWrapper(nn.Module): + + def __init__(self, decoder_layer, config): + super().__init__() + self.decoder_layer = decoder_layer + self.config = config + + # Create rotary embeddings as part of the wrapper + from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding + self.rotary_emb = LlamaRotaryEmbedding(config) + + def forward(self, hidden_states, attention_mask, position_ids): + # Generate position embeddings + cos, sin = self.rotary_emb(hidden_states, position_ids) + + # Call the decoder layer + outputs = self.decoder_layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + position_embeddings=(cos, sin), + past_key_value=None, + output_attentions=False, + use_cache=False, + ) + + # Return only the hidden states (first element of the tuple) + return outputs[0] + + +@pytest.mark.xdist_group("large_ML_models") +@pytest.mark.torch +@pytest.mark.autodiff +def test_llama_decoder_backward(): + # Create configuration + config = LlamaConfig( + hidden_size=512, + intermediate_size=1024, + num_attention_heads=8, + num_key_value_heads=8, + max_position_embeddings=128, + rms_norm_eps=1e-5, + rope_theta=10000.0, + attention_dropout=0.0, + hidden_act="silu", + ) + + # Create decoder layer + decoder_layer = LlamaDecoderLayer(config, layer_idx=0) + + # Prepare dummy inputs + batch_size = 2 + seq_length = 128 + + # Create input tensors + hidden_states = torch.randn(batch_size, seq_length, config.hidden_size) + attention_mask = torch.ones(batch_size, 1, seq_length, seq_length) + position_ids = torch.arange(seq_length).unsqueeze(0).expand(batch_size, seq_length) + + # Create wrapped model + wrapped_model = LlamaDecoderLayerWrapper(decoder_layer, config) + + # Avoid the simplify pass since it takes too long for this model + dace_model = DaceModule( + wrapped_model, + sdfg_name="test_llama_decoder_backward", + onnx_simplify=True, + backward=True, + ) + + hidden_states_pt, attention_mask_pt, position_ids_pt = (torch.clone(hidden_states), torch.clone(attention_mask), + torch.clone(position_ids)) + hidden_states_pt.requires_grad = True + + hidden_states_dace, attention_mask_dace, position_ids_dace = (torch.clone(hidden_states), + torch.clone(attention_mask), + torch.clone(position_ids)) + hidden_states_dace.requires_grad = True + + wrapped_model(hidden_states_pt, attention_mask_pt, position_ids_pt).sum().backward() + dace_model(hidden_states_dace, attention_mask_dace, position_ids_dace).sum().backward() + + # Check gradients of the parameters + for (name, dace_param), (pt_name, pt_param) in zip(wrapped_model.named_parameters(), dace_model.named_parameters()): + assert 'model.' + name == pt_name, f"Parameter name mismatch: expected 'model.{name}', got '{pt_name}'" + torch_tensors_close(name, dace_param.grad, pt_param.grad) + + # Check the gradients of the input tensor + torch_tensors_close("hidden_states_pt_grad", hidden_states_pt.grad, hidden_states_dace.grad) + + +if __name__ == "__main__": + test_llama_decoder_backward() diff --git a/tests/autodiff/torch_backward/test_llama_for_causalLM_backward.py b/tests/autodiff/torch_backward/test_llama_for_causalLM_backward.py new file mode 100644 index 0000000000..a1ec63a3c8 --- /dev/null +++ b/tests/autodiff/torch_backward/test_llama_for_causalLM_backward.py @@ -0,0 +1,105 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import pytest + +pytest.importorskip("torch", reason="PyTorch not installed. Please install with: pip install dace[ml]") +pytest.importorskip("transformers", + reason="transformers not installed. Please install with: pip install dace[ml-testing]") +import torch +import torch.nn as nn +from transformers import LlamaForCausalLM, LlamaConfig +from dace.ml import DaceModule +from tests.utils import torch_tensors_close + + +class LlamaWrapper(nn.Module): + + def __init__(self, model): + super().__init__() + self.model = model + self.config = model.config + + def forward(self, input_ids): + # Get the embeddings + inputs_embeds = self.model.model.embed_tokens(input_ids) + + # Create position ids + batch_size, seq_length = input_ids.shape + position_ids = torch.arange(seq_length, device=input_ids.device) + position_ids = position_ids.unsqueeze(0).expand(batch_size, -1) + + # Process through decoder layers + hidden_states = inputs_embeds + + # Create causal mask for attention + causal_mask = torch.triu(torch.ones((seq_length, seq_length), device=input_ids.device), diagonal=1) + causal_mask = causal_mask.masked_fill(causal_mask == 1, float('-inf')) + causal_mask = causal_mask.unsqueeze(0).unsqueeze(0) + + # Forward through each layer + for decoder_layer in self.model.model.layers: + # Get rotary embeddings + cos, sin = self.model.model.rotary_emb(hidden_states, position_ids) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=None, + output_attentions=False, + use_cache=False, + position_embeddings=(cos, sin), + ) + hidden_states = layer_outputs[0] + + # Final layer norm + hidden_states = self.model.model.norm(hidden_states) + + # Get logits + logits = self.model.lm_head(hidden_states) + + return logits + + +@pytest.mark.xdist_group("large_ML_models") +@pytest.mark.torch +@pytest.mark.autodiff +@pytest.mark.long +def test_llama_model_backward(): + # Create a small LLaMA configuration + config = LlamaConfig( + vocab_size=32000, + hidden_size=512, + intermediate_size=1024, + num_hidden_layers=4, + num_attention_heads=8, + num_key_value_heads=8, + max_position_embeddings=128, + rms_norm_eps=1e-5, + rope_theta=10000.0, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + ) + + # Create the full model + model = LlamaForCausalLM(config) + export_seq_length = 16 + export_batch_size = 1 + input = torch.randint(3, config.vocab_size, (export_batch_size, export_seq_length)) + + wrapped_model = LlamaWrapper(model) + + # Avoid the simplify pass since it takes too long for this model + dace_model = DaceModule(wrapped_model, sdfg_name="test_llama_model_backward", backward=True, onnx_simplify=True) + + wrapped_model(input.clone()).sum().backward() + dace_model(input.clone()).sum().backward() + + # Check gradients of the parameters + for (name, dace_param), (pt_name, pt_param) in zip(wrapped_model.named_parameters(), dace_model.named_parameters()): + assert 'model.' + name == pt_name, f"Parameter name mismatch: expected 'model.{name}', got '{pt_name}'" + torch_tensors_close(name, pt_param.grad, dace_param.grad) + + +if __name__ == "__main__": + test_llama_model_backward() diff --git a/tests/autodiff/torch_backward/test_multi_output_ad.py b/tests/autodiff/torch_backward/test_multi_output_ad.py new file mode 100644 index 0000000000..057813add3 --- /dev/null +++ b/tests/autodiff/torch_backward/test_multi_output_ad.py @@ -0,0 +1,64 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import pytest + +pytest.importorskip("torch", reason="PyTorch not installed. Please install with: pip install dace[ml]") +import torch + +from dace.ml import DaceModule +from tests.utils import torch_tensors_close + + +@pytest.mark.torch +@pytest.mark.autodiff +def test_multi_output(use_cpp_dispatcher: bool): + + class Module(torch.nn.Module): + + def forward(self, x): + return x + 1, x * 2 + + module = Module() + + input_value = torch.rand(5, 10, dtype=torch.float32) + + pytorch_input = torch.empty( + 5, + 10, + dtype=torch.float32, + requires_grad=False, + ) + pytorch_input.copy_(input_value) + + dace_input = torch.empty(5, 10, dtype=torch.float32, requires_grad=False) + dace_input.copy_(input_value) + + pytorch_input.requires_grad = True + dace_input.requires_grad = True + + torch_dy = torch.randn(5, 10, dtype=torch.float32) + dace_dy = torch_dy.clone() + + pytorch_y1, pytorch_y2 = module(pytorch_input) + + pytorch_y1.backward(torch_dy) + pytorch_y2.backward(torch_dy) + + dispatcher_suffix = "cpp" if use_cpp_dispatcher else "ctypes" + dace_module = DaceModule( + module, + sdfg_name=f"test_multi_output_ad_{dispatcher_suffix}", + backward=True, + compile_torch_extension=use_cpp_dispatcher, + ) + + dace_y1, dace_y2 = dace_module(dace_input) + + dace_y1.backward(dace_dy, retain_graph=True) + dace_y2.backward(dace_dy) + + torch_tensors_close("grad", pytorch_input.grad, dace_input.grad) + + +if __name__ == "__main__": + test_multi_output(use_cpp_dispatcher=True) + test_multi_output(use_cpp_dispatcher=False) diff --git a/tests/autodiff/torch_backward/test_pytorch.py b/tests/autodiff/torch_backward/test_pytorch.py new file mode 100644 index 0000000000..05d2f6948a --- /dev/null +++ b/tests/autodiff/torch_backward/test_pytorch.py @@ -0,0 +1,305 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import numpy as np +import pytest +import copy + +pytest.importorskip("torch", reason="PyTorch not installed. Please install with: pip install dace[ml]") +import torch +import torch.nn as nn +import torch.nn.functional as F + +from dace.ml import DaceModule +from tests.utils import torch_tensors_close + + +@pytest.mark.torch +@pytest.mark.autodiff +def run_pytorch_module( + module: torch.nn.Module, + sdfg_name: str, + shape: tuple = None, + use_max: bool = False, + auto_optimize: bool = False, + rtol: float = 1e-4, + atol: float = 1e-3, + post_onnx_hooks: list = None, +): + shape = shape or (3, 5) + + pt_model_for_dace = copy.deepcopy(module) + + input_value = torch.rand(*shape, dtype=torch.float32) + + pytorch_input = torch.empty( + *shape, + dtype=torch.float32, + requires_grad=False, + ) + pytorch_input.copy_(input_value) + + dace_input = torch.empty(*shape, dtype=torch.float32, requires_grad=False) + dace_input.copy_(input_value) + + pytorch_input.requires_grad = True + dace_input.requires_grad = True + + if use_max: + pytorch_s = module(pytorch_input).max() + else: + pytorch_s = module(pytorch_input).sum() + pytorch_s.backward() + + dace_module = DaceModule( + pt_model_for_dace, + sdfg_name=sdfg_name, + simplify=False, + backward=True, + auto_optimize=auto_optimize, + compile_torch_extension=True, + ) + if post_onnx_hooks is not None: + for i, h in enumerate(post_onnx_hooks): + dace_module.append_post_onnx_hook(str(i), h) + + if use_max: + dace_s = dace_module(dace_input).max() + else: + dace_s = dace_module(dace_input).sum() + dace_s.backward() + torch_tensors_close("grad", pytorch_input.grad, dace_input.grad, rtol=rtol, atol=atol) + + for (name, dace_param), (pt_name, pt_param) in zip(module.named_parameters(), dace_module.named_parameters()): + assert 'model.' + name == pt_name + torch_tensors_close(name, pt_param.grad, dace_param.grad, rtol=rtol, atol=atol) + + +@pytest.mark.torch +@pytest.mark.autodiff +def test_simple(): + + class Module(torch.nn.Module): + + def forward(self, x): + x = torch.sqrt(x) + x = torch.log(x) + return x + + run_pytorch_module(Module(), sdfg_name="test_simple") + + +@pytest.mark.torch +@pytest.mark.autodiff +def test_repeated(): + + class Module(torch.nn.Module): + + def forward(self, x): + x = torch.sqrt(x) + x = torch.sqrt(x) + return x + + run_pytorch_module(Module(), sdfg_name="test_repeated") + + +@pytest.mark.torch +@pytest.mark.autodiff +def test_softmax(): + + class Module(torch.nn.Module): + + def forward(self, x): + x = F.softmax(x, dim=1) + return x + + run_pytorch_module(Module(), sdfg_name="test_softmax", use_max=True) + + +@pytest.mark.torch +@pytest.mark.autodiff +def test_reshape_on_memlet_path(): + # required test: this function in a nn.Module, with apply simplify so that the reshape is + # inlined and copy is removed + class Module(torch.nn.Module): + + def forward(self, x): + reshaped = torch.reshape(x + 1, [3, 3]) + return torch.log(reshaped) + torch.reshape(torch.tensor([[3, 2, 1]], device=reshaped.device), [3]) + + run_pytorch_module(Module(), sdfg_name="test_reshape_on_memlet_path", shape=(9, )) + + +@pytest.mark.torch +@pytest.mark.autodiff +def test_weights_ln(): + + class Module(torch.nn.Module): + + def __init__(self): + super(Module, self).__init__() + self.fc1 = nn.Linear(784, 120) + self.fc2 = nn.Linear(120, 32) + self.ln = nn.LayerNorm(32) + self.fc3 = nn.Linear(32, 10) + + def forward(self, x): + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + x = self.ln(x) + x = self.fc3(x) + return x + + run_pytorch_module(Module(), sdfg_name="test_weights_ln", shape=(4, 784), auto_optimize=False) + + +@pytest.mark.torch +@pytest.mark.autodiff +def test_layernorm(): + + class Module(torch.nn.Module): + + def __init__(self): + super(Module, self).__init__() + self.ln = nn.LayerNorm(3) + + def forward(self, x): + return self.ln(x) + + run_pytorch_module(Module(), sdfg_name="test_layernorm", shape=(2, 3), use_max=True, atol=1e-2) + + +@pytest.mark.torch +@pytest.mark.autodiff +def test_weights(): + + class Module(torch.nn.Module): + + def __init__(self): + super(Module, self).__init__() + self.fc1 = nn.Linear(784, 120) + self.fc2 = nn.Linear(120, 32) + self.fc3 = nn.Linear(32, 10) + + def forward(self, x): + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + x = self.fc3(x) + return x + + run_pytorch_module(Module(), sdfg_name="test_weights", shape=(4, 784), use_max=False, auto_optimize=False) + + +@pytest.mark.torch +@pytest.mark.autodiff +def test_nested_gradient_summation(): + + class Module(torch.nn.Module): + + def __init__(self): + super(Module, self).__init__() + self.fc1 = nn.Parameter(torch.rand(10, 10)) + + def forward(self, x): + y = x @ self.fc1 + z = x * 2 + return z + y + + run_pytorch_module(Module(), + sdfg_name="test_nested_gradient_summation", + shape=(4, 10), + use_max=False, + auto_optimize=False) + + +@pytest.mark.torch +@pytest.mark.autodiff +def test_trans_add(): + + class Module(torch.nn.Module): + + def __init__(self): + super(Module, self).__init__() + + def forward(self, x): + x = x + 1 + x = torch.transpose(x.reshape(4, 4), 1, 0) + return x + + run_pytorch_module(Module(), sdfg_name="test_trans_add", shape=(16, ), use_max=False) + + +@pytest.mark.torch +@pytest.mark.autodiff +def test_batched_matmul(): + + class Module(torch.nn.Module): + + def __init__(self): + super(Module, self).__init__() + self.fc1 = nn.Parameter(torch.ones([10, 5, 3])) + + def forward(self, x): + return self.fc1 @ x + + run_pytorch_module(Module(), sdfg_name="test_batched_matmul", use_max=False, auto_optimize=False) + + +@pytest.mark.torch +@pytest.mark.autodiff +def test_scalar_forwarding(): + + class Module(torch.nn.Module): + + def __init__(self): + super(Module, self).__init__() + self.factor = nn.Parameter(torch.ones(())) + + def forward(self, x): + return self.factor * x + + run_pytorch_module(Module(), sdfg_name="test_scalar_forwarding", use_max=False, auto_optimize=False) + + +@pytest.mark.torch +@pytest.mark.autodiff +def test_scalar_buffer(): + + class Module(torch.nn.Module): + + def __init__(self): + super(Module, self).__init__() + self.register_buffer("factor", torch.tensor(2)) + + def forward(self, x): + return self.factor * x + + run_pytorch_module(Module(), sdfg_name="test_scalar_buffer", use_max=False) + + +@pytest.mark.torch +@pytest.mark.autodiff +@pytest.mark.skip(reason="Requires pure implementation of expand") +def test_simple_broadcasted_mul(): + + class Module(torch.nn.Module): + + def forward(self, x): + y = x.sum(axis=0) + return x * y + + run_pytorch_module(Module(), sdfg_name="test_simple_broadcasted_mul") + + +if __name__ == "__main__": + test_simple() + test_repeated() + test_softmax() + test_reshape_on_memlet_path() + test_weights_ln() + test_layernorm() + test_weights() + test_nested_gradient_summation() + test_trans_add() + test_batched_matmul() + test_scalar_forwarding() + test_scalar_buffer() + # test_simple_broadcasted_mul is skipped diff --git a/tests/autodiff/torch_backward/test_training.py b/tests/autodiff/torch_backward/test_training.py new file mode 100644 index 0000000000..0858288145 --- /dev/null +++ b/tests/autodiff/torch_backward/test_training.py @@ -0,0 +1,124 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import os +import copy +import pytest + +import numpy as np + +pytest.importorskip("torch", reason="PyTorch not installed. Please install with: pip install dace[ml]") +pytest.importorskip("transformers", + reason="transformers not installed. Please install with: pip install dace[ml-testing]") +import torch +from torch import nn, optim +from transformers import BertConfig +from transformers.models.bert.modeling_bert import BertLayer + +from dace.ml import DaceModule +from tests.utils import torch_tensors_close + + +def training_step( + dace_model: torch.nn.Module, + pt_model: torch.nn.Module, + train_batch: tuple, + sdfg_name: str, + train_criterion: torch.nn.Module = None, +): + + # Copy over the weights + dace_model.load_state_dict(pt_model.state_dict()) + for dace_value, value in zip(pt_model.state_dict().values(), dace_model.state_dict().values()): + assert torch.allclose(dace_value, value), "State dict copy verification failed" + + dace_model = DaceModule(dace_model, sdfg_name=sdfg_name, backward=True, simplify=True, training=True) + + x, y = train_batch + + train_criterion = train_criterion or nn.NLLLoss() + + pt_loss = train_criterion(pt_model(x), y) + + dace_output = dace_model(x) + dace_loss = train_criterion(dace_output, y) + + diff = abs(pt_loss.item() - dace_loss.item()) / pt_loss.item() + assert diff < 1e-5, f"Loss mismatch: relative difference {diff:.2e} exceeds tolerance 1e-5" + + pt_loss.backward() + dace_loss.backward() + + for (name, dace_param), (pt_name, pt_param) in zip(pt_model.named_parameters(), dace_model.named_parameters()): + assert 'model.' + name == pt_name, f"Parameter name mismatch: expected 'model.{name}', got '{pt_name}'" + torch_tensors_close(name, pt_param.grad, dace_param.grad) + + optimizer = optim.SGD(pt_model.parameters(), lr=0.001) + dace_optimizer = optim.SGD(dace_model.parameters(), lr=0.001) + optimizer.step() + dace_optimizer.step() + + for (name, dace_param), (pt_name, pt_param) in zip(pt_model.named_parameters(), dace_model.named_parameters()): + assert 'model.' + name == pt_name, f"Parameter name mismatch after optimizer step: expected 'model.{name}', got '{pt_name}'" + torch_tensors_close(name, pt_param.detach(), dace_param.detach()) + + +@pytest.mark.torch +@pytest.mark.autodiff +def test_mnist(): + input_size = 784 + hidden_sizes = [128, 64] + output_size = 10 + + # initialize modules + # yapf: disable + model = nn.Sequential(nn.Linear(input_size, hidden_sizes[0]), + nn.ReLU(), + nn.Linear(hidden_sizes[0], hidden_sizes[1]), + nn.ReLU(), + nn.Linear(hidden_sizes[1], output_size), + nn.LayerNorm(output_size), + nn.LogSoftmax(dim=1)) + + dace_model = nn.Sequential(nn.Linear(input_size, hidden_sizes[0]), + nn.ReLU(), + nn.Linear(hidden_sizes[0], hidden_sizes[1]), + nn.ReLU(), + nn.Linear(hidden_sizes[1], output_size), + nn.LayerNorm(output_size), + nn.LogSoftmax(dim=1)) + + # check forward pass using loss + images = torch.randn(64, 784) + labels = torch.randint(0, 10, [64], dtype=torch.long) + + training_step(dace_model, model, (images, labels), sdfg_name="test_mnist_training") + +@pytest.mark.xdist_group("large_ML_models") +@pytest.mark.torch +@pytest.mark.autodiff +@pytest.mark.skip(reason="Requires pure implementation of expand") +def test_bert(): + batch_size = 2 + seq_len = 512 + hidden_size = 768 + + class BertTokenSoftmaxClf(nn.Module): + + def __init__(self): + super(BertTokenSoftmaxClf, self).__init__() + self.bert = BertLayer(BertConfig(hidden_act="relu")).eval() + self.sm = nn.LogSoftmax(dim=-1) + + def forward(self, x): + embs = self.bert(x)[0] + return self.sm(embs.sum(dim=-1)) + + # check forward pass using loss + input = torch.randn([batch_size, seq_len, hidden_size]) + labels = torch.tensor([0, 123], dtype=torch.long) + + training_step(BertTokenSoftmaxClf(), BertTokenSoftmaxClf(), (input, labels), sdfg_name="test_bert_training") + + +if __name__ == "__main__": + test_mnist() + # test_bert is skipped diff --git a/tests/conftest.py b/tests/conftest.py index 57f611ce66..8fe2fb56f7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -13,3 +13,14 @@ def pytest_terminal_summary(terminalreporter, exitstatus, config): if config.option.markexpr == 'mpi': if exitstatus in (pytest.ExitCode.TESTS_FAILED, pytest.ExitCode.INTERNAL_ERROR, pytest.ExitCode.INTERRUPTED): os._exit(1) + + +def pytest_generate_tests(metafunc): + """ + This method sets up the parametrizations for the custom fixtures + """ + if "use_cpp_dispatcher" in metafunc.fixturenames: + metafunc.parametrize("use_cpp_dispatcher", [ + pytest.param(True, id="use_cpp_dispatcher"), + pytest.param(False, id="no_use_cpp_dispatcher"), + ]) diff --git a/tests/npbench/deep_learning/conv2d_bias_test.py b/tests/npbench/deep_learning/conv2d_bias_test.py index 648903ffb9..193c1f037c 100644 --- a/tests/npbench/deep_learning/conv2d_bias_test.py +++ b/tests/npbench/deep_learning/conv2d_bias_test.py @@ -10,6 +10,7 @@ from dace.transformation.dataflow import StreamingMemory, StreamingComposition from dace.transformation.auto.auto_optimize import auto_optimize, fpga_auto_opt from dace.config import set_temporary +from dace.autodiff import add_backward_pass C_in, C_out, H, K, N, W = (dc.symbol(s, dc.int64) for s in ('C_in', 'C_out', 'H', 'K', 'N', 'W')) @@ -72,6 +73,43 @@ def conv2d_bias_np(input, weights, bias): return conv2d_np(input, weights) + bias +def conv2d_lax(jnp, lax, input, weights): + # Kernel size, number of input images, and output dimensions. + K = weights.shape[0] # Assuming square kernel of size K x K. + N = input.shape[0] # Batch size. + H_out = input.shape[1] - K + 1 # Output height. + W_out = input.shape[2] - K + 1 # Output width. + C_out = weights.shape[3] # Number of output channels. + + # Allocate output array. + output = jnp.empty((N, H_out, W_out, C_out), dtype=input.dtype) + + # Row update: iterate over output rows. + def row_update(out, i): + # Column update: iterate over output columns. + def col_update(out, j): + # Extract a patch from 'input' at the given (i, j) position. + patch = lax.dynamic_slice(input, (0, i, j, 0), (N, K, K, input.shape[-1])) + # Expand dims on the patch to broadcast with weights. + # weights: shape (K, K, in_channels, C_out) + # patch[..., None] becomes shape (N, K, K, in_channels, 1) + # We add a new leading dimension to weights to broadcast: + conv = jnp.sum(patch[..., None] * weights[None, :, :, :], axis=(1, 2, 3)) + # conv now has shape (N, C_out). Update output at (0, i, j, 0). + out = lax.dynamic_update_slice(out, conv[:, None, None, :], (0, i, j, 0)) + return out, None + + out, _ = lax.scan(col_update, out, jnp.arange(W_out)) + return out, None + + output, _ = lax.scan(row_update, output, jnp.arange(H_out)) + return output + + +def conv2d_bias_jax_kernel(jnp, lax, input, weights, bias): + return jnp.sum(conv2d_lax(jnp, lax, input, weights) + bias) + + def run_conv2d_bias(device_type: dace.dtypes.DeviceType): ''' Runs conv2d_bias for the given device @@ -107,6 +145,52 @@ def run_conv2d_bias(device_type: dace.dtypes.DeviceType): return sdfg +def run_conv2d_bias_autodiff(): + import jax + import jax.numpy as jnp + import jax.lax as lax + + # Initialize data (npbench test size) + N, C_in, C_out, K, H, W = 4, 3, 8, 2, 12, 12 + input, weights, bias = initialize(C_in, C_out, H, K, N, W) + + # Initialize gradient computation data + gradient_input = np.zeros_like(input, dtype=np.float32) + gradient___return = np.ones((1, ), dtype=np.float32) + + # Define sum reduction for the output + @dc.program + def autodiff_kernel(input: dc.float32[N, H, W, C_in], weights: dc.float32[K, K, C_in, C_out], + bias: dc.float32[C_out]): + A = conv2d(input, weights) + bias + return np.sum(A) + + # Add the backward pass to the SDFG + sdfg = autodiff_kernel.to_sdfg() + add_backward_pass(sdfg=sdfg, inputs=["input"], outputs=["__return"]) + + sdfg(input, + weights, + bias, + C_in=C_in, + C_out=C_out, + H=H, + K=K, + N=N, + W=W, + gradient_input=gradient_input, + gradient___return=gradient___return) + + # Enable float32 for JAX to match DaCe consistency + jax.config.update("jax_enable_x64", False) + + # Numerically validate vs JAX + jax_kernel = lambda input, weights, bias: conv2d_bias_jax_kernel(jnp, lax, input, weights, bias) + jax_grad = jax.jit(jax.grad(jax_kernel, argnums=0)) + jax_grad_input = jax_grad(input, weights, bias) + np.testing.assert_allclose(gradient_input, jax_grad_input, atol=1e-6, rtol=1e-6) + + def test_cpu(): run_conv2d_bias(dace.dtypes.DeviceType.CPU) @@ -117,6 +201,12 @@ def test_gpu(): run_conv2d_bias(dace.dtypes.DeviceType.GPU) +@pytest.mark.autodiff +def test_autodiff(): + pytest.importorskip("jax", reason="jax not installed. Please install with: pip install dace[ml-testing]") + run_conv2d_bias_autodiff() + + @fpga_test(assert_ii_1=False) def test_fpga(): return run_conv2d_bias(dace.dtypes.DeviceType.FPGA) @@ -132,6 +222,7 @@ def test_fpga(): if target == "cpu": run_conv2d_bias(dace.dtypes.DeviceType.CPU) + run_conv2d_bias_autodiff() elif target == "gpu": run_conv2d_bias(dace.dtypes.DeviceType.GPU) elif target == "fpga": diff --git a/tests/npbench/deep_learning/lenet_test.py b/tests/npbench/deep_learning/lenet_test.py index 37cba9af9b..ef28755319 100644 --- a/tests/npbench/deep_learning/lenet_test.py +++ b/tests/npbench/deep_learning/lenet_test.py @@ -5,12 +5,10 @@ import dace as dc import pytest import argparse -from dace.fpga_testing import fpga_test, xilinx_test +from dace.fpga_testing import fpga_test from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG -from dace.transformation.dataflow import StreamingMemory, StreamingComposition -from dace.transformation.auto.auto_optimize import auto_optimize, fpga_auto_opt -from dace.config import temporary_config, Config -import os +from dace.transformation.auto.auto_optimize import auto_optimize +from dace.autodiff import add_backward_pass N, H, W, C_before_fc1, S0, S1, S2, S3, S4, S5 = (dc.symbol(s, dtype=dc.int64) for s in ('N', 'H', 'W', 'C_before_fc1', 'S0', 'S1', 'S2', 'S3', 'S4', @@ -146,6 +144,74 @@ def lenet5_np(input, conv1, conv1bias, conv2, conv2bias, fc1w, fc1b, fc2w, fc2b, return x @ fc3w + fc3b +def conv2d_lax(jnp, lax, input, weights): + # Kernel size, number of input images, and output dimensions. + K = weights.shape[0] # Assuming square kernel of size K x K. + N = input.shape[0] # Batch size. + H_out = input.shape[1] - K + 1 # Output height. + W_out = input.shape[2] - K + 1 # Output width. + C_out = weights.shape[3] # Number of output channels. + + # Allocate output array. + output = jnp.empty((N, H_out, W_out, C_out), dtype=input.dtype) + + # Row update: iterate over output rows. + def row_update(out, i): + # Column update: iterate over output columns. + def col_update(out, j): + # Extract a patch from 'input' at the given (i, j) position. + patch = lax.dynamic_slice(input, (0, i, j, 0), (N, K, K, input.shape[-1])) + # Expand dims on the patch to broadcast with weights. + # weights: shape (K, K, in_channels, C_out) + # patch[..., None] becomes shape (N, K, K, in_channels, 1) + # We add a new leading dimension to weights to broadcast: + conv = jnp.sum(patch[..., None] * weights[None, :, :, :], axis=(1, 2, 3)) + # conv now has shape (N, C_out). Update output at (0, i, j, 0). + out = lax.dynamic_update_slice(out, conv[:, None, None, :], (0, i, j, 0)) + return out, None + + out, _ = lax.scan(col_update, out, jnp.arange(W_out)) + return out, None + + output, _ = lax.scan(row_update, output, jnp.arange(H_out)) + return output + + +def maxpool2d_lax(jnp, lax, x): + output = jnp.empty([x.shape[0], x.shape[1] // 2, x.shape[2] // 2, x.shape[3]], dtype=x.dtype) + + def row_update(output, i): + + def col_update(output, j): + input_slice = lax.dynamic_slice(x, (0, 2 * i, 2 * j, 0), (x.shape[0], 2, 2, x.shape[3])) + output = lax.dynamic_update_slice(output, jnp.max(input_slice, axis=(1, 2))[:, None, None, :], (0, i, j, 0)) + return output, None + + output, _ = lax.scan(col_update, output, jnp.arange(x.shape[2] // 2)) + return output, None + + output, _ = lax.scan(row_update, output, jnp.arange(x.shape[1] // 2)) + + return output + + +def jax_relu(jnp, x): + return jnp.maximum(x, 0) + + +def lenet_jax_kernel(jnp, lax, input, conv1, conv1bias, conv2, conv2bias, fc1w, fc1b, fc2w, fc2b, fc3w, fc3b): + C_before_fc1 = fc1w.shape[0] + N = input.shape[0] + x = jax_relu(jnp, conv2d_lax(jnp, lax, input, conv1) + conv1bias) + x = maxpool2d_lax(jnp, lax, x) + x = jax_relu(jnp, conv2d_lax(jnp, lax, x, conv2) + conv2bias) + x = maxpool2d_lax(jnp, lax, x) + x = jnp.reshape(x, (N, C_before_fc1)) + x = jax_relu(jnp, x @ fc1w + fc1b) + x = jax_relu(jnp, x @ fc2w + fc2b) + return jnp.sum(x @ fc3w + fc3b) + + def run_lenet(device_type: dace.dtypes.DeviceType): ''' Runs lenet for the given device @@ -195,6 +261,58 @@ def run_lenet(device_type: dace.dtypes.DeviceType): return sdfg +def run_lenet_autodiff(): + import jax + import jax.numpy as jnp + import jax.lax as lax + + # Initialize data (npbench test size) + N, H, W = 4, 16, 16 + input, conv1, conv1bias, conv2, conv2bias, fc1w, fc1b, fc2w, fc2b, fc3w, fc3b, C_before_fc1 = initialize(N, H, W) + + # Initialize gradient computation data + gradient_input = np.zeros_like(input, dtype=np.float32) + gradient___return = np.ones((1, ), dtype=np.float32) + + # Define sum reduction for the output + @dc.program + def autodiff_kernel(input: dc.float32[N, H, W, 1], conv1: dc.float32[5, 5, 1, 6], conv1bias: dc.float32[6], + conv2: dc.float32[5, 5, 6, 16], conv2bias: dc.float32[16], fc1w: dc.float32[C_before_fc1, 120], + fc1b: dc.float32[120], fc2w: dc.float32[120, 84], fc2b: dc.float32[84], + fc3w: dc.float32[84, 10], fc3b: dc.float32[10]): + result = lenet5_kernel(input, conv1, conv1bias, conv2, conv2bias, fc1w, fc1b, fc2w, fc2b, fc3w, fc3b) + return np.sum(result) + + # Add the backward pass to the SDFG + sdfg = autodiff_kernel.to_sdfg() + add_backward_pass(sdfg=sdfg, inputs=["input"], outputs=["__return"]) + + sdfg(input, + conv1, + conv1bias, + conv2, + conv2bias, + fc1w, + fc1b, + fc2w, + fc2b, + fc3w, + fc3b, + N=N, + H=H, + W=W, + C_before_fc1=C_before_fc1, + gradient_input=gradient_input, + gradient___return=gradient___return) + + # Numerically validate vs JAX + jax_kernel = lambda input, conv1, conv1bias, conv2, conv2bias, fc1w, fc1b, fc2w, fc2b, fc3w, fc3b: lenet_jax_kernel( + jnp, lax, input, conv1, conv1bias, conv2, conv2bias, fc1w, fc1b, fc2w, fc2b, fc3w, fc3b) + jax_grad = jax.jit(jax.grad(jax_kernel, argnums=0)) + jax_grad_input = jax_grad(input, conv1, conv1bias, conv2, conv2bias, fc1w, fc1b, fc2w, fc2b, fc3w, fc3b) + np.testing.assert_allclose(gradient_input, jax_grad_input, rtol=1e-6) + + def test_cpu(monkeypatch): # Serialization causes issues, we temporarily disable it monkeypatch.setenv("DACE_testing_serialization", 0) @@ -207,6 +325,12 @@ def test_gpu(): run_lenet(dace.dtypes.DeviceType.GPU) +@pytest.mark.autodiff +def test_autodiff(): + pytest.importorskip("jax", reason="jax not installed. Please install with: pip install dace[ml-testing]") + run_lenet_autodiff() + + @pytest.mark.skip(reason="Dynamic memory allocation") @fpga_test(assert_ii_1=False) def test_fpga(): @@ -223,6 +347,7 @@ def test_fpga(): if target == "cpu": run_lenet(dace.dtypes.DeviceType.CPU) + run_lenet_autodiff() elif target == "gpu": run_lenet(dace.dtypes.DeviceType.GPU) elif target == "fpga": diff --git a/tests/npbench/deep_learning/mlp_test.py b/tests/npbench/deep_learning/mlp_test.py index 9588b66f68..28c36b4a79 100644 --- a/tests/npbench/deep_learning/mlp_test.py +++ b/tests/npbench/deep_learning/mlp_test.py @@ -10,6 +10,7 @@ from dace.transformation.dataflow import StreamingMemory, StreamingComposition from dace.transformation.auto.auto_optimize import auto_optimize, fpga_auto_opt from dace.config import set_temporary +from dace.autodiff import add_backward_pass C_in, N, S0, S1, S2, N1, N2 = (dc.symbol(s, dtype=dc.int64) for s in ('C_in', 'N', 'S0', 'S1', 'S2', 'N1', 'N2')) @@ -78,6 +79,25 @@ def mlp_np(input, w1, b1, w2, b2, w3, b3): return x +def jax_relu(jnp, x): + return jnp.maximum(x, 0) + + +# Numerically-stable version of softmax +def jax_softmax(jnp, x): + tmp_max = jnp.max(x, axis=-1, keepdims=True) + tmp_out = jnp.exp(x - tmp_max) + tmp_sum = jnp.sum(tmp_out, axis=-1, keepdims=True) + return tmp_out / tmp_sum + + +def mlp_jax_kernel(jnp, input, w1, b1, w2, b2, w3, b3): + x = jax_relu(jnp, input @ w1 + b1) + x = jax_relu(jnp, x @ w2 + b2) + x = jax_softmax(jnp, x @ w3 + b3) # Softmax call can be omitted if necessary + return jnp.sum(x) + + def run_mlp(device_type: dace.dtypes.DeviceType): ''' Runs conv2d_bias for the given device @@ -115,6 +135,53 @@ def run_mlp(device_type: dace.dtypes.DeviceType): return sdfg +def run_mlp_autodiff(): + import jax + import jax.numpy as jnp + + # Initialize data (npbench test size) + C_in, N, S0, S1, S2 = 3, 8, 30, 20, 20 + input, w1, b1, w2, b2, w3, b3 = initialize(C_in, N, S0, S1, S2) + + # Initialize gradient computation data + gradient_input = np.zeros_like(input, dtype=np.float32) + gradient___return = np.ones((1, ), dtype=np.float32) + + # Define sum reduction for the output + @dc.program + def autodiff_kernel(input: dc.float32[N, C_in], w1: dc.float32[C_in, S0], b1: dc.float32[S0], + w2: dc.float32[S0, S1], b2: dc.float32[S1], w3: dc.float32[S1, S2], b3: dc.float32[S2]): + x1 = relu(input @ w1 + b1) + x2 = relu(x1 @ w2 + b2) + x3 = softmax(x2 @ w3 + b3) + return np.sum(x3) + + # Add the backward pass to the SDFG + sdfg = autodiff_kernel.to_sdfg() + add_backward_pass(sdfg=sdfg, inputs=["input"], outputs=["__return"]) + + sdfg(input, + w1, + b1, + w2, + b2, + w3, + b3, + N=N, + S0=S0, + S1=S1, + S2=S2, + C_in=C_in, + gradient_input=gradient_input, + gradient___return=gradient___return) + + # Numerically validate vs JAX + jax_kernel = lambda input, w1, b1, w2, b2, w3, b3: mlp_jax_kernel(jnp, input, w1, b1, w2, b2, w3, b3) + jax_grad = jax.jit(jax.grad(jax_kernel, argnums=0)) + jax_grad_input = jax_grad(input, w1, b1, w2, b2, w3, b3) + np.testing.assert_allclose(gradient_input, jax_grad_input, rtol=1e-4, atol=1e-10) + + def test_cpu(): run_mlp(dace.dtypes.DeviceType.CPU) @@ -124,6 +191,12 @@ def test_gpu(): run_mlp(dace.dtypes.DeviceType.GPU) +@pytest.mark.autodiff +def test_autodiff(): + pytest.importorskip("jax", reason="jax not installed. Please install with: pip install dace[ml-testing]") + run_mlp_autodiff() + + @pytest.mark.skip(reason="Intel, compilation error") @fpga_test(assert_ii_1=False) def test_fpga(): @@ -140,6 +213,7 @@ def test_fpga(): if target == "cpu": run_mlp(dace.dtypes.DeviceType.CPU) + run_mlp_autodiff() elif target == "gpu": run_mlp(dace.dtypes.DeviceType.GPU) elif target == "fpga": diff --git a/tests/npbench/deep_learning/resnet_test.py b/tests/npbench/deep_learning/resnet_test.py index cfe43718e0..ba90e9a44b 100644 --- a/tests/npbench/deep_learning/resnet_test.py +++ b/tests/npbench/deep_learning/resnet_test.py @@ -10,6 +10,7 @@ from dace.transformation.dataflow import StreamingMemory, StreamingComposition from dace.transformation.auto.auto_optimize import auto_optimize, fpga_auto_opt from dace.config import set_temporary +from dace.autodiff import add_backward_pass N, H, W, C1, C2, S0, S1, S2, S3, S4, S5 = (dc.symbol(s, dtype=dc.int64) for s in ('N', 'H', 'W', 'C1', 'C2', 'S0', 'S1', 'S2', 'S3', 'S4', 'S5')) @@ -168,6 +169,65 @@ def resnet_basicblock_np(input, conv1, conv2, conv3): return relu_np(x + input) +def conv2d_lax(jnp, lax, input, weights): + # Kernel size, number of input images, and output dimensions. + K = weights.shape[0] # Assuming square kernel of size K x K. + N = input.shape[0] # Batch size. + H_out = input.shape[1] - K + 1 # Output height. + W_out = input.shape[2] - K + 1 # Output width. + C_out = weights.shape[3] # Number of output channels. + + # Allocate output array. + output = jnp.empty((N, H_out, W_out, C_out), dtype=input.dtype) + + # Row update: iterate over output rows. + def row_update(out, i): + # Column update: iterate over output columns. + def col_update(out, j): + # Extract a patch from 'input' at the given (i, j) position. + patch = lax.dynamic_slice(input, (0, i, j, 0), (N, K, K, input.shape[-1])) + # Expand dims on the patch to broadcast with weights. + # weights: shape (K, K, in_channels, C_out) + # patch[..., None] becomes shape (N, K, K, in_channels, 1) + # We add a new leading dimension to weights to broadcast: + conv = jnp.sum(patch[..., None] * weights[None, :, :, :], axis=(1, 2, 3)) + # conv now has shape (N, C_out). Update output at (0, i, j, 0). + out = lax.dynamic_update_slice(out, conv[:, None, None, :], (0, i, j, 0)) + return out, None + + out, _ = lax.scan(col_update, out, jnp.arange(W_out)) + return out, None + + output, _ = lax.scan(row_update, output, jnp.arange(H_out)) + return output + + +def jax_relu(jnp, x): + return jnp.maximum(x, 0) + + +# Batch normalization operator, as used in ResNet +def jax_batchnorm2d(jnp, x, eps=1e-5): + mean = jnp.mean(x, axis=0, keepdims=True) + std = jnp.std(x, axis=0, keepdims=True) + return (x - mean) / jnp.sqrt(std + eps) + + +def resnet_jax_kernel(jnp, lax, input, conv1, conv2, conv3): + # Pad output of first convolution for second convolution + padded = jnp.zeros((input.shape[0], input.shape[1] + 2, input.shape[2] + 2, conv1.shape[3]), dtype=input.dtype) + padded = lax.dynamic_update_slice(padded, conv2d_lax(jnp, lax, input, conv1), (0, 1, 1, 0)) + x = jax_batchnorm2d(jnp, padded) + x = jax_relu(jnp, x) + + x = conv2d_lax(jnp, lax, x, conv2) + x = jax_batchnorm2d(jnp, x) + x = jax_relu(jnp, x) + x = conv2d_lax(jnp, lax, x, conv3) + x = jax_batchnorm2d(jnp, x) + return jnp.sum(jax_relu(jnp, x + input)) + + def run_resnet(device_type: dace.dtypes.DeviceType): ''' Runs resnet for the given device @@ -203,6 +263,53 @@ def run_resnet(device_type: dace.dtypes.DeviceType): return sdfg +def run_resnet_autodiff(): + import jax + import jax.numpy as jnp + import jax.lax as lax + + # Initialize data (npbench test size) + N, W, H, C1, C2 = 2, 8, 8, 8, 4 + input, conv1, conv2, conv3 = initialize(N, W, H, C1, C2) + + # Initialize gradient computation data + gradient_input = np.zeros_like(input, dtype=np.float32) + gradient___return = np.ones((1, ), dtype=np.float32) + + # Define sum reduction for the output + @dc.program + def autodiff_kernel(input: dc.float32[N, H, W, C1], conv1: dc.float32[1, 1, C1, C2], + conv2: dc.float32[3, 3, C2, C2], conv3: dc.float32[1, 1, C2, C1]): + # Pad output of first convolution for second convolution + x = resnet_basicblock(input, conv1, conv2, conv3) + return np.sum(x) + + # Add the backward pass to the SDFG + sdfg = autodiff_kernel.to_sdfg(simplify=True) + add_backward_pass(sdfg=sdfg, inputs=["input"], outputs=["__return"]) + + sdfg(input, + conv1, + conv2, + conv3, + N=N, + W=W, + H=H, + C1=C1, + C2=C2, + gradient_input=gradient_input, + gradient___return=gradient___return) + + # Numerically validate vs JAX + jax_kernel = lambda input, conv1, conv2, conv3: resnet_jax_kernel(jnp, lax, input, conv1, conv2, conv3) + jax_grad = jax.jit(jax.grad(jax_kernel, argnums=0)) + jax_grad_input = jax_grad(input, conv1, conv2, conv3) + + # The tolerance if fairly high with float32 inputs + # The same code using float64 works with a 1e-12 tolerance + np.testing.assert_allclose(gradient_input, jax_grad_input, atol=1e-2, rtol=1e-2) + + def test_cpu(): run_resnet(dace.dtypes.DeviceType.CPU) @@ -213,6 +320,12 @@ def test_gpu(): run_resnet(dace.dtypes.DeviceType.GPU) +@pytest.mark.autodiff +def test_autodiff(): + pytest.importorskip("jax", reason="jax not installed. Please install with: pip install dace[ml-testing]") + run_resnet_autodiff() + + @pytest.mark.skip(reason="Dynamic memory allocation") @fpga_test(assert_ii_1=False) def test_fpga(): @@ -229,6 +342,7 @@ def test_fpga(): if target == "cpu": run_resnet(dace.dtypes.DeviceType.CPU) + run_resnet_autodiff() elif target == "gpu": run_resnet(dace.dtypes.DeviceType.GPU) elif target == "fpga": diff --git a/tests/npbench/deep_learning/softmax_test.py b/tests/npbench/deep_learning/softmax_test.py index 4408645adc..5e1e803cff 100644 --- a/tests/npbench/deep_learning/softmax_test.py +++ b/tests/npbench/deep_learning/softmax_test.py @@ -10,6 +10,7 @@ from dace.transformation.dataflow import StreamingMemory, StreamingComposition from dace.transformation.auto.auto_optimize import auto_optimize, fpga_auto_opt from dace.config import set_temporary +from dace.autodiff import add_backward_pass N, H, SM = (dc.symbol(s, dc.int64) for s in ('N', 'H', 'SM')) @@ -37,6 +38,13 @@ def ground_truth(x): return tmp_out / tmp_sum +def softmax_jax_kernel(jnp, x): + tmp_max = jnp.max(x, axis=-1, keepdims=True) + tmp_out = jnp.exp(x - tmp_max) + tmp_sum = jnp.sum(tmp_out, axis=-1, keepdims=True) + return jnp.sum(tmp_out / tmp_sum) + + def run_softmax(device_type: dace.dtypes.DeviceType): ''' Runs Softmax for the given device @@ -72,6 +80,36 @@ def run_softmax(device_type: dace.dtypes.DeviceType): return sdfg +def run_softmax_autodiff(): + import jax + import jax.numpy as jnp + + # Initialize data (npbench test size) + N, H, SM = 4, 4, 32 + x = initialize(N, H, SM) + out = np.zeros_like(x) + + # Initialize gradient computation data + gradient_x = np.zeros_like(x) + gradient___return = np.ones((1, ), dtype=np.float32) + + # Define sum reduction for the output + @dc.program + def softmax_autodiff_kernel(x: dc.float32[N, H, SM, SM]): + return np.sum(softmax_kernel(x)) + + # Add the backward pass to the SDFG + sdfg = softmax_autodiff_kernel.to_sdfg() + add_backward_pass(sdfg=sdfg, inputs=["x"], outputs=["__return"]) + sdfg(x, out, N=N, H=H, SM=SM, gradient_x=gradient_x, gradient___return=gradient___return) + + # Numerically validate vs JAX + jax_kernel = lambda x: softmax_jax_kernel(jnp, x) + jax_grad = jax.jit(jax.grad(jax_kernel, argnums=0)) + jax_grad_x = jax_grad(x) + np.testing.assert_allclose(gradient_x, jax_grad_x, atol=1e-6) + + def test_cpu(): run_softmax(dace.dtypes.DeviceType.CPU) @@ -81,6 +119,12 @@ def test_gpu(): run_softmax(dace.dtypes.DeviceType.GPU) +@pytest.mark.autodiff +def test_autodiff(): + pytest.importorskip("jax", reason="jax not installed. Please install with: pip install dace[ml-testing]") + run_softmax_autodiff() + + @fpga_test(assert_ii_1=False) def test_fpga(): return run_softmax(dace.dtypes.DeviceType.FPGA) @@ -96,6 +140,7 @@ def test_fpga(): if target == "cpu": run_softmax(dace.dtypes.DeviceType.CPU) + run_softmax_autodiff() elif target == "gpu": run_softmax(dace.dtypes.DeviceType.GPU) elif target == "fpga": diff --git a/tests/npbench/misc/cavity_flow_test.py b/tests/npbench/misc/cavity_flow_test.py index d2e4d50f13..10d048e35f 100644 --- a/tests/npbench/misc/cavity_flow_test.py +++ b/tests/npbench/misc/cavity_flow_test.py @@ -8,6 +8,7 @@ from dace.transformation.auto.auto_optimize import auto_optimize, fpga_auto_opt from dace.fpga_testing import fpga_test from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG +from dace.autodiff import add_backward_pass nx, ny, nit = (dace.symbol(s, dace.int64) for s in ('nx', 'ny', 'nit')) @@ -146,6 +147,67 @@ def initialize(ny, nx): return u, v, p, dx, dy, dt +def jax_build_up_b(b, rho, dt, u, v, dx, dy): + b = b.at[1:-1, 1:-1].set( + (rho * (1 / dt * ((u[1:-1, 2:] - u[1:-1, 0:-2]) / (2 * dx) + (v[2:, 1:-1] - v[0:-2, 1:-1]) / + (2 * dy)) - ((u[1:-1, 2:] - u[1:-1, 0:-2]) / (2 * dx))**2 - 2 * + ((u[2:, 1:-1] - u[0:-2, 1:-1]) / (2 * dy) * (v[1:-1, 2:] - v[1:-1, 0:-2]) / + (2 * dx)) - ((v[2:, 1:-1] - v[0:-2, 1:-1]) / (2 * dy))**2))) + return b + + +def jax_pressure_poisson(jnp, nit, p, dx, dy, b): + pn = jnp.empty_like(p) + pn = p.copy() + + for q in range(nit): + pn = p.copy() + p = p.at[1:-1, 1:-1].set((((pn[1:-1, 2:] + pn[1:-1, 0:-2]) * dy**2 + (pn[2:, 1:-1] + pn[0:-2, 1:-1]) * dx**2) / + (2 * (dx**2 + dy**2)) - dx**2 * dy**2 / (2 * (dx**2 + dy**2)) * b[1:-1, 1:-1])) + + p = p.at[:, -1].set(p[:, -2]) # dp/dx = 0 at x = 2 + p = p.at[0, :].set(p[1, :]) # dp/dy = 0 at y = 0 + p = p.at[:, 0].set(p[:, 1]) # dp/dx = 0 at x = 0 + p = p.at[-1, :].set(0) # p = 0 at y = 2 + return p + + +def cavity_flow_jax_kernel(jnp, nx, ny, nt, nit, u, v, dt, dx, dy, p, rho, nu): + un = jnp.empty_like(u) + vn = jnp.empty_like(v) + b = jnp.zeros((ny, nx)) + + for n in range(nt): + un = u.copy() + vn = v.copy() + + b = jax_build_up_b(b, rho, dt, u, v, dx, dy) + p = jax_pressure_poisson(jnp, nit, p, dx, dy, b) + + u = u.at[1:-1, 1:-1].set( + (un[1:-1, 1:-1] - un[1:-1, 1:-1] * dt / dx * (un[1:-1, 1:-1] - un[1:-1, 0:-2]) - vn[1:-1, 1:-1] * dt / dy * + (un[1:-1, 1:-1] - un[0:-2, 1:-1]) - dt / (2 * rho * dx) * (p[1:-1, 2:] - p[1:-1, 0:-2]) + nu * + (dt / dx**2 * (un[1:-1, 2:] - 2 * un[1:-1, 1:-1] + un[1:-1, 0:-2]) + dt / dy**2 * + (un[2:, 1:-1] - 2 * un[1:-1, 1:-1] + un[0:-2, 1:-1])))) + + v = v.at[1:-1, 1:-1].set( + (vn[1:-1, 1:-1] - un[1:-1, 1:-1] * dt / dx * (vn[1:-1, 1:-1] - vn[1:-1, 0:-2]) - vn[1:-1, 1:-1] * dt / dy * + (vn[1:-1, 1:-1] - vn[0:-2, 1:-1]) - dt / (2 * rho * dy) * (p[2:, 1:-1] - p[0:-2, 1:-1]) + nu * + (dt / dx**2 * (vn[1:-1, 2:] - 2 * vn[1:-1, 1:-1] + vn[1:-1, 0:-2]) + dt / dy**2 * + (vn[2:, 1:-1] - 2 * vn[1:-1, 1:-1] + vn[0:-2, 1:-1])))) + + u = u.at[0, :].set(0) + u = u.at[:, 0].set(0) + u = u.at[:, -1].set(0) + u = u.at[-1, :].set(1) + v = v.at[0, :].set(0) + v = v.at[-1, :].set(0) + v = v.at[:, 0].set(0) + v = v.at[:, -1].set(0) + + return jnp.sum(v) + + def run_cavity_flow(device_type: dace.dtypes.DeviceType): ''' Runs cavity-flow for the given device @@ -184,6 +246,58 @@ def run_cavity_flow(device_type: dace.dtypes.DeviceType): return sdfg +def run_cavity_flow_autodiff(): + import jax + import jax.numpy as jnp + + # Initialize data (test size from benchmark) + ny, nx, nt, nit, rho, nu = (4, 4, 4, 5, 1.0, 0.1) + u, v, p, dx, dy, dt = initialize(ny, nx) + jax_u, jax_v, jax_p = u.copy(), v.copy(), p.copy() + + # Initialize gradient computation data + gradient_u = np.zeros_like(u) + gradient___return = np.ones((1, ), dtype=np.float64) + + # Define cavity flow kernel based on benchmark with __return pattern + @dace.program + def autodiff_kernel(nt: dace.int64, nit: dace.int64, u: dace.float64[ny, nx], v: dace.float64[ny, nx], + dt: dace.float64, dx: dace.float64, dy: dace.float64, p: dace.float64[ny, nx], + rho: dace.float64, nu: dace.float64): + + dace_cavity_flow(nt, nit, u, v, dt, dx, dy, p, rho, nu) + return np.sum(v) + + # Add the backward pass to the SDFG + sdfg = autodiff_kernel.to_sdfg(simplify=True) + add_backward_pass(sdfg=sdfg, inputs=["u"], outputs=["__return"]) + sdfg(nt, + nit, + u, + v, + dt, + dx, + dy, + p, + rho, + nu, + ny=ny, + nx=nx, + nit=nit, + gradient_u=gradient_u, + gradient___return=gradient___return) + + # Enable float64 support + jax.config.update("jax_enable_x64", True) + + # Numerically validate vs JAX + jax_kernel = lambda nx, ny, nt, nit, u, v, dt, dx, dy, p, rho, nu: cavity_flow_jax_kernel( + jnp, nx, ny, nt, nit, u, v, dt, dx, dy, p, rho, nu) + jax_grad = jax.jit(jax.grad(jax_kernel, argnums=4), static_argnums=(0, 1, 2, 3)) + jax_grad_u = jax_grad(nx, ny, nt, nit, jax_u, jax_v, dt, dx, dy, jax_p, rho, nu) + np.testing.assert_allclose(gradient_u, jax_grad_u, rtol=1e-6, atol=1e-10) + + def test_cpu(): run_cavity_flow(dace.dtypes.DeviceType.CPU) @@ -193,6 +307,12 @@ def test_gpu(): run_cavity_flow(dace.dtypes.DeviceType.GPU) +@pytest.mark.autodiff +def test_autodiff(): + pytest.importorskip("jax", reason="jax not installed. Please install with: pip install dace[ml-testing]") + run_cavity_flow_autodiff() + + @pytest.mark.skip(reason="Intel FPGA kernel arguments") @fpga_test(assert_ii_1=False) def test_fpga(): @@ -209,6 +329,7 @@ def test_fpga(): if target == "cpu": run_cavity_flow(dace.dtypes.DeviceType.CPU) + run_cavity_flow_autodiff() elif target == "gpu": run_cavity_flow(dace.dtypes.DeviceType.GPU) elif target == "fpga": diff --git a/tests/npbench/misc/compute_test.py b/tests/npbench/misc/compute_test.py index feb5585ca5..e006f015b4 100644 --- a/tests/npbench/misc/compute_test.py +++ b/tests/npbench/misc/compute_test.py @@ -8,6 +8,7 @@ from dace.transformation.auto.auto_optimize import auto_optimize from dace.fpga_testing import fpga_test from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG +from dace.autodiff import add_backward_pass def relerror(val, ref): diff --git a/tests/npbench/misc/go_fast_test.py b/tests/npbench/misc/go_fast_test.py index bf686b5b15..f547e53fa3 100644 --- a/tests/npbench/misc/go_fast_test.py +++ b/tests/npbench/misc/go_fast_test.py @@ -10,6 +10,7 @@ from dace.transformation.dataflow import StreamingMemory, StreamingComposition from dace.transformation.auto.auto_optimize import auto_optimize, fpga_auto_opt from dace.config import set_temporary +from dace.autodiff import add_backward_pass N = dc.symbol('N', dtype=dc.int64) @@ -66,6 +67,51 @@ def run_go_fast(device_type: dace.dtypes.DeviceType): return sdfg +def go_fast_jax_kernel(jnp, lax, a): + + def body_fn(trace, i): + # Update the trace by adding tanh(a[i, i]) + new_trace = trace + jnp.tanh(a[i, i]) + return new_trace, None # Return a dummy output. + + trace, _ = lax.scan(body_fn, 0.0, jnp.arange(a.shape[0])) + return jnp.sum(a + trace) + + +def run_go_fast_autodiff(): + import jax + import jax.numpy as jnp + import jax.lax as lax + + # Initialize forward data (using smaller size for AD test) + N = 20 + a = initialize(N) + + # Initialize gradient computation data + gradient_a = np.zeros_like(a) + gradient___return = np.ones((1, ), dtype=np.float64) + + # Define sum reduction for the output + @dc.program + def autodiff_kernel(a: dc.float64[N, N]): + result = go_fast_kernel(a) + return np.sum(result) + + # Add the backward pass to the SDFG + sdfg = autodiff_kernel.to_sdfg() + add_backward_pass(sdfg=sdfg, inputs=["a"], outputs=["__return"]) + sdfg(a, N=N, gradient_a=gradient_a, gradient___return=gradient___return) + + # Enable float64 support + jax.config.update("jax_enable_x64", True) + + # Numerically validate vs JAX + jax_kernel = lambda a: go_fast_jax_kernel(jnp, lax, a) + jax_grad = jax.jit(jax.grad(jax_kernel, argnums=0)) + jax_grad_a = jax_grad(a) + np.testing.assert_allclose(gradient_a, jax_grad_a) + + def test_cpu(): run_go_fast(dace.dtypes.DeviceType.CPU) @@ -75,6 +121,12 @@ def test_gpu(): run_go_fast(dace.dtypes.DeviceType.GPU) +@pytest.mark.autodiff +def test_autodiff(): + pytest.importorskip("jax", reason="jax not installed. Please install with: pip install dace[ml-testing]") + run_go_fast_autodiff() + + @pytest.mark.skip(reason="Operand type in binary expressions") @fpga_test(assert_ii_1=False) def test_fpga(): @@ -91,6 +143,7 @@ def test_fpga(): if target == "cpu": run_go_fast(dace.dtypes.DeviceType.CPU) + run_go_fast_autodiff() elif target == "gpu": run_go_fast(dace.dtypes.DeviceType.GPU) elif target == "fpga": diff --git a/tests/npbench/polybench/adi_test.py b/tests/npbench/polybench/adi_test.py index 9cccf7da0f..d3be1a29a2 100644 --- a/tests/npbench/polybench/adi_test.py +++ b/tests/npbench/polybench/adi_test.py @@ -8,6 +8,7 @@ from dace.transformation.auto.auto_optimize import auto_optimize, fpga_auto_opt from dace.fpga_testing import fpga_test from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG +from dace.autodiff import add_backward_pass # Dataset sizes # TSTEPS, N @@ -43,7 +44,7 @@ def numpy_kernel(TSTEPS, N, u): e = 1.0 + mul2 f = d - for t in range(1, TSTEPS + 1): + for t in range(0, TSTEPS): v[0, 1:N - 1] = 1.0 p[1:N - 1, 0] = 0.0 q[1:N - 1, 0] = v[0, 1:N - 1] @@ -89,7 +90,7 @@ def adi_kernel(TSTEPS: dace.int64, u: dace.float64[N, N]): e = 1.0 + mul2 f = d - for t in range(1, TSTEPS + 1): + for t in range(0, TSTEPS): v[0, 1:N - 1] = 1.0 p[1:N - 1, 0] = 0.0 q[1:N - 1, 0] = v[0, 1:N - 1] @@ -118,6 +119,80 @@ def initialize(N, datatype=np.float64): return u +def adi_jax_kernel(jnp, lax, TSTEPS, u): + N = u.shape[0] + v = jnp.zeros_like(u) + p = jnp.zeros_like(u) + q = jnp.zeros_like(u) + + DX = 1.0 / N + DY = 1.0 / N + DT = 1.0 / TSTEPS + B1 = 2.0 + B2 = 1.0 + mul1 = B1 * DT / (DX * DX) + mul2 = B2 * DT / (DY * DY) + a = -mul1 / 2.0 + b = 1.0 + mul2 + c = a + d = -mul2 / 2.0 + e = 1.0 + mul2 + f = d + + def first_j_scan(carry, j): + p, q, u = carry + + p = p.at[1:N - 1, j].set(-c / (a * p[1:N - 1, j - 1] + b)) + q = q.at[1:N - 1, + j].set((-d * u[j, 0:N - 2] + (1.0 + 2.0 * d) * u[j, 1:N - 1] - f * u[j, 2:N] - a * q[1:N - 1, j - 1]) / + (a * p[1:N - 1, j - 1] + b)) + return (p, q, u), None + + def first_backward_j_scan(carry, j): + v, p, q = carry + idx = N - 2 - j # reverse order index: when j=0, idx = N-2; when j=N-2, idx = 0. + v = v.at[idx, 1:N - 1].set(p[1:N - 1, idx] * v[idx + 1, 1:N - 1] + q[1:N - 1, idx]) + return (v, p, q), None + + def second_j_scan(carry, j): + p, q, v = carry + p = p.at[1:N - 1, j].set(-f / (d * p[1:N - 1, j - 1] + e)) + q = q.at[1:N - 1, + j].set((-a * v[0:N - 2, j] + (1.0 + 2.0 * a) * v[1:N - 1, j] - c * v[2:N, j] - d * q[1:N - 1, j - 1]) / + (d * p[1:N - 1, j - 1] + e)) + return (p, q, v), None + + def second_backward_j_scan(carry, j): + u, p, q = carry + idx = N - 2 - j + u = u.at[1:N - 1, idx].set(p[1:N - 1, idx] * u[1:N - 1, idx + 1] + q[1:N - 1, idx]) + return (u, p, q), None + + def time_step_body(carry, t): + u, v, p, q = carry + + v = v.at[0, 1:N - 1].set(1.0) + p = p.at[1:N - 1, 0].set(0.0) + q = q.at[1:N - 1, 0].set(v[0, 1:N - 1]) + (p, q, u), _ = lax.scan(first_j_scan, (p, q, u), jnp.arange(1, N - 1)) + + v = v.at[N - 1, 1:N - 1].set(1.0) + + (v, p, q), _ = lax.scan(first_backward_j_scan, (v, p, q), jnp.arange(0, N - 2)) + + u = u.at[1:N - 1, 0].set(1.0) + p = p.at[1:N - 1, 0].set(0.0) + q = q.at[1:N - 1, 0].set(u[1:N - 1, 0]) + (p, q, v), _ = lax.scan(second_j_scan, (p, q, v), jnp.arange(1, N - 1)) + u = u.at[1:N - 1, N - 1].set(1.0) + (u, p, q), _ = lax.scan(second_backward_j_scan, (u, p, q), jnp.arange(0, N - 2)) + + return (u, v, p, q), None + + (u, v, p, q), _ = lax.scan(time_step_body, (u, v, p, q), jnp.arange(0, TSTEPS)) + return jnp.sum(u) + + def run_adi(device_type: dace.dtypes.DeviceType): ''' Runs ADI for the given device @@ -156,6 +231,46 @@ def run_adi(device_type: dace.dtypes.DeviceType): return sdfg +def run_adi_autodiff(): + import jax + import jax.numpy as jnp + import jax.lax as lax + + # Initialize data (polybench mini size for smaller problem) + _, N = sizes["mini"] + + # Use smaller number of timesteps to avoid exploding gradients + TSTEPS = 10 + + u = initialize(N) + + # Initialize gradient computation data + gradient_u = np.zeros_like(u) + gradient___return = np.ones((1, ), dtype=np.float64) + + # Define sum reduction for the output + @dace.program + def autodiff_kernel(TSTEPS: dace.int64, u: dace.float64[N, N]): + adi_kernel(TSTEPS, u) + return np.sum(u) + + # Add the backward pass to the SDFG + sdfg = autodiff_kernel.to_sdfg() + add_backward_pass(sdfg=sdfg, inputs=["u"], outputs=["__return"]) + sdfg(TSTEPS, u, N=N, gradient_u=gradient_u, gradient___return=gradient___return) + + # Enable float64 support + jax.config.update("jax_enable_x64", True) + + # Numerically validate vs JAX + jax_kernel = lambda TSTEPS, u: adi_jax_kernel(jnp, lax, TSTEPS, u) + jax_grad = jax.jit(jax.grad(jax_kernel, argnums=1), static_argnums=0) + u_jax = np.copy(initialize(N)) + jax_grad_u = jax_grad(TSTEPS, u_jax) + + np.testing.assert_allclose(gradient_u, jax_grad_u, rtol=1e-6) + + def test_cpu(): run_adi(dace.dtypes.DeviceType.CPU) @@ -165,6 +280,12 @@ def test_gpu(): run_adi(dace.dtypes.DeviceType.GPU) +@pytest.mark.autodiff +def test_autodiff(): + pytest.importorskip("jax", reason="jax not installed. Please install with: pip install dace[ml-testing]") + run_adi_autodiff() + + @pytest.mark.skip(reason="Intel FPGA argument overflow") @fpga_test(assert_ii_1=False) def test_fpga(): @@ -181,6 +302,7 @@ def test_fpga(): if target == "cpu": run_adi(dace.dtypes.DeviceType.CPU) + run_adi_autodiff() elif target == "gpu": run_adi(dace.dtypes.DeviceType.GPU) elif target == "fpga": diff --git a/tests/npbench/polybench/atax_test.py b/tests/npbench/polybench/atax_test.py index dc0a438fab..0baeb48137 100644 --- a/tests/npbench/polybench/atax_test.py +++ b/tests/npbench/polybench/atax_test.py @@ -10,6 +10,7 @@ from dace.transformation.dataflow import StreamingMemory, StreamingComposition from dace.transformation.auto.auto_optimize import auto_optimize, fpga_auto_opt from dace.config import set_temporary +from dace.autodiff import add_backward_pass # Data set sizes # M, N @@ -42,6 +43,11 @@ def init_data(M, N): return A, x, y +def atax_jax_kernel(jnp, A, x): + B = (A @ x) @ A + return jnp.sum(B) + + def run_atax(device_type: dace.dtypes.DeviceType): """ Runs ATAX for the given device @@ -91,6 +97,36 @@ def run_atax(device_type: dace.dtypes.DeviceType): return sdfg +def run_atax_autodiff(): + import jax + import jax.numpy as jnp + + # Initialize data (polybench mini size) + M, N = sizes["mini"] + A, x, y = init_data(M, N) + + # Initialize gradient computation data + gradient_A = np.zeros_like(A) + gradient___return = np.ones((1, ), dtype=np.float32) + + # Define sum reduction for the output + @dc.program + def autodiff_kernel(A: dc.float32[M, N], x: dc.float32[N]): + y = kernel(A, x) + return np.sum(y) + + # Add the backward pass to the SDFG + sdfg = autodiff_kernel.to_sdfg() + add_backward_pass(sdfg=sdfg, inputs=["A"], outputs=["__return"]) + sdfg(A, x, M=M, N=N, gradient_A=gradient_A, gradient___return=gradient___return) + + # Numerically validate vs JAX + jax_kernel = lambda A, x: atax_jax_kernel(jnp, A, x) + jax_grad = jax.jit(jax.grad(jax_kernel, argnums=0)) + jax_grad_A = jax_grad(A, x) + np.testing.assert_allclose(gradient_A, jax_grad_A, rtol=1e-6, atol=1e-6) + + def test_cpu(): run_atax(dace.dtypes.DeviceType.CPU) @@ -100,6 +136,12 @@ def test_gpu(): run_atax(dace.dtypes.DeviceType.GPU) +@pytest.mark.autodiff +def test_autodiff(): + pytest.importorskip("jax", reason="jax not installed. Please install with: pip install dace[ml-testing]") + run_atax_autodiff() + + @fpga_test(assert_ii_1=False) def test_fpga(): return run_atax(dace.dtypes.DeviceType.FPGA) @@ -121,6 +163,7 @@ def test_xilinx_decoupled_array_interfaces(): if target == "cpu": run_atax(dace.dtypes.DeviceType.CPU) + run_atax_autodiff() elif target == "gpu": run_atax(dace.dtypes.DeviceType.GPU) elif target == "fpga": diff --git a/tests/npbench/polybench/bicg_test.py b/tests/npbench/polybench/bicg_test.py index ae7daa10ef..9c6d8b1261 100644 --- a/tests/npbench/polybench/bicg_test.py +++ b/tests/npbench/polybench/bicg_test.py @@ -9,6 +9,7 @@ from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG from dace.transformation.dataflow import StreamingMemory from dace.transformation.auto.auto_optimize import auto_optimize, fpga_auto_opt +from dace.autodiff import add_backward_pass # Data set sizes # M, N @@ -36,6 +37,11 @@ def bicg_kernel(A: dc.float64[N, M], p: dc.float64[M], r: dc.float64[N]): return r @ A, A @ p +def bicg_jax_kernel(jnp, A, p, r): + B, D = r @ A, A @ p + return jnp.sum(D) + + def run_bicg(device_type: dace.dtypes.DeviceType): ''' Runs BiCG for the given device @@ -87,6 +93,41 @@ def run_bicg(device_type: dace.dtypes.DeviceType): return sdfg +def run_bicg_autodiff(): + import jax + import jax.numpy as jnp + + # Initialize data (polybench mini size) + M, N = sizes["mini"] + A, p, r = initialize(M, N) + + # Initialize gradient computation data + B = np.zeros((M, ), dtype=np.float64) + D = np.zeros((N, ), dtype=np.float64) + gradient_A = np.zeros_like(A) + gradient___return = np.ones((1, ), dtype=np.float64) + + # Define sum reduction for the output + @dc.program + def autodiff_kernel(A: dc.float64[N, M], p: dc.float64[M], r: dc.float64[N]): + B, D = bicg_kernel(A, p, r) + return np.sum(D) + + # Add the backward pass to the SDFG + sdfg = autodiff_kernel.to_sdfg() + add_backward_pass(sdfg=sdfg, inputs=["A"], outputs=["__return"]) + sdfg(A, p, r, M=M, N=N, gradient_A=gradient_A, gradient___return=gradient___return) + + # Enable float64 support + jax.config.update("jax_enable_x64", True) + + # Numerically validate vs JAX + jax_kernel = lambda A, p, r: bicg_jax_kernel(jnp, A, p, r) + jax_grad = jax.jit(jax.grad(jax_kernel, argnums=0)) + jax_grad_A = jax_grad(A, p, r) + np.testing.assert_allclose(gradient_A, jax_grad_A) + + def test_cpu(): run_bicg(dace.dtypes.DeviceType.CPU) @@ -96,6 +137,12 @@ def test_gpu(): run_bicg(dace.dtypes.DeviceType.GPU) +@pytest.mark.autodiff +def test_autodiff(): + pytest.importorskip("jax", reason="jax not installed. Please install with: pip install dace[ml-testing]") + run_bicg_autodiff() + + @fpga_test(assert_ii_1=False) def test_fpga(): return run_bicg(dace.dtypes.DeviceType.FPGA) @@ -111,6 +158,7 @@ def test_fpga(): if target == "cpu": run_bicg(dace.dtypes.DeviceType.CPU) + run_bicg_autodiff() elif target == "gpu": run_bicg(dace.dtypes.DeviceType.GPU) elif target == "fpga": diff --git a/tests/npbench/polybench/cholesky_test.py b/tests/npbench/polybench/cholesky_test.py index a83d153338..c3c12ca158 100644 --- a/tests/npbench/polybench/cholesky_test.py +++ b/tests/npbench/polybench/cholesky_test.py @@ -10,6 +10,7 @@ from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG from dace.transformation.dataflow import StreamingMemory, StreamingComposition from dace.transformation.auto.auto_optimize import auto_optimize +from dace.autodiff import add_backward_pass # Data set sizes # N @@ -43,6 +44,36 @@ def init_data(N): return A +def cholesky_jax_kernel(jnp, lax, A): + A = A.at[0, 0].set(jnp.sqrt(A[0, 0])) + + def row_update_body(A, i): + + def col_update_body(A, j): + + def do_update(_): + mask = jnp.arange(A.shape[1]) < j + A_i_slice = jnp.where(mask, A[i, :], 0) + A_j_slice = jnp.where(mask, A[j, :], 0) + dot_product = jnp.dot(A_i_slice, A_j_slice) + new_val = (A[i, j] - dot_product) / A[j, j] + return A.at[i, j].set(new_val) + + A = lax.cond(j < i, do_update, lambda _: A, operand=None) + return A, None + + A, _ = lax.scan(col_update_body, A, jnp.arange(A.shape[0])) + + mask = jnp.arange(A.shape[1]) < i + A_i_slice = jnp.where(mask, A[i, :], 0) + dot_product = jnp.dot(A_i_slice, A_i_slice) + A = A.at[i, i].set(jnp.sqrt(A[i, i] - dot_product)) + return A, None + + A, _ = lax.scan(row_update_body, A, jnp.arange(1, A.shape[0])) + return jnp.sum(A) + + def ground_truth(N, A): A[0, 0] = np.sqrt(A[0, 0]) for i in range(1, N): @@ -94,6 +125,38 @@ def run_cholesky(device_type: dace.dtypes.DeviceType): return sdfg +def run_cholesky_autodiff(): + import jax + import jax.numpy as jnp + import jax.lax as lax + + # Initialize data (polybench mini size) + N = 20 + A = init_data(N) + A_jax = jnp.copy(A) + + # Initialize gradient computation data + gradient_A = np.zeros_like(A) + gradient___return = np.ones((1, ), dtype=np.float32) + + # Define sum reduction for the output + @dc.program + def autodiff_kernel(A: dc.float32[N, N]): + kernel(A) + return np.sum(A) + + # Add the backward pass to the SDFG + sdfg = autodiff_kernel.to_sdfg(simplify=True) + add_backward_pass(sdfg=sdfg, inputs=["A"], outputs=["__return"]) + sdfg(A, N=N, gradient_A=gradient_A, gradient___return=gradient___return) + + # Numerically validate vs JAX + jax_kernel = lambda A: cholesky_jax_kernel(jnp, lax, A) + jax_grad = jax.jit(jax.grad(jax_kernel)) + jax_grad_A = jax_grad(A_jax) + np.testing.assert_allclose(gradient_A, jax_grad_A, rtol=1e-4, atol=1e-4) + + def test_cpu(): run_cholesky(dace.dtypes.DeviceType.CPU) @@ -103,6 +166,12 @@ def test_gpu(): run_cholesky(dace.dtypes.DeviceType.GPU) +@pytest.mark.autodiff +def test_autodiff(): + pytest.importorskip("jax", reason="jax not installed. Please install with: pip install dace[ml-testing]") + run_cholesky_autodiff() + + @fpga_test(assert_ii_1=False) def test_fpga(): return run_cholesky(dace.dtypes.DeviceType.FPGA) @@ -118,6 +187,7 @@ def test_fpga(): if target == "cpu": run_cholesky(dace.dtypes.DeviceType.CPU) + run_cholesky_autodiff() elif target == "gpu": run_cholesky(dace.dtypes.DeviceType.GPU) elif target == "fpga": diff --git a/tests/npbench/polybench/correlation_test.py b/tests/npbench/polybench/correlation_test.py index a5532cf829..9658408c75 100644 --- a/tests/npbench/polybench/correlation_test.py +++ b/tests/npbench/polybench/correlation_test.py @@ -7,6 +7,7 @@ import pytest import argparse from dace.transformation.auto.auto_optimize import auto_optimize +from dace.autodiff import add_backward_pass # Data set sizes # M, N @@ -42,6 +43,20 @@ def initialize(M, N, datatype=np.float64): return float_n, data +def correlation_jax_kernel(jnp, float_n, data): + mean = jnp.mean(data, axis=0) + M = data.shape[1] + stddev = jnp.sqrt(jnp.mean(jnp.subtract(data, mean)**2, axis=0)) + stddev = jnp.where(stddev <= 0.1, 1.0, stddev) + data = jnp.subtract(data, mean) + data = jnp.divide(data, jnp.sqrt(float_n) * stddev) + corr = jnp.eye(M, dtype=data.dtype) + for i in range(M - 1): + corr = corr.at[i, i + 1:M].set(data[:, i] @ data[:, i + 1:M]) + corr = corr.at[i + 1:M, i].set(corr[i, i + 1:M]) + return jnp.sum(corr) + + def ground_truth(M, float_n, data): mean = np.mean(data, axis=0) @@ -88,6 +103,40 @@ def run_correlation(device_type: dace.dtypes.DeviceType): return sdfg +def run_correlation_autodiff(): + import jax + import jax.numpy as jnp + + # Initialize data (polybench mini size) + M, N = sizes["mini"] + float_n, data = initialize(M, N) + + # Initialize gradient computation data + gradient_data = np.zeros_like(data) + gradient___return = np.ones((1, ), dtype=np.float64) + + # Define sum reduction for the output + @dc.program + def autodiff_kernel(float_n: dc.float64, data: dc.float64[N, M]): + corr = correlation_kernel(float_n, data) + return np.sum(corr) + + # Add the backward pass to the SDFG + sdfg = autodiff_kernel.to_sdfg() + add_backward_pass(sdfg=sdfg, inputs=["data"], outputs=["__return"]) + sdfg(float_n, data, M=M, N=N, gradient_data=gradient_data, gradient___return=gradient___return) + + # Enable float64 support + jax.config.update("jax_enable_x64", True) + + # Numerically validate vs JAX + jax_kernel = lambda float_n, data: correlation_jax_kernel(jnp, float_n, data) + jax_grad = jax.jit(jax.grad(jax_kernel, argnums=1), static_argnums=(0, )) + _, data_jax = initialize(M, N) + jax_grad_data = jax_grad(float_n, data_jax) + np.testing.assert_allclose(gradient_data, jax_grad_data, rtol=1e-8, atol=1e-8) + + def test_cpu(): run_correlation(dace.dtypes.DeviceType.CPU) @@ -97,6 +146,17 @@ def test_gpu(): run_correlation(dace.dtypes.DeviceType.GPU) +@pytest.mark.autodiff +def test_autodiff(): + pytest.importorskip("jax", reason="jax not installed. Please install with: pip install dace[ml-testing]") + # Serialization causes issues, we temporarily disable it + # TODO: open an issue to fix the serialization stability problem + last_value = os.environ.get('DACE_testing_serialization', '0') + os.environ['DACE_testing_serialization'] = '0' + run_correlation_autodiff() + os.environ['DACE_testing_serialization'] = last_value + + if __name__ == "__main__": parser = argparse.ArgumentParser() @@ -107,6 +167,7 @@ def test_gpu(): if target == "cpu": run_correlation(dace.dtypes.DeviceType.CPU) + run_correlation_autodiff() elif target == "gpu": run_correlation(dace.dtypes.DeviceType.GPU) elif target == "fpga": diff --git a/tests/npbench/polybench/covariance_test.py b/tests/npbench/polybench/covariance_test.py index 66878bad1a..8e89a02577 100644 --- a/tests/npbench/polybench/covariance_test.py +++ b/tests/npbench/polybench/covariance_test.py @@ -1,6 +1,7 @@ # Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. # Original application code: NPBench - https://github.com/spcl/npbench +import os import dace.dtypes import numpy as np import dace as dc @@ -12,6 +13,7 @@ from dace.transformation.auto.auto_optimize import auto_optimize, fpga_auto_opt from dace.libraries.standard import Reduce from dace.libraries.blas import Gemv +from dace.autodiff import add_backward_pass # Data set sizes # M, N @@ -36,6 +38,17 @@ def covariance_kernel(float_n: dc.float32, data: dc.float32[N, M]): return cov +def covariance_jax_kernel(jnp, float_n, data): + mean = jnp.mean(data, axis=0) + M = data.shape[1] + data -= mean + cov = jnp.zeros((M, M), dtype=data.dtype) + for i in range(M): + cov = cov.at[i:M, i].set(data[:, i] @ data[:, i:M] / (float_n - 1.0)) + cov = cov.at[i, i:M].set(data[:, i] @ data[:, i:M] / (float_n - 1.0)) + return jnp.sum(cov) + + def ground_truth(M, N, float_n, data): mean = np.empty((M, ), dtype=data.dtype) @@ -123,6 +136,37 @@ def run_covariance(device_type: dace.dtypes.DeviceType): return sdfg +def run_covariance_autodiff(): + import jax + import jax.numpy as jnp + + # Initialize data (polybench mini size) + M, N = sizes["mini"] + float_n, data = init_data(M, N) + data_jax = np.copy(data) + + # Initialize gradient computation data + gradient_data = np.zeros_like(data) + gradient___return = np.ones((1, ), dtype=np.float32) + + # Define sum reduction for the output + @dc.program + def autodiff_kernel(float_n: dc.float32, data: dc.float32[N, M]): + cov = covariance_kernel(float_n, data) + return np.sum(cov) + + # Add the backward pass to the SDFG + sdfg = autodiff_kernel.to_sdfg() + add_backward_pass(sdfg=sdfg, inputs=["data"], outputs=["__return"]) + sdfg(float_n, data, M=M, N=N, gradient_data=gradient_data, gradient___return=gradient___return) + + # Numerically validate vs JAX + jax_kernel = lambda float_n, data: covariance_jax_kernel(jnp, float_n, data) + jax_grad = jax.jit(jax.grad(jax_kernel, argnums=1), static_argnums=(0, )) + jax_grad_data = jax_grad(float_n, data_jax) + np.testing.assert_allclose(gradient_data, jax_grad_data, rtol=1e-5, atol=1e-8) + + def test_cpu(monkeypatch): # Serialization causes issues, we temporarily disable it monkeypatch.setenv("DACE_testing_serialization", 0) @@ -134,6 +178,17 @@ def test_gpu(): run_covariance(dace.dtypes.DeviceType.GPU) +@pytest.mark.autodiff +def test_autodiff(): + pytest.importorskip("jax", reason="jax not installed. Please install with: pip install dace[ml-testing]") + # Serialization causes issues, we temporarily disable it + # TODO: open an issue to fix the serialization stability problem + last_value = os.environ.get('DACE_testing_serialization', '0') + os.environ['DACE_testing_serialization'] = '0' + run_covariance_autodiff() + os.environ['DACE_testing_serialization'] = last_value + + @fpga_test(assert_ii_1=False) def test_fpga(): return run_covariance(dace.dtypes.DeviceType.FPGA) @@ -149,6 +204,7 @@ def test_fpga(): if target == "cpu": run_covariance(dace.dtypes.DeviceType.CPU) + run_covariance_autodiff() elif target == "gpu": run_covariance(dace.dtypes.DeviceType.GPU) elif target == "fpga": diff --git a/tests/npbench/polybench/deriche_test.py b/tests/npbench/polybench/deriche_test.py index b2fe7d47e2..6aff7df937 100644 --- a/tests/npbench/polybench/deriche_test.py +++ b/tests/npbench/polybench/deriche_test.py @@ -9,6 +9,7 @@ from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG from dace.transformation.dataflow import StreamingMemory from dace.transformation.auto.auto_optimize import auto_optimize, fpga_auto_opt +from dace.autodiff import add_backward_pass # Data set sizes # W, H @@ -76,6 +77,64 @@ def initialize(W, H, datatype=np.float64): return alpha, imgIn +def deriche_jax_kernel(jnp, lax, alpha, imgIn): + + k = (1.0 - jnp.exp(-alpha))**2 / (1.0 + alpha * jnp.exp(-alpha) - jnp.exp(2.0 * alpha)) + a1 = a5 = k + a2 = a6 = k * jnp.exp(-alpha) * (alpha - 1.0) + a3 = a7 = k * jnp.exp(-alpha) * (alpha + 1.0) + a4 = a8 = -k * jnp.exp(-2.0 * alpha) + b1 = 2.0**(-alpha) + b2 = -jnp.exp(-2.0 * alpha) + c1 = c2 = 1 + + y1 = jnp.empty_like(imgIn) + y1 = y1.at[:, 0].set(a1 * imgIn[:, 0]) + y1 = y1.at[:, 1].set(a1 * imgIn[:, 1] + a2 * imgIn[:, 0] + b1 * y1[:, 0]) + + def horizontal_forward_body(y1, j): + new_y1 = y1.at[:, j].set(a1 * imgIn[:, j] + a2 * imgIn[:, j - 1] + b1 * y1[:, j - 1] + b2 * y1[:, j - 2]) + return new_y1, None + + y1, _ = lax.scan(horizontal_forward_body, y1, jnp.arange(2, imgIn.shape[1])) + + y2 = jnp.empty_like(imgIn) + y2 = y2.at[:, -1].set(0.0) + y2 = y2.at[:, -2].set(a3 * imgIn[:, -1]) + + def horizontal_backward_body(y2, j): + idx = imgIn.shape[1] - 3 - j + new_y2 = y2.at[:, idx].set(a3 * imgIn[:, idx + 1] + a4 * imgIn[:, idx + 2] + b1 * y2[:, idx + 1] + + b2 * y2[:, idx + 2]) + return new_y2, None + + y2, _ = lax.scan(horizontal_backward_body, y2, jnp.arange(0, imgIn.shape[1] - 2)) + + imgOut = c1 * (y1 + y2) + + y1 = y1.at[0, :].set(a5 * imgOut[0, :]) + y1 = y1.at[1, :].set(a5 * imgOut[1, :] + a6 * imgOut[0, :] + b1 * y1[0, :]) + + def vertical_forward_body(y1, i): + new_y1 = y1.at[i, :].set(a5 * imgOut[i, :] + a6 * imgOut[i - 1, :] + b1 * y1[i - 1, :] + b2 * y1[i - 2, :]) + return new_y1, None + + y1, _ = lax.scan(vertical_forward_body, y1, jnp.arange(2, imgIn.shape[0])) + + y2 = y2.at[-1, :].set(0.0) + y2 = y2.at[-2, :].set(a7 * imgOut[-1, :]) + + def vertical_backward_body(y2, i): + idx = imgIn.shape[0] - 3 - i + new_y2 = y2.at[idx, :].set(a7 * imgOut[idx + 1, :] + a8 * imgOut[idx + 2, :] + b1 * y2[idx + 1, :] + + b2 * y2[idx + 2, :]) + return new_y2, None + + y2, _ = lax.scan(vertical_backward_body, y2, jnp.arange(0, imgIn.shape[0] - 2)) + + return jnp.sum(c2 * (y1 + y2)) + + def ground_truth(alpha, imgIn): k = (1.0 - np.exp(-alpha)) * (1.0 - np.exp(-alpha)) / (1.0 + alpha * np.exp(-alpha) - np.exp(2.0 * alpha)) @@ -156,6 +215,41 @@ def run_deriche(device_type: dace.dtypes.DeviceType): return sdfg +def run_deriche_autodiff(): + import jax + import jax.numpy as jnp + import jax.lax as lax + + # Initialize data (test size for efficiency) + W, H = sizes["mini"] + alpha, imgIn = initialize(W, H) + + # Initialize gradient computation data + gradient_imgIn = np.zeros_like(imgIn) + gradient___return = np.ones((1, ), dtype=np.float64) + + # Define sum reduction for the output using __return pattern + @dc.program + def autodiff_kernel(alpha: dc.float64, imgIn: dc.float64[W, H]): + imgOut = deriche_kernel(alpha, imgIn) + return np.sum(imgOut) + + # Add the backward pass to the SDFG + sdfg = autodiff_kernel.to_sdfg() + add_backward_pass(sdfg=sdfg, inputs=["imgIn"], outputs=["__return"]) + sdfg(alpha, imgIn, W=W, H=H, gradient_imgIn=gradient_imgIn, gradient___return=gradient___return) + + # Enable float64 support + jax.config.update("jax_enable_x64", True) + + # Numerically validate vs JAX + jax_kernel = lambda alpha, imgIn: deriche_jax_kernel(jnp, lax, alpha, imgIn) + jax_grad = jax.jit(jax.grad(jax_kernel, argnums=1)) + alpha_jax, imgIn_jax = initialize(W, H) + jax_grad_imgIn = jax_grad(alpha_jax, imgIn_jax) + np.testing.assert_allclose(gradient_imgIn, jax_grad_imgIn) + + def test_cpu(): run_deriche(dace.dtypes.DeviceType.CPU) @@ -165,6 +259,12 @@ def test_gpu(): run_deriche(dace.dtypes.DeviceType.GPU) +@pytest.mark.autodiff +def test_autodiff(): + pytest.importorskip("jax", reason="jax not installed. Please install with: pip install dace[ml-testing]") + run_deriche_autodiff() + + @fpga_test(assert_ii_1=False, intel=False) def test_fpga(): return run_deriche(dace.dtypes.DeviceType.FPGA) @@ -180,6 +280,7 @@ def test_fpga(): if target == "cpu": run_deriche(dace.dtypes.DeviceType.CPU) + run_deriche_autodiff() elif target == "gpu": run_deriche(dace.dtypes.DeviceType.GPU) elif target == "fpga": diff --git a/tests/npbench/polybench/doitgen_test.py b/tests/npbench/polybench/doitgen_test.py index 52fffd1d0d..c8202929fa 100644 --- a/tests/npbench/polybench/doitgen_test.py +++ b/tests/npbench/polybench/doitgen_test.py @@ -8,6 +8,7 @@ from dace.fpga_testing import fpga_test from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG from dace.transformation.auto.auto_optimize import auto_optimize +from dace.autodiff import add_backward_pass # Data set sizes # NQ, NR, NP @@ -38,6 +39,15 @@ def initialize(NR, NQ, NP, datatype=np.float64): return A, C4 +def doitgen_jax_kernel(jnp, A, C4): + NR = A.shape[0] + NQ = A.shape[1] + NP = A.shape[2] + for r in range(NR): + A = A.at[r, :, :].set(jnp.reshape(jnp.reshape(A[r], (NQ, NP)) @ C4, (NQ, NP))) + return jnp.sum(A) + + def ground_truth(NR, NQ, NP, A, C4): A[:] = np.reshape(np.reshape(A, (NR, NQ, 1, NP)) @ C4, (NR, NQ, NP)) @@ -83,6 +93,40 @@ def run_doitgen(device_type: dace.dtypes.DeviceType): return sdfg +def run_doitgen_autodiff(): + import jax + import jax.numpy as jnp + + # Initialize data (polybench mini size) + NQ, NR, NP = sizes["mini"] + A, C4 = initialize(NR, NQ, NP) + + # Initialize gradient computation data + gradient_A = np.zeros_like(A) + gradient___return = np.ones((1, ), dtype=np.float64) + + # Define sum reduction for the output + @dc.program + def autodiff_kernel(A: dc.float64[NR, NQ, NP], C4: dc.float64[NP, NP]): + doitgen_kernel(A, C4) + return np.sum(A) + + # Add the backward pass to the SDFG + sdfg = autodiff_kernel.to_sdfg() + add_backward_pass(sdfg=sdfg, inputs=["A"], outputs=["__return"]) + sdfg(A, C4, NR=NR, NQ=NQ, NP=NP, gradient_A=gradient_A, gradient___return=gradient___return) + + # Enable float64 support + jax.config.update("jax_enable_x64", True) + + # Numerically validate vs JAX + jax_kernel = lambda A, C4: doitgen_jax_kernel(jnp, A, C4) + jax_grad = jax.jit(jax.grad(jax_kernel, argnums=0)) + A_jax, C4_jax = initialize(NR, NQ, NP) + jax_grad_A = jax_grad(A_jax, C4_jax) + np.testing.assert_allclose(gradient_A, jax_grad_A) + + def test_cpu(): run_doitgen(dace.dtypes.DeviceType.CPU) @@ -92,6 +136,12 @@ def test_gpu(): run_doitgen(dace.dtypes.DeviceType.GPU) +@pytest.mark.autodiff +def test_autodiff(): + pytest.importorskip("jax", reason="jax not installed. Please install with: pip install dace[ml-testing]") + run_doitgen_autodiff() + + @pytest.mark.skip(reason="long long support for IntelFPGA") @fpga_test(assert_ii_1=False) def test_fpga(): @@ -108,6 +158,7 @@ def test_fpga(): if target == "cpu": run_doitgen(dace.dtypes.DeviceType.CPU) + run_doitgen_autodiff() elif target == "gpu": run_doitgen(dace.dtypes.DeviceType.GPU) elif target == "fpga": diff --git a/tests/npbench/polybench/durbin_test.py b/tests/npbench/polybench/durbin_test.py index ffeff150d9..890416b7e2 100644 --- a/tests/npbench/polybench/durbin_test.py +++ b/tests/npbench/polybench/durbin_test.py @@ -10,6 +10,7 @@ from dace.transformation.dataflow import StreamingMemory, StreamingComposition from dace.transformation.auto.auto_optimize import auto_optimize, fpga_auto_opt from dace.config import set_temporary +from dace.autodiff import add_backward_pass # Data set sizes # N @@ -63,6 +64,46 @@ def ground_truth(r): return y +def durbin_jax_kernel(jnp, lax, r): + # Initialize y, alpha, and beta. + y = jnp.empty_like(r) + alpha = -r[0] + beta = 1.0 + y = y.at[0].set(-r[0]) + + # Define the scan body. The loop index k will run from 1 to r.shape[0]-1. + def scan_body(carry, k): + alpha, beta, y, r = carry + + # Update beta. + beta = beta * (1.0 - alpha * alpha) + + # Create a mask for indices less than k. + mask = jnp.arange(r.shape[0]) < k + + # Compute the dot product between y and a shifted version of r. + # Note: jnp.roll(jnp.flip(r), [k], 0) is equivalent to shifting along axis 0. + products = jnp.where(mask, y * jnp.roll(jnp.flip(r), k, axis=0), 0.0) + dot_prod = jnp.sum(products) + + # Update alpha based on the k-th element of r and the dot product. + alpha = -(r[k] + dot_prod) / beta + + # Compute an update slice from a shifted version of y. + y_update_slice = jnp.where(mask, jnp.roll(jnp.flip(y), k, axis=0) * alpha, 0.0) + + # Update y by adding the computed slice and setting the k-th element to alpha. + y = y + y_update_slice + y = y.at[k].set(alpha) + + return (alpha, beta, y, r), None + + # Run the scan from k = 1 to r.shape[0]-1. + (alpha, beta, y, r), _ = lax.scan(scan_body, (alpha, beta, y, r), jnp.arange(1, r.shape[0])) + + return jnp.sum(y) + + def run_durbin(device_type: dace.dtypes.DeviceType): ''' Runs Durbin for the given device @@ -98,6 +139,41 @@ def run_durbin(device_type: dace.dtypes.DeviceType): return sdfg +def run_durbin_autodiff(): + import jax + import jax.numpy as jnp + import jax.lax as lax + + # Initialize data (polybench small size) + N = sizes["small"] + r = initialize(N) + + # Initialize gradient computation data + gradient_r = np.zeros_like(r) + gradient___return = np.ones((1, ), dtype=np.float64) + + # Define sum reduction for the output + @dc.program + def autodiff_kernel(r: dc.float64[N]): + y = durbin_kernel(r) + return np.sum(y) + + # Add the backward pass to the SDFG + sdfg = autodiff_kernel.to_sdfg(simplify=True) + add_backward_pass(sdfg=sdfg, inputs=["r"], outputs=["__return"], simplify=False) + sdfg(r, N=N, gradient_r=gradient_r, gradient___return=gradient___return) + + # Enable float64 support + jax.config.update("jax_enable_x64", True) + + # Numerically validate vs JAX + jax_kernel = lambda r: durbin_jax_kernel(jnp, lax, r) + jax_grad = jax.jit(jax.grad(jax_kernel)) + r_jax = initialize(N) + jax_grad_r = jax_grad(r_jax) + np.testing.assert_allclose(gradient_r, jax_grad_r) + + def test_cpu(): run_durbin(dace.dtypes.DeviceType.CPU) @@ -107,6 +183,12 @@ def test_gpu(): run_durbin(dace.dtypes.DeviceType.GPU) +@pytest.mark.autodiff +def test_autodiff(): + pytest.importorskip("jax", reason="jax not installed. Please install with: pip install dace[ml-testing]") + run_durbin_autodiff() + + @fpga_test(assert_ii_1=False) def test_fpga(): return run_durbin(dace.dtypes.DeviceType.FPGA) @@ -122,6 +204,7 @@ def test_fpga(): if target == "cpu": run_durbin(dace.dtypes.DeviceType.CPU) + run_durbin_autodiff() elif target == "gpu": run_durbin(dace.dtypes.DeviceType.GPU) elif target == "fpga": diff --git a/tests/npbench/polybench/fdtd_2d_test.py b/tests/npbench/polybench/fdtd_2d_test.py index 37db5de743..0ad83a4e9c 100644 --- a/tests/npbench/polybench/fdtd_2d_test.py +++ b/tests/npbench/polybench/fdtd_2d_test.py @@ -10,6 +10,7 @@ from dace.transformation.dataflow import StreamingMemory, StreamingComposition, MapFusionVertical from dace.transformation.auto.auto_optimize import auto_optimize import argparse +from dace.autodiff import add_backward_pass # Data set sizes # TMAX, NX, NY @@ -50,6 +51,26 @@ def init_data(TMAX, NX, NY): return ex, ey, hz, _fict_ +def fdtd_2d_jax_kernel(jnp, lax, ex, ey, hz, _fict_): + """JAX implementation using efficient lax.scan operations""" + TMAX = _fict_.shape[0] + + def scan_body(carry, t): + ex, ey, hz = carry + # Set the top row of ey using _fict_ for the current time step. + ey = ey.at[0, :].set(_fict_[t]) + # Update ey for rows 1 and beyond. + ey = ey.at[1:, :].set(ey[1:, :] - 0.5 * (hz[1:, :] - hz[:-1, :])) + # Update ex for columns 1 and beyond. + ex = ex.at[:, 1:].set(ex[:, 1:] - 0.5 * (hz[:, 1:] - hz[:, :-1])) + # Update hz for the interior (all but last row and col). + hz = hz.at[:-1, :-1].set(hz[:-1, :-1] - 0.7 * ((ex[:-1, 1:] - ex[:-1, :-1]) + (ey[1:, :-1] - ey[:-1, :-1]))) + return (ex, ey, hz), None + + (ex, ey, hz), _ = lax.scan(scan_body, (ex, ey, hz), jnp.arange(TMAX)) + return jnp.sum(hz) + + def ground_truth(TMAX, NX, NY, ex, ey, hz, _fict_): for t in range(TMAX): @@ -113,6 +134,39 @@ def run_fdtd_2d(device_type: dace.dtypes.DeviceType): return sdfg +def run_fdtd_2d_autodiff(): + import jax + import jax.numpy as jnp + import jax.lax as lax + + # Initialize data (test size for efficiency) + TMAX, NX, NY = (2, 10, 12) + ex, ey, hz, _fict_ = init_data(TMAX, NX, NY) + + # Initialize gradient computation data + gradient_ex = np.zeros_like(ex) + gradient___return = np.ones((1, ), dtype=np.float32) + + # Define sum reduction for the output using __return pattern + @dc.program + def fdtd_2d_autodiff_kernel(ex: dc.float32[NX, NY], ey: dc.float32[NX, NY], hz: dc.float32[NX, NY], + _fict_: dc.float32[TMAX]): + kernel(ex, ey, hz, _fict_) + return np.sum(hz) + + # Add the backward pass to the SDFG + sdfg = fdtd_2d_autodiff_kernel.to_sdfg() + add_backward_pass(sdfg=sdfg, inputs=["ex"], outputs=["__return"]) + sdfg(ex, ey, hz, _fict_, TMAX=TMAX, NX=NX, NY=NY, gradient_ex=gradient_ex, gradient___return=gradient___return) + + # Numerically validate vs JAX (use float32 consistent with kernel) + jax_kernel = lambda ex, ey, hz, _fict_: fdtd_2d_jax_kernel(jnp, lax, ex, ey, hz, _fict_) + jax_grad = jax.jit(jax.grad(jax_kernel, argnums=0)) + ex_jax, ey_jax, hz_jax, _fict_jax = init_data(TMAX, NX, NY) + jax_grad_ex = jax_grad(ex_jax, ey_jax, hz_jax, _fict_jax) + np.testing.assert_allclose(gradient_ex, jax_grad_ex) + + def test_cpu(): run_fdtd_2d(dace.dtypes.DeviceType.CPU) @@ -122,6 +176,12 @@ def test_gpu(): run_fdtd_2d(dace.dtypes.DeviceType.GPU) +@pytest.mark.autodiff +def test_autodiff(): + pytest.importorskip("jax", reason="jax not installed. Please install with: pip install dace[ml-testing]") + run_fdtd_2d_autodiff() + + @fpga_test(assert_ii_1=False) def test_fpga(): return run_fdtd_2d(dace.dtypes.DeviceType.FPGA) @@ -137,6 +197,7 @@ def test_fpga(): if target == "cpu": run_fdtd_2d(dace.dtypes.DeviceType.CPU) + run_fdtd_2d_autodiff() elif target == "gpu": run_fdtd_2d(dace.dtypes.DeviceType.GPU) elif target == "fpga": diff --git a/tests/npbench/polybench/gemm_npbench_test.py b/tests/npbench/polybench/gemm_npbench_test.py index 58948f295d..266a44b2ee 100644 --- a/tests/npbench/polybench/gemm_npbench_test.py +++ b/tests/npbench/polybench/gemm_npbench_test.py @@ -9,7 +9,7 @@ from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG from dace.transformation.dataflow import StreamingMemory, StreamingComposition from dace.transformation.auto.auto_optimize import auto_optimize, fpga_auto_opt -from dace.config import set_temporary +from dace.autodiff import add_backward_pass # Data set sizes # NI, NJ, NK @@ -40,6 +40,10 @@ def initialize(NI, NJ, NK, datatype=np.float64): return alpha, beta, C, A, B +def gemm_jax_kernel(jnp, alpha, beta, A, B, C): + return jnp.sum(alpha * A @ B + beta * C) + + def run_gemm(device_type: dace.dtypes.DeviceType): ''' Runs Gemm for the given device @@ -76,6 +80,40 @@ def run_gemm(device_type: dace.dtypes.DeviceType): return sdfg +def run_gemm_autodiff(): + import jax + import jax.numpy as jnp + + # Initialize data (polybench mini size) + NI, NJ, NK = sizes["mini"] + alpha, beta, C, A, B = initialize(NI, NJ, NK) + + # Initialize gradient computation data + gradient_A = np.zeros_like(A) + gradient___return = np.ones((1, ), dtype=np.float64) + + # Define sum reduction for the output + @dc.program + def autodiff_kernel(alpha: dc.float64, beta: dc.float64, C: dc.float64[NI, NJ], A: dc.float64[NI, NK], + B: dc.float64[NK, NJ]): + gemm_kernel(alpha, beta, C, A, B) + return np.sum(C) + + # Add the backward pass to the SDFG + sdfg = autodiff_kernel.to_sdfg() + add_backward_pass(sdfg=sdfg, inputs=["A"], outputs=["__return"]) + sdfg(alpha, beta, C, A, B, NI=NI, NJ=NJ, NK=NK, gradient_A=gradient_A, gradient___return=gradient___return) + + # Enable float64 support + jax.config.update("jax_enable_x64", True) + + # Numerically validate vs JAX + jax_kernel = lambda alpha, beta, A, B, C: gemm_jax_kernel(jnp, alpha, beta, A, B, C) + jax_grad = jax.jit(jax.grad(jax_kernel, argnums=2), static_argnums=(0, 1)) + jax_grad_A = jax_grad(alpha, beta, A, B, C) + np.testing.assert_allclose(gradient_A, jax_grad_A) + + def test_cpu(): run_gemm(dace.dtypes.DeviceType.CPU) @@ -85,6 +123,12 @@ def test_gpu(): run_gemm(dace.dtypes.DeviceType.GPU) +@pytest.mark.autodiff +def test_autodiff(): + pytest.importorskip("jax", reason="jax not installed. Please install with: pip install dace[ml-testing]") + run_gemm_autodiff() + + @fpga_test(assert_ii_1=False, xilinx=False) def test_fpga(): return run_gemm(dace.dtypes.DeviceType.FPGA) @@ -100,6 +144,7 @@ def test_fpga(): if target == "cpu": run_gemm(dace.dtypes.DeviceType.CPU) + run_gemm_autodiff() elif target == "gpu": run_gemm(dace.dtypes.DeviceType.GPU) elif target == "fpga": diff --git a/tests/npbench/polybench/gemver_test.py b/tests/npbench/polybench/gemver_test.py index 58e078fe11..7b66df7c5b 100644 --- a/tests/npbench/polybench/gemver_test.py +++ b/tests/npbench/polybench/gemver_test.py @@ -5,11 +5,10 @@ import dace as dc import pytest import argparse -from dace.fpga_testing import fpga_test, xilinx_test +from dace.fpga_testing import fpga_test from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG -from dace.transformation.dataflow import StreamingMemory, StreamingComposition -from dace.transformation.auto.auto_optimize import auto_optimize, fpga_auto_opt -from dace.config import set_temporary +from dace.transformation.auto.auto_optimize import auto_optimize +from dace.autodiff import add_backward_pass # Data set sizes # N @@ -45,6 +44,13 @@ def initialize(N, datatype=np.float64): return alpha, beta, A, u1, v1, u2, v2, w, x, y, z +def gemver_jax_kernel(jnp, alpha, beta, A, u1, v1, u2, v2, w, x, y, z): + A += jnp.outer(u1, v1) + jnp.outer(u2, v2) + x += beta * y @ A + z + w += alpha * A @ x + return jnp.sum(w) + + def run_gemver(device_type: dace.dtypes.DeviceType): ''' Runs Gemver for the given device @@ -86,6 +92,56 @@ def run_gemver(device_type: dace.dtypes.DeviceType): return sdfg +def run_gemver_autodiff(): + import jax + import jax.numpy as jnp + + # Initialize data (polybench mini size) + N = sizes["mini"] + alpha, beta, A, u1, v1, u2, v2, w, x, y, z = initialize(N) + A_jax, u1_jax, v1_jax, u2_jax, v2_jax, w_jax, x_jax, y_jax, z_jax = map(np.copy, (A, u1, v1, u2, v2, w, x, y, z)) + + # Initialize gradient computation data + gradient_A = np.zeros_like(A) + gradient___return = np.ones((1, ), dtype=np.float64) + + # Define sum reduction for the output + @dc.program + def autodiff_kernel(alpha: dc.float64, beta: dc.float64, A: dc.float64[N, N], u1: dc.float64[N], v1: dc.float64[N], + u2: dc.float64[N], v2: dc.float64[N], w: dc.float64[N], x: dc.float64[N], y: dc.float64[N], + z: dc.float64[N]): + gemver_kernel(alpha, beta, A, u1, v1, u2, v2, w, x, y, z) + return np.sum(w) + + # Add the backward pass to the SDFG + sdfg = autodiff_kernel.to_sdfg() + add_backward_pass(sdfg=sdfg, inputs=["A"], outputs=["__return"]) + sdfg(alpha, + beta, + A, + np.copy(u1), + v1, + u2, + v2, + w, + x, + y, + z, + N=N, + gradient_A=gradient_A, + gradient___return=gradient___return) + + # Enable float64 support + jax.config.update("jax_enable_x64", True) + + # Numerically validate vs JAX + jax_kernel = lambda alpha, beta, A, u1, v1, u2, v2, w, x, y, z: gemver_jax_kernel( + jnp, alpha, beta, A, u1, v1, u2, v2, w, x, y, z) + jax_grad = jax.jit(jax.grad(jax_kernel, argnums=2)) + jax_grad_A = jax_grad(alpha, beta, A_jax, u1_jax, v1_jax, u2_jax, v2_jax, w_jax, x_jax, y_jax, z_jax) + np.testing.assert_allclose(gradient_A, jax_grad_A) + + def test_cpu(): run_gemver(dace.dtypes.DeviceType.CPU) @@ -95,6 +151,12 @@ def test_gpu(): run_gemver(dace.dtypes.DeviceType.GPU) +@pytest.mark.autodiff +def test_autodiff(): + pytest.importorskip("jax", reason="jax not installed. Please install with: pip install dace[ml-testing]") + run_gemver_autodiff() + + @fpga_test(assert_ii_1=False) def test_fpga(): return run_gemver(dace.dtypes.DeviceType.FPGA) @@ -110,6 +172,7 @@ def test_fpga(): if target == "cpu": run_gemver(dace.dtypes.DeviceType.CPU) + run_gemver_autodiff() elif target == "gpu": run_gemver(dace.dtypes.DeviceType.GPU) elif target == "fpga": diff --git a/tests/npbench/polybench/gesummv_test.py b/tests/npbench/polybench/gesummv_test.py index 9ba11a6b44..75dde5c480 100644 --- a/tests/npbench/polybench/gesummv_test.py +++ b/tests/npbench/polybench/gesummv_test.py @@ -9,7 +9,7 @@ from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG from dace.transformation.dataflow import StreamingMemory, StreamingComposition from dace.transformation.auto.auto_optimize import auto_optimize, fpga_auto_opt -from dace.config import set_temporary +from dace.autodiff import add_backward_pass # Data set sizes # N @@ -34,6 +34,10 @@ def initialize(N, datatype=np.float64): return alpha, beta, A, B, x +def gesummv_jax_kernel(jnp, alpha, beta, A, B, x): + return jnp.sum(alpha * (A @ x) + beta * (B @ x)) + + def run_gesummv(device_type: dace.dtypes.DeviceType): ''' Runs Gesummv for the given device @@ -69,6 +73,40 @@ def run_gesummv(device_type: dace.dtypes.DeviceType): return sdfg +def run_gesummv_autodiff(): + import jax + import jax.numpy as jnp + + # Initialize data (polybench mini size) + N = sizes["mini"] + alpha, beta, A, B, x = initialize(N) + + # Initialize gradient computation data + gradient_A = np.zeros_like(A) + gradient___return = np.ones((1, ), dtype=np.float64) + + # Define sum reduction for the output + @dc.program + def autodiff_kernel(alpha: dc.float64, beta: dc.float64, A: dc.float64[N, N], B: dc.float64[N, N], + x: dc.float64[N]): + C = gesummv_kernel(alpha, beta, A, B, x) + return np.sum(C) + + # Add the backward pass to the SDFG + sdfg = autodiff_kernel.to_sdfg() + add_backward_pass(sdfg=sdfg, inputs=["A"], outputs=["__return"]) + sdfg(alpha, beta, A, B, x, N=N, gradient_A=gradient_A, gradient___return=gradient___return) + + # Enable float64 support + jax.config.update("jax_enable_x64", True) + + # Numerically validate vs JAX + jax_kernel = lambda alpha, beta, A, B, x: gesummv_jax_kernel(jnp, alpha, beta, A, B, x) + jax_grad = jax.jit(jax.grad(jax_kernel, argnums=2), static_argnums=(0, 1)) + jax_grad_A = jax_grad(alpha, beta, A, B, x) + np.testing.assert_allclose(gradient_A, jax_grad_A) + + def test_cpu(): run_gesummv(dace.dtypes.DeviceType.CPU) @@ -78,6 +116,12 @@ def test_gpu(): run_gesummv(dace.dtypes.DeviceType.GPU) +@pytest.mark.autodiff +def test_autodiff(): + pytest.importorskip("jax", reason="jax not installed. Please install with: pip install dace[ml-testing]") + run_gesummv_autodiff() + + @pytest.mark.skip(reason="Xilinx synthesis fails") @fpga_test(assert_ii_1=False) def test_fpga(): @@ -94,6 +138,7 @@ def test_fpga(): if target == "cpu": run_gesummv(dace.dtypes.DeviceType.CPU) + run_gesummv_autodiff() elif target == "gpu": run_gesummv(dace.dtypes.DeviceType.GPU) elif target == "fpga": diff --git a/tests/npbench/polybench/gramschmidt_test.py b/tests/npbench/polybench/gramschmidt_test.py index 5217db86f8..278e8169b6 100644 --- a/tests/npbench/polybench/gramschmidt_test.py +++ b/tests/npbench/polybench/gramschmidt_test.py @@ -10,6 +10,7 @@ from dace.transformation.dataflow import StreamingMemory, StreamingComposition from dace.transformation.auto.auto_optimize import auto_optimize, fpga_auto_opt from dace.config import set_temporary +from dace.autodiff import add_backward_pass # Data set sizes # M, N @@ -47,6 +48,39 @@ def initialize(M, N, datatype=np.float64): return A +def gramschmidt_jax_kernel(jnp, lax, A): + n = A.shape[1] + Q = jnp.zeros_like(A) + R = jnp.zeros((n, n), dtype=A.dtype) + + def body_fun(carry, k): + Q, R, A = carry + + nrm = jnp.dot(A[:, k], A[:, k]) + R = R.at[k, k].set(jnp.sqrt(nrm)) + Q = Q.at[:, k].set(A[:, k] / R[k, k]) + + def inner_body_fun(carry_inner, j): + Q, R, A = carry_inner + + def do_update(_): + new_R = R.at[k, j].set(jnp.dot(Q[:, k], A[:, j])) + new_A = A.at[:, j].add(-Q[:, k] * new_R[k, j]) + return (Q, new_R, new_A) + + def no_update(_): + return (Q, R, A) + + Q, R, A = lax.cond(j >= (k + 1), do_update, no_update, operand=None) + return (Q, R, A), None + + (Q, R, A), _ = lax.scan(inner_body_fun, (Q, R, A), jnp.arange(n)) + return (Q, R, A), None + + (Q, R, A), _ = lax.scan(body_fun, (Q, R, A), jnp.arange(n)) + return jnp.sum(A) + + def ground_truth(A): Q = np.zeros_like(A) @@ -100,6 +134,41 @@ def run_gramschmidt(device_type: dace.dtypes.DeviceType): return sdfg +def run_gramschmidt_autodiff(): + import jax + import jax.numpy as jnp + import jax.lax as lax + + # Initialize data (polybench mini size) + M, N = sizes["mini"] + A = initialize(M, N) + + # Initialize gradient computation data + gradient_A = np.zeros_like(A) + gradient___return = np.ones((1, ), dtype=np.float64) + + # Define sum reduction for the output + @dc.program + def autodiff_kernel(A: dc.float64[M, N]): + Q, R = gramschmidt_kernel(A) + return np.sum(A) # Sum the modified A matrix + + # Add the backward pass to the SDFG + sdfg = autodiff_kernel.to_sdfg() + add_backward_pass(sdfg=sdfg, inputs=["A"], outputs=["__return"]) + sdfg(A, M=M, N=N, gradient_A=gradient_A, gradient___return=gradient___return) + + # Enable float64 support + jax.config.update("jax_enable_x64", True) + + # Numerically validate vs JAX + jax_kernel = lambda A: gramschmidt_jax_kernel(jnp, lax, A) + jax_grad = jax.jit(jax.grad(jax_kernel)) + A_jax = initialize(M, N) + jax_grad_A = jax_grad(A_jax) + np.testing.assert_allclose(gradient_A, jax_grad_A) + + def test_cpu(): run_gramschmidt(dace.dtypes.DeviceType.CPU) @@ -109,6 +178,12 @@ def test_gpu(): run_gramschmidt(dace.dtypes.DeviceType.GPU) +@pytest.mark.autodiff +def test_autodiff(): + pytest.importorskip("jax", reason="jax not installed. Please install with: pip install dace[ml-testing]") + run_gramschmidt_autodiff() + + @fpga_test(assert_ii_1=False) def test_fpga(): return run_gramschmidt(dace.dtypes.DeviceType.FPGA) @@ -124,6 +199,7 @@ def test_fpga(): if target == "cpu": run_gramschmidt(dace.dtypes.DeviceType.CPU) + run_gramschmidt_autodiff() elif target == "gpu": run_gramschmidt(dace.dtypes.DeviceType.GPU) elif target == "fpga": diff --git a/tests/npbench/polybench/heat_3d_test.py b/tests/npbench/polybench/heat_3d_test.py index 75ad902c4b..7be17976ee 100644 --- a/tests/npbench/polybench/heat_3d_test.py +++ b/tests/npbench/polybench/heat_3d_test.py @@ -10,6 +10,7 @@ from dace.transformation.dataflow import StreamingMemory, StreamingComposition from dace.transformation.auto.auto_optimize import auto_optimize, fpga_auto_opt from dace.config import set_temporary +from dace.autodiff import add_backward_pass # Dataset sizes # TSTEPS, N @@ -39,6 +40,29 @@ def initialize(N, datatype=np.float64): return A, B +def heat_3d_jax_kernel(jnp, lax, TSTEPS, A, B): + + def time_step(carry, t): + A, B = carry + B_new = B.at[1:-1, 1:-1, + 1:-1].set(0.125 * (A[2:, 1:-1, 1:-1] - 2.0 * A[1:-1, 1:-1, 1:-1] + A[:-2, 1:-1, 1:-1]) + 0.125 * + (A[1:-1, 2:, 1:-1] - 2.0 * A[1:-1, 1:-1, 1:-1] + A[1:-1, :-2, 1:-1]) + 0.125 * + (A[1:-1, 1:-1, 2:] - 2.0 * A[1:-1, 1:-1, 1:-1] + A[1:-1, 1:-1, :-2]) + + A[1:-1, 1:-1, 1:-1]) + A_new = A.at[1:-1, 1:-1, + 1:-1].set(0.125 * + (B_new[2:, 1:-1, 1:-1] - 2.0 * B_new[1:-1, 1:-1, 1:-1] + B_new[:-2, 1:-1, 1:-1]) + + 0.125 * + (B_new[1:-1, 2:, 1:-1] - 2.0 * B_new[1:-1, 1:-1, 1:-1] + B_new[1:-1, :-2, 1:-1]) + + 0.125 * + (B_new[1:-1, 1:-1, 2:] - 2.0 * B_new[1:-1, 1:-1, 1:-1] + B_new[1:-1, 1:-1, :-2]) + + B_new[1:-1, 1:-1, 1:-1]) + return (A_new, B_new), None + + (A_final, B_final), _ = lax.scan(time_step, (A, B), jnp.arange(1, TSTEPS)) + return jnp.sum(A_final) + + def ground_truth(TSTEPS, A, B): for t in range(1, TSTEPS): @@ -101,6 +125,41 @@ def count_maps(sdfg: dc.SDFG) -> int: return sdfg +def run_heat_3d_autodiff(): + import jax + import jax.numpy as jnp + import jax.lax as lax + + # Initialize data (polybench small size) + TSTEPS, N = sizes["small"] + A, B = initialize(N) + + # Initialize gradient computation data + gradient_A = np.zeros_like(A) + gradient___return = np.ones((1, ), dtype=np.float64) + + # Define sum reduction for the output + @dc.program + def autodiff_kernel(TSTEPS: dc.int64, A: dc.float64[N, N, N], B: dc.float64[N, N, N]): + heat_3d_kernel(TSTEPS, A, B) + return np.sum(A) + + # Add the backward pass to the SDFG + sdfg = autodiff_kernel.to_sdfg() + add_backward_pass(sdfg=sdfg, inputs=["A"], outputs=["__return"]) + sdfg(TSTEPS, A, B, N=N, gradient_A=gradient_A, gradient___return=gradient___return) + + # Enable float64 support + jax.config.update("jax_enable_x64", True) + + # Numerically validate vs JAX + jax_kernel = lambda TSTEPS, A, B: heat_3d_jax_kernel(jnp, lax, TSTEPS, A, B) + jax_grad = jax.jit(jax.grad(jax_kernel, argnums=1), static_argnums=(0, )) + A_jax, B_jax = initialize(N) + jax_grad_A = jax_grad(TSTEPS, A_jax, B_jax) + np.testing.assert_allclose(gradient_A, jax_grad_A) + + def test_cpu(): run_heat_3d(dace.dtypes.DeviceType.CPU) @@ -110,6 +169,12 @@ def test_gpu(): run_heat_3d(dace.dtypes.DeviceType.GPU) +@pytest.mark.autodiff +def test_autodiff(): + pytest.importorskip("jax", reason="jax not installed. Please install with: pip install dace[ml-testing]") + run_heat_3d_autodiff() + + @fpga_test(assert_ii_1=False) def test_fpga(): return run_heat_3d(dace.dtypes.DeviceType.FPGA) @@ -125,6 +190,7 @@ def test_fpga(): if target == "cpu": run_heat_3d(dace.dtypes.DeviceType.CPU) + run_heat_3d_autodiff() elif target == "gpu": run_heat_3d(dace.dtypes.DeviceType.GPU) elif target == "fpga": diff --git a/tests/npbench/polybench/jacobi_1d_test.py b/tests/npbench/polybench/jacobi_1d_test.py index 61f7ba1211..97a8fbb680 100644 --- a/tests/npbench/polybench/jacobi_1d_test.py +++ b/tests/npbench/polybench/jacobi_1d_test.py @@ -9,7 +9,7 @@ from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG from dace.transformation.dataflow import StreamingMemory, StreamingComposition from dace.transformation.auto.auto_optimize import auto_optimize -from dace.config import set_temporary +from dace.autodiff import add_backward_pass # Dataset sizes # TSTEPS, N @@ -25,6 +25,15 @@ def jacobi_1d_kernel(TSTEPS: dc.int64, A: dc.float64[N], B: dc.float64[N]): A[1:-1] = 0.33333 * (B[:-2] + B[1:-1] + B[2:]) +def jacobi_1d_jax_kernel(jax, jnp, TSTEPS, A, B): + + for t in range(1, TSTEPS): + B = B.at[1:-1].set(0.33333 * (A[:-2] + A[1:-1] + A[2:])) + A = A.at[1:-1].set(0.33333 * (B[:-2] + B[1:-1] + B[2:])) + + return jax.block_until_ready(jnp.sum(A)) + + def initialize(N, datatype=np.float64): A = np.fromfunction(lambda i: (i + 2) / N, (N, ), dtype=datatype) B = np.fromfunction(lambda i: (i + 3) / N, (N, ), dtype=datatype) @@ -76,6 +85,40 @@ def run_jacobi_1d(device_type: dace.dtypes.DeviceType): return sdfg +def run_jacobi_1d_autodiff(): + import jax + import jax.numpy as jnp + + # Initialize data (polybench mini size) + TSTEPS, N = (20, 30) + A, B = initialize(N) + jax_A, jax_B = np.copy(A), np.copy(B) + + # Intiialize gradient computation data + gradient_A = np.zeros_like(A) + gradient___return = np.ones((1, ), dtype=np.float64) + + # Define sum reduction for the output + @dc.program + def autodiff_kernel(TSTEPS: dc.int64, A: dc.float64[N], B: dc.float64[N]): + jacobi_1d_kernel(TSTEPS, A, B) + return np.sum(A) + + # Add the backward pass to the SDFG + sdfg = autodiff_kernel.to_sdfg() + add_backward_pass(sdfg=sdfg, inputs=["A"], outputs=["__return"]) + sdfg(TSTEPS, A, B, gradient_A=gradient_A, gradient___return=gradient___return) + + # Enable float64 support + jax.config.update("jax_enable_x64", True) + + # Numerically validate vs JAX + jax_kernel = lambda TSTEPS, A, B: jacobi_1d_jax_kernel(jax, jnp, TSTEPS, A, B) + jax_grad = jax.jit(jax.grad(jax_kernel, argnums=1), static_argnums=0) + jax_grad_A = jax_grad(TSTEPS, jax_A, jax_B) + np.testing.assert_allclose(gradient_A, jax_grad_A) + + def test_cpu(): run_jacobi_1d(dace.dtypes.DeviceType.CPU) @@ -85,6 +128,12 @@ def test_gpu(): run_jacobi_1d(dace.dtypes.DeviceType.GPU) +@pytest.mark.autodiff +def test_autodiff(): + pytest.importorskip("jax", reason="jax not installed. Please install with: pip install dace[ml-testing]") + run_jacobi_1d_autodiff() + + @fpga_test(assert_ii_1=False) def test_fpga(): return run_jacobi_1d(dace.dtypes.DeviceType.FPGA) @@ -100,6 +149,7 @@ def test_fpga(): if target == "cpu": run_jacobi_1d(dace.dtypes.DeviceType.CPU) + run_jacobi_1d_autodiff() elif target == "gpu": run_jacobi_1d(dace.dtypes.DeviceType.GPU) elif target == "fpga": diff --git a/tests/npbench/polybench/jacobi_2d_test.py b/tests/npbench/polybench/jacobi_2d_test.py index 58ceb4c365..b4f908f3ac 100644 --- a/tests/npbench/polybench/jacobi_2d_test.py +++ b/tests/npbench/polybench/jacobi_2d_test.py @@ -9,6 +9,7 @@ from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG from dace.transformation.dataflow import StreamingMemory, MapFusionVertical from dace.transformation.auto.auto_optimize import auto_optimize +from dace.autodiff import add_backward_pass N = dc.symbol('N', dtype=dc.int32) @@ -21,6 +22,20 @@ def kernel(TSTEPS: dc.int32, A: dc.float32[N, N], B: dc.float32[N, N]): A[1:-1, 1:-1] = 0.2 * (B[1:-1, 1:-1] + B[1:-1, :-2] + B[1:-1, 2:] + B[2:, 1:-1] + B[:-2, 1:-1]) +def kernel_jax(jnp, lax, TSTEPS, A, B): + + def body_fn(carry, t): + A, B = carry + + B = B.at[1:-1, 1:-1].set(0.2 * (A[1:-1, 1:-1] + A[1:-1, :-2] + A[1:-1, 2:] + A[2:, 1:-1] + A[:-2, 1:-1])) + + A = A.at[1:-1, 1:-1].set(0.2 * (B[1:-1, 1:-1] + B[1:-1, :-2] + B[1:-1, 2:] + B[2:, 1:-1] + B[:-2, 1:-1])) + return (A, B), None + + (A, B), _ = lax.scan(body_fn, (A, B), jnp.arange(1, TSTEPS)) + return jnp.sum(A) + + def init_data(N): A = np.empty((N, N), dtype=np.float32) B = np.empty((N, N), dtype=np.float32) @@ -81,6 +96,38 @@ def run_jacobi_2d(device_type: dace.dtypes.DeviceType): return sdfg +def run_jacobi_2d_autodiff(): + import jax + import jax.numpy as jnp + import jax.lax as lax + + # Initialize data (polybench mini size) + TSTEPS, N = (20, 30) + A, B = init_data(N) + jax_A, jax_B = np.copy(A), np.copy(B) + + # Intiialize gradient computation data + gradient_A = np.zeros_like(A) + gradient___return = np.ones((1, ), dtype=np.float32) + + # Define sum reduction for the output + @dc.program + def jacobi_2d_autodiff_kernel(TSTEPS: dc.int32, A: dc.float32[N, N], B: dc.float32[N, N]): + kernel(TSTEPS, A, B) + return np.sum(A) + + # Add the backward pass to the SDFG + sdfg = jacobi_2d_autodiff_kernel.to_sdfg() + add_backward_pass(sdfg=sdfg, inputs=["A"], outputs=["__return"]) + sdfg(TSTEPS, A, B, gradient_A=gradient_A, gradient___return=gradient___return) + + # Numerically validate vs JAX + jax_kernel = lambda TSTEPS, A, B: kernel_jax(jnp, lax, TSTEPS, A, B) + jax_grad = jax.jit(jax.grad(jax_kernel, argnums=1), static_argnums=0) + jax_grad_A = jax_grad(TSTEPS, jax_A, jax_B) + np.testing.assert_allclose(gradient_A, jax_grad_A, rtol=1e-6, atol=1e-6) + + def test_cpu(): run_jacobi_2d(dace.dtypes.DeviceType.CPU) @@ -90,6 +137,12 @@ def test_gpu(): run_jacobi_2d(dace.dtypes.DeviceType.GPU) +@pytest.mark.autodiff +def test_autodiff(): + pytest.importorskip("jax", reason="jax not installed. Please install with: pip install dace[ml-testing]") + run_jacobi_2d_autodiff() + + @fpga_test(assert_ii_1=False) def test_fpga(): return run_jacobi_2d(dace.dtypes.DeviceType.FPGA) @@ -105,6 +158,7 @@ def test_fpga(): if target == "cpu": run_jacobi_2d(dace.dtypes.DeviceType.CPU) + run_jacobi_2d_autodiff() elif target == "gpu": run_jacobi_2d(dace.dtypes.DeviceType.GPU) elif target == "fpga": diff --git a/tests/npbench/polybench/k2mm_test.py b/tests/npbench/polybench/k2mm_test.py index e7a26833fb..d4ae52363c 100644 --- a/tests/npbench/polybench/k2mm_test.py +++ b/tests/npbench/polybench/k2mm_test.py @@ -9,7 +9,7 @@ from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG from dace.transformation.dataflow import StreamingMemory, StreamingComposition from dace.transformation.auto.auto_optimize import auto_optimize, fpga_auto_opt -from dace.config import set_temporary +from dace.autodiff import add_backward_pass # Data set sizes # NI, NJ, NK, NL @@ -31,6 +31,11 @@ def k2mm_kernel(alpha: dc.float64, beta: dc.float64, A: dc.float64[NI, NK], B: d D[:] = alpha * A @ B @ C + beta * D +def k2mm_jax(jnp, alpha, beta, A, B, C, D): + D = alpha * A @ B @ C + beta * D + return jnp.sum(D) + + def initialize(NI, NJ, NK, NL, datatype=np.float64): alpha = datatype(1.5) beta = datatype(1.2) @@ -77,6 +82,51 @@ def run_k2mm(device_type: dace.dtypes.DeviceType): return sdfg +def run_k2mm_autodiff(): + import jax + import jax.numpy as jnp + + # Initialize forward data + NI, NJ, NK, NL = sizes["small"] + alpha, beta, A, B, C, D = initialize(NI, NJ, NK, NL) + + # Intiialize gradient computation data + gradient_A = np.zeros_like(A) + gradient___return = np.ones((1, ), dtype=np.float64) + + # Define sum reduction for the output + @dc.program + def autodiff_kernel(alpha: dc.float64, beta: dc.float64, A: dc.float64[NI, NK], B: dc.float64[NK, NJ], + C: dc.float64[NJ, NL], D: dc.float64[NI, NL]): + k2mm_kernel(alpha, beta, A, B, C, D) + return np.sum(D) + + # Add the backward pass to the SDFG + sdfg = autodiff_kernel.to_sdfg() + add_backward_pass(sdfg=sdfg, inputs=["A"], outputs=["__return"]) + sdfg(alpha, + beta, + A, + B, + C, + D, + NI=NI, + NJ=NJ, + NK=NK, + NL=NL, + gradient_A=gradient_A, + gradient___return=gradient___return) + + # Enable float64 support + jax.config.update("jax_enable_x64", True) + + # Numerically validate vs JAX + jax_kernel = lambda alpha, beta, A, B, C, D: k2mm_jax(jnp, alpha, beta, A, B, C, D) + jax_grad = jax.jit(jax.grad(jax_kernel, argnums=2)) + jax_grad_A = jax_grad(alpha, beta, A, B, C, D) + np.testing.assert_allclose(gradient_A, jax_grad_A) + + def test_cpu(): run_k2mm(dace.dtypes.DeviceType.CPU) @@ -86,6 +136,12 @@ def test_gpu(): run_k2mm(dace.dtypes.DeviceType.GPU) +@pytest.mark.autodiff +def test_autodiff(): + pytest.importorskip("jax", reason="jax not installed. Please install with: pip install dace[ml-testing]") + run_k2mm_autodiff() + + @fpga_test(assert_ii_1=False, xilinx=False) def test_fpga(): return run_k2mm(dace.dtypes.DeviceType.FPGA) @@ -101,6 +157,7 @@ def test_fpga(): if target == "cpu": run_k2mm(dace.dtypes.DeviceType.CPU) + run_k2mm_autodiff() elif target == "gpu": run_k2mm(dace.dtypes.DeviceType.GPU) elif target == "fpga": diff --git a/tests/npbench/polybench/k3mm_test.py b/tests/npbench/polybench/k3mm_test.py index 398b30e107..53ffc5a8c6 100644 --- a/tests/npbench/polybench/k3mm_test.py +++ b/tests/npbench/polybench/k3mm_test.py @@ -9,7 +9,7 @@ from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG from dace.transformation.dataflow import StreamingMemory, StreamingComposition from dace.transformation.auto.auto_optimize import auto_optimize, fpga_auto_opt -from dace.config import set_temporary +from dace.autodiff import add_backward_pass # Data set sizes # NI, NJ, NK, NL, NM @@ -30,6 +30,11 @@ def k3mm_kernel(A: dc.float64[NI, NK], B: dc.float64[NK, NJ], C: dc.float64[NJ, return A @ B @ C @ D +def k3mm_jax(jnp, A, B, C, D): + E = A @ B @ C @ D + return jnp.sum(E) + + def initialize(NI, NJ, NK, NL, NM, datatype=np.float64): A = np.fromfunction(lambda i, j: ((i * j + 1) % NI) / (5 * NI), (NI, NK), dtype=datatype) B = np.fromfunction(lambda i, j: ((i * (j + 1) + 2) % NJ) / (5 * NJ), (NK, NJ), dtype=datatype) @@ -73,6 +78,39 @@ def run_k3mm(device_type: dace.dtypes.DeviceType): return sdfg +def run_k3mm_autodiff(): + import jax + import jax.numpy as jnp + + # Initialize forward data + NI, NJ, NK, NL, NM = sizes["small"] + A, B, C, D = initialize(NI, NJ, NK, NL, NM) + + # Intiialize gradient computation data + gradient_A = np.zeros_like(A) + gradient___return = np.ones((1, ), dtype=np.float64) + + # Define sum reduction for the output + @dc.program + def autodiff_kernel(A: dc.float64[NI, NK], B: dc.float64[NK, NJ], C: dc.float64[NJ, NM], D: dc.float64[NM, NL]): + E = k3mm_kernel(A, B, C, D) + return np.sum(E) + + # Add the backward pass to the SDFG + sdfg = autodiff_kernel.to_sdfg() + add_backward_pass(sdfg=sdfg, inputs=["A"], outputs=["__return"]) + sdfg(A, B, C, D, NI=NI, NJ=NJ, NK=NK, NL=NL, NM=NM, gradient_A=gradient_A, gradient___return=gradient___return) + + # Enable float64 support + jax.config.update("jax_enable_x64", True) + + # Numerically validate vs JAX + jax_kernel = lambda A, B, C, D: k3mm_jax(jnp, A, B, C, D) + jax_grad = jax.jit(jax.grad(jax_kernel, argnums=0)) + jax_grad_A = jax_grad(A, B, C, D) + np.testing.assert_allclose(gradient_A, jax_grad_A) + + def test_cpu(): run_k3mm(dace.dtypes.DeviceType.CPU) @@ -82,6 +120,12 @@ def test_gpu(): run_k3mm(dace.dtypes.DeviceType.GPU) +@pytest.mark.autodiff +def test_autodiff(): + pytest.importorskip("jax", reason="jax not installed. Please install with: pip install dace[ml-testing]") + run_k3mm_autodiff() + + @fpga_test(assert_ii_1=False, xilinx=False) def test_fpga(): return run_k3mm(dace.dtypes.DeviceType.FPGA) @@ -97,6 +141,7 @@ def test_fpga(): if target == "cpu": run_k3mm(dace.dtypes.DeviceType.CPU) + run_k3mm_autodiff() elif target == "gpu": run_k3mm(dace.dtypes.DeviceType.GPU) elif target == "fpga": diff --git a/tests/npbench/polybench/lu_test.py b/tests/npbench/polybench/lu_test.py index 1786503918..89fa2e8bc2 100644 --- a/tests/npbench/polybench/lu_test.py +++ b/tests/npbench/polybench/lu_test.py @@ -10,6 +10,7 @@ from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG from dace.transformation.dataflow import StreamingMemory, MapFusionVertical, StreamingComposition, PruneConnectors from dace.transformation.auto.auto_optimize import auto_optimize, fpga_auto_opt +from dace.autodiff import add_backward_pass N = dc.symbol('N', dtype=dc.int32) @@ -52,6 +53,43 @@ def init_data(N): return A +def lu_jax_kernel(jnp, lax, A): + n = A.shape[0] + + def outer_loop_body(A, i): + + def inner_loop_1_body(A, j): + + def update_fn(_): + mask = jnp.arange(n) < j + A_slice_1 = jnp.where(mask, A[i, :], 0.0) + A_slice_2 = jnp.where(mask, A[:, j], 0.0) + new_val = (A[i, j] - A_slice_1 @ A_slice_2) / A[j, j] + return A.at[i, j].set(new_val) + + A = lax.cond(j < i, lambda _: update_fn(None), lambda _: A, operand=None) + return A, None + + def inner_loop_2_body(A, j): + + def update_fn(_): + mask = jnp.arange(n) < i + A_slice_1 = jnp.where(mask, A[i, :], 0.0) + A_slice_2 = jnp.where(mask, A[:, j], 0.0) + new_val = A[i, j] - A_slice_1 @ A_slice_2 + return A.at[i, j].set(new_val) + + A = lax.cond(j >= i, lambda _: update_fn(None), lambda _: A, operand=None) + return A, None + + A, _ = lax.scan(inner_loop_1_body, A, jnp.arange(n)) + A, _ = lax.scan(inner_loop_2_body, A, jnp.arange(n)) + return A, None + + A, _ = lax.scan(outer_loop_body, A, jnp.arange(n)) + return jnp.sum(A) + + def run_lu(device_type: dace.dtypes.DeviceType): """ Runs LU for the given device @@ -101,6 +139,38 @@ def run_lu(device_type: dace.dtypes.DeviceType): return sdfg +def run_lu_autodiff(): + import jax + import jax.numpy as jnp + import jax.lax as lax + + # Initialize data (polybench mini size) + N = 5 + A = init_data(N) + A_jax = jnp.copy(A) + + # Initialize gradient computation data + gradient_A = np.zeros_like(A) + gradient___return = np.ones((1, ), dtype=np.float32) + + # Define sum reduction for the output + @dc.program + def autodiff_kernel(A: dc.float32[N, N]): + lu_kernel(A) + return np.sum(A) + + # Add the backward pass to the SDFG + sdfg = autodiff_kernel.to_sdfg() + add_backward_pass(sdfg=sdfg, inputs=["A"], outputs=["__return"]) + sdfg(A, N=N, gradient_A=gradient_A, gradient___return=gradient___return) + + # Numerically validate vs JAX + jax_kernel = lambda A: lu_jax_kernel(jnp, lax, A) + jax_grad = jax.jit(jax.grad(jax_kernel)) + jax_grad_A = jax_grad(A_jax) + np.testing.assert_allclose(gradient_A, jax_grad_A, rtol=1e-5, atol=1e-5) + + def test_cpu(): run_lu(dace.dtypes.DeviceType.CPU) @@ -110,6 +180,12 @@ def test_gpu(): run_lu(dace.dtypes.DeviceType.GPU) +@pytest.mark.autodiff +def test_autodiff(): + pytest.importorskip("jax", reason="jax not installed. Please install with: pip install dace[ml-testing]") + run_lu_autodiff() + + @fpga_test(assert_ii_1=False, xilinx=False) def test_fpga(): return run_lu(dace.dtypes.DeviceType.FPGA) @@ -125,6 +201,7 @@ def test_fpga(): if target == "cpu": run_lu(dace.dtypes.DeviceType.CPU) + run_lu_autodiff() elif target == "gpu": run_lu(dace.dtypes.DeviceType.GPU) elif target == "fpga": diff --git a/tests/npbench/polybench/ludcmp_test.py b/tests/npbench/polybench/ludcmp_test.py index 2ffa681616..3d4af86f7b 100644 --- a/tests/npbench/polybench/ludcmp_test.py +++ b/tests/npbench/polybench/ludcmp_test.py @@ -1,5 +1,6 @@ # Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. # Original application code: NPBench - https://github.com/spcl/npbench +import os import dace.dtypes import numpy as np import dace as dc @@ -10,6 +11,7 @@ from dace.transformation.dataflow import StreamingMemory, StreamingComposition from dace.transformation.auto.auto_optimize import auto_optimize from dace.config import set_temporary +from dace.autodiff import add_backward_pass # Dataset sizes # TSTEPS, N @@ -69,6 +71,100 @@ def ground_truth(A, b): return x, y +def ludcmp_jax_kernel(jnp, lax, A, b): + n = A.shape[0] + x = jnp.zeros_like(b) + y = jnp.zeros_like(b) + + def outer_loop_body_1(A, i): + + def inner_loop_1_body(A, j): + + def update(): + A_slice_1 = jnp.where(jnp.arange(n) < j, A[i, :], 0.0) + A_slice_2 = jnp.where(jnp.arange(n) < j, A[:, j], 0.0) + new_val = (A[i, j] - A_slice_1 @ A_slice_2) / A[j, j] + return A.at[i, j].set(new_val) + + A = lax.cond(j < i, lambda _: update(), lambda _: A, operand=None) + return A, None + + def inner_loop_2_body(A, j): + + def update(): + A_slice_1 = jnp.where(jnp.arange(n) < i, A[i, :], 0.0) + A_slice_2 = jnp.where(jnp.arange(n) < i, A[:, j], 0.0) + new_val = A[i, j] - A_slice_1 @ A_slice_2 + return A.at[i, j].set(new_val) + + A = lax.cond(j >= i, lambda _: update(), lambda _: A, operand=None) + return A, None + + A, _ = lax.scan(inner_loop_1_body, A, jnp.arange(n)) + A, _ = lax.scan(inner_loop_2_body, A, jnp.arange(n)) + return A, None + + A, _ = lax.scan(outer_loop_body_1, A, jnp.arange(n)) + + def loop_body_2_scan(loop_vars, i): + A, y, b = loop_vars + A_slice = jnp.where(jnp.arange(n) < i, A[i, :], 0.0) + y_slice = jnp.where(jnp.arange(n) < i, y, 0.0) + new_y = b[i] - A_slice @ y_slice + y = y.at[i].set(new_y) + return (A, y, b), None + + (A, y, b), _ = lax.scan(loop_body_2_scan, (A, y, b), jnp.arange(n)) + + def loop_body_3_scan(loop_vars, t): + A, x, y = loop_vars + i = n - 1 - t # reverse order + A_slice = jnp.where(jnp.arange(n) > i, A[i, :], 0.0) + x_slice = jnp.where(jnp.arange(n) > i, x, 0.0) + new_x = (y[i] - A_slice @ x_slice) / A[i, i] + x = x.at[i].set(new_x) + return (A, x, y), None + + (A, x, y), _ = lax.scan(loop_body_3_scan, (A, x, y), jnp.arange(n)) + + return jnp.sum(x) + + +def run_ludcmp_autodiff(): + import jax + import jax.numpy as jnp + import jax.lax as lax + + # Initialize data (polybench mini size) + N = sizes["mini"] + A, b = initialize(N) + + # Initialize gradient computation data + gradient_A = np.zeros_like(A) + gradient___return = np.ones((1, ), dtype=np.float64) + + # Define sum reduction for the output + @dc.program + def autodiff_kernel(A: dc.float64[N, N], b: dc.float64[N]): + x, y = ludcmp_kernel(A, b) + return np.sum(x) + + # Add the backward pass to the SDFG + sdfg = autodiff_kernel.to_sdfg() + add_backward_pass(sdfg=sdfg, inputs=["A"], outputs=["__return"]) + sdfg(A, b, N=N, gradient_A=gradient_A, gradient___return=gradient___return) + + # Enable float64 support + jax.config.update("jax_enable_x64", True) + + # Numerically validate vs JAX + jax_kernel = lambda A, b: ludcmp_jax_kernel(jnp, lax, A, b) + jax_grad = jax.jit(jax.grad(jax_kernel, argnums=0)) + A_jax, b_jax = initialize(N) + jax_grad_A = jax_grad(A_jax, b_jax) + np.testing.assert_allclose(gradient_A, jax_grad_A) + + def run_ludcmp(device_type: dace.dtypes.DeviceType): ''' Runs Ludcmp for the given device @@ -115,6 +211,17 @@ def test_gpu(): run_ludcmp(dace.dtypes.DeviceType.GPU) +@pytest.mark.autodiff +def test_autodiff(): + pytest.importorskip("jax", reason="jax not installed. Please install with: pip install dace[ml-testing]") + # Serialization causes issues, we temporarily disable it + # TODO: open an issue to fix the serialization stability problem + last_value = os.environ.get('DACE_testing_serialization', '0') + os.environ['DACE_testing_serialization'] = '0' + run_ludcmp_autodiff() + os.environ['DACE_testing_serialization'] = last_value + + @fpga_test(assert_ii_1=False) def test_fpga(): return run_ludcmp(dace.dtypes.DeviceType.FPGA) @@ -130,6 +237,7 @@ def test_fpga(): if target == "cpu": run_ludcmp(dace.dtypes.DeviceType.CPU) + run_ludcmp_autodiff() elif target == "gpu": run_ludcmp(dace.dtypes.DeviceType.GPU) elif target == "fpga": diff --git a/tests/npbench/polybench/mvt_test.py b/tests/npbench/polybench/mvt_test.py index a45024a15c..0920c41dd1 100644 --- a/tests/npbench/polybench/mvt_test.py +++ b/tests/npbench/polybench/mvt_test.py @@ -9,7 +9,7 @@ from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG from dace.transformation.dataflow import StreamingMemory, StreamingComposition from dace.transformation.auto.auto_optimize import auto_optimize, fpga_auto_opt -from dace.config import set_temporary +from dace.autodiff import add_backward_pass # Data set sizes # N @@ -35,6 +35,12 @@ def initialize(N, datatype=np.float64): return x1, x2, y_1, y_2, A +def mvt_jax_kernel(jnp, x1, x2, y_1, y_2, A): + x1 += A @ y_1 + x2 += y_2 @ A + return jnp.sum(x2) + + def run_mvt(device_type: dace.dtypes.DeviceType): ''' Runs MVT for the given device @@ -73,6 +79,41 @@ def run_mvt(device_type: dace.dtypes.DeviceType): return sdfg +def run_mvt_autodiff(): + import jax + import jax.numpy as jnp + + # Initialize data (polybench mini size) + N = sizes["mini"] + x1, x2, y_1, y_2, A = initialize(N) + + # Initialize gradient computation data + gradient_A = np.zeros_like(A) + gradient___return = np.ones((1, ), dtype=np.float64) + + # Define sum reduction for the output + @dc.program + def autodiff_kernel(x1: dc.float64[N], x2: dc.float64[N], y_1: dc.float64[N], y_2: dc.float64[N], A: dc.float64[N, + N]): + mvt_kernel(x1, x2, y_1, y_2, A) + return np.sum(x2) + + # Add the backward pass to the SDFG + sdfg = autodiff_kernel.to_sdfg() + add_backward_pass(sdfg=sdfg, inputs=["A"], outputs=["__return"]) + sdfg(x1, x2, y_1, y_2, A, N=N, gradient_A=gradient_A, gradient___return=gradient___return) + + # Enable float64 support + jax.config.update("jax_enable_x64", True) + + # Numerically validate vs JAX + jax_kernel = lambda x1, x2, y_1, y_2, A: mvt_jax_kernel(jnp, x1, x2, y_1, y_2, A) + jax_grad = jax.jit(jax.grad(jax_kernel, argnums=4)) + x1_jax, x2_jax, y_1_jax, y_2_jax, A_jax = initialize(N) + jax_grad_A = jax_grad(x1_jax, x2_jax, y_1_jax, y_2_jax, A_jax) + np.testing.assert_allclose(gradient_A, jax_grad_A) + + def test_cpu(): run_mvt(dace.dtypes.DeviceType.CPU) @@ -82,6 +123,12 @@ def test_gpu(): run_mvt(dace.dtypes.DeviceType.GPU) +@pytest.mark.autodiff +def test_autodiff(): + pytest.importorskip("jax", reason="jax not installed. Please install with: pip install dace[ml-testing]") + run_mvt_autodiff() + + @fpga_test(assert_ii_1=False) def test_fpga(): return run_mvt(dace.dtypes.DeviceType.FPGA) @@ -97,6 +144,7 @@ def test_fpga(): if target == "cpu": run_mvt(dace.dtypes.DeviceType.CPU) + run_mvt_autodiff() elif target == "gpu": run_mvt(dace.dtypes.DeviceType.GPU) elif target == "fpga": diff --git a/tests/npbench/polybench/seidel_2d_test.py b/tests/npbench/polybench/seidel_2d_test.py index 7d1f8a8389..6e1565e8d2 100644 --- a/tests/npbench/polybench/seidel_2d_test.py +++ b/tests/npbench/polybench/seidel_2d_test.py @@ -10,6 +10,7 @@ from dace.transformation.dataflow import StreamingMemory, StreamingComposition from dace.transformation.auto.auto_optimize import auto_optimize from dace.config import set_temporary +from dace.autodiff import add_backward_pass # Dataset sizes # TSTEPS, N @@ -35,6 +36,32 @@ def initialize(N, datatype=np.float64): return A +def seidel_2d_jax_kernel(jnp, lax, TSTEPS, A): + """JAX implementation using efficient lax.scan operations""" + N = A.shape[0] + + def loop1_body(A, t): + + def loop2_body(A, i): + update_val = (A[i, 1:-1] + (A[i - 1, :-2] + A[i - 1, 1:-1] + A[i - 1, 2:] + A[i, 2:] + A[i + 1, :-2] + + A[i + 1, 1:-1] + A[i + 1, 2:])) + A = A.at[i, 1:-1].set(update_val) + + def loop3_body(A, j): + new_val = (A[i, j] + A[i, j - 1]) / 9.0 + A = A.at[i, j].set(new_val) + return A, None + + A, _ = lax.scan(loop3_body, A, jnp.arange(1, N - 1)) + return A, None + + A, _ = lax.scan(loop2_body, A, jnp.arange(1, N - 1)) + return A, None + + A, _ = lax.scan(loop1_body, A, jnp.arange(TSTEPS - 1)) + return jnp.sum(A) + + def ground_truth(TSTEPS, N, A): for t in range(0, TSTEPS - 1): @@ -87,6 +114,47 @@ def run_seidel_2d(device_type: dace.dtypes.DeviceType): return sdfg +def run_seidel_2d_autodiff(): + import jax + import jax.numpy as jnp + import jax.lax as lax + + # Initialize data (test size for efficiency) + TSTEPS, N = (2, 8) + A = initialize(N) + + # Initialize gradient computation data + gradient_A = np.zeros_like(A) + gradient___return = np.ones((1, ), dtype=np.float64) + + # Define sum reduction for the output using __return pattern + @dc.program + def autodiff_kernel(TSTEPS: dc.int64, A: dc.float64[N, N]): + for t in range(0, TSTEPS - 1): + for i in range(1, N - 1): + A[i, 1:-1] += (A[i - 1, :-2] + A[i - 1, 1:-1] + A[i - 1, 2:] + A[i, 2:] + A[i + 1, :-2] + + A[i + 1, 1:-1] + A[i + 1, 2:]) + for j in range(1, N - 1): + A[i, j] += A[i, j - 1] + A[i, j] /= 9.0 + return np.sum(A) + + # Add the backward pass to the SDFG + sdfg = autodiff_kernel.to_sdfg() + add_backward_pass(sdfg=sdfg, inputs=["A"], outputs=["__return"]) + sdfg(TSTEPS, A, N=N, gradient_A=gradient_A, gradient___return=gradient___return) + + # Enable float64 support + jax.config.update("jax_enable_x64", True) + + # Numerically validate vs JAX + jax_kernel = lambda TSTEPS, A: seidel_2d_jax_kernel(jnp, lax, TSTEPS, A) + jax_grad = jax.jit(jax.grad(jax_kernel, argnums=1), static_argnums=(0, )) + A_jax = initialize(N) + jax_grad_A = jax_grad(TSTEPS, A_jax) + np.testing.assert_allclose(gradient_A, jax_grad_A) + + def test_cpu(): run_seidel_2d(dace.dtypes.DeviceType.CPU) @@ -96,6 +164,12 @@ def test_gpu(): run_seidel_2d(dace.dtypes.DeviceType.GPU) +@pytest.mark.autodiff +def test_autodiff(): + pytest.importorskip("jax", reason="jax not installed. Please install with: pip install dace[ml-testing]") + run_seidel_2d_autodiff() + + @fpga_test(assert_ii_1=False) def test_fpga(): return run_seidel_2d(dace.dtypes.DeviceType.FPGA) @@ -111,6 +185,7 @@ def test_fpga(): if target == "cpu": run_seidel_2d(dace.dtypes.DeviceType.CPU) + run_seidel_2d_autodiff() elif target == "gpu": run_seidel_2d(dace.dtypes.DeviceType.GPU) elif target == "fpga": diff --git a/tests/npbench/polybench/symm_test.py b/tests/npbench/polybench/symm_test.py index d0bae1edfc..fc8244c0f4 100644 --- a/tests/npbench/polybench/symm_test.py +++ b/tests/npbench/polybench/symm_test.py @@ -9,7 +9,7 @@ from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG from dace.transformation.dataflow import StreamingMemory, StreamingComposition from dace.transformation.auto.auto_optimize import auto_optimize, fpga_auto_opt -from dace.config import set_temporary +from dace.autodiff import add_backward_pass # Data set sizes # M, N @@ -43,6 +43,32 @@ def initialize(M, N, datatype=np.float64): return alpha, beta, C, A, B +def symm_jax_kernel(jnp, lax, alpha, beta, C, A, B): + temp2 = jnp.empty((C.shape[1], ), dtype=C.dtype) + C = C * beta + + def row_update_body(carry, i): + C, temp2 = carry + + def col_update_body(carry_inner, j): + C, temp2 = carry_inner + + A_slice = jnp.where(jnp.arange(A.shape[1]) < i, A[i, :], 0.0) + B_slice = jnp.where(jnp.arange(B.shape[0]) < i, B[:, j], 0.0) + + updated_col = C[:, j] + (alpha * B[i, j] * A_slice) + C = lax.dynamic_update_slice(C, updated_col[:, None], (0, j)) + temp2 = temp2.at[j].set(B_slice @ A_slice) + return (C, temp2), jnp.array(0) + + (C, temp2), _ = lax.scan(col_update_body, (C, temp2), jnp.arange(C.shape[1])) + C = C.at[i, :].add(alpha * B[i, :] * A[i, i] + alpha * temp2) + return (C, temp2), jnp.array(0) + + (C, temp2), _ = lax.scan(row_update_body, (C, temp2), jnp.arange(C.shape[0])) + return jnp.sum(C) + + def ground_truth(alpha, beta, C, A, B): temp2 = np.empty((C.shape[1], ), dtype=C.dtype) @@ -90,6 +116,42 @@ def run_symm(device_type: dace.dtypes.DeviceType): return sdfg +def run_symm_autodiff(): + import jax + import jax.numpy as jnp + import jax.lax as lax + + # Initialize data (polybench mini size) + M, N = sizes["mini"] + alpha, beta, C, A, B = initialize(M, N) + + # Initialize gradient computation data + gradient_A = np.zeros_like(A) + gradient___return = np.ones((1, ), dtype=np.float64) + + # Define sum reduction for the output + @dc.program + def autodiff_kernel(alpha: dc.float64, beta: dc.float64, C: dc.float64[M, N], A: dc.float64[M, M], + B: dc.float64[M, N]): + symm_kernel(alpha, beta, C, A, B) + return np.sum(C) + + # Add the backward pass to the SDFG + sdfg = autodiff_kernel.to_sdfg() + add_backward_pass(sdfg=sdfg, inputs=["A"], outputs=["__return"]) + sdfg(alpha, beta, C, A, B, M=M, N=N, gradient_A=gradient_A, gradient___return=gradient___return) + + # Enable float64 support + jax.config.update("jax_enable_x64", True) + + # Numerically validate vs JAX + jax_kernel = lambda alpha, beta, C, A, B: symm_jax_kernel(jnp, lax, alpha, beta, C, A, B) + jax_grad = jax.jit(jax.grad(jax_kernel, argnums=3), static_argnums=(0, 1)) + alpha, beta, C_jax, A_jax, B_jax = initialize(M, N) + jax_grad_A = jax_grad(alpha, beta, C_jax, A_jax, B_jax) + np.testing.assert_allclose(gradient_A, jax_grad_A) + + def test_cpu(): run_symm(dace.dtypes.DeviceType.CPU) @@ -99,6 +161,12 @@ def test_gpu(): run_symm(dace.dtypes.DeviceType.GPU) +@pytest.mark.autodiff +def test_autodiff(): + pytest.importorskip("jax", reason="jax not installed. Please install with: pip install dace[ml-testing]") + run_symm_autodiff() + + @fpga_test(assert_ii_1=False) def test_fpga(): return run_symm(dace.dtypes.DeviceType.FPGA) @@ -114,6 +182,7 @@ def test_fpga(): if target == "cpu": run_symm(dace.dtypes.DeviceType.CPU) + run_symm_autodiff() elif target == "gpu": run_symm(dace.dtypes.DeviceType.GPU) elif target == "fpga": diff --git a/tests/npbench/polybench/syr2k_test.py b/tests/npbench/polybench/syr2k_test.py index 0edb5a3045..4453755471 100644 --- a/tests/npbench/polybench/syr2k_test.py +++ b/tests/npbench/polybench/syr2k_test.py @@ -5,11 +5,10 @@ import dace as dc import pytest import argparse -from dace.fpga_testing import fpga_test, xilinx_test +from dace.fpga_testing import fpga_test from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG -from dace.transformation.dataflow import StreamingMemory, StreamingComposition -from dace.transformation.auto.auto_optimize import auto_optimize, fpga_auto_opt -from dace.config import set_temporary +from dace.transformation.auto.auto_optimize import auto_optimize +from dace.autodiff import add_backward_pass # Data set sizes # M, N @@ -37,6 +36,51 @@ def initialize(M, N, datatype=np.float64): return alpha, beta, C, A, B +def syr2k_jax_kernel(jnp, lax, alpha, beta, C, A, B): + m = A.shape[0] # outer loop range + n = A.shape[1] # inner loop range + + def outer_body_fun(carry, i): + # Unpack loop variables for the outer loop. + alpha, beta, C, A, B = carry + + # Outer-loop update: scale row i of C by beta, but only for columns < i+1. + C_slice = jnp.where(jnp.arange(C.shape[1]) < (i + 1), C[i, :], 0.0) + C_slice = C_slice * beta + C_slice = jnp.where(jnp.arange(C.shape[1]) < (i + 1), C_slice, C[i, :]) + C = lax.dynamic_update_slice(C, C_slice[None, :], (i, 0)) + + # Define the inner scan that will update row i of C using index k. + def inner_body_fun(inner_carry, k): + # Unpack inner loop variables. + alpha_inner, C_inner, A_inner, B_inner = inner_carry + + # For A_update_slice and B_update_slice, only entries for indices < i+1 are used. + A_update_slice = jnp.where(jnp.arange(A_inner.shape[0]) < (i + 1), A_inner[:, k], 0.0) + A_update_slice = A_update_slice * (alpha_inner * B_inner[i, k]) + + B_update_slice = jnp.where(jnp.arange(B_inner.shape[0]) < (i + 1), B_inner[:, k], 0.0) + B_update_slice = B_update_slice * (alpha_inner * A_inner[i, k]) + + # Compute an update for row i of C: take its current values (only for indices < i+1) + # and add the contributions from A_update_slice and B_update_slice. + C_update_slice = jnp.where(jnp.arange(C_inner.shape[1]) < (i + 1), C_inner[i, :], 0.0) + C_update_slice = C_update_slice + A_update_slice + B_update_slice + # For indices not less than i+1, keep the original C[i, :]. + C_update_slice = jnp.where(jnp.arange(C_inner.shape[1]) < (i + 1), C_update_slice, C_inner[i, :]) + # Update row i of C. + C_inner = lax.dynamic_update_slice(C_inner, C_update_slice[None, :], (i, 0)) + return (alpha_inner, C_inner, A_inner, B_inner), None + + # Run the inner scan over k from 0 to n-1. + (alpha, C, A, B), _ = lax.scan(inner_body_fun, (alpha, C, A, B), jnp.arange(n)) + return (alpha, beta, C, A, B), None + + # Run the outer scan over i from 0 to m-1. + (alpha, beta, C, A, B), _ = lax.scan(outer_body_fun, (alpha, beta, C, A, B), jnp.arange(m)) + return jnp.sum(C) + + def ground_truth(alpha, beta, C, A, B): for i in range(A.shape[0]): @@ -78,6 +122,42 @@ def run_syr2k(device_type: dace.dtypes.DeviceType): return sdfg +def run_syr2k_autodiff(): + import jax + import jax.numpy as jnp + import jax.lax as lax + + # Initialize data (polybench mini size) + M, N = sizes["mini"] + alpha, beta, C, A, B = initialize(M, N) + + # Initialize gradient computation data + gradient_A = np.zeros_like(A) + gradient___return = np.ones((1, ), dtype=np.float64) + + # Define sum reduction for the output + @dc.program + def autodiff_kernel(alpha: dc.float64, beta: dc.float64, C: dc.float64[N, N], A: dc.float64[N, M], + B: dc.float64[N, M]): + syr2k_kernel(alpha, beta, C, A, B) + return np.sum(C) + + # Add the backward pass to the SDFG + sdfg = autodiff_kernel.to_sdfg() + add_backward_pass(sdfg=sdfg, inputs=["A"], outputs=["__return"]) + sdfg(alpha, beta, C, A, B, M=M, N=N, gradient_A=gradient_A, gradient___return=gradient___return) + + # Enable float64 support + jax.config.update("jax_enable_x64", True) + + # Numerically validate vs JAX + jax_kernel = lambda alpha, beta, C, A, B: syr2k_jax_kernel(jnp, lax, alpha, beta, C, A, B) + jax_grad = jax.jit(jax.grad(jax_kernel, argnums=3), static_argnums=(0, 1)) + alpha, beta, C_jax, A_jax, B_jax = initialize(M, N) + jax_grad_A = jax_grad(alpha, beta, C_jax, A_jax, B_jax) + np.testing.assert_allclose(gradient_A, jax_grad_A) + + def test_cpu(): run_syr2k(dace.dtypes.DeviceType.CPU) @@ -87,6 +167,12 @@ def test_gpu(): run_syr2k(dace.dtypes.DeviceType.GPU) +@pytest.mark.autodiff +def test_autodiff(): + pytest.importorskip("jax", reason="jax not installed. Please install with: pip install dace[ml-testing]") + run_syr2k_autodiff() + + @fpga_test(assert_ii_1=False) def test_fpga(): return run_syr2k(dace.dtypes.DeviceType.FPGA) @@ -102,6 +188,7 @@ def test_fpga(): if target == "cpu": run_syr2k(dace.dtypes.DeviceType.CPU) + run_syr2k_autodiff() elif target == "gpu": run_syr2k(dace.dtypes.DeviceType.GPU) elif target == "fpga": diff --git a/tests/npbench/polybench/syrk_test.py b/tests/npbench/polybench/syrk_test.py index 6e92411128..fc4c10ac12 100644 --- a/tests/npbench/polybench/syrk_test.py +++ b/tests/npbench/polybench/syrk_test.py @@ -10,6 +10,7 @@ from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG from dace.transformation.dataflow import StreamingMemory, MapFusionVertical, StreamingComposition, PruneConnectors from dace.transformation.auto.auto_optimize import auto_optimize, fpga_auto_opt +from dace.autodiff import add_backward_pass # M, N sizes = {"mini": (20, 30), "small": (60, 80), "medium": (200, 240), "large": (1000, 1200), "extra-large": (2000, 2600)} @@ -42,6 +43,48 @@ def init_data(N, M): return alpha, beta, C, A +def syrk_jax_kernel(jnp, lax, alpha, beta, C, A): + m = A.shape[0] # number of rows + n = A.shape[1] # number of columns + + def outer_body_fun(carry, i): + # Unpack outer loop carry. + alpha, beta, C, A = carry + + # Outer loop update: scale row i of C by beta for indices < i+1. + col_mask = jnp.arange(C.shape[1]) < (i + 1) + C_slice = jnp.where(col_mask, C[i, :], 0.0) + C_slice = C_slice * beta + # Preserve the original values for indices >= i+1. + C_slice = jnp.where(col_mask, C_slice, C[i, :]) + C = lax.dynamic_update_slice(C, C_slice[None, :], (i, 0)) + + # Define the inner loop which updates row i of C using column updates from A. + def inner_body_fun(inner_carry, k): + alpha_inner, C_inner, A_inner = inner_carry + + # Compute an update slice from A[:, k] for rows < i+1. + row_mask = jnp.arange(A_inner.shape[0]) < (i + 1) + A_update_slice = jnp.where(row_mask, A_inner[:, k], 0.0) + A_update_slice = A_update_slice * (alpha_inner * A_inner[i, k]) + + # Update C[i, :] by adding the A_update_slice, only for columns < i+1. + col_mask_inner = jnp.arange(C_inner.shape[1]) < (i + 1) + C_update_slice = jnp.where(col_mask_inner, C_inner[i, :], 0.0) + C_update_slice = C_update_slice + A_update_slice + C_update_slice = jnp.where(col_mask_inner, C_update_slice, C_inner[i, :]) + C_inner = lax.dynamic_update_slice(C_inner, C_update_slice[None, :], (i, 0)) + return (alpha_inner, C_inner, A_inner), None + + # Run the inner loop over k = 0,..., n-1. + (alpha, C, A), _ = lax.scan(inner_body_fun, (alpha, C, A), jnp.arange(n)) + return (alpha, beta, C, A), None + + # Run the outer loop over i = 0,..., m-1. + (alpha, beta, C, A), _ = lax.scan(outer_body_fun, (alpha, beta, C, A), jnp.arange(m)) + return jnp.sum(C) + + def ground_truth(N, M, alpha, beta, C, A): for i in range(N): @@ -86,6 +129,38 @@ def run_syrk(device_type: dace.dtypes.DeviceType): return sdfg +def run_syrk_autodiff(): + import jax + import jax.numpy as jnp + import jax.lax as lax + + # Initialize data (polybench mini size) - note the order swap for this test + M, N = sizes["mini"] + alpha, beta, C, A = init_data(N, M) + + # Initialize gradient computation data + gradient_A = np.zeros_like(A) + gradient___return = np.ones((1, ), dtype=np.float32) + + # Define sum reduction for the output + @dc.program + def autodiff_kernel(alpha: dc.float32, beta: dc.float32, C: dc.float32[N, N], A: dc.float32[N, M]): + kernel(alpha, beta, C, A) + return np.sum(C) + + # Add the backward pass to the SDFG + sdfg = autodiff_kernel.to_sdfg() + add_backward_pass(sdfg=sdfg, inputs=["A"], outputs=["__return"]) + sdfg(alpha=alpha, beta=beta, C=C, A=A, M=M, N=N, gradient_A=gradient_A, gradient___return=gradient___return) + + # Numerically validate vs JAX + jax_kernel = lambda alpha, beta, C, A: syrk_jax_kernel(jnp, lax, alpha, beta, C, A) + jax_grad = jax.jit(jax.grad(jax_kernel, argnums=3), static_argnums=(0, 1)) + alpha, beta, C_jax, A_jax = init_data(N, M) + jax_grad_A = jax_grad(alpha, beta, C_jax, A_jax) + np.testing.assert_allclose(gradient_A, jax_grad_A, rtol=1e-6, atol=1e-5) + + def test_cpu(): run_syrk(dace.dtypes.DeviceType.CPU) @@ -95,6 +170,12 @@ def test_gpu(): run_syrk(dace.dtypes.DeviceType.GPU) +@pytest.mark.autodiff +def test_autodiff(): + pytest.importorskip("jax", reason="jax not installed. Please install with: pip install dace[ml-testing]") + run_syrk_autodiff() + + @fpga_test(assert_ii_1=False) def test_fpga(): return run_syrk(dace.dtypes.DeviceType.FPGA) @@ -110,6 +191,7 @@ def test_fpga(): if target == "cpu": run_syrk(dace.dtypes.DeviceType.CPU) + run_syrk_autodiff() elif target == "gpu": run_syrk(dace.dtypes.DeviceType.GPU) elif target == "fpga": diff --git a/tests/npbench/polybench/trisolv_test.py b/tests/npbench/polybench/trisolv_test.py index d9ec2a1802..1442b12c87 100644 --- a/tests/npbench/polybench/trisolv_test.py +++ b/tests/npbench/polybench/trisolv_test.py @@ -9,7 +9,7 @@ from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG from dace.transformation.dataflow import StreamingMemory, StreamingComposition from dace.transformation.auto.auto_optimize import auto_optimize, fpga_auto_opt -from dace.config import set_temporary +from dace.autodiff import add_backward_pass # Data set sizes # N @@ -30,6 +30,20 @@ def initialize(N, datatype=np.float64): return L, x, b +def trisolv_jax_kernel(jnp, lax, L, x, b): + + def scan_body(carry, i): + L, x, b = carry + mask = jnp.arange(x.shape[0]) < i + products = jnp.where(mask, L[i, :] * x, 0.0) + dot_product = jnp.sum(products) + x = x.at[i].set((b[i] - dot_product) / L[i, i]) + return (L, x, b), None + + (L, x, b), _ = lax.scan(scan_body, (L, x, b), jnp.arange(x.shape[0])) + return jnp.sum(x) + + def ground_truth(L, x, b): for i in range(x.shape[0]): x[i] = (b[i] - L[i, :i] @ x[:i]) / L[i, i] @@ -71,6 +85,41 @@ def run_trisolv(device_type: dace.dtypes.DeviceType): return sdfg +def run_trisolv_autodiff(): + import jax + import jax.numpy as jnp + import jax.lax as lax + + # Initialize data (polybench mini size) + N = sizes["mini"] + L, x, b = initialize(N) + + # Initialize gradient computation data + gradient_L = np.zeros_like(L) + gradient___return = np.ones((1, ), dtype=np.float64) + + # Define sum reduction for the output + @dc.program + def autodiff_kernel(L: dc.float64[N, N], x: dc.float64[N], b: dc.float64[N]): + trisolv_kernel(L, x, b) + return np.sum(x) + + # Add the backward pass to the SDFG + sdfg = autodiff_kernel.to_sdfg() + add_backward_pass(sdfg=sdfg, inputs=["L"], outputs=["__return"]) + sdfg(L, x, np.copy(b), N=N, gradient_L=gradient_L, gradient___return=gradient___return) + + # Enable float64 support + jax.config.update("jax_enable_x64", True) + + # Numerically validate vs JAX + jax_kernel = lambda L, x, b: trisolv_jax_kernel(jnp, lax, L, x, b) + jax_grad = jax.jit(jax.grad(jax_kernel, argnums=0)) + L_jax, x_jax, b_jax = initialize(N) + jax_grad_L = jax_grad(L_jax, x_jax, b_jax) + np.testing.assert_allclose(gradient_L, jax_grad_L) + + def test_cpu(): run_trisolv(dace.dtypes.DeviceType.CPU) @@ -80,6 +129,12 @@ def test_gpu(): run_trisolv(dace.dtypes.DeviceType.GPU) +@pytest.mark.autodiff +def test_autodiff(): + pytest.importorskip("jax", reason="jax not installed. Please install with: pip install dace[ml-testing]") + run_trisolv_autodiff() + + @fpga_test(assert_ii_1=False, xilinx=False) def test_fpga(): return run_trisolv(dace.dtypes.DeviceType.FPGA) @@ -95,6 +150,7 @@ def test_fpga(): if target == "cpu": run_trisolv(dace.dtypes.DeviceType.CPU) + run_trisolv_autodiff() elif target == "gpu": run_trisolv(dace.dtypes.DeviceType.GPU) elif target == "fpga": diff --git a/tests/npbench/polybench/trmm_test.py b/tests/npbench/polybench/trmm_test.py index 51e2367df1..53b5d519e0 100644 --- a/tests/npbench/polybench/trmm_test.py +++ b/tests/npbench/polybench/trmm_test.py @@ -1,15 +1,16 @@ # Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. # Original application code: NPBench - https://github.com/spcl/npbench +import os import dace.dtypes import numpy as np import dace as dc import pytest import argparse -from dace.fpga_testing import fpga_test, xilinx_test +from dace.fpga_testing import fpga_test from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG from dace.transformation.dataflow import StreamingMemory, StreamingComposition from dace.transformation.auto.auto_optimize import auto_optimize, fpga_auto_opt -from dace.config import set_temporary +from dace.autodiff import add_backward_pass # Data set sizes # M, N @@ -37,6 +38,27 @@ def initialize(M, N, datatype=np.float64): return alpha, A, B +def trmm_jax_kernel(jnp, lax, alpha, A, B): + + def outer_body(carry, i): + B = carry + + def inner_body(B, j): + + mask = (jnp.arange(A.shape[0]) > i).astype(A.dtype) + dot_val = jnp.sum(A[:, i] * B[:, j] * mask) + new_val = B[i, j] + dot_val + B = B.at[i, j].set(new_val) + return B, jnp.array(0) + + B, _ = lax.scan(inner_body, B, jnp.arange(B.shape[1])) + return B, jnp.array(0) + + B, _ = lax.scan(outer_body, B, jnp.arange(B.shape[0])) + B = B * alpha + return jnp.sum(B) + + def ground_truth(alpha, A, B): for i in range(B.shape[0]): for j in range(B.shape[1]): @@ -80,6 +102,41 @@ def run_trmm(device_type: dace.dtypes.DeviceType): return sdfg +def run_trmm_autodiff(): + import jax + import jax.numpy as jnp + import jax.lax as lax + + # Initialize data (polybench mini size) + M, N = sizes["mini"] + alpha, A, B = initialize(M, N) + + # Initialize gradient computation data + gradient_A = np.zeros_like(A) + gradient___return = np.ones((1, ), dtype=np.float64) + + # Define sum reduction for the output + @dc.program + def autodiff_kernel(alpha: dc.float64, A: dc.float64[M, M], B: dc.float64[M, N]): + trmm_kernel(alpha, A, B) + return np.sum(B) + + # Add the backward pass to the SDFG + sdfg = autodiff_kernel.to_sdfg() + add_backward_pass(sdfg=sdfg, inputs=["A"], outputs=["__return"]) + sdfg(alpha, A, B, M=M, N=N, gradient_A=gradient_A, gradient___return=gradient___return) + + # Enable float64 support + jax.config.update("jax_enable_x64", True) + + # Numerically validate vs JAX + jax_kernel = lambda alpha, A, B: trmm_jax_kernel(jnp, lax, alpha, A, B) + jax_grad = jax.jit(jax.grad(jax_kernel, argnums=1), static_argnums=(0, )) + alpha, A_jax, B_jax = initialize(M, N) + jax_grad_A = jax_grad(alpha, A_jax, B_jax) + np.testing.assert_allclose(gradient_A, jax_grad_A) + + def test_cpu(): run_trmm(dace.dtypes.DeviceType.CPU) @@ -89,6 +146,17 @@ def test_gpu(): run_trmm(dace.dtypes.DeviceType.GPU) +@pytest.mark.autodiff +def test_autodiff(): + pytest.importorskip("jax", reason="jax not installed. Please install with: pip install dace[ml-testing]") + # Serialization causes issues, we temporarily disable it + # TODO: open an issue to fix the serialization stability problem + last_value = os.environ.get('DACE_testing_serialization', '0') + os.environ['DACE_testing_serialization'] = '0' + run_trmm_autodiff() + os.environ['DACE_testing_serialization'] = last_value + + @fpga_test(assert_ii_1=False) def test_fpga(): return run_trmm(dace.dtypes.DeviceType.FPGA) @@ -104,6 +172,7 @@ def test_fpga(): if target == "cpu": run_trmm(dace.dtypes.DeviceType.CPU) + run_trmm_autodiff() elif target == "gpu": run_trmm(dace.dtypes.DeviceType.GPU) elif target == "fpga": diff --git a/tests/npbench/weather_stencils/hdiff_test.py b/tests/npbench/weather_stencils/hdiff_test.py index bd0150af91..b42b203342 100644 --- a/tests/npbench/weather_stencils/hdiff_test.py +++ b/tests/npbench/weather_stencils/hdiff_test.py @@ -9,6 +9,7 @@ from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG from dace.transformation.dataflow import StreamingMemory, StreamingComposition from dace.transformation.auto.auto_optimize import auto_optimize +from dace.autodiff import add_backward_pass I, J, K = (dc.symbol(s, dtype=dc.int64) for s in ('I', 'J', 'K')) @@ -38,6 +39,32 @@ def hdiff_kernel(in_field: dc.float64[I + 4, J + 4, K], out_field: dc.float64[I, fly_field[:, 1:, :] - fly_field[:, :-1, :]) +def hdiff_jax_kernel(jnp, in_field, out_field, coeff): + I, J, K = out_field.shape[0], out_field.shape[1], out_field.shape[2] + lap_field = 4.0 * in_field[1:I + 3, 1:J + 3, :] - (in_field[2:I + 4, 1:J + 3, :] + in_field[0:I + 2, 1:J + 3, :] + + in_field[1:I + 3, 2:J + 4, :] + in_field[1:I + 3, 0:J + 2, :]) + + res = lap_field[1:, 1:J + 1, :] - lap_field[:-1, 1:J + 1, :] + flx_field = jnp.where( + (res * (in_field[2:I + 3, 2:J + 2, :] - in_field[1:I + 2, 2:J + 2, :])) > 0, + 0, + res, + ) + + res = lap_field[1:I + 1, 1:, :] - lap_field[1:I + 1, :-1, :] + fly_field = jnp.where( + (res * (in_field[2:I + 2, 2:J + 3, :] - in_field[2:I + 2, 1:J + 2, :])) > 0, + 0, + res, + ) + + out_field = out_field.at[:, :, :].set( + in_field[2:I + 2, 2:J + 2, :] - coeff[:, :, :] * + (flx_field[1:, :, :] - flx_field[:-1, :, :] + fly_field[:, 1:, :] - fly_field[:, :-1, :])) + + return jnp.sum(out_field) + + def initialize(I, J, K): from numpy.random import default_rng rng = default_rng(42) @@ -105,6 +132,47 @@ def run_hdiff(device_type: dace.dtypes.DeviceType): return sdfg +def run_hdiff_autodiff(): + import jax + import jax.numpy as jnp + + # Initialize data (npbench small size) + I, J, K = 64, 64, 60 + in_field, out_field, coeff = initialize(I, J, K) + + # Initialize gradient computation data + gradient_in_field = np.zeros_like(in_field) + gradient___return = np.ones((1, ), dtype=np.float64) + + # Define sum reduction for the output + @dc.program + def autodiff_kernel(in_field: dc.float64[I + 4, J + 4, K], out_field: dc.float64[I, J, K], coeff: dc.float64[I, J, + K]): + hdiff_kernel(in_field, out_field, coeff) + return np.sum(out_field) + + # Add the backward pass to the SDFG + sdfg = autodiff_kernel.to_sdfg() + add_backward_pass(sdfg=sdfg, inputs=["in_field"], outputs=["__return"]) + sdfg(in_field, + out_field, + coeff, + I=I, + J=J, + K=K, + gradient_in_field=gradient_in_field, + gradient___return=gradient___return) + + # Enable float64 support + jax.config.update("jax_enable_x64", True) + + # Numerically validate vs JAX + jax_kernel = lambda in_field, out_field, coeff: hdiff_jax_kernel(jnp, in_field, out_field, coeff) + jax_grad = jax.jit(jax.grad(jax_kernel, argnums=0)) + jax_grad_in_field = jax_grad(in_field, out_field, coeff) + np.testing.assert_allclose(gradient_in_field, jax_grad_in_field) + + def test_cpu(): run_hdiff(dace.dtypes.DeviceType.CPU) @@ -114,6 +182,12 @@ def test_gpu(): run_hdiff(dace.dtypes.DeviceType.GPU) +@pytest.mark.autodiff +def test_autodiff(): + pytest.importorskip("jax", reason="jax not installed. Please install with: pip install dace[ml-testing]") + run_hdiff_autodiff() + + @fpga_test(assert_ii_1=False) def test_fpga(): return run_hdiff(dace.dtypes.DeviceType.FPGA) @@ -129,6 +203,7 @@ def test_fpga(): if target == "cpu": run_hdiff(dace.dtypes.DeviceType.CPU) + run_hdiff_autodiff() elif target == "gpu": run_hdiff(dace.dtypes.DeviceType.GPU) elif target == "fpga": diff --git a/tests/npbench/weather_stencils/vadv_test.py b/tests/npbench/weather_stencils/vadv_test.py index cf01a0cd31..f74a46bae1 100644 --- a/tests/npbench/weather_stencils/vadv_test.py +++ b/tests/npbench/weather_stencils/vadv_test.py @@ -9,6 +9,8 @@ from dace.transformation.interstate import FPGATransformSDFG, InlineSDFG from dace.transformation.dataflow import StreamingMemory, StreamingComposition from dace.transformation.auto.auto_optimize import auto_optimize +from dace.autodiff import add_backward_pass + # Sample constants BET_M = 0.5 BET_P = 0.5 @@ -89,9 +91,92 @@ def vadv_kernel(utens_stage: dc.float64[I, J, K], u_stage: dc.float64[I, J, K], for k in range(K - 2, -1, -1): # datacol = dcol[:, :, k] - ccol[:, :, k] * data_col[:, :] - datacol[:] = dcol[:, :, k] - ccol[:, :, k] * data_col[:, :] - data_col[:] = datacol - utens_stage[:, :, k] = dtr_stage * (datacol - u_pos[:, :, k]) + data_col[:] = dcol[:, :, k] - ccol[:, :, k] * data_col[:, :] + utens_stage[:, :, k] = dtr_stage * (data_col - u_pos[:, :, k]) + + +def vadv_jax_kernel(jnp, lax, utens_stage, u_stage, wcon, u_pos, utens, dtr_stage): + I, J, K = utens_stage.shape[0], utens_stage.shape[1], utens_stage.shape[2] + # Allocate working arrays. + ccol = jnp.empty((I, J, K), dtype=utens_stage.dtype) + dcol = jnp.empty((I, J, K), dtype=utens_stage.dtype) + data_col = jnp.empty((I, J), dtype=utens_stage.dtype) + + # --- Loop 1: for k in range(0, 1) --- + def loop1_body(carry, k): + ccol, dcol = carry + # Note: 0+1 is just 1. + gcv = 0.25 * (wcon[1:, :, 1] + wcon[:-1, :, 1]) + cs = gcv * BET_M + bs = gcv * BET_P + bcol = dtr_stage - bs + # update the d column correction term. + correction_term = -cs * (u_stage[:, :, k + 1] - u_stage[:, :, k]) + divided = 1.0 / bcol + ccol = ccol.at[:, :, k].set(bs * divided) + dcol = dcol.at[:, :, k].set( + (dtr_stage * u_pos[:, :, k] + utens[:, :, k] + utens_stage[:, :, k] + correction_term) * divided) + return (ccol, dcol), None + + (ccol, dcol), _ = lax.scan(loop1_body, (ccol, dcol), jnp.arange(0, 1)) + + # --- Loop 2: for k in range(1, K-1) --- + def loop2_body(carry, k): + ccol, dcol = carry + gav = -0.25 * (wcon[1:, :, k] + wcon[:-1, :, k]) + gcv = 0.25 * (wcon[1:, :, k + 1] + wcon[:-1, :, k + 1]) + as_ = gav * BET_M + cs = gcv * BET_M + bs = gcv * BET_P + acol = gav * BET_P + bcol = dtr_stage - acol - bs + correction_term = (-as_ * (u_stage[:, :, k - 1] - u_stage[:, :, k]) - cs * + (u_stage[:, :, k + 1] - u_stage[:, :, k])) + divided = 1.0 / (bcol - ccol[:, :, k - 1] * acol) + ccol = ccol.at[:, :, k].set(bs * divided) + dcol = dcol.at[:, :, k].set( + ((dtr_stage * u_pos[:, :, k] + utens[:, :, k] + utens_stage[:, :, k] + correction_term) - + dcol[:, :, k - 1] * acol) * divided) + return (ccol, dcol), None + + (ccol, dcol), _ = lax.scan(loop2_body, (ccol, dcol), jnp.arange(1, K - 1)) + + # --- Loop 3: for k in range(K-1, K) --- + def loop3_body(dcol, k): + gav = -0.25 * (wcon[1:, :, k] + wcon[:-1, :, k]) + as_ = gav * BET_M + acol = gav * BET_P + bcol = dtr_stage - acol + correction_term = -as_ * (u_stage[:, :, k - 1] - u_stage[:, :, k]) + divided = 1.0 / (bcol - ccol[:, :, k - 1] * acol) + dcol = dcol.at[:, :, k].set( + ((dtr_stage * u_pos[:, :, k] + utens[:, :, k] + utens_stage[:, :, k] + correction_term) - + dcol[:, :, k - 1] * acol) * divided) + return dcol, None + + dcol, _ = lax.scan(loop3_body, dcol, jnp.arange(K - 1, K)) + + # --- Loop 4: for k in range(K-1, K) --- + def loop4_body(carry, k): + data_col, utens_stage = carry + datacol = dcol[:, :, k] + data_col = data_col.at[:].set(datacol) + utens_stage = utens_stage.at[:, :, k].set(dtr_stage * (datacol - u_pos[:, :, k])) + return (data_col, utens_stage), None + + (data_col, utens_stage), _ = lax.scan(loop4_body, (data_col, utens_stage), jnp.arange(K - 1, K)) + + # --- Loop 5: for k in range(0, K-1) with reverse order --- + def loop5_body(carry, k): + data_col, utens_stage = carry + idx = (K - 2) - k # Reverse order: when k=0, idx=K-2; when k=K-2, idx=0. + datacol = dcol[:, :, idx] - ccol[:, :, idx] * data_col[:, :] + data_col = data_col.at[:].set(datacol) + utens_stage = utens_stage.at[:, :, idx].set(dtr_stage * (datacol - u_pos[:, :, idx])) + return (data_col, utens_stage), None + + (data_col, utens_stage), _ = lax.scan(loop5_body, (data_col, utens_stage), jnp.arange(0, K - 1)) + return jnp.sum(utens_stage) def initialize(I, J, K): @@ -211,6 +296,55 @@ def run_vadv(device_type: dace.dtypes.DeviceType): return sdfg +def run_vadv_autodiff(): + import jax + import jax.numpy as jnp + import jax.lax as lax + + # Initialize data (npbench small size) + I, J, K = 4, 4, 3 + dtr_stage, utens_stage, u_stage, wcon, u_pos, utens = initialize(I, J, K) + dtr_stage_jax, utens_stage_jax, u_stage_jax, wcon_jax, u_pos_jax, utens_jax = [ + np.copy(arr) for arr in (dtr_stage, utens_stage, u_stage, wcon, u_pos, utens) + ] + + # Initialize gradient computation data + gradient_utens = np.zeros_like(utens) + gradient___return = np.ones((1, ), dtype=np.float64) + + # Define sum reduction for the output + @dc.program + def autodiff_kernel(utens_stage: dc.float64[I, J, K], u_stage: dc.float64[I, J, K], wcon: dc.float64[I + 1, J, K], + u_pos: dc.float64[I, J, K], utens: dc.float64[I, J, K], dtr_stage: dc.float64): + vadv_kernel(utens_stage, u_stage, wcon, u_pos, utens, dtr_stage) + return np.sum(utens_stage) + + # Add the backward pass to the SDFG + sdfg = autodiff_kernel.to_sdfg() + add_backward_pass(sdfg=sdfg, inputs=["utens"], outputs=["__return"]) + sdfg(utens_stage, + u_stage, + wcon, + u_pos, + utens, + dtr_stage, + I=I, + J=J, + K=K, + gradient_utens=gradient_utens, + gradient___return=gradient___return) + + # Enable float64 support + jax.config.update("jax_enable_x64", True) + + # Numerically validate vs JAX + jax_kernel = lambda utens_stage, u_stage, wcon, u_pos, utens, dtr_stage: vadv_jax_kernel( + jnp, lax, utens_stage, u_stage, wcon, u_pos, utens, dtr_stage) + jax_grad = jax.jit(jax.grad(jax_kernel, argnums=4)) + jax_grad_utens = jax_grad(utens_stage_jax, u_stage_jax, wcon_jax, u_pos_jax, utens_jax, dtr_stage_jax) + np.testing.assert_allclose(gradient_utens, jax_grad_utens) + + def test_cpu(monkeypatch): # NOTE: Serialization fails because of "k - k" expression simplified to "0" monkeypatch.setenv("DACE_testing_serialization", 0) @@ -222,6 +356,12 @@ def test_gpu(): run_vadv(dace.dtypes.DeviceType.GPU) +@pytest.mark.autodiff +def test_autodiff(): + pytest.importorskip("jax", reason="jax not installed. Please install with: pip install dace[ml-testing]") + run_vadv_autodiff() + + @pytest.mark.skip(reason="Xilinx internal compiler error") @fpga_test(assert_ii_1=False) def test_fpga(): @@ -238,6 +378,7 @@ def test_fpga(): if target == "cpu": run_vadv(dace.dtypes.DeviceType.CPU) + run_vadv_autodiff() elif target == "gpu": run_vadv(dace.dtypes.DeviceType.GPU) elif target == "fpga": diff --git a/tests/onnx/pure_expansions/test_conv_expansion.py b/tests/onnx/pure_expansions/test_conv_expansion.py new file mode 100644 index 0000000000..8ca3616b29 --- /dev/null +++ b/tests/onnx/pure_expansions/test_conv_expansion.py @@ -0,0 +1,61 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import pytest + +pytest.importorskip("onnx", reason="ONNX not installed. Please install with: pip install dace[ml]") +pytest.importorskip("torch", reason="PyTorch not installed. Please install with: pip install dace[ml]") +import dace +import dace.libraries.onnx as donnx +import torch +import torch.nn.functional as F +import numpy as np + + +@pytest.mark.onnx +@pytest.mark.parametrize("num_in_channels, kernel_size, num_filters, bias", + [(1, (3, 3), 8, True), (8, (3, 3), 3, False), (8, (5, 5), 3, True), (8, (4, 4), 3, False)]) +def test_conv_simple(num_in_channels, kernel_size, num_filters, bias): + + batch_size = 8 + + X = np.random.rand(batch_size, num_in_channels, 32, 32).astype(np.float32) + W = np.random.rand(num_filters, num_in_channels, *kernel_size).astype(np.float32) + + if bias: + B = np.random.rand(num_filters).astype(np.float32) + torch_Z = F.conv2d(torch.from_numpy(X), torch.from_numpy(W), bias=torch.from_numpy(B)).numpy() + else: + B = None + torch_Z = F.conv2d(torch.from_numpy(X), torch.from_numpy(W)).numpy() + + dace_Z = np.zeros_like(torch_Z) + + if bias: + + @dace.program + def conv(X_: dace.float32[tuple(X.shape)], W_: dace.float32[tuple(W.shape)], B_: dace.float32[tuple(B.shape)], + Z_: dace.float32[tuple(torch_Z.shape)]): + donnx.ONNXConv(X=X_, W=W_, B=B_, Y=Z_) + else: + + @dace.program + def conv(X_: dace.float32[tuple(X.shape)], W_: dace.float32[tuple(W.shape)], + Z_: dace.float32[tuple(torch_Z.shape)]): + donnx.ONNXConv(X=X_, W=W_, Y=Z_) + + sdfg = conv.to_sdfg() + sdfg.expand_library_nodes() + + if bias: + sdfg(X_=X, W_=W, Z_=dace_Z, B_=B) + else: + sdfg(X_=X, W_=W, Z_=dace_Z) + + print(torch_Z - dace_Z) + assert np.allclose(torch_Z, dace_Z) + + +if __name__ == "__main__": + # Test with different parameter combinations + params = [(1, (3, 3), 8, True), (8, (3, 3), 3, False), (8, (5, 5), 3, True), (8, (4, 4), 3, False)] + for num_in_channels, kernel_size, num_filters, bias in params: + test_conv_simple(num_in_channels, kernel_size, num_filters, bias) diff --git a/tests/onnx/pure_expansions/test_expansion_utils.py b/tests/onnx/pure_expansions/test_expansion_utils.py new file mode 100644 index 0000000000..d71c54dc5a --- /dev/null +++ b/tests/onnx/pure_expansions/test_expansion_utils.py @@ -0,0 +1,42 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import pytest + +pytest.importorskip("onnx", reason="ONNX not installed. Please install with: pip install dace[ml]") +import numpy as np + +import dace +import dace.libraries.onnx as donnx + + +@pytest.mark.onnx +def test_sqrt_expansion(): + # sqrt expansion makes use of the program_for_node function + sdfg = dace.SDFG("test_sqrt_expansion") + + sdfg.add_array("inp", [2, 4], dace.float32) + sdfg.add_array("__return", [2, 4], dace.float32) + + state = sdfg.add_state() + access_in = state.add_access("inp") + access_result = state.add_access("__return") + + op_node = donnx.ONNXSqrt("sqrt") + + state.add_node(op_node) + state.add_edge(access_in, None, op_node, "X", sdfg.make_array_memlet("inp")) + + state.add_edge(op_node, "Y", access_result, None, sdfg.make_array_memlet("__return")) + + X = np.random.rand(2, 4).astype(np.float32) + + sdfg.expand_library_nodes() + # check that the expansion worked. The default ORT expansion wouldn't produce a map + assert any(isinstance(n, dace.nodes.MapEntry) for n, _ in sdfg.all_nodes_recursive()) + + result = sdfg(inp=X) + + assert np.allclose(np.sqrt(X), result) + + +if __name__ == "__main__": + test_sqrt_expansion() diff --git a/tests/onnx/pure_expansions/test_expansions.py b/tests/onnx/pure_expansions/test_expansions.py new file mode 100644 index 0000000000..1c70738c92 --- /dev/null +++ b/tests/onnx/pure_expansions/test_expansions.py @@ -0,0 +1,561 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import pytest + +pytest.importorskip("onnx", reason="ONNX not installed. Please install with: pip install dace[ml]") +pytest.importorskip("torch", reason="PyTorch not installed. Please install with: pip install dace[ml]") + +import copy +import numpy as np + +import dace +from dace import transformation, data as dt +from dace.libraries import blas +import dace.library + +import dace.libraries.onnx as donnx +from dace.transformation.onnx import expand_onnx_nodes + + +def assert_allclose(a, b, rtol=1e-5, atol=1e-8): + np.testing.assert_allclose(a, b, rtol=rtol, atol=atol) + + +@pytest.mark.onnx +@pytest.mark.parametrize("a_shape, b_shape", [([2, 4], [4, 3])]) +def test_matmul_expansion(a_shape, b_shape): + blas.Gemm.default_implementation = "pure" + sdfg = dace.SDFG("test_matmul_expansion") + + X = np.random.rand(*a_shape).astype(np.float32) + Z = np.random.rand(*b_shape).astype(np.float32) + expected_result = X @ Z + sdfg.add_array("X", a_shape, dace.float32) + sdfg.add_array("Z", b_shape, dace.float32) + sdfg.add_array("__return", expected_result.shape, dace.float32) + + state = sdfg.add_state() + access_X = state.add_access("X") + access_Z = state.add_access("Z") + access_result = state.add_access("__return") + + op_node = donnx.ONNXMatMul("Matmul") + + state.add_node(op_node) + state.add_edge(access_X, None, op_node, "A", sdfg.make_array_memlet("X")) + state.add_edge(access_Z, None, op_node, "B", sdfg.make_array_memlet("Z")) + + state.add_edge(op_node, "Y", access_result, None, sdfg.make_array_memlet("__return")) + + with dace.library.change_default(blas, "pure"): + sdfg.expand_library_nodes() + # check that the expansion worked. The default ORT expansion contains a Tasklet with suffix _onnx_code + assert not any( + isinstance(n, dace.nodes.Tasklet) and n.name.endswith("_onnx_code") for n, _ in sdfg.all_nodes_recursive()) + + result = sdfg(X=X, Z=Z) + + assert_allclose(expected_result, result) + + +@pytest.mark.onnx +def test_cast_int_to_float(): + sdfg = dace.SDFG("test_cast_int_to_float") + + sdfg.add_array("X", [2, 4], dace.int32) + sdfg.add_array("__return", [2, 4], dace.float32) + + state = sdfg.add_state() + access_X = state.add_access("X") + access_result = state.add_access("__return") + + op_node = donnx.ONNXCast("Cast") + op_node.to = donnx.converters.typeclass_to_onnx_tensor_type_int(dace.float32) + + state.add_node(op_node) + state.add_edge(access_X, None, op_node, "input", sdfg.make_array_memlet("X")) + + state.add_edge(op_node, "output", access_result, None, sdfg.make_array_memlet("__return")) + + X = np.random.randint(0, 10, size=(2, 4), dtype=np.int32) + + sdfg.expand_library_nodes() + # check that the expansion worked. The default ORT expansion contains a Tasklet with suffix _onnx_code + assert not any( + isinstance(n, dace.nodes.Tasklet) and n.name.endswith("_onnx_code") for n, _ in sdfg.all_nodes_recursive()) + + result = sdfg(X=X) + + assert_allclose(X.astype(np.float32), result) + + +@pytest.mark.onnx +def test_cast_float_to_int(): + sdfg = dace.SDFG("test_cast_float_to_int") + + sdfg.add_array("X", [2, 4], dace.float32) + sdfg.add_array("__return", [2, 4], dace.int32) + + state = sdfg.add_state() + access_X = state.add_access("X") + access_result = state.add_access("__return") + + op_node = donnx.ONNXCast("Cast") + op_node.to = donnx.converters.typeclass_to_onnx_tensor_type_int(dace.int32) + + state.add_node(op_node) + state.add_edge(access_X, None, op_node, "input", sdfg.make_array_memlet("X")) + + state.add_edge(op_node, "output", access_result, None, sdfg.make_array_memlet("__return")) + + X = np.random.normal(scale=10, size=(2, 4)).astype(np.float32) + + sdfg.expand_library_nodes() + # check that the expansion worked. The default ORT expansion contains a Tasklet with suffix _onnx_code + assert not any( + isinstance(n, dace.nodes.Tasklet) and n.name.endswith("_onnx_code") for n, _ in sdfg.all_nodes_recursive()) + + result = sdfg(X=X) + + assert_allclose(X.astype(np.int32), result) + + +@pytest.mark.onnx +def test_cast_float_to_long(): + sdfg = dace.SDFG("test_cast_float_to_long") + + sdfg.add_array("X", [2, 4], dace.float32) + sdfg.add_array("__return", [2, 4], dace.int64) + + state = sdfg.add_state() + access_X = state.add_access("X") + access_result = state.add_access("__return") + + op_node = donnx.ONNXCast("Cast") + op_node.to = donnx.converters.typeclass_to_onnx_tensor_type_int(dace.int64) + + state.add_node(op_node) + state.add_edge(access_X, None, op_node, "input", sdfg.make_array_memlet("X")) + + state.add_edge(op_node, "output", access_result, None, sdfg.make_array_memlet("__return")) + + X = np.random.normal(scale=10, size=(2, 4)).astype(np.float32) + + sdfg.expand_library_nodes() + # check that the expansion worked. The default ORT expansion contains a Tasklet with suffix _onnx_code + assert not any( + isinstance(n, dace.nodes.Tasklet) and n.name.endswith("_onnx_code") for n, _ in sdfg.all_nodes_recursive()) + + result = sdfg(X=X) + + assert_allclose(X.astype(np.int64), result) + + +@pytest.mark.onnx +#+yapf: disable +@pytest.mark.parametrize("reduce_type, keepdims, axes", + [('Sum', True, [0]), + ('Sum', False, [-1]), + ('Sum', True, [0, -1]), + ('Max', False, [0, -1]), + ('Max', True, [0]), + ('Max', True, [-1]), + ('Mean', True, [-1]), + ('Mean', True, [0, -1]), + ('Mean', False, [0])]) +#+yapf: enable +def test_reduce(keepdims, reduce_type, axes): + + X = np.random.normal(scale=10, size=(2, 4, 10)).astype(np.float32) + + sdfg = dace.SDFG("test_reduce") + + sdfg.add_array("X", [2, 4, 10], dace.float32) + + numpy_func = getattr(np, reduce_type.lower()) + numpy_result = numpy_func(X.copy(), axis=tuple(axes), keepdims=keepdims) + + resulting_shape = numpy_result.shape + + sdfg.add_array("__return", resulting_shape, dace.float32) + + state = sdfg.add_state() + access_X = state.add_access("X") + access_result = state.add_access("__return") + + op_node = getattr(donnx, "ONNXReduce" + reduce_type)("reduce") + op_node.axes = axes + op_node.keepdims = 1 if keepdims else 0 + + state.add_node(op_node) + state.add_edge(access_X, None, op_node, "data", sdfg.make_array_memlet("X")) + + state.add_edge(op_node, "reduced", access_result, None, sdfg.make_array_memlet("__return")) + + sdfg.expand_library_nodes() + # check that the expansion worked. The default ORT expansion contains a Tasklet with suffix _onnx_code + assert not any( + isinstance(n, dace.nodes.Tasklet) and n.name.endswith("_onnx_code") for n, _ in sdfg.all_nodes_recursive()) + result = sdfg(X=X) + + assert_allclose(numpy_result, result, rtol=1e-5, atol=1e-5) + + +@pytest.mark.onnx +def test_reduce_scalar(): + X = np.random.normal(scale=10, size=(2, 4, 10)).astype(np.float32) + + sdfg = dace.SDFG("test_reduce_scalar") + + numpy_result = np.mean(X) + + sdfg.add_array("X", [2, 4, 10], dace.float32) + sdfg.add_scalar("Y", dace.float32, transient=True) + sdfg.add_array("__return", [1], dace.float32) + + state = sdfg.add_state() + access_X = state.add_access("X") + access_Y = state.add_access("Y") + access_result = state.add_access("__return") + + op_node = donnx.ONNXReduceMean("mean") + op_node.keepdims = 0 + + state.add_node(op_node) + state.add_edge(access_X, None, op_node, "data", sdfg.make_array_memlet("X")) + + state.add_edge(op_node, "reduced", access_Y, None, sdfg.make_array_memlet("Y")) + + state.add_edge(access_Y, None, access_result, None, sdfg.make_array_memlet("__return")) + + sdfg.expand_library_nodes() + + result = sdfg(X=X) + + assert_allclose(numpy_result, result, rtol=1e-5, atol=1e-5) + + +@pytest.mark.onnx +@pytest.mark.parametrize("new_shape", [[8, 10], [80], [2, 40]]) +def test_reshape(new_shape): + X = np.random.normal(scale=10, size=(2, 4, 10)).astype(np.float32) + + sdfg = dace.SDFG("test_reshape") + + numpy_result = X.reshape(*new_shape) + + sdfg.add_array("X", [2, 4, 10], dace.float32) + sdfg.add_array("shape", [len(new_shape)], dace.int64) + sdfg.add_array("__return", new_shape, dace.float32) + + state = sdfg.add_state() + access_X = state.add_access("X") + access_shape = state.add_access("shape") + access_result = state.add_access("__return") + + op_node = donnx.ONNXReshape("reshape") + + state.add_node(op_node) + state.add_edge(access_X, None, op_node, "data", sdfg.make_array_memlet("X")) + state.add_edge(access_shape, None, op_node, "shape", sdfg.make_array_memlet("shape")) + + state.add_edge(op_node, "reshaped", access_result, None, sdfg.make_array_memlet("__return")) + + sdfg.expand_library_nodes() + + # we don't need shape anymore + del sdfg.arrays["shape"] + + result = sdfg(X=X) + + assert_allclose(numpy_result, result) + + +@pytest.mark.onnx +def test_flatten(): + + new_shape = [2, 40] + X = np.random.normal(scale=10, size=(2, 4, 10)).astype(np.float32) + + sdfg = dace.SDFG("test_flatten") + + numpy_result = X.reshape(*new_shape) + + sdfg.add_array("X", [2, 4, 10], dace.float32) + sdfg.add_array("__return", new_shape, dace.float32) + + state = sdfg.add_state() + access_X = state.add_access("X") + access_result = state.add_access("__return") + + op_node = donnx.ONNXFlatten("flatten") + + state.add_node(op_node) + state.add_edge(access_X, None, op_node, "input", sdfg.make_array_memlet("X")) + + state.add_edge(op_node, "output", access_result, None, sdfg.make_array_memlet("__return")) + + sdfg.expand_library_nodes() + + result = sdfg(X=X) + + assert_allclose(numpy_result, result) + + +@pytest.mark.onnx +def test_reciprocal(): + X = np.random.normal(scale=10, size=(2, 4, 10)).astype(np.float32) + + numpy_result = 1 / X + sdfg = dace.SDFG("test_reciprocal") + + sdfg.add_array("X", [2, 4, 10], dace.float32) + sdfg.add_array("__return", numpy_result.shape, dace.float32) + + state = sdfg.add_state() + access_X = state.add_access("X") + access_result = state.add_access("__return") + + op_node = donnx.ONNXReciprocal("reciprocal") + + state.add_node(op_node) + state.add_edge(access_X, None, op_node, "X", sdfg.make_array_memlet("X")) + + state.add_edge(op_node, "Y", access_result, None, sdfg.make_array_memlet("__return")) + + sdfg.expand_library_nodes() + + # check that the expansion worked. The default ORT expansion contains a Tasklet with suffix _onnx_code + assert not any( + isinstance(n, dace.nodes.Tasklet) and n.name.endswith("_onnx_code") for n, _ in sdfg.all_nodes_recursive()) + + result = sdfg(X=X) + + assert_allclose(numpy_result, result) + + +@pytest.mark.onnx +def test_einsum(): + + @dace.program + def test_einsum(A: dace.float64[5, 4, 3], B: dace.float64[3, 2]): + Y = dace.define_local([5, 4, 2], dace.float64) + donnx.ONNXEinsum(Inputs__0=A, Inputs__1=B, Output=Y, equation="bij, jk -> bik") + return Y + + sdfg = test_einsum.to_sdfg() + expand_onnx_nodes(sdfg) + assert any(isinstance(n, blas.Gemm) for n, _ in sdfg.all_nodes_recursive()) + + A = np.random.rand(5, 4, 3).astype(np.float64) + B = np.random.rand(3, 2).astype(np.float64) + result = test_einsum(A.copy(), B.copy()) + assert_allclose(result, np.einsum("bij ,jk -> bik", A, B)) + + +@pytest.mark.onnx +def test_reshape_add(): + + @dace.program + def add_reshape(inp: dace.float64[9], bias: dace.float64[3], target_shape: dace.int64[2]): + reshaped = dace.define_local([3, 3], dace.float64) + donnx.ONNXReshape(data=inp, shape=target_shape, reshaped=reshaped) + + return reshaped + bias + + sdfg: dace.SDFG = add_reshape.to_sdfg(simplify=False) + + sdfg.apply_transformations_repeated([transformation.interstate.StateFusion]) + + inp = np.arange(9).astype(np.float64) + bias = np.arange(3).astype(np.float64) + result = sdfg(inp=inp.copy(), bias=bias.copy(), target_shape=np.array([3, 3]).astype(np.int64)) + + assert_allclose(result, inp.reshape(3, 3) + bias) + + +@pytest.mark.onnx +@pytest.mark.parametrize("input_desc", [dace.float32[2, 3], dace.float32[1], dace.float32]) +def test_sum_arrays(input_desc): + + if isinstance(input_desc, dt.Array): + shape = input_desc.shape + else: + shape = [1] + + def prog(inp0: copy.deepcopy(input_desc), inp1: copy.deepcopy(input_desc), inp2: copy.deepcopy(input_desc)): + result = dace.define_local(shape, dace.float32) + donnx.ONNXSum(data_0__0=inp0, data_0__1=inp1, data_0__2=inp2, sum=result) + return result + + prog.__name__ = "test_sum_arrays" + prog = dace.program(prog) + + inputs = [np.random.randn(*shape).astype(np.float32) for _ in range(3)] + if not isinstance(input_desc, dt.Array): + inputs = [i[0] for i in inputs] + np_result = (inputs[0] + inputs[1]) + inputs[2] + result = prog(*inputs) + + assert_allclose(result, np_result) + + +@pytest.mark.onnx +def test_shape(): + + @dace.program + def shape(inp: dace.float64[9, 5, 3]): + shp = dace.define_local([3], dace.int64) + donnx.ONNXShape(data=inp, shape=shp) + return shp + + sdfg: dace.SDFG = shape.to_sdfg() + sdfg.expand_library_nodes() + sdfg.simplify() + + inp = np.random.rand(9, 5, 3).astype(np.float64) + result = sdfg(inp=inp.copy()) + assert_allclose(result, [9, 5, 3]), result + + +@pytest.mark.onnx +def test_gather_onnx_1(): + # gather in ONNX operators.md + @dace.program + def gather(inp: dace.float64[3, 2], indices: dace.int64[2, 2]): + output = dace.define_local([2, 2, 2], dace.float64) + donnx.ONNXGather(data=inp, output=output, indices=indices, axis=0) + return output + + sdfg: dace.SDFG = gather.to_sdfg() + sdfg.expand_library_nodes() + sdfg.simplify() + + data = np.array([[1.0, 1.2], [2.3, 3.4], [4.5, 5.7]]) + indices = np.array([[0, 1], [1, 2]]) + result = sdfg(inp=data.copy(), indices=indices.copy()) + assert_allclose(result, data[indices]) + + +@pytest.mark.onnx +def test_gather_bert(): + # gather found at start of bert model + @dace.program + def gather(embs: dace.float64[64, 8], input_ids: dace.int64[8, 16]): + output = dace.define_local([8, 16, 8], dace.float64) + donnx.ONNXGather(data=embs, output=output, indices=input_ids, axis=0) + return output + + sdfg: dace.SDFG = gather.to_sdfg() + sdfg.expand_library_nodes() + sdfg.simplify() + + embs = np.random.rand(64, 8).astype(np.float64) + input_ids = np.random.randint(low=0, high=64, size=(8, 16)).astype(np.int64) + result = sdfg(embs=embs.copy(), input_ids=input_ids.copy()) + assert_allclose(result, embs[input_ids]) + + +@pytest.mark.onnx +def test_gather_scalar(): + # gather test 2 in BERT model (third last op) + @dace.program + def gather(inp: dace.float64[1, 8, 32], indices: dace.int64): + output = dace.define_local([1, 32], dace.float64) + donnx.ONNXGather(data=inp, output=output, indices=indices, axis=1) + return output + + sdfg: dace.SDFG = gather.to_sdfg() + sdfg.expand_library_nodes() + sdfg.simplify() + + data = np.random.rand(1, 8, 32) + indices = np.int64(5) + result = sdfg(inp=data.copy(), indices=indices.copy()) + np_result = np.take(data, indices, axis=1) + + assert_allclose(result, np_result) + + +@pytest.mark.onnx +def test_gather_onnx_2(): + # gather test 2 in ONNX operators.md + @dace.program + def gather(inp: dace.float64[3, 3], indices: dace.int64[1, 2]): + output = dace.define_local([3, 1, 2], dace.float64) + donnx.ONNXGather(data=inp, output=output, indices=indices, axis=1) + return output + + sdfg: dace.SDFG = gather.to_sdfg() + sdfg.expand_library_nodes() + sdfg.simplify() + + data = np.array([ + [1.0, 1.2, 1.9], + [2.3, 3.4, 3.9], + [4.5, 5.7, 5.9], + ]) + indices = np.array([[0, 2]]) + result = sdfg(inp=data.copy(), indices=indices.copy()) + np_result = np.take(data, indices, axis=1) + + assert_allclose(result, np_result) + + +@pytest.mark.onnx +def test_unsqueeze(): + + @dace.program + def unsqueeze(inp: dace.float64[3, 3]): + output = dace.define_local([3, 1, 3, 1], dace.float64) + axes = dace.define_local([2], dace.int64) + axes[0] = 1 + axes[1] = 3 + donnx.ONNXUnsqueeze(data=inp, expanded=output, axes=axes) + return output + + sdfg: dace.SDFG = unsqueeze.to_sdfg() + + data = np.array([ + [1.0, 1.2, 1.9], + [2.3, 3.4, 3.9], + [4.5, 5.7, 5.9], + ]) + + np_result = np.reshape(data, [3, 1, 3, 1]) + + result = sdfg(inp=data.copy()) + assert result.shape == (3, 1, 3, 1) + assert_allclose(result, np_result) + + +if __name__ == "__main__": + test_matmul_expansion(a_shape=[2, 4], b_shape=[4, 3]) + test_cast_int_to_float() + test_cast_float_to_int() + test_cast_float_to_long() + + reduce_params = [(True, 'Sum', [0]), (False, 'Sum', [-1]), (True, 'Sum', [0, -1]), (False, 'Max', [0, -1]), + (True, 'Max', [0]), (True, 'Max', [-1]), (True, 'Mean', [-1]), (True, 'Mean', [0, -1]), + (False, 'Mean', [0])] + for keepdims, reduce_type, axes in reduce_params: + test_reduce(keepdims=keepdims, reduce_type=reduce_type, axes=axes) + + test_reduce_scalar() + + for new_shape in [[8, 10], [80], [2, 40]]: + test_reshape(new_shape=new_shape) + + test_flatten() + test_reciprocal() + test_einsum() + test_reshape_add() + + for input_desc in [dace.float32[2, 3], dace.float32[1], dace.float32]: + test_sum_arrays(input_desc=input_desc) + + test_shape() + test_gather_onnx_1() + test_gather_bert() + test_gather_scalar() + test_gather_onnx_2() + test_unsqueeze() diff --git a/tests/onnx/test_bert_subgraphs.py b/tests/onnx/test_bert_subgraphs.py new file mode 100644 index 0000000000..8d3d90a13f --- /dev/null +++ b/tests/onnx/test_bert_subgraphs.py @@ -0,0 +1,106 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +Regression tests for BERT subgraphs +""" +import numpy as np +import pytest + +pytest.importorskip("onnx", reason="ONNX not installed. Please install with: pip install dace[ml]") +pytest.importorskip("torch", reason="PyTorch not installed. Please install with: pip install dace[ml]") + +from onnx import helper, numpy_helper, TensorProto +import torch +from dace.ml import ONNXModel + + +def make_slice_model(): + """Create a simple ONNX model with a Slice operation.""" + data_input = helper.make_tensor_value_info('data', TensorProto.FLOAT, [2]) + output = helper.make_tensor_value_info('output', TensorProto.FLOAT, [1]) + + starts = numpy_helper.from_array(np.array([0], dtype=np.int64), name='starts') + ends = numpy_helper.from_array(np.array([1], dtype=np.int64), name='ends') + axes = numpy_helper.from_array(np.array([0], dtype=np.int64), name='axes') + + slice_node = helper.make_node('Slice', inputs=['data', 'starts', 'ends', 'axes'], outputs=['output']) + + graph = helper.make_graph([slice_node], + 'slice_graph', + inputs=[data_input], + outputs=[output], + initializer=[starts, ends, axes]) + + model = helper.make_model(graph, opset_imports=[helper.make_opsetid('', 12)]) + model.ir_version = 7 + return model + + +def make_reshape_model(): + """Create an ONNX model simulating BERT embedding reshape operations.""" + output = helper.make_tensor_value_info('bert/embeddings/Reshape_4:0', TensorProto.FLOAT, [1, 256, 768]) + + position_embeddings = numpy_helper.from_array(np.random.randn(512, 768).astype(np.float32), + name='bert/embeddings/position_embeddings:0') + slice_starts = numpy_helper.from_array(np.array([0, 0], dtype=np.int64), name='const_slice__40') + slice_ends = numpy_helper.from_array(np.array([256, 2147483647], dtype=np.int64), name='const_slice__41') + reshape_shape = numpy_helper.from_array(np.array([1, 256, 768], dtype=np.int32), + name='bert/embeddings/Reshape_4/shape:0') + + slice_node = helper.make_node( + 'Slice', + inputs=['bert/embeddings/position_embeddings:0', 'const_slice__40', 'const_slice__41'], + outputs=['bert/embeddings/Slice:0']) + + cast_node = helper.make_node('Cast', + inputs=['bert/embeddings/Reshape_4/shape:0'], + outputs=['bert/embeddings/Reshape_4__42:0'], + to=TensorProto.INT64) + + reshape_node = helper.make_node('Reshape', + inputs=['bert/embeddings/Slice:0', 'bert/embeddings/Reshape_4__42:0'], + outputs=['bert/embeddings/Reshape_4:0']) + + graph = helper.make_graph([slice_node, cast_node, reshape_node], + 'reshape_graph', + inputs=[], + outputs=[output], + initializer=[position_embeddings, slice_starts, slice_ends, reshape_shape]) + + model = helper.make_model(graph, opset_imports=[helper.make_opsetid('', 12)]) + model.ir_version = 7 + return model + + +@pytest.mark.onnx +def test_slice(): + model = make_slice_model() + dace_model = ONNXModel("test_slice", model, onnx_simplify=False) + + data = torch.ones(2) + + out = dace_model(data=data) + assert out.shape == (1, ), f"Expected output shape (1,), got {out.shape}" + assert out[0] == 1.0, f"Expected output value 1.0, got {out[0]}" + + +@pytest.mark.onnx +def test_reshape(): + model = make_reshape_model() + dace_model = ONNXModel("test_reshape", model) + dace_model() + + +@pytest.mark.onnx +def test_save_transients(): + model = make_reshape_model() + transients = {} + dace_model = ONNXModel("test_save_transients", model, save_transients=transients) + dace_model() + assert torch.allclose(transients["bertSLASHembeddingsSLASHReshape_4COLON0"].cpu(), + dace_model.weights["bert/embeddings/Reshape_4:0"]) + + +if __name__ == "__main__": + test_slice() + test_reshape() + test_save_transients() diff --git a/tests/onnx/test_input_outputs.py b/tests/onnx/test_input_outputs.py new file mode 100644 index 0000000000..bc979a7ac8 --- /dev/null +++ b/tests/onnx/test_input_outputs.py @@ -0,0 +1,229 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +Testing input and output combinations for onnx Ops + +| Output / Input | Array CPU | +|----------------+-----------| +| Scalar CPU | Shape | +| Array CPU | Add | +""" +import pytest + +pytest.importorskip("onnx", reason="ONNX not installed. Please install with: pip install dace[ml]") + +import numpy as np +import pytest + +import dace +import dace.libraries.onnx as donnx + + +@pytest.mark.onnx +@pytest.mark.parametrize("simplify", [True, False]) +def test_squeeze(simplify: bool): + + sdfg = dace.SDFG("test_squeeze") + + sdfg.add_array("X_arr", [1], dace.float32) + sdfg.add_array("axes", [1], dace.int64, transient=True) + sdfg.add_scalar("scalar", dace.float32, transient=True) + sdfg.add_array("__return", [1], dace.float32) + + state = sdfg.add_state() + access_X = state.add_access("X_arr") + access_axes = state.add_access("axes") + access_scalar = state.add_access("scalar") + + access_result = state.add_access("__return") + + # Tasklet to initialize axes + init_axes = state.add_tasklet("init_axes", + inputs={}, + outputs={"__axes": dace.pointer(dace.int64)}, + code="__axes[0] = 0;", + language=dace.Language.CPP) + + state.add_edge(init_axes, "__axes", access_axes, None, sdfg.make_array_memlet("axes")) + + op_node = donnx.ONNXSqueeze("Squeeze") + + state.add_node(op_node) + state.add_edge(access_X, None, op_node, "data", sdfg.make_array_memlet("X_arr")) + + state.add_edge(op_node, "squeezed", access_scalar, None, sdfg.make_array_memlet("scalar")) + + unsqueeze_op = donnx.ONNXUnsqueeze("Unsqueeze") + state.add_node(unsqueeze_op) + state.add_edge(access_scalar, None, unsqueeze_op, "data", sdfg.make_array_memlet("scalar")) + state.add_edge(access_axes, None, unsqueeze_op, "axes", sdfg.make_array_memlet("axes")) + state.add_edge(unsqueeze_op, "expanded", access_result, None, sdfg.make_array_memlet("__return")) + + X = np.random.rand(1).astype(np.float32) + + if simplify: + sdfg.expand_library_nodes() + sdfg.simplify() + + sdfg.expand_library_nodes() + result = sdfg(X_arr=X) + + assert result.shape == (1, ), f"Expected shape (1,), got {result.shape}" + assert result[0] == X, f"Expected value {X}, got {result[0]}" + + +@pytest.mark.onnx +@pytest.mark.parametrize("simplify", [True, False]) +def test_shape(simplify: bool): + sdfg = dace.SDFG("test_shape") + + sdfg.add_array("X_arr", [2, 4], dace.float32) + sdfg.add_array("__return", [2], dace.int64) + + state = sdfg.add_state() + access_X = state.add_access("X_arr") + + access_result = state.add_access("__return") + + op_node = donnx.ONNXShape("Shape") + + state.add_node(op_node) + state.add_edge(access_X, None, op_node, "data", sdfg.make_array_memlet("X_arr")) + + state.add_edge(op_node, "shape", access_result, None, sdfg.make_array_memlet("__return")) + + X = np.random.rand(2, 4).astype(np.float32) + + if simplify: + sdfg.expand_library_nodes() + sdfg.simplify() + + result = sdfg(X_arr=X) + + assert np.all(result == (2, 4)) + + +@pytest.mark.onnx +@pytest.mark.parametrize("simplify", [True, False]) +def test_unsqueeze(simplify: bool): + sdfg = dace.SDFG("test_unsqueeze") + + sdfg.add_scalar("X_arr", dace.float32) + sdfg.add_array("axes", [1], dace.int64, transient=True) + sdfg.add_array("__return", [1], dace.float32) + + state = sdfg.add_state() + access_X = state.add_access("X_arr") + access_axes = state.add_access("axes") + + access_result = state.add_access("__return") + + # Tasklet to initialize axes + init_axes = state.add_tasklet("init_axes", + inputs={}, + outputs={"__axes": dace.pointer(dace.int64)}, + code="__axes[0] = 0;", + language=dace.Language.CPP) + + state.add_edge(init_axes, "__axes", access_axes, None, sdfg.make_array_memlet("axes")) + + op_node = donnx.ONNXUnsqueeze("Unsqueeze") + + state.add_node(op_node) + state.add_edge(access_X, None, op_node, "data", sdfg.make_array_memlet("X_arr")) + state.add_edge(access_axes, None, op_node, "axes", sdfg.make_array_memlet("axes")) + + state.add_edge(op_node, "expanded", access_result, None, sdfg.make_array_memlet("__return")) + + X = np.float32(np.random.rand()) + + if simplify: + sdfg.expand_library_nodes() + sdfg.simplify() + + result = sdfg(X_arr=X) + + assert result.shape == (1, ), f"Expected shape (1,), got {result.shape}" + assert X == result[0], f"Expected value {X}, got {result[0]}" + + +@pytest.mark.onnx +@pytest.mark.parametrize("scalars", [True, False]) +@pytest.mark.parametrize("simplify", [True, False]) +def test_add(scalars: bool, simplify: bool): + sdfg = dace.SDFG("test_add") + + if scalars: + sdfg.add_scalar("X_arr", dace.float32) + sdfg.add_scalar("W_arr", dace.float32) + sdfg.add_scalar("Z_arr", dace.float32, transient=True) + sdfg.add_array("axes", [1], dace.int64, transient=True) + sdfg.add_array("__return", [1], dace.float32) + else: + sdfg.add_array("X_arr", [2, 2], dace.float32) + sdfg.add_array("W_arr", [2, 2], dace.float32) + sdfg.add_array("__return", [2, 2], dace.float32) + + state = sdfg.add_state() + access_X = state.add_access("X_arr") + access_W = state.add_access("W_arr") + + if scalars: + access_Z = state.add_access("Z_arr") + access_axes = state.add_access("axes") + + access_result = state.add_access("__return") + + op_node = donnx.ONNXAdd("Add") + + state.add_node(op_node) + state.add_edge(access_X, None, op_node, "A", sdfg.make_array_memlet("X_arr")) + state.add_edge(access_W, None, op_node, "B", sdfg.make_array_memlet("W_arr")) + + if scalars: + state.add_edge(op_node, "C", access_Z, None, sdfg.make_array_memlet("Z_arr")) + else: + state.add_edge(op_node, "C", access_result, None, sdfg.make_array_memlet("__return")) + + if scalars: + # Tasklet to initialize axes + init_axes = state.add_tasklet("init_axes", + inputs={}, + outputs={"__axes": dace.pointer(dace.int64)}, + code="__axes[0] = 0;", + language=dace.Language.CPP) + + state.add_edge(init_axes, "__axes", access_axes, None, sdfg.make_array_memlet("axes")) + + unsqueeze_op = donnx.ONNXUnsqueeze("Unsqueeze") + state.add_node(unsqueeze_op) + state.add_edge(access_Z, None, unsqueeze_op, "data", sdfg.make_array_memlet("Z_arr")) + state.add_edge(access_axes, None, unsqueeze_op, "axes", sdfg.make_array_memlet("axes")) + state.add_edge(unsqueeze_op, "expanded", access_result, None, sdfg.make_array_memlet("__return")) + + shapes = [] if scalars else [2, 2] + X = np.random.rand(*shapes) + W = np.random.rand(*shapes) + if not scalars: + X = X.astype(np.float32) + W = W.astype(np.float32) + + if simplify: + sdfg.expand_library_nodes() + sdfg.simplify() + + result = sdfg(X_arr=X, W_arr=W) + + numpy_result = X + W + + assert np.allclose(result, numpy_result) + + +if __name__ == "__main__": + for simplify in [True, False]: + test_squeeze(simplify=simplify) + test_shape(simplify=simplify) + test_unsqueeze(simplify=simplify) + + for scalars in [True, False]: + for simplify in [True, False]: + test_add(scalars=scalars, simplify=simplify) diff --git a/tests/onnx/test_models/test_bert.py b/tests/onnx/test_models/test_bert.py new file mode 100644 index 0000000000..1c447ac0e3 --- /dev/null +++ b/tests/onnx/test_models/test_bert.py @@ -0,0 +1,92 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +Test a full model including indexing and input preparation. The model also includes lots of symbolic dimensions. +""" + +import pytest + +pytest.importorskip("onnx", reason="ONNX not installed. Please install with: pip install dace[ml]") +pytest.importorskip("torch", reason="PyTorch not installed. Please install with: pip install dace[ml]") +pytest.importorskip("onnxsim", reason="ONNX Simplifier not installed. Please install with: pip install dace[ml]") +pytest.importorskip("transformers", + reason="transformers not installed. Please install with: pip install dace[ml-testing]") +import os + +import onnx +import onnxsim +import pathlib +import urllib + +import torch +from transformers import BertTokenizer, BertModel + +import dace +import dace.libraries.onnx as donnx +from tests.utils import torch_tensors_close + + +def get_data_file(url, directory_name=None) -> str: + """ Get a data file from ``url``, cache it locally and return the local file path to it. + + :param url: the url to download from. + :param directory_name: an optional relative directory path where the file will be downloaded to. + :return: the path of the downloaded file. + """ + + data_directory = (pathlib.Path(dace.__file__).parent.parent / 'tests' / 'data') + + if directory_name is not None: + data_directory /= directory_name + + data_directory.mkdir(exist_ok=True, parents=True) + + file_name = os.path.basename(urllib.parse.urlparse(url).path) + file_path = str(data_directory / file_name) + + if not os.path.exists(file_path): + urllib.request.urlretrieve(url, file_path) + return file_path + + +@pytest.mark.xdist_group("large_ML_models") +@pytest.mark.onnx +def test_bert_full(): + bert_tiny_root = 'http://spclstorage.inf.ethz.ch/~rauscho/bert-tiny' + get_data_file(bert_tiny_root + "/config.json", directory_name='bert-tiny') + vocab = get_data_file(bert_tiny_root + "/vocab.txt", directory_name='bert-tiny') + bert_path = get_data_file(bert_tiny_root + "/bert-tiny.onnx", directory_name='bert-tiny') + get_data_file(bert_tiny_root + "/pytorch_model.bin", directory_name='bert-tiny') + model_dir = os.path.dirname(vocab) + + tokenizer = BertTokenizer.from_pretrained(vocab) + pt_model = BertModel.from_pretrained(model_dir) + + text = "[CLS] how are you today [SEP] dude [SEP]" + tokenized_text = tokenizer.tokenize(text) + indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text) + segment_ids = [0] * 6 + [1] * 2 + + tokens_tensor = torch.tensor([indexed_tokens]) + segments_tensors = torch.tensor([segment_ids]) + attention_mask = torch.ones(1, 8, dtype=torch.int64) + + model = onnx.load(bert_path) + # infer shapes + model, _ = onnxsim.simplify(model, + skip_fuse_bn=True, + input_shapes=dict(input_ids=tokens_tensor.shape, + token_type_ids=segments_tensors.shape, + attention_mask=attention_mask.shape)) + + dace_model = donnx.ONNXModel("test_bert_full", model, auto_merge=True) + + dace_output = dace_model(input_ids=tokens_tensor, token_type_ids=segments_tensors, attention_mask=attention_mask) + + output = pt_model(tokens_tensor, token_type_ids=segments_tensors, attention_mask=attention_mask) + + torch_tensors_close("output_0", output[0], dace_output[0]) + torch_tensors_close("output_1", output[1], dace_output[1]) + + +if __name__ == "__main__": + test_bert_full() diff --git a/tests/onnx/test_name_shadowing.py b/tests/onnx/test_name_shadowing.py new file mode 100644 index 0000000000..aed2be894e --- /dev/null +++ b/tests/onnx/test_name_shadowing.py @@ -0,0 +1,37 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import pytest + +pytest.importorskip("onnx", reason="ONNX not installed. Please install with: pip install dace[ml]") + +import dace + +import dace.libraries.onnx as donnx + + +@pytest.mark.onnx +def test_shadowing(): + new_shape = [8, 10] + sdfg = dace.SDFG("test_shadowing") + + sdfg.add_array("X", [2, 4, 10], dace.float32) + sdfg.add_array("shape", [len(new_shape)], dace.int64) + sdfg.add_array("__return", new_shape, dace.float32) + + state = sdfg.add_state() + access_X = state.add_access("X") + access_shape = state.add_access("shape") + access_result = state.add_access("__return") + + op_node = donnx.ONNXReshape("reshape") + + state.add_node(op_node) + state.add_edge(access_X, None, op_node, "data", sdfg.make_array_memlet("X")) + state.add_edge(access_shape, None, op_node, "shape", sdfg.make_array_memlet("shape")) + + state.add_edge(op_node, "reshaped", access_result, None, sdfg.make_array_memlet("__return")) + + sdfg.compile() + + +if __name__ == "__main__": + test_shadowing() diff --git a/tests/onnx/test_onnx_return_scalars.py b/tests/onnx/test_onnx_return_scalars.py new file mode 100644 index 0000000000..bc0f9cfd19 --- /dev/null +++ b/tests/onnx/test_onnx_return_scalars.py @@ -0,0 +1,61 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import pytest + +pytest.importorskip("onnx", reason="ONNX not installed. Please install with: pip install dace[ml]") +pytest.importorskip("torch", reason="PyTorch not installed. Please install with: pip install dace[ml]") + +import torch +import onnx + +from dace.libraries import onnx as donnx + + +@pytest.mark.onnx +def test_onnx_return_scalars(): + # Dace programs can't return scalars. + # this test checks that we correctly copy out the scalars using a size [1] array + + # we will have a single operator that computes the sum of a 1D tensor + X = onnx.helper.make_tensor_value_info('X', onnx.TensorProto.FLOAT, [5]) + + # Create axes constant with value 0 + axes_constant = onnx.helper.make_tensor( + name='axes', + data_type=onnx.TensorProto.INT64, + dims=[1], # Single element array + vals=[0] # Reduce along axis 0 + ) + + # return value is a scalar + Y = onnx.helper.make_tensor_value_info('Y', onnx.TensorProto.FLOAT, []) + + node_def = onnx.helper.make_node( + 'ReduceSum', + ['X', "axes"], + ['Y'], + keepdims=0, + ) + + graph_def = onnx.helper.make_graph( + [node_def], + 'test-scalar-return', + [X], # inputs + [Y], # outputs + [axes_constant] # initializers (constants) + ) + + model_def = onnx.helper.make_model(graph_def, ir_version=10, opset_imports=[onnx.helper.make_opsetid('', 13)]) + + onnx.checker.check_model(model_def) + + # now we can test the backend + dace_model = donnx.ONNXModel("test_onnx_return_scalars", model_def) + inp = torch.arange(5).type(torch.float32) + + result = dace_model(inp) + assert result.shape == (), f"Expected scalar shape (), got {result.shape}" + assert result[()] == 1 + 2 + 3 + 4, f"Expected sum 10, got {result[()]}" + + +if __name__ == "__main__": + test_onnx_return_scalars() diff --git a/tests/onnx/test_python_frontend.py b/tests/onnx/test_python_frontend.py new file mode 100644 index 0000000000..495e5cddc3 --- /dev/null +++ b/tests/onnx/test_python_frontend.py @@ -0,0 +1,31 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +Test the python frontend of onnx nodes +""" + +import pytest + +pytest.importorskip("onnx", reason="ONNX not installed. Please install with: pip install dace[ml]") +import numpy as np + +import dace +import dace.libraries.onnx as donnx + + +@pytest.mark.onnx +def test_matmul(): + + @dace + def matmul(inp1: dace.float32[5, 5], inp2: dace.float32[5, 3]): + out = dace.define_local([5, 3], dace.float32) + donnx.ONNXMatMul(A=inp1, B=inp2, Y=out) + return out + + A = np.random.normal(size=(5, 5)).astype(np.float32) + B = np.random.normal(size=(5, 3)).astype(np.float32) + result = matmul(inp1=A.copy(), inp2=B.copy()) + np.testing.assert_allclose(A @ B, result, atol=1e-5, rtol=1e-5, err_msg="MatMul output mismatch") + + +if __name__ == "__main__": + test_matmul() diff --git a/tests/onnx/test_shared_input_output.py b/tests/onnx/test_shared_input_output.py new file mode 100644 index 0000000000..1e53cdd326 --- /dev/null +++ b/tests/onnx/test_shared_input_output.py @@ -0,0 +1,111 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +Batch Norm is the only op that has a shared name between inputs and outputs. Test that prepending "in_" and "out_" works +""" + +import pytest + +pytest.importorskip("onnx", reason="ONNX not installed. Please install with: pip install dace[ml]") +pytest.importorskip("torch", reason="PyTorch not installed. Please install with: pip install dace[ml]") + +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F + +import dace +import dace.libraries.onnx as donnx +from dace.ml import DaceModule + +from tests.utils import torch_tensors_close + + +@pytest.mark.onnx +@pytest.mark.parametrize("training_mode", [True, False]) +def test_bn_standalone(training_mode: bool): + + if training_mode: + + @dace.program + def test_bn_standalone(X: dace.float32[8, 3, 32, + 32], scale: dace.float32[3], B: dace.float32[3], mean: dace.float32[3], + var: dace.float32[3], running_mean: dace.float32[3], running_var: dace.float32[3]): + Y = dace.define_local([8, 3, 32, 32], dace.float32) + donnx.ONNXBatchNormalization( + X=X, + scale=scale, + B=B, + input_mean=mean, + input_var=var, + Y=Y, + running_mean=running_mean, + running_var=running_var, + training_mode=True, + ) + return Y + else: + + @dace.program + def test_bn_standalone(X: dace.float32[8, 3, 32, 32], scale: dace.float32[3], B: dace.float32[3], + mean: dace.float32[3], var: dace.float32[3]): + + Y = dace.define_local([8, 3, 32, 32], dace.float32) + donnx.ONNXBatchNormalization(X=X, + scale=scale, + B=B, + input_mean=mean, + input_var=var, + Y=Y, + training_mode=training_mode) + return Y + + X = torch.randn(8, 3, 32, 32) + scale = torch.randn(3) + B = torch.randn(3) + mean = torch.randn(3) + var = torch.abs(torch.randn(3)) + X_torch, scale_torch, B_torch, mean_torch, var_torch = X.clone(), scale.clone(), B.clone(), mean.clone(), var.clone( + ) + if training_mode: + running_mean = np.zeros(3, dtype=np.float32) + running_var = np.ones(3, dtype=np.float32) + dace_result = test_bn_standalone(X, scale, B, mean, var, running_mean, running_var) + else: + dace_result = test_bn_standalone(X, scale, B, mean, var) + + pt_result = F.batch_norm(X_torch, mean_torch, var_torch, scale_torch, B_torch, training=training_mode) + torch_tensors_close("output", pt_result, torch.from_numpy(dace_result)) + + +@pytest.mark.onnx +def test_bn_in_import(): + + class Module(torch.nn.Module): + + def __init__(self): + super(Module, self).__init__() + self.bn = nn.BatchNorm2d(3, track_running_stats=True) + + def forward(self, x): + return self.bn(x) + + pt_module = Module() + pt_module.eval() + dace_module = Module() + dace_module.eval() + + dace_module.load_state_dict(pt_module.state_dict()) + + dace_module = DaceModule(dace_module, sdfg_name="test_bn_in_import") + + X = torch.randn(8, 3, 32, 32) + pt_result = pt_module(X) + dace_result = dace_module(X) + + torch_tensors_close("output", pt_result, dace_result) + + +if __name__ == "__main__": + for training_mode in [True, False]: + test_bn_standalone(training_mode=training_mode) + test_bn_in_import() diff --git a/tests/onnx/test_variadic.py b/tests/onnx/test_variadic.py new file mode 100644 index 0000000000..635af8b9ce --- /dev/null +++ b/tests/onnx/test_variadic.py @@ -0,0 +1,53 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import pytest + +pytest.importorskip("onnx", reason="ONNX not installed. Please install with: pip install dace[ml]") +import numpy as np + +import dace +import dace.libraries.onnx as donnx + + +@pytest.mark.onnx +def test_sum(): + sdfg = dace.SDFG("test_sum") + + sdfg.add_array("A_arr", [2, 2], dace.float32) + sdfg.add_array("B_arr", [2, 2], dace.float32) + sdfg.add_array("C_arr", [2, 2], dace.float32) + sdfg.add_array("__return", [2, 2], dace.float32) + + state = sdfg.add_state() + access_A = state.add_access("A_arr") + access_B = state.add_access("B_arr") + access_C = state.add_access("C_arr") + + access_result = state.add_access("__return") + + op_node = donnx.ONNXSum("Sum") + + state.add_node(op_node) + for i in range(3): + op_node.add_in_connector("data_0__{}".format(i)) + state.add_edge(access_A, None, op_node, "data_0__0", sdfg.make_array_memlet("A_arr")) + state.add_edge(access_B, None, op_node, "data_0__1", sdfg.make_array_memlet("B_arr")) + state.add_edge(access_C, None, op_node, "data_0__2", sdfg.make_array_memlet("C_arr")) + + state.add_edge(op_node, "sum", access_result, None, sdfg.make_array_memlet("__return")) + + A = np.random.rand(2, 2).astype(np.float32) + B = np.random.rand(2, 2).astype(np.float32) + C = np.random.rand(2, 2).astype(np.float32) + + sdfg.validate() + + result = sdfg(A_arr=A, B_arr=B, C_arr=C) + + numpy_result = A + B + C + + assert np.allclose(result, + numpy_result), f"Variadic sum mismatch: max diff = {np.max(np.abs(result - numpy_result))}" + + +if __name__ == "__main__": + test_sum() diff --git a/tests/tensorflow/callback_test.py b/tests/tensorflow/callback_test.py index 3dc359aac8..01b706e765 100644 --- a/tests/tensorflow/callback_test.py +++ b/tests/tensorflow/callback_test.py @@ -6,7 +6,7 @@ @pytest.mark.tensorflow def test_callback(): import tensorflow as tf - from dace.frontend.tensorflow import TFSession + from dace.frontend.ml.tensorflow import TFSession input_image = tf.constant(0.69, tf.float64, [2, 2, 5, 5, 2]) conv_filter = tf.constant(0.01, tf.float64, [1, 1, 1, 2, 2]) diff --git a/tests/tensorflow/compile_test.py b/tests/tensorflow/compile_test.py index 6f58597a3c..1fd00d32ed 100644 --- a/tests/tensorflow/compile_test.py +++ b/tests/tensorflow/compile_test.py @@ -6,7 +6,7 @@ @pytest.mark.tensorflow def test_compile(): import tensorflow as tf - from dace.frontend.tensorflow import TFSession + from dace.frontend.ml.tensorflow import TFSession print('DaCe Tensorflow frontend compile API test') diff --git a/tests/tensorflow/conv_test.py b/tests/tensorflow/conv_test.py index d7c44a98c4..41a6e384d7 100644 --- a/tests/tensorflow/conv_test.py +++ b/tests/tensorflow/conv_test.py @@ -7,7 +7,7 @@ def test_conv(): import tensorflow as tf from tensorflow.python.ops import gen_nn_ops - from dace.frontend.tensorflow import TFSession + from dace.frontend.ml.tensorflow import TFSession inp_shape = [10, 10, 10, 10] filter_shape = [3, 3, 10, 3] strides = [1, 3, 3, 1] diff --git a/tests/tensorflow/fbn_test.py b/tests/tensorflow/fbn_test.py index d3373745fb..fcc5f56fde 100644 --- a/tests/tensorflow/fbn_test.py +++ b/tests/tensorflow/fbn_test.py @@ -7,7 +7,7 @@ def test_fused_batch_norm(): import tensorflow as tf from tensorflow.python.ops import gen_nn_ops - from dace.frontend.tensorflow import TFSession + from dace.frontend.ml.tensorflow import TFSession num_channels = 3 size = [8, 224, 224, num_channels] diff --git a/tests/tensorflow/ops_test.py b/tests/tensorflow/ops_test.py index e685572f64..a34e882967 100644 --- a/tests/tensorflow/ops_test.py +++ b/tests/tensorflow/ops_test.py @@ -6,7 +6,7 @@ @pytest.mark.tensorflow def test_shapen(): import tensorflow as tf - from dace.frontend.tensorflow import TFSession + from dace.frontend.ml.tensorflow import TFSession myshape = [69, 96, 666] num_inputs = 5 @@ -28,7 +28,7 @@ def test_shapen(): @pytest.mark.tensorflow def test_mean(): import tensorflow as tf - from dace.frontend.tensorflow import TFSession + from dace.frontend.ml.tensorflow import TFSession shape = [10, 11, 12, 13] inp = tf.placeholder(tf.float64, shape) @@ -58,7 +58,7 @@ def test_mean(): @pytest.mark.tensorflow def test_addn(): import tensorflow as tf - from dace.frontend.tensorflow import TFSession + from dace.frontend.ml.tensorflow import TFSession shape = [10, 11, 12, 13] inputs = [np.random.rand(*shape) for _ in range(10)] addn_test_0 = tf.add_n(inputs) @@ -81,7 +81,7 @@ def test_addn(): @pytest.mark.tensorflow def test_slice(): import tensorflow as tf - from dace.frontend.tensorflow import TFSession + from dace.frontend.ml.tensorflow import TFSession t = tf.placeholder(tf.int32, [3, 2, 3]) b = tf.placeholder(tf.int32, [3]) s = tf.placeholder(tf.int32, [3]) diff --git a/tests/tensorflow/pool_test.py b/tests/tensorflow/pool_test.py index d9c1a8f4d4..b30b5f01fb 100644 --- a/tests/tensorflow/pool_test.py +++ b/tests/tensorflow/pool_test.py @@ -6,7 +6,7 @@ @pytest.mark.tensorflow def test_pooling(): import tensorflow as tf - from dace.frontend.tensorflow import TFSession + from dace.frontend.ml.tensorflow import TFSession size_in = [1, 112, 112, 3] # size_in = [4, 4, 4, 4] np.random.seed(0) diff --git a/tests/tensorflow/simple_test.py b/tests/tensorflow/simple_test.py index 92917936a4..31abdb8513 100644 --- a/tests/tensorflow/simple_test.py +++ b/tests/tensorflow/simple_test.py @@ -6,7 +6,7 @@ @pytest.mark.tensorflow def test_simple(): import tensorflow as tf - from dace.frontend.tensorflow import TFSession + from dace.frontend.ml.tensorflow import TFSession print('DaCe Tensorflow frontend test') A = np.random.rand(16, 16).astype(np.float32) diff --git a/tests/torch_forward/test_attn.py b/tests/torch_forward/test_attn.py new file mode 100644 index 0000000000..df673e057c --- /dev/null +++ b/tests/torch_forward/test_attn.py @@ -0,0 +1,39 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import pytest + +pytest.importorskip("torch", reason="PyTorch not installed. Please install with: pip install dace[ml]") +import torch + +from dace.ml import DaceModule + +from dace.transformation.dataflow import RedundantSecondArray +from tests.utils import torch_tensors_close + + +@pytest.mark.torch +def test_attn(use_cpp_dispatcher: bool): + B = 2 + H = 16 + P = 64 + N = P * H + SM, SN = 512, 512 + K, Q, V = [torch.randn([SM, B, N]), torch.randn([SN, B, N]), torch.randn([SM, B, N])] + ptmodel = torch.nn.MultiheadAttention(N, H, bias=False) + + pt_outputs = ptmodel(Q, K, V) + + dispatcher_suffix = "cpp" if use_cpp_dispatcher else "ctypes" + dace_model = DaceModule(ptmodel, + sdfg_name=f"test_attn_{dispatcher_suffix}", + compile_torch_extension=use_cpp_dispatcher, + auto_optimize=False) + + dace_outputs = dace_model(Q, K, V) + + torch_tensors_close("outputs_0", pt_outputs[0], dace_outputs[0]) + torch_tensors_close("outputs_1", pt_outputs[1], dace_outputs[1]) + + +if __name__ == "__main__": + test_attn(use_cpp_dispatcher=True) + test_attn(use_cpp_dispatcher=False) diff --git a/tests/torch_forward/test_conv2d.py b/tests/torch_forward/test_conv2d.py new file mode 100644 index 0000000000..89a21bc27e --- /dev/null +++ b/tests/torch_forward/test_conv2d.py @@ -0,0 +1,55 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import pytest + +pytest.importorskip("torch", reason="PyTorch not installed. Please install with: pip install dace[ml]") +import torch +import torch.nn as nn +import torch.nn.functional as F + +import numpy as np +import dace +from dace.ml import DaceModule + + +@pytest.mark.torch +def test_conv2d(use_cpp_dispatcher: bool): + + class Model(nn.Module): + + def __init__(self): + super(Model, self).__init__() + self.conv1 = nn.Conv2d(1, 4, 3) + self.conv2 = nn.Conv2d(4, 4, 3) + + def forward(self, x): + x = F.relu(self.conv1(x)) + return F.relu(self.conv2(x)) + + ptmodel = Model() + x = torch.rand(1, 1, 8, 8) + + dispatcher_suffix = "cpp" if use_cpp_dispatcher else "ctypes" + + @dace.ml.module(sdfg_name=f"test_conv2d_decorator_{dispatcher_suffix}") + class TestDecorator(Model): + pass + + dace_model = DaceModule(ptmodel, + sdfg_name=f"test_conv2d_{dispatcher_suffix}", + compile_torch_extension=use_cpp_dispatcher) + dace_output = dace_model(x) + + dace_model_decorated = TestDecorator() + dace_model_decorated(x) + + torch_output = ptmodel(x) + + np.testing.assert_allclose(torch_output.detach().numpy(), + dace_output.detach().numpy(), + atol=1e-06, + err_msg="Conv2d output mismatch between PyTorch and DaCe") + + +if __name__ == "__main__": + test_conv2d(use_cpp_dispatcher=True) + test_conv2d(use_cpp_dispatcher=False) diff --git a/tests/torch_forward/test_cpp_extension.py b/tests/torch_forward/test_cpp_extension.py new file mode 100644 index 0000000000..c8b1b624f4 --- /dev/null +++ b/tests/torch_forward/test_cpp_extension.py @@ -0,0 +1,120 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import pytest + +pytest.importorskip("torch", reason="PyTorch not installed. Please install with: pip install dace[ml]") +import os + +import numpy as np +import torch +import torch.utils.cpp_extension +from dace.codegen import targets, compiler +from dace.codegen.codeobject import CodeObject +from torch import nn + +import dace +from dace.libraries.torch import PyTorch +from tests.utils import torch_tensors_close + +op_source = """ +#include +#include + +#include + +using torch::Tensor; +using torch::DeviceType; +using torch::autograd::tensor_list; +using torch::autograd::AutogradContext; + +Tensor myadd(const Tensor& self, const Tensor& other) { + static auto op = torch::Dispatcher::singleton() + .findSchemaOrThrow("myops::myadd", "") + .typed(); + return op.call(self, other); +} + +TORCH_LIBRARY(myops, m) { + m.def("myadd(Tensor self, Tensor other) -> Tensor"); +} + +Tensor myadd_cpu(const Tensor& self_, const Tensor& other_) { + TORCH_CHECK(self_.sizes() == other_.sizes()); + TORCH_INTERNAL_ASSERT(self_.device().type() == DeviceType::CPU); + TORCH_INTERNAL_ASSERT(other_.device().type() == DeviceType::CPU); + Tensor self = self_.contiguous(); + Tensor other = other_.contiguous(); + Tensor result = torch::empty(self.sizes(), self.options()); + const float* self_ptr = self.data_ptr(); + const float* other_ptr = other.data_ptr(); + float* result_ptr = result.data_ptr(); + for (int64_t i = 0; i < result.numel(); i++) { + result_ptr[i] = self_ptr[i] + other_ptr[i]; + } + return result; +} + +class MyAddFunction : public torch::autograd::Function { + public: + static Tensor forward( + AutogradContext *ctx, torch::Tensor self, torch::Tensor other) { + at::AutoDispatchBelowADInplaceOrView g; + return myadd(self, other); + } + + static tensor_list backward(AutogradContext *ctx, tensor_list grad_outputs) { + auto grad_output = grad_outputs[0]; + return {grad_output, grad_output}; + } +}; + +Tensor myadd_autograd(const Tensor& self, const Tensor& other) { + return MyAddFunction::apply(self, other)[0]; +} + +TORCH_LIBRARY_IMPL(myops, CPU, m) { + m.impl("myadd", myadd_cpu); +} + +TORCH_LIBRARY_IMPL(myops, Autograd, m) { + m.impl("myadd", myadd_autograd); +} +""" + + +@pytest.mark.torch +def test_extension(): + program = CodeObject("myadd", + op_source, + "cpp", + targets.cpu.CPUCodeGen, + "MyAddFunction", + environments={PyTorch.full_class_path()}) + + BUILD_PATH = os.path.join('.dacecache', "pt_extension") + compiler.generate_program_folder(None, [program], BUILD_PATH) + torch.utils.cpp_extension.load( + name='pt_extension', + sources=[os.path.join(BUILD_PATH, 'src', 'cpu', 'myadd.cpp')], + is_python_module=False, + ) + torch.ops.myops.myadd(torch.randn(32, 32), torch.rand(32, 32)) + + +@pytest.mark.torch +def test_module_with_constant(): + + @dace.ml.module(sdfg_name="test_module_with_constant") + class Module(nn.Module): + + def forward(self, x): + return x + 1 + + inp = torch.ones((5, 5)) + output = Module()(inp) + + torch_tensors_close("output", inp + 1, output.cpu()) + + +if __name__ == "__main__": + test_extension() + test_module_with_constant() diff --git a/tests/torch_forward/test_debug_transients.py b/tests/torch_forward/test_debug_transients.py new file mode 100644 index 0000000000..995b986b8d --- /dev/null +++ b/tests/torch_forward/test_debug_transients.py @@ -0,0 +1,36 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import pytest + +pytest.importorskip("torch", reason="PyTorch not installed. Please install with: pip install dace[ml]") +import torch +from torch import nn +import numpy as np + +import dace +from tests.utils import torch_tensors_close + + +@dace.ml.module(debug_transients=True, sdfg_name="test_debug_transients") +class Module(nn.Module): + + def forward(self, x): + y = x + 3 + return y * 5 + + +@pytest.mark.torch +def test_debug_transients(): + + module = Module() + + x = torch.rand(5, 5) + outputs = module(x) + output, y, y2 = outputs + + torch_tensors_close("output", (x + 3) * 5, output) + torch_tensors_close("y2", (x + 3) * 5, y2) + torch_tensors_close("y", x + 3, y) + + +if __name__ == "__main__": + test_debug_transients() diff --git a/tests/torch_forward/test_dlpack.py b/tests/torch_forward/test_dlpack.py new file mode 100644 index 0000000000..879417ea80 --- /dev/null +++ b/tests/torch_forward/test_dlpack.py @@ -0,0 +1,26 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import ctypes + +import pytest + +pytest.importorskip("torch", reason="PyTorch not installed. Please install with: pip install dace[ml]") + +import dace +import numpy as np + +from dace.libraries.torch.dlpack import array_to_torch_tensor + + +@pytest.mark.torch +def test_desc_to_dlpack(): + mydata = np.arange(6).reshape(2, 3).astype(np.float32) + + ptr = ctypes.c_void_p(mydata.__array_interface__["data"][0]) + tensor = array_to_torch_tensor(ptr, dace.float32[2, 3]) + np.testing.assert_allclose(tensor, mydata), "Initial DLPack tensor conversion failed" + mydata += 1 + np.testing.assert_allclose(tensor, mydata), "DLPack tensor does not share memory with numpy array" + + +if __name__ == "__main__": + test_desc_to_dlpack() diff --git a/tests/torch_forward/test_efficientnet_block.py b/tests/torch_forward/test_efficientnet_block.py new file mode 100644 index 0000000000..6401f561a6 --- /dev/null +++ b/tests/torch_forward/test_efficientnet_block.py @@ -0,0 +1,114 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import pytest + +pytest.importorskip("torch", reason="PyTorch not installed. Please install with: pip install dace[ml]") +pytest.importorskip("efficientnet_pytorch", + reason="efficientnet_pytorch not installed. Please install with: pip install dace[ml-testing]") +import torch +import numpy as np +from dace.transformation.dataflow import TrivialMapElimination +from dace.transformation.interstate import HoistState +from efficientnet_pytorch import get_model_params +from efficientnet_pytorch.model import MBConvBlock + +from dace.ml import DaceModule +from tests.utils import torch_tensors_close + + +@pytest.mark.torch +def test_mbconv(use_cpp_dispatcher: bool): + + with torch.no_grad(): + dace_inputs = torch.rand(8, 32, 224, 224) + torch_inputs = torch.clone(dace_inputs) + + block_params, global_params = get_model_params("efficientnet-b0", {}) + + torch_model = MBConvBlock(block_params[0], global_params).eval() + torch_model.set_swish(memory_efficient=False) + dace_model = MBConvBlock(block_params[0], global_params).eval() + dace_model.set_swish(memory_efficient=False) + + # Get the DaceModule + sdfg_name = f"efficientnet_mbconv_{use_cpp_dispatcher}" + dace_model = DaceModule(dace_model, sdfg_name=sdfg_name, compile_torch_extension=use_cpp_dispatcher) + dace_model.model.load_state_dict(torch_model.state_dict()) + + for (dace_name, dace_value), (torch_name, value) in zip(dace_model.model.state_dict().items(), + torch_model.state_dict().items()): + assert dace_name == torch_name, f"Parameter name mismatch: {dace_name} != {torch_name}" + np.testing.assert_allclose(value, dace_value, err_msg=f"{dace_name} tensors do not match") + + dace_output = dace_model(dace_inputs) + + torch_output = torch_model(torch_inputs) + np.testing.assert_allclose(torch_output.detach(), + dace_output.detach(), + rtol=1e-3, + atol=1e-3, + err_msg="output tensors do not match") + + # check that the batch norm running means and so on are written out correctly + for (dace_name, dace_value), (torch_name, value) in zip(dace_model.model.state_dict().items(), + torch_model.state_dict().items()): + + assert dace_name == torch_name, f"Parameter name mismatch after inference: {dace_name} != {torch_name}" + if "num_batches_tracked" in dace_name: + # we don't update this parameter + continue + np.testing.assert_allclose(value, dace_value, err_msg=f"{dace_name} tensors do not match") + + +@pytest.mark.torch +def test_fast_mb(use_cpp_dispatcher: bool): + with torch.no_grad(): + dace_inputs = torch.rand(8, 32, 224, 224) + torch_inputs = torch.clone(dace_inputs) + + block_params, global_params = get_model_params("efficientnet-b0", {}) + + torch_model = MBConvBlock(block_params[0], global_params).eval() + torch_model.set_swish(memory_efficient=False) + dace_model = MBConvBlock(block_params[0], global_params).eval() + dace_model.set_swish(memory_efficient=False) + + # Get the DaceModule + sdfg_name = f"efficientnet_fast_mbconv_{use_cpp_dispatcher}" + dace_model = DaceModule(dace_model, sdfg_name=sdfg_name, compile_torch_extension=use_cpp_dispatcher) + dace_model.model.load_state_dict(torch_model.state_dict()) + + for (dace_name, dace_value), (torch_name, value) in zip(dace_model.model.state_dict().items(), + torch_model.state_dict().items()): + assert dace_name == torch_name, f"Parameter name mismatch: {dace_name} != {torch_name}" + torch_tensors_close(dace_name, value, dace_value) + + def fuse_everything(module: DaceModule): + sdfg = module.sdfg + + sdfg.apply_transformations_repeated(HoistState) + sdfg.apply_transformations_repeated(TrivialMapElimination) + sdfg.simplify() + + dace_model.append_post_onnx_hook("fuse_sg", fuse_everything) + + dace_output = dace_model(dace_inputs) + + torch_output = torch_model(torch_inputs) + torch_tensors_close("output", torch_output, dace_output, rtol=1e-3, atol=1e-3) + + # check that the batch norm running means and so on are written out correctly + for (dace_name, dace_value), (torch_name, value) in zip(dace_model.model.state_dict().items(), + torch_model.state_dict().items()): + + assert dace_name == torch_name, f"Parameter name mismatch after inference: {dace_name} != {torch_name}" + if "num_batches_tracked" in dace_name: + # we don't update this parameter + continue + torch_tensors_close(dace_name, value, dace_value) + + +if __name__ == "__main__": + test_mbconv(use_cpp_dispatcher=True) + test_mbconv(use_cpp_dispatcher=False) + test_fast_mb(use_cpp_dispatcher=True) + test_fast_mb(use_cpp_dispatcher=False) diff --git a/tests/torch_forward/test_img_op_implementations.py b/tests/torch_forward/test_img_op_implementations.py new file mode 100644 index 0000000000..4ec120b0f9 --- /dev/null +++ b/tests/torch_forward/test_img_op_implementations.py @@ -0,0 +1,95 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import pytest + +pytest.importorskip("torch", reason="PyTorch not installed. Please install with: pip install dace[ml]") +import torch +from torch import nn +import numpy as np + +from dace.ml import DaceModule +from tests.utils import torch_tensors_close + + +class CustomBatchNorm(torch.autograd.Function): + + @staticmethod + def forward(ctx, x, running_mean, running_var, weight, bias, training, momentum, eps): + output = torch.nn.functional.batch_norm(x, running_mean, running_var, weight, bias, training, momentum, eps) + return output, running_mean, running_var + + @staticmethod + def symbolic(g, x, running_mean, running_var, weight, bias, training, momentum, eps): + outputs = g.op("BatchNormalization", + x, + weight, + bias, + running_mean, + running_var, + training_mode_i=int(training), + momentum_f=momentum, + epsilon_f=eps, + outputs=3) + y, new_running_mean, new_running_var = outputs + return y, new_running_mean, new_running_var + + +class BatchNorm2dMeanVar(nn.Module): + + def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True): + super(BatchNorm2dMeanVar, self).__init__() + self.bn = nn.BatchNorm2d(num_features, + eps=eps, + momentum=momentum, + affine=affine, + track_running_stats=track_running_stats) + + def forward(self, x): + return CustomBatchNorm.apply(x, self.bn.running_mean, self.bn.running_var, self.bn.weight, self.bn.bias, + self.bn.training, self.bn.momentum, self.bn.eps) + + +@pytest.mark.torch +def test_bn(): + + inputs = torch.rand(1, 64, 60, 60) + + # pytorch and onnx specification differ in the way they use momentum: + # pytorch_momentum = 1 - onnx_momentum + # to guarantee matching behavior, we set the momentum to 0.5 + + pt_model = BatchNorm2dMeanVar(64, momentum=0.5) + dace_model = BatchNorm2dMeanVar(64, momentum=0.5) + pt_model.train() + dace_model.train() + + dace_model.load_state_dict(pt_model.state_dict()) + + dace_model = DaceModule(dace_model, sdfg_name="test_bn", training=True) + dace_output, dace_mean, dace_var = dace_model(inputs) + pt_output, pt_mean, pt_var = pt_model(inputs) + + torch_tensors_close("output", pt_output, dace_output) + torch_tensors_close("mean", pt_mean, dace_mean) + torch_tensors_close("var", pt_var, dace_var) + + +@pytest.mark.torch +def test_global_avg_pool(): + inputs = torch.rand(1, 64, 60, 60) + + pt_model = nn.AdaptiveAvgPool2d(1) + dace_model = nn.AdaptiveAvgPool2d(1) + + # Note: AdaptiveAvgPool2d has no parameters, but load_state_dict ensures compatibility + dace_model.load_state_dict(pt_model.state_dict()) + + dace_model = DaceModule(dace_model, sdfg_name="test_global_avg_pool", training=True) + dace_output = dace_model(inputs) + pt_output = pt_model(inputs) + + torch_tensors_close("output", pt_output, dace_output) + + +if __name__ == "__main__": + test_bn() + test_global_avg_pool() diff --git a/tests/torch_forward/test_lenet.py b/tests/torch_forward/test_lenet.py new file mode 100644 index 0000000000..e6db37f740 --- /dev/null +++ b/tests/torch_forward/test_lenet.py @@ -0,0 +1,60 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import pytest + +pytest.importorskip("onnx", reason="ONNX not installed. Please install with: pip install dace[ml]") +pytest.importorskip("torch", reason="PyTorch not installed. Please install with: pip install dace[ml]") + +from dace.ml import DaceModule + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from tests.utils import torch_tensors_close + + +class LeNet(nn.Module): + + def __init__(self): + super(LeNet, self).__init__() + self.conv1 = nn.Conv2d(1, 6, (3, 3)) + self.conv2 = nn.Conv2d(6, 16, (3, 3)) + self.fc1 = nn.Linear(16 * 6 * 6, 120) # 6*6 from image dimension + self.fc2 = nn.Linear(120, 84) + self.fc3 = nn.Linear(84, 10) + + def forward(self, x): + x = F.max_pool2d(F.relu(self.conv1(x)), 2) + x = F.max_pool2d(F.relu(self.conv2(x)), 2) + + x = x.view(-1, 16 * 6 * 6) + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + x = self.fc3(x) + x = F.log_softmax(x, dim=1) + return x + + +@pytest.mark.torch +def test_lenet(use_cpp_dispatcher: bool): + + input = torch.rand(8, 1, 32, 32, dtype=torch.float32) + + net = LeNet() + dace_net = LeNet() + dace_net.load_state_dict(net.state_dict()) + dispatcher_suffix = "cpp" if use_cpp_dispatcher else "ctypes" + dace_net = DaceModule(dace_net, + sdfg_name=f"test_lenet_{dispatcher_suffix}", + compile_torch_extension=use_cpp_dispatcher) + + torch_output = net(torch.clone(input)) + dace_output = dace_net(torch.clone(input)) + dace_net.sdfg.expand_library_nodes() + + torch_tensors_close("output", torch_output, dace_output) + + +if __name__ == "__main__": + test_lenet(use_cpp_dispatcher=True) + test_lenet(use_cpp_dispatcher=False) diff --git a/tests/torch_forward/test_module_dace_program.py b/tests/torch_forward/test_module_dace_program.py new file mode 100644 index 0000000000..d4dfd0c363 --- /dev/null +++ b/tests/torch_forward/test_module_dace_program.py @@ -0,0 +1,63 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import pytest + +pytest.importorskip("torch", reason="PyTorch not installed. Please install with: pip install dace[ml]") + +import numpy as np +import torch +from torch import nn + +import dace + +from dace.ml import DaceModule +from tests.utils import tensors_close, torch_tensors_close + + +@pytest.mark.torch +def test_parse_forward_simple(): + torch_module = torch.nn.Sequential(torch.nn.Linear(12, 24), torch.nn.Linear(24, 2)) + dace_module = DaceModule(torch_module, sdfg_name='test_parse_forward_simple') + x = torch.randn(2, 12) + expected = torch_module(x) + result = dace_module(x) + + torch_tensors_close('output', expected, result) + + @dace + def train_step(y): + # output is potentially a gpu tensor + output = dace_module(y) + cpu = np.empty_like(output) + cpu[:] = output + return cpu.sum() + + result = train_step(x) + tensors_close('parsed', expected.sum(), result) + + +@pytest.mark.torch +def test_parse_forward_nested(): + + torch_module = torch.nn.Sequential(torch.nn.Sequential(torch.nn.Linear(12, 24), torch.nn.Linear(24, 2)), + nn.Softmax(dim=1)) + dace_module2 = DaceModule(torch_module, sdfg_name='test_parse_forward_nested') + x = torch.randn(2, 12) + expected = torch_module(x) + result = dace_module2(x) + + torch_tensors_close('output', expected, result) + + @dace + def train_step(y): + output = dace_module2(y) + cpu = np.empty_like(output) + cpu[:] = output + return cpu.sum() + + result = train_step(x) + tensors_close('parsed', expected.sum(), result) + + +if __name__ == "__main__": + test_parse_forward_simple() + test_parse_forward_nested() diff --git a/tests/torch_forward/test_multi_output.py b/tests/torch_forward/test_multi_output.py new file mode 100644 index 0000000000..070902b8da --- /dev/null +++ b/tests/torch_forward/test_multi_output.py @@ -0,0 +1,44 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import pytest + +pytest.importorskip("torch", reason="PyTorch not installed. Please install with: pip install dace[ml]") +import torch +from torch import nn + +from dace.ml import DaceModule +from tests.utils import torch_tensors_close + + +class Model(nn.Module): + + def __init__(self, new_shape): + super(Model, self).__init__() + self.new_shape = new_shape + + def forward(self, x): + return x + 1, x + 2 + + +@pytest.mark.torch +def test_multiple_outputs(use_cpp_dispatcher: bool): + + ptmodel = Model([5, 5]) + x = torch.rand([25]) + + torch_outputs = ptmodel(torch.clone(x)) + + dispatcher_suffix = "cpp" if use_cpp_dispatcher else "ctypes" + dace_model = DaceModule(ptmodel, + sdfg_name=f"test_multi_output_{dispatcher_suffix}", + auto_optimize=False, + compile_torch_extension=use_cpp_dispatcher) + + dace_outputs = dace_model(x) + + torch_tensors_close("output_0", torch_outputs[0], dace_outputs[0]) + torch_tensors_close("output_1", torch_outputs[1], dace_outputs[1]) + + +if __name__ == "__main__": + test_multiple_outputs(use_cpp_dispatcher=True) + test_multiple_outputs(use_cpp_dispatcher=False) diff --git a/tests/torch_forward/test_reshape.py b/tests/torch_forward/test_reshape.py new file mode 100644 index 0000000000..ad48d6c453 --- /dev/null +++ b/tests/torch_forward/test_reshape.py @@ -0,0 +1,38 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import pytest + +pytest.importorskip("torch", reason="PyTorch not installed. Please install with: pip install dace[ml]") +import torch +from torch import nn +from dace.ml import DaceModule +from tests.utils import torch_tensors_close + + +class Model(nn.Module): + + def __init__(self, new_shape): + super(Model, self).__init__() + self.new_shape = new_shape + + def forward(self, x): + x = x.reshape(self.new_shape) + return x + + +@pytest.mark.torch +def test_reshape_module(): + + ptmodel = Model([5, 5]) + x = torch.rand([25]) + + torch_output = ptmodel(torch.clone(x)) + + dace_model = DaceModule(ptmodel, sdfg_name="test_reshape_module", auto_optimize=False, dummy_inputs=(x, )) + + dace_output = dace_model(x) + + torch_tensors_close("output", torch_output, dace_output) + + +if __name__ == "__main__": + test_reshape_module() diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 0000000000..b4ddb24314 --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,60 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. + +import os +import typing +import urllib.request, urllib.parse +import pathlib +import pytest +import dace +import numpy as np + + +def get_data_file(url, directory_name=None) -> str: + """ Get a data file from ``url``, cache it locally and return the local file path to it. + + :param url: the url to download from. + :param directory_name: an optional relative directory path where the file will be downloaded to. + :return: the path of the downloaded file. + """ + + data_directory = (pathlib.Path(dace.__file__).parent.parent / 'tests' / 'data') + + if directory_name is not None: + data_directory /= directory_name + + data_directory.mkdir(exist_ok=True, parents=True) + + file_name = os.path.basename(urllib.parse.urlparse(url).path) + file_path = str(data_directory / file_name) + + if not os.path.exists(file_path): + urllib.request.urlretrieve(url, file_path) + return file_path + + +def tensors_close(name, expected, result, rtol=1e-5, atol=1e-5): + + def to_numpy(x): + if hasattr(x, 'detach'): + x = x.detach() + if hasattr(x, 'cpu'): + x = x.cpu() + if hasattr(x, 'numpy'): + x = x.numpy() + return x + + expected = to_numpy(expected) + result = to_numpy(result) + np.testing.assert_allclose(result, expected, rtol=rtol, atol=atol, err_msg=f'{name} not close') + + +def torch_tensors_close(name, torch_v, dace_v, rtol=1e-5, atol=1e-4): + """ + Assert that the two torch tensors are close. Prints a nice error string if not. + """ + # check that the device is correct + assert torch_v.device == dace_v.device, "Tensors are on different devices" + + torch_v = torch_v.detach().cpu().numpy() + dace_v = dace_v.detach().cpu().numpy() + np.testing.assert_allclose(dace_v, torch_v, rtol=rtol, atol=atol, err_msg=f'{name} not close') From c115df96845870c223bfe8fb94f08d9a930de68d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20M=C3=BCller?= <147368808+philip-paul-mueller@users.noreply.github.com> Date: Thu, 4 Dec 2025 09:12:35 +0100 Subject: [PATCH 02/17] Updated Reloading Scheme for `ReloadableDLL` (#2218) Modified the reloading scheme used by `ReloadableDLL`. If the library (of the compiled SDFG) is already loaded, through another instance of `CompiledSDFG` then `ReloadableDLL` will copy the SDFG library and try to load that until it founds a name that is free. In ICON4Py we noticed that this leads sometime to a segmentation fault on Linux, but not on MacOS X. We traced the main issue down to the fact that `ReloadableDLL` created a copy of the SDFG library without checking if the new name is already used, instead the file is simply overwritten. The new scheme changes this slightly, in the following ways: - If the new name is already taken, then no copy is performed and the class tries to use that file, that already exists. - Instead of copying library `n - 1` to `n` it always makes a copy from the initial library. --------- Co-authored-by: Philipp Schaad --- dace/codegen/compiled_sdfg.py | 37 ++++++++++++++++++++++++++++++----- 1 file changed, 32 insertions(+), 5 deletions(-) diff --git a/dace/codegen/compiled_sdfg.py b/dace/codegen/compiled_sdfg.py index 733f0ba53c..b80273ded3 100644 --- a/dace/codegen/compiled_sdfg.py +++ b/dace/codegen/compiled_sdfg.py @@ -9,6 +9,7 @@ import warnings import tempfile import pickle +import pathlib import sys import numpy as np @@ -77,7 +78,8 @@ def is_loaded(self) -> bool: lib_cfilename = ctypes.c_wchar_p(self._library_filename) else: # As UTF-8 - lib_cfilename = ctypes.c_char_p(self._library_filename.encode('utf-8')) + tt = self._library_filename.encode('utf-8') + lib_cfilename = ctypes.c_char_p(tt) return self._stub.is_library_loaded(lib_cfilename) == 1 @@ -96,21 +98,39 @@ def load(self): # Check if library is already loaded is_loaded = True lib_cfilename = None + lib_filename = self._library_filename + counter = 0 while is_loaded: # Convert library filename to string according to OS if os.name == 'nt': # As UTF-16 - lib_cfilename = ctypes.c_wchar_p(self._library_filename) + lib_cfilename = ctypes.c_wchar_p(lib_filename) else: # As UTF-8 - lib_cfilename = ctypes.c_char_p(self._library_filename.encode('utf-8')) + lib_cfilename = ctypes.c_char_p(lib_filename.encode('utf-8')) + # Test if the library is loaded. is_loaded = self._stub.is_library_loaded(lib_cfilename) + if is_loaded == 1: warnings.warn(f'Library {self._library_filename} already loaded, renaming file') + + # The library is loaded, copy the _original_ library file to a new file + # and then try to load that. We only do the copy if the new new name is + # free. It seems that at least on LINUX there is some issue if we + # overwrite a file that already exists. + lib_filename = self._library_filename + f'_{counter}' + counter += 1 + if pathlib.Path(lib_filename).exists(): + assert pathlib.Path(lib_filename).is_file() + continue + + # The file name is not taken, so make a copy. There might be a race condition + # here in the presence of multiple processes. + # TODO: Investigate if we should switch to hardlinks if they are supported. try: - shutil.copyfile(self._library_filename, self._library_filename + '_') - self._library_filename += '_' + assert self._library_filename != lib_filename + shutil.copyfile(self._library_filename, lib_filename) except shutil.Error: raise cgx.DuplicateDLLError(f'Library {os.path.basename(self._library_filename)}' 'is already loaded somewhere else and cannot be unloaded. ' @@ -118,6 +138,7 @@ def load(self): # Actually load the library self._lib = ctypes.c_void_p(self._stub.load_library(lib_cfilename)) + self._library_filename = lib_filename if self._lib.value is None: # Try to understand why the library is not loading, if dynamic @@ -147,6 +168,12 @@ def __enter__(self, *args, **kwargs): def __exit__(self, *args, **kwargs): self.unload() + def __copy__(self): + raise RuntimeError(f'Can not copy ReloadableDLL({self._library_filename})') + + def __deepcopy__(self, memodict={}): + raise RuntimeError(f'Can not copy ReloadableDLL({self._library_filename})') + class CompiledSDFG(object): """ A compiled SDFG object that can be called through Python. From 076fd31fb6cb3e177986d3d6f817124e43def0e2 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Thu, 4 Dec 2025 07:47:45 -0800 Subject: [PATCH 03/17] Update .coveragerc --- .coveragerc | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/.coveragerc b/.coveragerc index 62db6f887c..c82a90c324 100644 --- a/.coveragerc +++ b/.coveragerc @@ -19,14 +19,16 @@ exclude_lines = if False: if __name__ == .__main__.: pass + if TYPE_CHECKING: + if typing.TYPE_CHECKING: omit = # Omit files that cannot be tested dace/jupyter.py # Omit deprecated files - dace/frontend/tensorflow/__init__.py - dace/frontend/tensorflow/tensorflow.py - dace/frontend/tensorflow/winograd.py - dace/frontend/tensorflow/transformations/__init__.py - dace/frontend/tensorflow/transformations/redundant_array.py + dace/frontend/ml/tensorflow/__init__.py + dace/frontend/ml/tensorflow/tensorflow.py + dace/frontend/ml/tensorflow/winograd.py + dace/frontend/ml/tensorflow/transformations/__init__.py + dace/frontend/ml/tensorflow/transformations/redundant_array.py From e19e785fd7c2ab13d5b415904a58f0e443b80509 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Thu, 4 Dec 2025 07:49:52 -0800 Subject: [PATCH 04/17] Modify codecov.yml to change ignored files and builds Updated ignored paths and build notification settings. --- codecov.yml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/codecov.yml b/codecov.yml index 1f7e594398..49fbd61acd 100644 --- a/codecov.yml +++ b/codecov.yml @@ -1,6 +1,8 @@ ignore: - "dace/jupyter.py" # Omit files that cannot be tested - - "dace/frontend/tensorflow/**/*" # Omit deprecated files + - "dace/frontend/ml/tensorflow/**/*" # Omit deprecated files + - "samples/**/*" + - "tests/**/*" coverage: range: 40..90 @@ -18,6 +20,6 @@ coverage: codecov: notify: - after_n_builds: 18 + after_n_builds: 23 comment: false From 0efa622bed9329832bf1384e7ad7aa6c06e25b3a Mon Sep 17 00:00:00 2001 From: Afif <37773945+affifboudaoud@users.noreply.github.com> Date: Mon, 8 Dec 2025 15:12:44 +0100 Subject: [PATCH 05/17] Increase timeout for ML tests (#2243) Increased pytest timeout from 300 to 600 seconds. --- .github/workflows/ml-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ml-ci.yml b/.github/workflows/ml-ci.yml index b36890d562..d6ffc83a5d 100644 --- a/.github/workflows/ml-ci.yml +++ b/.github/workflows/ml-ci.yml @@ -53,7 +53,7 @@ jobs: else export DACE_optimizer_automatic_simplification=${{ matrix.simplify }} fi - pytest --cov-report=xml --cov=dace --tb=short --timeout_method thread --timeout=300 -v -m "(torch or onnx or autodiff) and not gpu" + pytest --cov-report=xml --cov=dace --tb=short --timeout_method thread --timeout=600 -v -m "(torch or onnx or autodiff) and not gpu" ./codecov - uses: codecov/codecov-action@v4 From 173de0b24d17f1134232c9e27c323f41a02e6371 Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Mon, 8 Dec 2025 16:07:18 +0000 Subject: [PATCH 06/17] Refactor dace/data.py into dace/data/ package (#2245) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Refactor `dace/data.py` into `dace/data/` package ### Summary This PR refactors the monolithic `dace/data.py` file into a modular `dace/data/` package with separate files for different functionality, improving code organization and maintainability. ### Changes - [x] **`dace/data/core.py`**: Core data descriptor classes (`Data`, `Scalar`, `Array`, `ContainerArray`, `Stream`, `Structure`, `View`, `Reference` and their subclasses) - [x] **`dace/data/tensor.py`**: Tensor/sparse tensor support (`Tensor`, `TensorIndex*` classes) - [x] **`dace/data/creation.py`**: Data descriptor creation functions (`create_datadescriptor`, `make_array_from_descriptor`, `make_reference_from_descriptor`) - [x] **`dace/data/ctypes_interop.py`**: Ctypes interoperability (`make_ctypes_argument`) - [x] **`dace/data/ml.py`**: ML-related descriptors (`ParameterArray`) - [x] **`dace/data/__init__.py`**: Re-exports all public API for backward compatibility - [x] **`dace/utils.py`**: Utility functions (`find_new_name`, `deduplicate`, `prod`) - [x] **`dace/properties.py`**: Updated to handle circular import gracefully - [x] **`dace/autodiff/library/library.py`**: Updated to import `ParameterArray` from the new location - [x] **Deleted** old `dace/data.py` file - [x] **Removed** `Number` and `ArrayLike` from `dace/data/__init__.py` (other places import directly) - [x] **Moved** `_prod` to `dace/utils.py` as `prod` (kept `_prod` export for backward compat) - [x] **Fixed** broken imports in `data_report.py`, `data_layout_tuner.py`, and `cutout.py` ### Backward Compatibility All public APIs are re-exported from `dace.data`, ensuring backward compatibility with existing code.
Original prompt > > ---- > > *This section details on the original issue you should resolve* > > Refactor `dace/data.py` > `data.py` is a monolithic file containing classes for core data containers (Data, Scalar, Array, Stream, View, Reference, and their subclasses `*{View, Reference}`; functionality to get data descriptors from arbitrary objects; derived objects for Tensors and sparse tensors; and other functions. > > This issue will be resolved once `data.py` is refactored to a `dace/data/*` folder, which will contain separate files for: > 1. core descriptor classes > 2. structures (the Structure class and similar functionality) > 3. tensors/sparse tensors > 4. descriptor creation > 5. ML-related data descriptors, such as parameter arrays (see `dace/autodiff/library/library.py`) > 6...N. Other functions and classes categorized by their semantic meaning. > > The code for `dace/data/*` will be refactored out of `data.py` (which should not exist at the end of this issue), `dtypes.py` (which may exist but be shorter), and other files that contain data descriptors (subclasses of Data/Array/Stream/Structure/View/Reference, such as ParameterArray. Try to find all such subclasses in the codebase barring tests/* and samples/*). > > Lastly, utility functions in `data.py` and `dtypes.py` (only those two files for this issue), such as `find_new_name` from data.py and `deduplicate` from dtypes.py, should find themselves in a new `dace/utils.py` file. > > ## Comments on the Issue (you are @copilot in this section) > > > >
- Fixes spcl/dace#2244 --- 💡 You can make Copilot smarter by setting up custom instructions, customizing its development environment and configuring Model Context Protocol (MCP) servers. Learn more [Copilot coding agent tips](https://gh.io/copilot-coding-agent-tips) in the docs. --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: tbennun <8348955+tbennun@users.noreply.github.com> --- dace/autodiff/library/library.py | 102 +- .../instrumentation/data/data_report.py | 6 +- dace/data/__init__.py | 110 + dace/{data.py => data/core.py} | 1773 ++++------------- dace/data/creation.py | 243 +++ dace/data/ctypes_interop.py | 133 ++ dace/data/ml.py | 113 ++ dace/data/tensor.py | 698 +++++++ dace/frontend/common/einsum.py | 2 +- dace/libraries/linalg/nodes/tensordot.py | 2 +- dace/libraries/mpi/nodes/gather.py | 2 +- dace/libraries/mpi/nodes/redistribute.py | 2 +- dace/libraries/mpi/nodes/scatter.py | 2 +- dace/optimization/data_layout_tuner.py | 9 +- dace/properties.py | 12 +- dace/sdfg/analysis/cutout.py | 12 +- .../dataflow/map_distribution.py | 2 +- dace/utils.py | 58 + 18 files changed, 1761 insertions(+), 1520 deletions(-) create mode 100644 dace/data/__init__.py rename dace/{data.py => data/core.py} (59%) create mode 100644 dace/data/creation.py create mode 100644 dace/data/ctypes_interop.py create mode 100644 dace/data/ml.py create mode 100644 dace/data/tensor.py create mode 100644 dace/utils.py diff --git a/dace/autodiff/library/library.py b/dace/autodiff/library/library.py index b5cc0e5d97..b5e2a60e98 100644 --- a/dace/autodiff/library/library.py +++ b/dace/autodiff/library/library.py @@ -18,106 +18,8 @@ from dace.sdfg.utils import in_edge_with_name from dace.transformation.passes.analysis import AccessSets - -@properties.make_properties -class ParameterArray(data.Array): - """ - An array for which a gradient can be computed. - """ - # since this can be None, this is not a DataProperty - gradient = properties.Property(dtype=str, desc="The corresponding gradient buffer", default=None, allow_none=True) - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - def __repr__(self): - return "Parameter" + data.Array.__repr__(self) - - def add_gradient_buffer(self, sdfg: SDFG, name: str) -> str: - """ - Find or create a gradient buffer for the parameter in the given SDFG. - - :param sdfg: the SDFG containing the parameter - :param name: the name of the parameter - :return: the name of the gradient buffer - """ - - if self.gradient: - return self.gradient - - # First, check if this array already has a gradient buffer in a nested - # SDFG. This happens, for example when pytorch modules are used in the - # frontend. In that case: - # 1. the parser assembles the closure of the module, which adds - # descriptors for all the parameters and their gradients (if they - # are required). - # 2. A nested sdfg is added for the module, with those array names. - # 3. The DaceProgram will then pass these arrays in when the - # DaceProgram is called, using the names from the closure that - # match the names from the NestedSDFG - # 4. When parsing the backward nodes, we want the gradient buffers in - # the closure to match the gradient buffers that we pass in. Thus, - # we need to make sure that we use the same name as the NestedSDFG - # - # Note that we do not currently do any nesting beyond this level, - # because nested modules are converted to one SDFG. - - cands = set() - for state in sdfg.nodes(): - for node in state.nodes(): - if not isinstance(node, nodes.NestedSDFG): - continue - - nested_names = set() - - for edge in state.in_edges(node): - if edge.data.data == name: - nested_names.add(edge.dst_conn) - for edge in state.out_edges(node): - if edge.data.data == name: - nested_names.add(edge.dst_conn) - - for name in nested_names: - nested_desc = node.sdfg.arrays[name] - if isinstance(nested_desc, ParameterArray) and nested_desc.gradient: - cands.add(nested_desc.gradient) - - if len(cands) > 1: - raise ValueError("Multiple gradient buffers found for parameter " + name) - elif len(cands) == 1: - # we found a name of a gradient buffer in a nested SDFG: - # reuse the same name in the outer sdfg if there is a matching descriptor - grad_name = cands.pop() - if grad_name in sdfg.arrays: - self.gradient = grad_name - return grad_name - else: - grad_name = sdfg._find_new_name('gradient_' + name) - - # Create a gradient buffer for the array - grad_desc = copy.deepcopy(self) - grad_desc.__class__ = data.Array - grad_desc.transient = True - grad_name = sdfg.add_datadesc(grad_name, grad_desc, find_new_name=True) - self.gradient = grad_name - return grad_name - - @staticmethod - def make_parameter(sdfg: SDFG, name: str): - """ - Converts an existing array into a parameter, without copying. - - :param sdfg: the SDFG containing the array. - :param name: the name of the array. - """ - desc = sdfg.arrays[name] - if isinstance(desc, ParameterArray): - return - - new_desc = copy.deepcopy(desc) - new_desc.__class__ = ParameterArray - new_desc.gradient = None - sdfg.arrays[name] = new_desc +# Import ParameterArray from the data package for backward compatibility +from dace.data.ml import ParameterArray @dace.library.expansion diff --git a/dace/codegen/instrumentation/data/data_report.py b/dace/codegen/instrumentation/data/data_report.py index d944c916f3..c13fabae77 100644 --- a/dace/codegen/instrumentation/data/data_report.py +++ b/dace/codegen/instrumentation/data/data_report.py @@ -2,10 +2,14 @@ from dataclasses import dataclass import struct from typing import Any, Dict, List, Set, Tuple, Union +from numbers import Number import os from dace import dtypes, SDFG -from dace.data import ArrayLike, Number # Type hint +try: + from numpy.typing import ArrayLike +except ImportError: + ArrayLike = Any # type: ignore import numpy as np diff --git a/dace/data/__init__.py b/dace/data/__init__.py new file mode 100644 index 0000000000..4620474f01 --- /dev/null +++ b/dace/data/__init__.py @@ -0,0 +1,110 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +Data descriptors for DaCe. + +This package contains classes for describing data containers (arrays, scalars, streams, etc.) +that can be used in SDFGs. The classes in this package are used to specify the shape, type, +and storage location of data, as well as other properties that affect code generation. + +For backward compatibility, all classes and functions are re-exported from the top-level +`dace.data` module. +""" + +# Core data descriptors +from dace.data.core import ( + Data, + Scalar, + Array, + ContainerArray, + Stream, + Structure, + View, + Reference, + ArrayView, + StructureView, + ContainerView, + ArrayReference, + StructureReference, + ContainerArrayReference, +) + +# Import prod from utils and expose as _prod for backward compatibility +from dace.utils import prod as _prod + +# Tensor/sparse tensor support +from dace.data.tensor import ( + TensorIterationTypes, + TensorAssemblyType, + TensorIndex, + TensorIndexDense, + TensorIndexCompressed, + TensorIndexSingleton, + TensorIndexRange, + TensorIndexOffset, + Tensor, +) + +# Convenience aliases for tensor indices +Dense = TensorIndexDense +Compressed = TensorIndexCompressed +Singleton = TensorIndexSingleton +Range = TensorIndexRange +Offset = TensorIndexOffset + +# ML-related data descriptors +from dace.data.ml import ParameterArray + +# Descriptor creation and array creation from descriptors +from dace.data.creation import ( + create_datadescriptor, + make_array_from_descriptor, + make_reference_from_descriptor, +) + +# Ctypes interoperability +from dace.data.ctypes_interop import make_ctypes_argument + +# Import utility function from utils (for backward compatibility) +from dace.utils import find_new_name + +__all__ = [ + # Core classes + 'Data', + 'Scalar', + 'Array', + 'ContainerArray', + 'Stream', + 'Structure', + 'View', + 'Reference', + 'ArrayView', + 'StructureView', + 'ContainerView', + 'ArrayReference', + 'StructureReference', + 'ContainerArrayReference', + # Tensor support + 'TensorIterationTypes', + 'TensorAssemblyType', + 'TensorIndex', + 'TensorIndexDense', + 'TensorIndexCompressed', + 'TensorIndexSingleton', + 'TensorIndexRange', + 'TensorIndexOffset', + 'Tensor', + # Tensor aliases + 'Dense', + 'Compressed', + 'Singleton', + 'Range', + 'Offset', + # ML descriptors + 'ParameterArray', + # Functions + 'create_datadescriptor', + 'make_array_from_descriptor', + 'make_reference_from_descriptor', + 'make_ctypes_argument', + 'find_new_name', +] diff --git a/dace/data.py b/dace/data/core.py similarity index 59% rename from dace/data.py rename to dace/data/core.py index 6026b24f32..4c64ed7b78 100644 --- a/dace/data.py +++ b/dace/data/core.py @@ -1,15 +1,17 @@ # Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. -import aenum +""" +Core data descriptor classes. + +This module contains the base ``Data`` class and all core descriptor classes: +``Scalar``, ``Array``, ``ContainerArray``, ``Stream``, ``Structure``, +``View``, ``Reference``, and their subclasses. +""" import copy as cp import ctypes import dataclasses -import functools -import warnings -from abc import ABC, abstractmethod from collections import OrderedDict -from numbers import Number -from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union +from typing import Any, Dict, List, Set, Tuple, Union import numpy as np import sympy as sp @@ -19,146 +21,27 @@ except (ModuleNotFoundError, ImportError): ArrayLike = Any -from dace import config, dtypes, serialize, symbolic +from dace import dtypes, serialize, symbolic from dace.codegen import cppunparse from dace.properties import (DebugInfoProperty, DictProperty, EnumProperty, ListProperty, NestedDataClassProperty, OrderedDictProperty, Property, ShapeProperty, SymbolicProperty, TypeClassProperty, make_properties) +from dace.utils import prod +# Backward compatibility alias +_prod = prod -def create_datadescriptor(obj, no_custom_desc=False): - """ Creates a data descriptor from various types of objects. - :see: dace.data.Data - """ - if isinstance(obj, Data): - return obj - elif not no_custom_desc and hasattr(obj, '__descriptor__'): - return obj.__descriptor__() - elif not no_custom_desc and hasattr(obj, 'descriptor'): - return obj.descriptor - elif type(obj).__module__ == "torch" and type(obj).__name__ == "Tensor": - # special case for torch tensors. Maybe __array__ could be used here for a more - # general solution, but torch doesn't support __array__ for cuda tensors. - try: - # If torch is importable, define translations between typeclasses and torch types. These are reused by daceml. - # conversion happens here in pytorch: - # https://github.com/pytorch/pytorch/blob/143ef016ee1b6a39cf69140230d7c371de421186/torch/csrc/utils/tensor_numpy.cpp#L237 - import torch - TYPECLASS_TO_TORCH_DTYPE = { - dtypes.bool_: torch.bool, - dtypes.int8: torch.int8, - dtypes.int16: torch.int16, - dtypes.int32: torch.int32, - dtypes.int64: torch.int64, - dtypes.uint8: torch.uint8, - dtypes.float16: torch.float16, - dtypes.float32: torch.float32, - dtypes.float64: torch.float64, - dtypes.complex64: torch.complex64, - dtypes.complex128: torch.complex128, - } - - TORCH_DTYPE_TO_TYPECLASS = {v: k for k, v in TYPECLASS_TO_TORCH_DTYPE.items()} - - storage = dtypes.StorageType.GPU_Global if obj.device.type == 'cuda' else dtypes.StorageType.Default - - return Array(dtype=TORCH_DTYPE_TO_TYPECLASS[obj.dtype], - strides=obj.stride(), - shape=tuple(obj.shape), - storage=storage) - except ImportError: - raise ValueError("Attempted to convert a torch.Tensor, but torch could not be imported") - elif dtypes.is_array(obj) and (hasattr(obj, '__array_interface__') or hasattr(obj, '__cuda_array_interface__')): - if dtypes.is_gpu_array(obj): - interface = obj.__cuda_array_interface__ - storage = dtypes.StorageType.GPU_Global - else: - interface = obj.__array_interface__ - storage = dtypes.StorageType.Default +def _arrays_to_json(arrays): + if arrays is None: + return None + return [(k, serialize.to_json(v)) for k, v in arrays.items()] - if hasattr(obj, 'dtype') and obj.dtype.fields is not None: # Struct - dtype = dtypes.struct('unnamed', **{k: dtypes.typeclass(v[0].type) for k, v in obj.dtype.fields.items()}) - else: - if np.dtype(interface['typestr']).type is np.void: # Struct from __array_interface__ - if 'descr' in interface: - dtype = dtypes.struct('unnamed', **{ - k: dtypes.typeclass(np.dtype(v).type) - for k, v in interface['descr'] - }) - else: - raise TypeError(f'Cannot infer data type of array interface object "{interface}"') - else: - dtype = dtypes.typeclass(np.dtype(interface['typestr']).type) - itemsize = np.dtype(interface['typestr']).itemsize - if len(interface['shape']) == 0: - return Scalar(dtype, storage=storage) - return Array(dtype=dtype, - shape=interface['shape'], - strides=(tuple(s // itemsize for s in interface['strides']) if interface['strides'] else None), - storage=storage) - elif isinstance(obj, (list, tuple)): - # Lists and tuples are cast to numpy - obj = np.array(obj) - - if obj.dtype.fields is not None: # Struct - dtype = dtypes.struct('unnamed', **{k: dtypes.typeclass(v[0].type) for k, v in obj.dtype.fields.items()}) - else: - dtype = dtypes.typeclass(obj.dtype.type) - return Array(dtype=dtype, strides=tuple(s // obj.itemsize for s in obj.strides), shape=obj.shape) - elif type(obj).__module__ == "cupy" and type(obj).__name__ == "ndarray": - # special case for CuPy and HIP, which does not support __cuda_array_interface__ - storage = dtypes.StorageType.GPU_Global - dtype = dtypes.typeclass(obj.dtype.type) - itemsize = obj.itemsize - return Array(dtype=dtype, shape=obj.shape, strides=tuple(s // itemsize for s in obj.strides), storage=storage) - elif symbolic.issymbolic(obj): - return Scalar(symbolic.symtype(obj)) - elif isinstance(obj, dtypes.typeclass): - return Scalar(obj) - elif (obj is int or obj is float or obj is complex or obj is bool or obj is None): - return Scalar(dtypes.typeclass(obj)) - elif isinstance(obj, type) and issubclass(obj, np.number): - return Scalar(dtypes.typeclass(obj)) - elif isinstance(obj, (Number, np.number, np.bool_)): - return Scalar(dtypes.typeclass(type(obj))) - elif obj is type(None): - # NoneType is void * - return Scalar(dtypes.pointer(dtypes.typeclass(None))) - elif isinstance(obj, str) or obj is str: - return Scalar(dtypes.string) - elif callable(obj): - # Cannot determine return value/argument types from function object - return Scalar(dtypes.callback(None)) - - raise TypeError(f'Could not create a DaCe data descriptor from object {obj}. ' - 'If this is a custom object, consider creating a `__descriptor__` ' - 'adaptor method to the type hint or object itself.') - - -def _prod(sequence): - return functools.reduce(lambda a, b: a * b, sequence, 1) - - -def find_new_name(name: str, existing_names: Sequence[str]) -> str: - """ - Returns a name that matches the given ``name`` as a prefix, but does not - already exist in the given existing name set. The behavior is typically - to append an underscore followed by a unique (increasing) number. If the - name does not already exist in the set, it is returned as-is. - - :param name: The given name to find. - :param existing_names: The set of existing names. - :return: A new name that is not in existing_names. - """ - if name not in existing_names: - return name - cur_offset = 0 - new_name = name + '_' + str(cur_offset) - while new_name in existing_names: - cur_offset += 1 - new_name = name + '_' + str(cur_offset) - return new_name + +def _arrays_from_json(obj, context=None): + if obj is None: + return {} + return OrderedDict((k, serialize.from_json(v, context)) for k, v in obj) @make_properties @@ -345,910 +228,15 @@ def __matmul__(self, storage: dtypes.StorageType): Syntactic sugar for specifying the storage of a data descriptor. This enables controlling the storage location as follows: - .. code-block:: python - - @dace - def add(X: dace.float32[10, 10] @ dace.StorageType.GPU_Global): - return X + 1 - """ - new_desc = cp.deepcopy(self) - new_desc.storage = storage - return new_desc - - -def _arrays_to_json(arrays): - if arrays is None: - return None - return [(k, serialize.to_json(v)) for k, v in arrays.items()] - - -def _arrays_from_json(obj, context=None): - if obj is None: - return {} - return OrderedDict((k, serialize.from_json(v, context)) for k, v in obj) - - -@make_properties -class Structure(Data): - """ Base class for structures. """ - - members = OrderedDictProperty(default=OrderedDict(), - desc="Dictionary of structure members", - from_json=_arrays_from_json, - to_json=_arrays_to_json) - name = Property(dtype=str, desc="Structure type name") - - def __init__(self, - members: Union[Dict[str, Data], List[Tuple[str, Data]]], - name: str = 'Structure', - transient: bool = False, - storage: dtypes.StorageType = dtypes.StorageType.Default, - location: Dict[str, str] = None, - lifetime: dtypes.AllocationLifetime = dtypes.AllocationLifetime.Scope, - debuginfo: dtypes.DebugInfo = None): - - self.members = OrderedDict(members) - for k, v in self.members.items(): - if isinstance(v, dtypes.typeclass): - v = Scalar(v) - self.members[k] = v - v.transient = transient - - self.name = name - fields_and_types = OrderedDict() - symbols = set() - for k, v in self.members.items(): - if isinstance(v, Structure): - symbols |= v.free_symbols - fields_and_types[k] = (v.dtype, str(v.total_size)) - elif isinstance(v, Array): - symbols |= v.free_symbols - fields_and_types[k] = (dtypes.pointer(v.dtype), str(_prod(v.shape))) - elif isinstance(v, Scalar): - symbols |= v.free_symbols - fields_and_types[k] = v.dtype - elif isinstance(v, dtypes.typeclass): - fields_and_types[k] = v - elif isinstance(v, (sp.Basic, symbolic.SymExpr)): - symbols |= v.free_symbols - fields_and_types[k] = symbolic.symtype(v) - elif isinstance(v, (int, np.integer)): - fields_and_types[k] = dtypes.typeclass(type(v)) - else: - raise TypeError(f"Attribute {k}'s value {v} has unsupported type: {type(v)}") - - # NOTE: We will not store symbols in the dtype for now, but leaving it as a comment to investigate later. - # NOTE: See discussion about data/object symbols. - # for s in symbols: - # if str(s) in fields_and_types: - # continue - # if hasattr(s, "dtype"): - # fields_and_types[str(s)] = s.dtype - # else: - # fields_and_types[str(s)] = dtypes.int32 - - dtype = dtypes.pointer(dtypes.struct(name, **fields_and_types)) - dtype.base_type.__descriptor__ = self - shape = (1, ) - super(Structure, self).__init__(dtype, shape, transient, storage, location, lifetime, debuginfo) - - @staticmethod - def from_json(json_obj, context=None): - if json_obj['type'] != 'Structure': - raise TypeError("Invalid data type") - - # Create dummy object - ret = Structure({}) - serialize.set_properties_from_json(ret, json_obj, context=context) - - return ret - - @staticmethod - def from_dataclass(cls, **overrides) -> 'Structure': - """ - Creates a Structure data descriptor from a dataclass instance. - - :param cls: The dataclass to convert. - :param overrides: Optional overrides for the structure fields. - :return: A Structure data descriptor. - """ - members = {} - for field in dataclasses.fields(cls): - # Recursive structures - if dataclasses.is_dataclass(field.type): - members[field.name] = Structure.from_dataclass(field.type) - continue - members[field.name] = field.type - - members.update(overrides) - return Structure(members, name=cls.__name__) - - @property - def total_size(self): - return -1 - - @property - def offset(self): - return [0] - - @property - def start_offset(self): - return 0 - - @property - def strides(self): - return [1] - - @property - def free_symbols(self) -> Set[symbolic.SymbolicType]: - """ Returns a set of undefined symbols in this data descriptor. """ - result = set() - for k, v in self.members.items(): - result |= v.free_symbols - return result - - def __repr__(self): - return f"{self.name} ({', '.join([f'{k}: {v}' for k, v in self.members.items()])})" - - def as_arg(self, with_types=True, for_call=False, name=None): - if self.storage is dtypes.StorageType.GPU_Global: - return Array(self.dtype, [1]).as_arg(with_types, for_call, name) - if not with_types or for_call: - return name - return self.dtype.as_arg(name) - - def __getitem__(self, s): - """ This is syntactic sugar that allows us to define an array type - with the following syntax: ``Structure[N,M]`` - :return: A ``data.ContainerArray`` data descriptor. - """ - if isinstance(s, list) or isinstance(s, tuple): - return ContainerArray(self, tuple(s)) - return ContainerArray(self, (s, )) - - # NOTE: Like Scalars? - @property - def may_alias(self) -> bool: - return False - - # TODO: Can Structures be optional? - @property - def optional(self) -> bool: - return False - - def keys(self): - result = self.members.keys() - for k, v in self.members.items(): - if isinstance(v, Structure): - result |= set(map(lambda x: f"{k}.{x}", v.keys())) - return result - - def clone(self): - return Structure(self.members, self.name, self.transient, self.storage, self.location, self.lifetime, - self.debuginfo) - - # NOTE: Like scalars? - @property - def pool(self) -> bool: - return False - - def make_argument(self, **fields) -> ctypes.Structure: - """ - Creates a structure instance from the given field values, which can be used as - an argument for DaCe programs. - - :param fields: Dictionary of field names to values. - :return: A ctypes Structure instance. - """ - struct_type: dtypes.struct = self.dtype.base_type - struct_ctype = struct_type.as_ctypes() - - def _make_arg(arg: Any, expected_type: Data, name: str) -> Any: - if isinstance(expected_type, Structure): - return ctypes.pointer(expected_type.make_argument_from_object(arg)) - return make_ctypes_argument(arg, expected_type, name) - - args = { - field_name: _make_arg(field_value, self.members[field_name], field_name) - for field_name, field_value in fields.items() if field_name in self.members - } - - struct_instance = struct_ctype(**args) - return struct_instance - - def make_argument_from_object(self, obj) -> ctypes.Structure: - """ - Creates a structure instance from the given object, which can be used as - an argument for DaCe programs. If the object has attributes matching the field names, - those attributes are used as field values. Other attributes are ignored. - - :param obj: Object containing field values. - :return: A ctypes Structure instance. - """ - return self.make_argument(**{field_name: getattr(obj, field_name) for field_name in self.members}) - - -class TensorIterationTypes(aenum.AutoNumberEnum): - """ - Types of tensor iteration capabilities. - - Value (Coordinate Value Iteration) allows to directly iterate over - coordinates such as when using the Dense index type. - - Position (Coordinate Position Iteratation) iterates over coordinate - positions, at which the actual coordinates lie. This is for example the case - with a compressed index, in which the pos array enables one to iterate over - the positions in the crd array that hold the actual coordinates. - """ - Value = () - Position = () - - -class TensorAssemblyType(aenum.AutoNumberEnum): - """ - Types of possible assembly strategies for the individual indices. - - NoAssembly: Assembly is not possible as such. - - Insert: index allows inserting elements at random (e.g. Dense) - - Append: index allows appending to a list of existing coordinates. Depending - on append order, this affects whether the index is ordered or not. This - could be changed by sorting the index after assembly - """ - NoAssembly = () - Insert = () - Append = () - - -class TensorIndex(ABC): - """ - Abstract base class for tensor index implementations. - """ - - @property - @abstractmethod - def iteration_type(self) -> TensorIterationTypes: - """ - Iteration capability supported by this index. - - See TensorIterationTypes for reference. - """ - pass - - @property - @abstractmethod - def locate(self) -> bool: - """ - True if the index supports locate (aka random access), False otw. - """ - pass - - @property - @abstractmethod - def assembly(self) -> TensorAssemblyType: - """ - What assembly type is supported by the index. - - See TensorAssemblyType for reference. - """ - pass - - @property - @abstractmethod - def full(self) -> bool: - """ - True if the level is full, False otw. - - A level is considered full if it encompasses all valid coordinates along - the corresponding tensor dimension. - """ - pass - - @property - @abstractmethod - def ordered(self) -> bool: - """ - True if the level is ordered, False otw. - - A level is ordered when all coordinates that share the same ancestor are - ordered by increasing value (e.g. in typical CSR). - """ - pass - - @property - @abstractmethod - def unique(self) -> bool: - """ - True if coordinate in the level are unique, False otw. - - A level is considered unique if no collection of coordinates that share - the same ancestor contains duplicates. In CSR this is True, in COO it is - not. - """ - pass - - @property - @abstractmethod - def branchless(self) -> bool: - """ - True if the level doesn't branch, false otw. - - A level is considered branchless if no coordinate has a sibling (another - coordinate with same ancestor) and all coordinates in parent level have - a child. In other words if there is a bijection between the coordinates - in this level and the parent level. An example of the is the Singleton - index level in the COO format. - """ - pass - - @property - @abstractmethod - def compact(self) -> bool: - """ - True if the level is compact, false otw. - - A level is compact if no two coordinates are separated by an unlabled - node that does not encode a coordinate. An example of a compact level - can be found in CSR, while the DIA formats range and offset levels are - not compact (they have entries that would coorespond to entries outside - the tensors index range, e.g. column -1). - """ - pass - - @abstractmethod - def fields(self, lvl: int, dummy_symbol: symbolic.SymExpr) -> Dict[str, Data]: - """ - Generates the fields needed for the index. - - :return: a Dict of fields that need to be present in the struct - """ - pass - - def to_json(self): - attrs = serialize.all_properties_to_json(self) - - retdict = {"type": type(self).__name__, "attributes": attrs} - - return retdict - - @classmethod - def from_json(cls, json_obj, context=None): - - # Selecting proper subclass - if json_obj['type'] == "TensorIndexDense": - self = TensorIndexDense.__new__(TensorIndexDense) - elif json_obj['type'] == "TensorIndexCompressed": - self = TensorIndexCompressed.__new__(TensorIndexCompressed) - elif json_obj['type'] == "TensorIndexSingleton": - self = TensorIndexSingleton.__new__(TensorIndexSingleton) - elif json_obj['type'] == "TensorIndexRange": - self = TensorIndexRange.__new__(TensorIndexRange) - elif json_obj['type'] == "TensorIndexOffset": - self = TensorIndexOffset.__new__(TensorIndexOffset) - else: - raise TypeError(f"Invalid data type, got: {json_obj['type']}") - - serialize.set_properties_from_json(self, json_obj['attributes'], context=context) - - return self - - -@make_properties -class TensorIndexDense(TensorIndex): - """ - Dense tensor index. - - Levels of this type encode the the coordinate in the interval [0, N), where - N is the size of the corresponding dimension. This level doesn't need any - index structure beyond the corresponding dimension size. - """ - - _ordered = Property(dtype=bool, default=False) - _unique = Property(dtype=bool) - - @property - def iteration_type(self) -> TensorIterationTypes: - return TensorIterationTypes.Value - - @property - def locate(self) -> bool: - return True - - @property - def assembly(self) -> TensorAssemblyType: - return TensorAssemblyType.Insert - - @property - def full(self) -> bool: - return True - - @property - def ordered(self) -> bool: - return self._ordered - - @property - def unique(self) -> bool: - return self._unique - - @property - def branchless(self) -> bool: - return False - - @property - def compact(self) -> bool: - return True - - def __init__(self, ordered: bool = True, unique: bool = True): - self._ordered = ordered - self._unique = unique - - def fields(self, lvl: int, dummy_symbol: symbolic.SymExpr) -> Dict[str, Data]: - return {} - - def __repr__(self) -> str: - s = "Dense" - - non_defaults = [] - if not self._ordered: - non_defaults.append("¬O") - if not self._unique: - non_defaults.append("¬U") - - if len(non_defaults) > 0: - s += f"({','.join(non_defaults)})" - - return s - - -@make_properties -class TensorIndexCompressed(TensorIndex): - """ - Tensor level that stores coordinates in segmented array. - - Levels of this type are compressed using a segented array. The pos array - holds the start and end positions of the segment in the crd (coordinate) - array that holds the child coordinates corresponding the parent. - """ - - _full = Property(dtype=bool, default=False) - _ordered = Property(dtype=bool, default=False) - _unique = Property(dtype=bool, default=False) - - @property - def iteration_type(self) -> TensorIterationTypes: - return TensorIterationTypes.Position - - @property - def locate(self) -> bool: - return False - - @property - def assembly(self) -> TensorAssemblyType: - return TensorAssemblyType.Append - - @property - def full(self) -> bool: - return self._full - - @property - def ordered(self) -> bool: - return self._ordered - - @property - def unique(self) -> bool: - return self._unique - - @property - def branchless(self) -> bool: - return False - - @property - def compact(self) -> bool: - return True - - def __init__(self, full: bool = False, ordered: bool = True, unique: bool = True): - self._full = full - self._ordered = ordered - self._unique = unique - - def fields(self, lvl: int, dummy_symbol: symbolic.SymExpr) -> Dict[str, Data]: - return { - f"idx{lvl}_pos": dtypes.int32[dummy_symbol], # TODO (later) choose better length - f"idx{lvl}_crd": dtypes.int32[dummy_symbol], # TODO (later) choose better length - } - - def __repr__(self) -> str: - s = "Compressed" - - non_defaults = [] - if self._full: - non_defaults.append("F") - if not self._ordered: - non_defaults.append("¬O") - if not self._unique: - non_defaults.append("¬U") - - if len(non_defaults) > 0: - s += f"({','.join(non_defaults)})" - - return s - - -@make_properties -class TensorIndexSingleton(TensorIndex): - """ - Tensor index that encodes a single coordinate per parent coordinate. - - Levels of this type hold exactly one coordinate for every coordinate in the - parent level. An example can be seen in the COO format, where every - coordinate but the first is encoded in this manner. - """ - - _full = Property(dtype=bool, default=False) - _ordered = Property(dtype=bool, default=False) - _unique = Property(dtype=bool, default=False) - - @property - def iteration_type(self) -> TensorIterationTypes: - return TensorIterationTypes.Position - - @property - def locate(self) -> bool: - return False - - @property - def assembly(self) -> TensorAssemblyType: - return TensorAssemblyType.Append - - @property - def full(self) -> bool: - return self._full - - @property - def ordered(self) -> bool: - return self._ordered - - @property - def unique(self) -> bool: - return self._unique - - @property - def branchless(self) -> bool: - return True - - @property - def compact(self) -> bool: - return True - - def __init__(self, full: bool = False, ordered: bool = True, unique: bool = True): - self._full = full - self._ordered = ordered - self._unique = unique - - def fields(self, lvl: int, dummy_symbol: symbolic.SymExpr) -> Dict[str, Data]: - return { - f"idx{lvl}_crd": dtypes.int32[dummy_symbol], # TODO (later) choose better length - } - - def __repr__(self) -> str: - s = "Singleton" - - non_defaults = [] - if self._full: - non_defaults.append("F") - if not self._ordered: - non_defaults.append("¬O") - if not self._unique: - non_defaults.append("¬U") - - if len(non_defaults) > 0: - s += f"({','.join(non_defaults)})" - - return s - - -@make_properties -class TensorIndexRange(TensorIndex): - """ - Tensor index that encodes a interval of coordinates for every parent. - - The interval is computed from an offset for each parent together with the - tensor dimension size of this level (M) and the parent level (N) parents - corresponding tensor. Given the parent coordinate i, the level encodes the - range of coordinates between max(0, -offset[i]) and min(N, M - offset[i]). - """ - - _ordered = Property(dtype=bool, default=False) - _unique = Property(dtype=bool, default=False) - - @property - def iteration_type(self) -> TensorIterationTypes: - return TensorIterationTypes.Value - - @property - def locate(self) -> bool: - return False - - @property - def assembly(self) -> TensorAssemblyType: - return TensorAssemblyType.NoAssembly - - @property - def full(self) -> bool: - return False - - @property - def ordered(self) -> bool: - return self._ordered - - @property - def unique(self) -> bool: - return self._unique - - @property - def branchless(self) -> bool: - return False - - @property - def compact(self) -> bool: - return False - - def __init__(self, ordered: bool = True, unique: bool = True): - self._ordered = ordered - self._unique = unique - - def fields(self, lvl: int, dummy_symbol: symbolic.SymExpr) -> Dict[str, Data]: - return { - f"idx{lvl}_offset": dtypes.int32[dummy_symbol], # TODO (later) choose better length - } - - def __repr__(self) -> str: - s = "Range" - - non_defaults = [] - if not self._ordered: - non_defaults.append("¬O") - if not self._unique: - non_defaults.append("¬U") - - if len(non_defaults) > 0: - s += f"({','.join(non_defaults)})" - - return s - - -@make_properties -class TensorIndexOffset(TensorIndex): - """ - Tensor index that encodes the next coordinates as offset from parent. - - Given a parent coordinate i and an offset index k, the level encodes the - coordinate j = i + offset[k]. - """ - - _ordered = Property(dtype=bool, default=False) - _unique = Property(dtype=bool, default=False) - - @property - def iteration_type(self) -> TensorIterationTypes: - return TensorIterationTypes.Position - - @property - def locate(self) -> bool: - return False - - @property - def assembly(self) -> TensorAssemblyType: - return TensorAssemblyType.NoAssembly - - @property - def full(self) -> bool: - return False - - @property - def ordered(self) -> bool: - return self._ordered - - @property - def unique(self) -> bool: - return self._unique - - @property - def branchless(self) -> bool: - return True - - @property - def compact(self) -> bool: - return False - - def __init__(self, ordered: bool = True, unique: bool = True): - self._ordered = ordered - self._unique = unique - - def fields(self, lvl: int, dummy_symbol: symbolic.SymExpr) -> Dict[str, Data]: - return { - f"idx{lvl}_offset": dtypes.int32[dummy_symbol], # TODO (later) choose better length - } - - def __repr__(self) -> str: - s = "Offset" - - non_defaults = [] - if not self._ordered: - non_defaults.append("¬O") - if not self._unique: - non_defaults.append("¬U") - - if len(non_defaults) > 0: - s += f"({','.join(non_defaults)})" - - return s - - -@make_properties -class Tensor(Structure): - """ - Abstraction for Tensor storage format. - - This abstraction is based on [https://doi.org/10.1145/3276493]. - """ - - value_dtype = TypeClassProperty(default=dtypes.int32, choices=dtypes.Typeclasses) - tensor_shape = ShapeProperty(default=[]) - indices = ListProperty(element_type=TensorIndex) - index_ordering = ListProperty(element_type=symbolic.SymExpr) - value_count = SymbolicProperty(default=0) - - def __init__(self, - value_dtype: dtypes.Typeclasses, - tensor_shape, - indices: List[Tuple[TensorIndex, Union[int, symbolic.SymExpr]]], - value_count: symbolic.SymExpr, - name: str, - transient: bool = False, - storage: dtypes.StorageType = dtypes.StorageType.Default, - location: Dict[str, str] = None, - lifetime: dtypes.AllocationLifetime = dtypes.AllocationLifetime.Scope, - debuginfo: dtypes.DebugInfo = None): - """ - Constructor for Tensor storage format. - - Below are examples of common matrix storage formats: - - .. code-block:: python - - M, N, nnz = (dace.symbol(s) for s in ('M', 'N', 'nnz')) - - csr = dace.data.Tensor( - dace.float32, - (M, N), - [(dace.data.Dense(), 0), (dace.data.Compressed(), 1)], - nnz, - "CSR_Matrix", - ) - - csc = dace.data.Tensor( - dace.float32, - (M, N), - [(dace.data.Dense(), 1), (dace.data.Compressed(), 0)], - nnz, - "CSC_Matrix", - ) - - coo = dace.data.Tensor( - dace.float32, - (M, N), - [ - (dace.data.Compressed(unique=False), 0), - (dace.data.Singleton(), 1), - ], - nnz, - "CSC_Matrix", - ) - - num_diags = dace.symbol('num_diags') # number of diagonals stored - - diag = dace.data.Tensor( - dace.float32, - (M, N), - [ - (dace.data.Dense(), num_diags), - (dace.data.Range(), 0), - (dace.data.Offset(), 1), - ], - nnz, - "DIA_Matrix", - ) - - Below you can find examples of common 3rd order tensor storage formats: - - .. code-block:: python - - I, J, K, nnz = (dace.symbol(s) for s in ('I', 'J', 'K', 'nnz')) - - coo = dace.data.Tensor( - dace.float32, - (I, J, K), - [ - (dace.data.Compressed(unique=False), 0), - (dace.data.Singleton(unique=False), 1), - (dace.data.Singleton(), 2), - ], - nnz, - "COO_3D_Tensor", - ) - - csf = dace.data.Tensor( - dace.float32, - (I, J, K), - [ - (dace.data.Compressed(), 0), - (dace.data.Compressed(), 1), - (dace.data.Compressed(), 2), - ], - nnz, - "CSF_3D_Tensor", - ) - - :param value_type: data type of the explicitly stored values. - :param tensor_shape: logical shape of tensor (#rows, #cols, etc...) - :param indices: - a list of tuples, each tuple represents a level in the tensor - storage hirachy, specifying the levels tensor index type, and the - corresponding dimension this level encodes (as index of the - tensor_shape tuple above). The order of the dimensions may differ - from the logical shape of the tensor, e.g. as seen in the CSC - format. If an index's dimension is unrelated to the tensor shape - (e.g. in diagonal format where the first index's dimension is the - number of diagonals stored), a symbol can be specified instead. - :param value_count: number of explicitly stored values. - :param name: name of resulting struct. - :param others: See Structure class for remaining arguments - """ - - self.value_dtype = value_dtype - self.tensor_shape = tensor_shape - self.value_count = value_count - - indices, index_ordering = zip(*indices) - self.indices, self.index_ordering = list(indices), list(index_ordering) - - num_dims = len(tensor_shape) - dimension_order = [idx for idx in self.index_ordering if isinstance(idx, int)] - - # all tensor dimensions must occure exactly once in indices - if not sorted(dimension_order) == list(range(num_dims)): - raise TypeError((f"All tensor dimensions must be refferenced exactly once in " - f"tensor indices. (referenced dimensions: {dimension_order}; " - f"tensor dimensions: {list(range(num_dims))})")) - - # assembling permanent and index specific fields - fields = dict( - order=Scalar(dtypes.int32), - dim_sizes=dtypes.int32[num_dims], - value_count=value_count, - values=dtypes.float32[value_count], - ) - - for (lvl, index) in enumerate(indices): - fields.update(index.fields(lvl, value_count)) - - super(Tensor, self).__init__(fields, name, transient, storage, location, lifetime, debuginfo) - - def __repr__(self): - return f"{self.name} (dtype: {self.value_dtype}, shape: {list(self.tensor_shape)}, indices: {self.indices})" - - @staticmethod - def from_json(json_obj, context=None): - if json_obj['type'] != 'Tensor': - raise TypeError("Invalid data type") - - # Create dummy object - tensor = Tensor.__new__(Tensor) - serialize.set_properties_from_json(tensor, json_obj, context=context) + .. code-block:: python - return tensor + @dace + def add(X: dace.float32[10, 10] @ dace.StorageType.GPU_Global): + return X + 1 + """ + new_desc = cp.deepcopy(self) + new_desc.storage = storage + return new_desc @make_properties @@ -1727,6 +715,60 @@ def is_packed_c_strides(self) -> bool: return tuple(strides) == tuple(self.strides) +@make_properties +class ContainerArray(Array): + """ An array that may contain other data containers (e.g., Structures, other arrays). """ + + stype = NestedDataClassProperty(allow_none=True, default=None) + + def __init__(self, + stype: Data, + shape, + transient=False, + allow_conflicts=False, + storage=dtypes.StorageType.Default, + location=None, + strides=None, + offset=None, + may_alias=False, + lifetime=dtypes.AllocationLifetime.Scope, + alignment=0, + debuginfo=None, + total_size=None, + start_offset=None, + optional=None, + pool=False): + + self.stype = stype + if stype: + if isinstance(stype, Structure): + dtype = stype.dtype + else: + dtype = dtypes.pointer(stype.dtype) + else: + dtype = dtypes.pointer(dtypes.typeclass(None)) # void* + super(ContainerArray, + self).__init__(dtype, shape, transient, allow_conflicts, storage, location, strides, offset, may_alias, + lifetime, alignment, debuginfo, total_size, start_offset, optional, pool) + + @classmethod + def from_json(cls, json_obj, context=None): + # Create dummy object + ret = cls(None, ()) + serialize.set_properties_from_json(ret, json_obj, context=context) + + # Default shape-related properties + if not ret.offset: + ret.offset = [0] * len(ret.shape) + if not ret.strides: + # Default strides are C-ordered + ret.strides = [_prod(ret.shape[i + 1:]) for i in range(len(ret.shape))] + if ret.total_size == 0: + ret.total_size = _prod(ret.shape) + + return ret + + @make_properties class Stream(Data): """ Stream (or stream array) data descriptor. """ @@ -1773,174 +815,322 @@ def from_json(cls, json_obj, context=None): ret = cls(dtypes.int8, 1) serialize.set_properties_from_json(ret, json_obj, context=context) - return ret + return ret + + def __repr__(self): + return '%s (dtype=%s, shape=%s)' % (type(self).__name__, self.dtype, self.shape) + + @property + def total_size(self): + return _prod(self.shape) + + @property + def strides(self): + return [_prod(self.shape[i + 1:]) for i in range(len(self.shape))] + + @property + def start_offset(self): + return 0 + + @property + def optional(self) -> bool: + return False + + @property + def may_alias(self) -> bool: + return False + + def clone(self): + return type(self)(self.dtype, self.buffer_size, self.shape, self.transient, self.storage, self.location, + self.offset, self.lifetime, self.debuginfo) + + # Checks for equivalent shape and type + def is_equivalent(self, other): + if not isinstance(other, type(self)): + return False + + # Test type + if self.dtype != other.dtype: + return False + + # Test dimensionality + if len(self.shape) != len(other.shape): + return False + + # Test shape + for dim, otherdim in zip(self.shape, other.shape): + if dim != otherdim: + return False + return True + + def as_arg(self, with_types=True, for_call=False, name=None): + if not with_types or for_call: return name + if self.storage in [dtypes.StorageType.GPU_Global, dtypes.StorageType.GPU_Shared]: + return 'dace::GPUStream<%s, %s> %s' % (str( + self.dtype.ctype), 'true' if sp.log(self.buffer_size, 2).is_Integer else 'false', name) + + return 'dace::Stream<%s> %s' % (str(self.dtype.ctype), name) + + def sizes(self): + return [d.name if isinstance(d, symbolic.symbol) else str(d) for d in self.shape] + + def size_string(self): + return (" * ".join([cppunparse.pyexpr2cpp(symbolic.symstr(s, cpp_mode=True)) for s in self.shape])) + + def is_stream_array(self): + return _prod(self.shape) != 1 + + def covers_range(self, rng): + if len(rng) != len(self.shape): + return False + + for s, (rb, re, rs) in zip(self.shape, rng): + # Shape has to be positive + if isinstance(s, sp.Basic): + olds = s + if 'positive' in s.assumptions0: + s = sp.Symbol(str(s), **s.assumptions0) + else: + s = sp.Symbol(str(s), positive=True, **s.assumptions0) + if isinstance(rb, sp.Basic): + rb = rb.subs({olds: s}) + if isinstance(re, sp.Basic): + re = re.subs({olds: s}) + if isinstance(rs, sp.Basic): + rs = rs.subs({olds: s}) + + try: + if rb < 0: # Negative offset + return False + except TypeError: # cannot determine truth value of Relational + pass + #print('WARNING: Cannot evaluate relational expression %s, assuming true.' % (rb > 0), + # 'If this expression is false, please refine symbol definitions in the program.') + try: + if re > s: # Beyond shape + return False + except TypeError: # cannot determine truth value of Relational + pass + #print('WARNING: Cannot evaluate relational expression %s, assuming true.' % (re < s), + # 'If this expression is false, please refine symbol definitions in the program.') + + return True + + def used_symbols(self, all_symbols: bool) -> Set[symbolic.SymbolicType]: + result = super().used_symbols(all_symbols) + if (self.transient or all_symbols) and isinstance(self.buffer_size, sp.Expr): + result |= set(self.buffer_size.free_symbols) + for o in self.offset: + if isinstance(o, sp.Expr): + result |= set(o.free_symbols) + + return result + + @property + def free_symbols(self): + return self.used_symbols(all_symbols=True) + + +@make_properties +class Structure(Data): + """ Base class for structures. """ + + members = OrderedDictProperty(default=OrderedDict(), + desc="Dictionary of structure members", + from_json=_arrays_from_json, + to_json=_arrays_to_json) + name = Property(dtype=str, desc="Structure type name") + + def __init__(self, + members: Union[Dict[str, Data], List[Tuple[str, Data]]], + name: str = 'Structure', + transient: bool = False, + storage: dtypes.StorageType = dtypes.StorageType.Default, + location: Dict[str, str] = None, + lifetime: dtypes.AllocationLifetime = dtypes.AllocationLifetime.Scope, + debuginfo: dtypes.DebugInfo = None): + + self.members = OrderedDict(members) + for k, v in self.members.items(): + if isinstance(v, dtypes.typeclass): + v = Scalar(v) + self.members[k] = v + v.transient = transient + + self.name = name + fields_and_types = OrderedDict() + symbols = set() + for k, v in self.members.items(): + if isinstance(v, Structure): + symbols |= v.free_symbols + fields_and_types[k] = (v.dtype, str(v.total_size)) + elif isinstance(v, Array): + symbols |= v.free_symbols + fields_and_types[k] = (dtypes.pointer(v.dtype), str(_prod(v.shape))) + elif isinstance(v, Scalar): + symbols |= v.free_symbols + fields_and_types[k] = v.dtype + elif isinstance(v, dtypes.typeclass): + fields_and_types[k] = v + elif isinstance(v, (sp.Basic, symbolic.SymExpr)): + symbols |= v.free_symbols + fields_and_types[k] = symbolic.symtype(v) + elif isinstance(v, (int, np.integer)): + fields_and_types[k] = dtypes.typeclass(type(v)) + else: + raise TypeError(f"Attribute {k}'s value {v} has unsupported type: {type(v)}") + + # NOTE: We will not store symbols in the dtype for now, but leaving it as a comment to investigate later. + # NOTE: See discussion about data/object symbols. + # for s in symbols: + # if str(s) in fields_and_types: + # continue + # if hasattr(s, "dtype"): + # fields_and_types[str(s)] = s.dtype + # else: + # fields_and_types[str(s)] = dtypes.int32 + + dtype = dtypes.pointer(dtypes.struct(name, **fields_and_types)) + dtype.base_type.__descriptor__ = self + shape = (1, ) + super(Structure, self).__init__(dtype, shape, transient, storage, location, lifetime, debuginfo) + + @staticmethod + def from_json(json_obj, context=None): + if json_obj['type'] != 'Structure': + raise TypeError("Invalid data type") + + # Create dummy object + ret = Structure({}) + serialize.set_properties_from_json(ret, json_obj, context=context) + + return ret + + @staticmethod + def from_dataclass(cls, **overrides) -> 'Structure': + """ + Creates a Structure data descriptor from a dataclass instance. + + :param cls: The dataclass to convert. + :param overrides: Optional overrides for the structure fields. + :return: A Structure data descriptor. + """ + members = {} + for field in dataclasses.fields(cls): + # Recursive structures + if dataclasses.is_dataclass(field.type): + members[field.name] = Structure.from_dataclass(field.type) + continue + members[field.name] = field.type - def __repr__(self): - return '%s (dtype=%s, shape=%s)' % (type(self).__name__, self.dtype, self.shape) + members.update(overrides) + return Structure(members, name=cls.__name__) @property def total_size(self): - return _prod(self.shape) + return -1 @property - def strides(self): - return [_prod(self.shape[i + 1:]) for i in range(len(self.shape))] + def offset(self): + return [0] @property def start_offset(self): return 0 @property - def optional(self) -> bool: - return False + def strides(self): + return [1] @property - def may_alias(self) -> bool: - return False - - def clone(self): - return type(self)(self.dtype, self.buffer_size, self.shape, self.transient, self.storage, self.location, - self.offset, self.lifetime, self.debuginfo) - - # Checks for equivalent shape and type - def is_equivalent(self, other): - if not isinstance(other, type(self)): - return False - - # Test type - if self.dtype != other.dtype: - return False - - # Test dimensionality - if len(self.shape) != len(other.shape): - return False + def free_symbols(self) -> Set[symbolic.SymbolicType]: + """ Returns a set of undefined symbols in this data descriptor. """ + result = set() + for k, v in self.members.items(): + result |= v.free_symbols + return result - # Test shape - for dim, otherdim in zip(self.shape, other.shape): - if dim != otherdim: - return False - return True + def __repr__(self): + return f"{self.name} ({', '.join([f'{k}: {v}' for k, v in self.members.items()])})" def as_arg(self, with_types=True, for_call=False, name=None): - if not with_types or for_call: return name - if self.storage in [dtypes.StorageType.GPU_Global, dtypes.StorageType.GPU_Shared]: - return 'dace::GPUStream<%s, %s> %s' % (str( - self.dtype.ctype), 'true' if sp.log(self.buffer_size, 2).is_Integer else 'false', name) - - return 'dace::Stream<%s> %s' % (str(self.dtype.ctype), name) - - def sizes(self): - return [d.name if isinstance(d, symbolic.symbol) else str(d) for d in self.shape] - - def size_string(self): - return (" * ".join([cppunparse.pyexpr2cpp(symbolic.symstr(s, cpp_mode=True)) for s in self.shape])) - - def is_stream_array(self): - return _prod(self.shape) != 1 - - def covers_range(self, rng): - if len(rng) != len(self.shape): - return False - - for s, (rb, re, rs) in zip(self.shape, rng): - # Shape has to be positive - if isinstance(s, sp.Basic): - olds = s - if 'positive' in s.assumptions0: - s = sp.Symbol(str(s), **s.assumptions0) - else: - s = sp.Symbol(str(s), positive=True, **s.assumptions0) - if isinstance(rb, sp.Basic): - rb = rb.subs({olds: s}) - if isinstance(re, sp.Basic): - re = re.subs({olds: s}) - if isinstance(rs, sp.Basic): - rs = rs.subs({olds: s}) + if self.storage is dtypes.StorageType.GPU_Global: + return Array(self.dtype, [1]).as_arg(with_types, for_call, name) + if not with_types or for_call: + return name + return self.dtype.as_arg(name) - try: - if rb < 0: # Negative offset - return False - except TypeError: # cannot determine truth value of Relational - pass - #print('WARNING: Cannot evaluate relational expression %s, assuming true.' % (rb > 0), - # 'If this expression is false, please refine symbol definitions in the program.') - try: - if re > s: # Beyond shape - return False - except TypeError: # cannot determine truth value of Relational - pass - #print('WARNING: Cannot evaluate relational expression %s, assuming true.' % (re < s), - # 'If this expression is false, please refine symbol definitions in the program.') + def __getitem__(self, s): + """ This is syntactic sugar that allows us to define an array type + with the following syntax: ``Structure[N,M]`` + :return: A ``data.ContainerArray`` data descriptor. + """ + if isinstance(s, list) or isinstance(s, tuple): + return ContainerArray(self, tuple(s)) + return ContainerArray(self, (s, )) - return True + # NOTE: Like Scalars? + @property + def may_alias(self) -> bool: + return False - def used_symbols(self, all_symbols: bool) -> Set[symbolic.SymbolicType]: - result = super().used_symbols(all_symbols) - if (self.transient or all_symbols) and isinstance(self.buffer_size, sp.Expr): - result |= set(self.buffer_size.free_symbols) - for o in self.offset: - if isinstance(o, sp.Expr): - result |= set(o.free_symbols) + # TODO: Can Structures be optional? + @property + def optional(self) -> bool: + return False + def keys(self): + result = self.members.keys() + for k, v in self.members.items(): + if isinstance(v, Structure): + result |= set(map(lambda x: f"{k}.{x}", v.keys())) return result - @property - def free_symbols(self): - return self.used_symbols(all_symbols=True) + def clone(self): + return Structure(self.members, self.name, self.transient, self.storage, self.location, self.lifetime, + self.debuginfo) + # NOTE: Like scalars? + @property + def pool(self) -> bool: + return False -@make_properties -class ContainerArray(Array): - """ An array that may contain other data containers (e.g., Structures, other arrays). """ + def make_argument(self, **fields) -> ctypes.Structure: + """ + Creates a structure instance from the given field values, which can be used as + an argument for DaCe programs. - stype = NestedDataClassProperty(allow_none=True, default=None) + :param fields: Dictionary of field names to values. + :return: A ctypes Structure instance. + """ + # Import here to avoid circular import + from dace.data.ctypes_interop import make_ctypes_argument + struct_type: dtypes.struct = self.dtype.base_type + struct_ctype = struct_type.as_ctypes() - def __init__(self, - stype: Data, - shape, - transient=False, - allow_conflicts=False, - storage=dtypes.StorageType.Default, - location=None, - strides=None, - offset=None, - may_alias=False, - lifetime=dtypes.AllocationLifetime.Scope, - alignment=0, - debuginfo=None, - total_size=None, - start_offset=None, - optional=None, - pool=False): + def _make_arg(arg: Any, expected_type: Data, name: str) -> Any: + if isinstance(expected_type, Structure): + return ctypes.pointer(expected_type.make_argument_from_object(arg)) + return make_ctypes_argument(arg, expected_type, name) - self.stype = stype - if stype: - if isinstance(stype, Structure): - dtype = stype.dtype - else: - dtype = dtypes.pointer(stype.dtype) - else: - dtype = dtypes.pointer(dtypes.typeclass(None)) # void* - super(ContainerArray, - self).__init__(dtype, shape, transient, allow_conflicts, storage, location, strides, offset, may_alias, - lifetime, alignment, debuginfo, total_size, start_offset, optional, pool) + args = { + field_name: _make_arg(field_value, self.members[field_name], field_name) + for field_name, field_value in fields.items() if field_name in self.members + } - @classmethod - def from_json(cls, json_obj, context=None): - # Create dummy object - ret = cls(None, ()) - serialize.set_properties_from_json(ret, json_obj, context=context) + struct_instance = struct_ctype(**args) + return struct_instance - # Default shape-related properties - if not ret.offset: - ret.offset = [0] * len(ret.shape) - if not ret.strides: - # Default strides are C-ordered - ret.strides = [_prod(ret.shape[i + 1:]) for i in range(len(ret.shape))] - if ret.total_size == 0: - ret.total_size = _prod(ret.shape) + def make_argument_from_object(self, obj) -> ctypes.Structure: + """ + Creates a structure instance from the given object, which can be used as + an argument for DaCe programs. If the object has attributes matching the field names, + those attributes are used as field values. Other attributes are ignored. - return ret + :param obj: Object containing field values. + :return: A ctypes Structure instance. + """ + return self.make_argument(**{field_name: getattr(obj, field_name) for field_name in self.members}) class View: @@ -2245,230 +1435,3 @@ def as_array(self): copy = cp.deepcopy(self) copy.__class__ = ContainerArray return copy - - -def make_array_from_descriptor(descriptor: Array, - original_array: Optional[ArrayLike] = None, - symbols: Optional[Dict[str, Any]] = None) -> ArrayLike: - """ - Creates an array that matches the given data descriptor, and optionally copies another array to it. - - :param descriptor: The data descriptor to create the array from. - :param original_array: An optional array to fill the content of the return value with. - :param symbols: An optional symbol mapping between symbol names and their values. Used for creating arrays - with symbolic sizes. - :return: A NumPy-compatible array (CuPy for GPU storage) with the specified size and strides. - """ - symbols = symbols or {} - - free_syms = set(map(str, descriptor.free_symbols)) - symbols.keys() - if free_syms: - raise NotImplementedError(f'Cannot make Python references to arrays with undefined symbolic sizes: {free_syms}') - - if descriptor.storage == dtypes.StorageType.GPU_Global: - try: - import cupy as cp - except (ImportError, ModuleNotFoundError): - raise NotImplementedError('GPU memory can only be allocated in Python if cupy is installed') - - def create_array(shape: Tuple[int], dtype: np.dtype, total_size: int, strides: Tuple[int]) -> ArrayLike: - buffer = cp.ndarray(shape=[total_size], dtype=dtype) - view = cp.ndarray(shape=shape, - dtype=dtype, - memptr=buffer.data, - strides=[s * dtype.itemsize for s in strides]) - return view - - def copy_array(dst, src): - dst[:] = cp.asarray(src) - - elif descriptor.storage == dtypes.StorageType.FPGA_Global: - raise TypeError('Cannot allocate FPGA array in Python') - else: - - def create_array(shape: Tuple[int], dtype: np.dtype, total_size: int, strides: Tuple[int]) -> ArrayLike: - buffer = np.ndarray([total_size], dtype=dtype) - view = np.ndarray(shape, dtype, buffer=buffer, strides=[s * dtype.itemsize for s in strides]) - return view - - def copy_array(dst, src): - dst[:] = src - - # Make numpy array from data descriptor - npdtype = descriptor.dtype.as_numpy_dtype() - evaluated_shape = tuple(symbolic.evaluate(s, symbols) for s in descriptor.shape) - evaluated_size = symbolic.evaluate(descriptor.total_size, symbols) - evaluated_strides = tuple(symbolic.evaluate(s, symbols) for s in descriptor.strides) - view = create_array(evaluated_shape, npdtype, evaluated_size, evaluated_strides) - if original_array is not None: - copy_array(view, original_array) - - return view - - -def make_reference_from_descriptor(descriptor: Array, - original_array: ctypes.c_void_p, - symbols: Optional[Dict[str, Any]] = None) -> ArrayLike: - """ - Creates an array that matches the given data descriptor from the given pointer. Shares the memory - with the argument (does not create a copy). - - :param descriptor: The data descriptor to create the array from. - :param original_array: The array whose memory the return value would be used in. - :param symbols: An optional symbol mapping between symbol names and their values. Used for referencing arrays - with symbolic sizes. - :return: A NumPy-compatible array (CuPy for GPU storage) with the specified size and strides, sharing memory - with the pointer specified in ``original_array``. - """ - symbols = symbols or {} - - original_array: int = ctypes.cast(original_array, ctypes.c_void_p).value - - free_syms = set(map(str, descriptor.free_symbols)) - symbols.keys() - if free_syms: - raise NotImplementedError(f'Cannot make Python references to arrays with undefined symbolic sizes: {free_syms}') - - if descriptor.storage == dtypes.StorageType.GPU_Global: - try: - import cupy as cp - except (ImportError, ModuleNotFoundError): - raise NotImplementedError('GPU memory can only be referenced in Python if cupy is installed') - - def create_array(shape: Tuple[int], dtype: np.dtype, total_size: int, strides: Tuple[int]) -> ArrayLike: - buffer = dtypes.ptrtocupy(original_array, descriptor.dtype.as_ctypes(), (total_size, )) - view = cp.ndarray(shape=shape, - dtype=dtype, - memptr=buffer.data, - strides=[s * dtype.itemsize for s in strides]) - return view - - elif descriptor.storage == dtypes.StorageType.FPGA_Global: - raise TypeError('Cannot reference FPGA array in Python') - else: - - def create_array(shape: Tuple[int], dtype: np.dtype, total_size: int, strides: Tuple[int]) -> ArrayLike: - buffer = dtypes.ptrtonumpy(original_array, descriptor.dtype.as_ctypes(), (total_size, )) - view = np.ndarray(shape, dtype, buffer=buffer, strides=[s * dtype.itemsize for s in strides]) - return view - - # Make numpy array from data descriptor - npdtype = descriptor.dtype.as_numpy_dtype() - evaluated_shape = tuple(symbolic.evaluate(s, symbols) for s in descriptor.shape) - evaluated_size = symbolic.evaluate(descriptor.total_size, symbols) - evaluated_strides = tuple(symbolic.evaluate(s, symbols) for s in descriptor.strides) - return create_array(evaluated_shape, npdtype, evaluated_size, evaluated_strides) - - -def make_ctypes_argument(arg: Any, - argtype: Data, - name: Optional[str] = None, - allow_views: Optional[bool] = None, - symbols: Optional[Dict[str, Any]] = None, - callback_retval_references: Optional[List[Any]] = None) -> Any: - """ - Converts a given argument to the expected ``ctypes`` type for passing to compiled SDFG functions. - - :param arg: The argument to convert. - :param argtype: The expected data descriptor type of the argument. - :param name: The name of the argument (for error messages). - :param allow_views: Whether to allow views and references as input. If False, raises an error if a view or - reference is passed. If None (default), uses the global configuration setting - ``compiler.allow_view_arguments``. - :param symbols: An optional symbol mapping between symbol names and their values. Used for evaluating symbolic - sizes in callback arguments. - :param callback_retval_references: A list to store references to callback return values (to avoid garbage - collection of said return values). This object must be kept alive until the - SDFG call is complete. - :return: The argument converted to the appropriate ctypes type. - """ - if allow_views is None: - no_view_arguments = not config.Config.get_bool('compiler', 'allow_view_arguments') - else: - no_view_arguments = not allow_views - a = name or '' - atype = argtype - - result = arg - is_array = dtypes.is_array(arg) - is_ndarray = isinstance(arg, np.ndarray) - is_dtArray = isinstance(argtype, Array) - if not is_array and is_dtArray: - if isinstance(arg, list): - print(f'WARNING: Casting list argument "{a}" to ndarray') - elif arg is None: - if atype.optional is False: # If array cannot be None - raise TypeError(f'Passing a None value to a non-optional array in argument "{a}"') - # Otherwise, None values are passed as null pointers below - elif isinstance(arg, ctypes._Pointer): - pass - elif isinstance(arg, str): - # Cast to bytes - result = ctypes.c_char_p(arg.encode('utf-8')) - else: - raise TypeError(f'Passing an object (type {type(arg).__name__}) to an array in argument "{a}"') - elif is_array and not is_dtArray: - # GPU scalars and return values are pointers, so this is fine - if atype.storage != dtypes.StorageType.GPU_Global and not a.startswith('__return'): - raise TypeError(f'Passing an array to a scalar (type {atype.dtype.ctype}) in argument "{a}"') - elif (is_dtArray and is_ndarray and not isinstance(atype, ContainerArray) - and atype.dtype.as_numpy_dtype() != arg.dtype): - # Make exception for vector types - if (isinstance(atype.dtype, dtypes.vector) and atype.dtype.vtype.as_numpy_dtype() == arg.dtype): - pass - else: - print(f'WARNING: Passing {arg.dtype} array argument "{a}" to a {atype.dtype.type.__name__} array') - elif is_dtArray and is_ndarray and arg.base is not None and not '__return' in a and no_view_arguments: - raise TypeError(f'Passing a numpy view (e.g., sub-array or "A.T") "{a}" to DaCe ' - 'programs is not allowed in order to retain analyzability. ' - 'Please make a copy with "numpy.copy(...)". If you know what ' - 'you are doing, you can override this error in the ' - 'configuration by setting compiler.allow_view_arguments ' - 'to True.') - elif (not isinstance(atype, (Array, Structure)) and not isinstance(atype.dtype, dtypes.callback) - and not isinstance(arg, (atype.dtype.type, sp.Basic)) - and not (isinstance(arg, symbolic.symbol) and arg.dtype == atype.dtype)): - is_int = isinstance(arg, int) - if is_int and atype.dtype.type == np.int64: - pass - elif (is_int and atype.dtype.type == np.int32 and abs(arg) <= (1 << 31) - 1): - pass - elif (is_int and atype.dtype.type == np.uint32 and arg >= 0 and arg <= (1 << 32) - 1): - pass - elif isinstance(arg, float) and atype.dtype.type == np.float64: - pass - elif isinstance(arg, bool) and atype.dtype.type == np.bool_: - pass - elif (isinstance(arg, str) or arg is None) and atype.dtype == dtypes.string: - if arg is None: - result = ctypes.c_char_p(None) - else: - # Cast to bytes - result = ctypes.c_char_p(arg.encode('utf-8')) - else: - warnings.warn(f'Casting scalar argument "{a}" from {type(arg).__name__} to {atype.dtype.type}') - result = atype.dtype.type(arg) - - # Call a wrapper function to make NumPy arrays from pointers. - if isinstance(argtype.dtype, dtypes.callback): - result = argtype.dtype.get_trampoline(result, symbols or {}, callback_retval_references) - # List to array - elif isinstance(result, list) and isinstance(argtype, Array): - result = np.array(result, dtype=argtype.dtype.type) - # Null pointer - elif result is None and isinstance(argtype, Array): - result = ctypes.c_void_p(0) - - # Retain only the element datatype for upcoming checks and casts - actype = argtype.dtype.as_ctypes() - - try: - if dtypes.is_array(result): # `c_void_p` is subclass of `ctypes._SimpleCData`. - result = ctypes.c_void_p(dtypes.array_interface_ptr(result, atype.storage)) - elif not isinstance(result, (ctypes._SimpleCData, ctypes._Pointer)): - result = actype(result) - else: - pass - except TypeError as ex: - raise TypeError(f'Invalid type for scalar argument "{a}": {ex}') - - return result diff --git a/dace/data/creation.py b/dace/data/creation.py new file mode 100644 index 0000000000..04f8971f9e --- /dev/null +++ b/dace/data/creation.py @@ -0,0 +1,243 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +Data descriptor creation functions. + +This module contains functions for creating data descriptors from arbitrary objects, +as well as functions for creating arrays from descriptors. +""" +import ctypes + +from numbers import Number +from typing import Any, Dict, Optional, Tuple + +import numpy as np + +try: + from numpy.typing import ArrayLike +except (ModuleNotFoundError, ImportError): + ArrayLike = Any + +from dace import dtypes, symbolic +from dace.data.core import Array, Data, Scalar + + +def create_datadescriptor(obj, no_custom_desc=False): + """ Creates a data descriptor from various types of objects. + + :see: dace.data.Data + """ + if isinstance(obj, Data): + return obj + elif not no_custom_desc and hasattr(obj, '__descriptor__'): + return obj.__descriptor__() + elif not no_custom_desc and hasattr(obj, 'descriptor'): + return obj.descriptor + elif type(obj).__module__ == "torch" and type(obj).__name__ == "Tensor": + # special case for torch tensors. Maybe __array__ could be used here for a more + # general solution, but torch doesn't support __array__ for cuda tensors. + try: + # If torch is importable, define translations between typeclasses and torch types. These are reused by daceml. + # conversion happens here in pytorch: + # https://github.com/pytorch/pytorch/blob/143ef016ee1b6a39cf69140230d7c371de421186/torch/csrc/utils/tensor_numpy.cpp#L237 + import torch + TYPECLASS_TO_TORCH_DTYPE = { + dtypes.bool_: torch.bool, + dtypes.int8: torch.int8, + dtypes.int16: torch.int16, + dtypes.int32: torch.int32, + dtypes.int64: torch.int64, + dtypes.uint8: torch.uint8, + dtypes.float16: torch.float16, + dtypes.float32: torch.float32, + dtypes.float64: torch.float64, + dtypes.complex64: torch.complex64, + dtypes.complex128: torch.complex128, + } + + TORCH_DTYPE_TO_TYPECLASS = {v: k for k, v in TYPECLASS_TO_TORCH_DTYPE.items()} + + storage = dtypes.StorageType.GPU_Global if obj.device.type == 'cuda' else dtypes.StorageType.Default + + return Array(dtype=TORCH_DTYPE_TO_TYPECLASS[obj.dtype], + strides=obj.stride(), + shape=tuple(obj.shape), + storage=storage) + except ImportError: + raise ValueError("Attempted to convert a torch.Tensor, but torch could not be imported") + elif dtypes.is_array(obj) and (hasattr(obj, '__array_interface__') or hasattr(obj, '__cuda_array_interface__')): + if dtypes.is_gpu_array(obj): + interface = obj.__cuda_array_interface__ + storage = dtypes.StorageType.GPU_Global + else: + interface = obj.__array_interface__ + storage = dtypes.StorageType.Default + + if hasattr(obj, 'dtype') and obj.dtype.fields is not None: # Struct + dtype = dtypes.struct('unnamed', **{k: dtypes.typeclass(v[0].type) for k, v in obj.dtype.fields.items()}) + else: + if np.dtype(interface['typestr']).type is np.void: # Struct from __array_interface__ + if 'descr' in interface: + dtype = dtypes.struct('unnamed', **{ + k: dtypes.typeclass(np.dtype(v).type) + for k, v in interface['descr'] + }) + else: + raise TypeError(f'Cannot infer data type of array interface object "{interface}"') + else: + dtype = dtypes.typeclass(np.dtype(interface['typestr']).type) + itemsize = np.dtype(interface['typestr']).itemsize + if len(interface['shape']) == 0: + return Scalar(dtype, storage=storage) + return Array(dtype=dtype, + shape=interface['shape'], + strides=(tuple(s // itemsize for s in interface['strides']) if interface['strides'] else None), + storage=storage) + elif isinstance(obj, (list, tuple)): + # Lists and tuples are cast to numpy + obj = np.array(obj) + + if obj.dtype.fields is not None: # Struct + dtype = dtypes.struct('unnamed', **{k: dtypes.typeclass(v[0].type) for k, v in obj.dtype.fields.items()}) + else: + dtype = dtypes.typeclass(obj.dtype.type) + return Array(dtype=dtype, strides=tuple(s // obj.itemsize for s in obj.strides), shape=obj.shape) + elif type(obj).__module__ == "cupy" and type(obj).__name__ == "ndarray": + # special case for CuPy and HIP, which does not support __cuda_array_interface__ + storage = dtypes.StorageType.GPU_Global + dtype = dtypes.typeclass(obj.dtype.type) + itemsize = obj.itemsize + return Array(dtype=dtype, shape=obj.shape, strides=tuple(s // itemsize for s in obj.strides), storage=storage) + elif symbolic.issymbolic(obj): + return Scalar(symbolic.symtype(obj)) + elif isinstance(obj, dtypes.typeclass): + return Scalar(obj) + elif (obj is int or obj is float or obj is complex or obj is bool or obj is None): + return Scalar(dtypes.typeclass(obj)) + elif isinstance(obj, type) and issubclass(obj, np.number): + return Scalar(dtypes.typeclass(obj)) + elif isinstance(obj, (Number, np.number, np.bool_)): + return Scalar(dtypes.typeclass(type(obj))) + elif obj is type(None): + # NoneType is void * + return Scalar(dtypes.pointer(dtypes.typeclass(None))) + elif isinstance(obj, str) or obj is str: + return Scalar(dtypes.string) + elif callable(obj): + # Cannot determine return value/argument types from function object + return Scalar(dtypes.callback(None)) + + raise TypeError(f'Could not create a DaCe data descriptor from object {obj}. ' + 'If this is a custom object, consider creating a `__descriptor__` ' + 'adaptor method to the type hint or object itself.') + + +def make_array_from_descriptor(descriptor: Array, + original_array: Optional[ArrayLike] = None, + symbols: Optional[Dict[str, Any]] = None) -> ArrayLike: + """ + Creates an array that matches the given data descriptor, and optionally copies another array to it. + + :param descriptor: The data descriptor to create the array from. + :param original_array: An optional array to fill the content of the return value with. + :param symbols: An optional symbol mapping between symbol names and their values. Used for creating arrays + with symbolic sizes. + :return: A NumPy-compatible array (CuPy for GPU storage) with the specified size and strides. + """ + symbols = symbols or {} + + free_syms = set(map(str, descriptor.free_symbols)) - symbols.keys() + if free_syms: + raise NotImplementedError(f'Cannot make Python references to arrays with undefined symbolic sizes: {free_syms}') + + if descriptor.storage == dtypes.StorageType.GPU_Global: + try: + import cupy as cp + except (ImportError, ModuleNotFoundError): + raise NotImplementedError('GPU memory can only be allocated in Python if cupy is installed') + + def create_array(shape: Tuple[int], dtype: np.dtype, total_size: int, strides: Tuple[int]) -> ArrayLike: + buffer = cp.ndarray(shape=[total_size], dtype=dtype) + view = cp.ndarray(shape=shape, + dtype=dtype, + memptr=buffer.data, + strides=[s * dtype.itemsize for s in strides]) + return view + + def copy_array(dst, src): + dst[:] = cp.asarray(src) + + elif descriptor.storage == dtypes.StorageType.FPGA_Global: + raise TypeError('Cannot allocate FPGA array in Python') + else: + + def create_array(shape: Tuple[int], dtype: np.dtype, total_size: int, strides: Tuple[int]) -> ArrayLike: + buffer = np.ndarray([total_size], dtype=dtype) + view = np.ndarray(shape, dtype, buffer=buffer, strides=[s * dtype.itemsize for s in strides]) + return view + + def copy_array(dst, src): + dst[:] = src + + # Make numpy array from data descriptor + npdtype = descriptor.dtype.as_numpy_dtype() + evaluated_shape = tuple(symbolic.evaluate(s, symbols) for s in descriptor.shape) + evaluated_size = symbolic.evaluate(descriptor.total_size, symbols) + evaluated_strides = tuple(symbolic.evaluate(s, symbols) for s in descriptor.strides) + view = create_array(evaluated_shape, npdtype, evaluated_size, evaluated_strides) + if original_array is not None: + copy_array(view, original_array) + + return view + + +def make_reference_from_descriptor(descriptor: Array, + original_array: ctypes.c_void_p, + symbols: Optional[Dict[str, Any]] = None) -> ArrayLike: + """ + Creates an array that matches the given data descriptor from the given pointer. Shares the memory + with the argument (does not create a copy). + + :param descriptor: The data descriptor to create the array from. + :param original_array: The array whose memory the return value would be used in. + :param symbols: An optional symbol mapping between symbol names and their values. Used for referencing arrays + with symbolic sizes. + :return: A NumPy-compatible array (CuPy for GPU storage) with the specified size and strides, sharing memory + with the pointer specified in ``original_array``. + """ + symbols = symbols or {} + + original_array: int = ctypes.cast(original_array, ctypes.c_void_p).value + + free_syms = set(map(str, descriptor.free_symbols)) - symbols.keys() + if free_syms: + raise NotImplementedError(f'Cannot make Python references to arrays with undefined symbolic sizes: {free_syms}') + + if descriptor.storage == dtypes.StorageType.GPU_Global: + try: + import cupy as cp + except (ImportError, ModuleNotFoundError): + raise NotImplementedError('GPU memory can only be referenced in Python if cupy is installed') + + def create_array(shape: Tuple[int], dtype: np.dtype, total_size: int, strides: Tuple[int]) -> ArrayLike: + buffer = dtypes.ptrtocupy(original_array, descriptor.dtype.as_ctypes(), (total_size, )) + view = cp.ndarray(shape=shape, + dtype=dtype, + memptr=buffer.data, + strides=[s * dtype.itemsize for s in strides]) + return view + + elif descriptor.storage == dtypes.StorageType.FPGA_Global: + raise TypeError('Cannot reference FPGA array in Python') + else: + + def create_array(shape: Tuple[int], dtype: np.dtype, total_size: int, strides: Tuple[int]) -> ArrayLike: + buffer = dtypes.ptrtonumpy(original_array, descriptor.dtype.as_ctypes(), (total_size, )) + view = np.ndarray(shape, dtype, buffer=buffer, strides=[s * dtype.itemsize for s in strides]) + return view + + # Make numpy array from data descriptor + npdtype = descriptor.dtype.as_numpy_dtype() + evaluated_shape = tuple(symbolic.evaluate(s, symbols) for s in descriptor.shape) + evaluated_size = symbolic.evaluate(descriptor.total_size, symbols) + evaluated_strides = tuple(symbolic.evaluate(s, symbols) for s in descriptor.strides) + return create_array(evaluated_shape, npdtype, evaluated_size, evaluated_strides) diff --git a/dace/data/ctypes_interop.py b/dace/data/ctypes_interop.py new file mode 100644 index 0000000000..d9dfba58e1 --- /dev/null +++ b/dace/data/ctypes_interop.py @@ -0,0 +1,133 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +Ctypes interoperability for data descriptors. + +This module contains functions for converting data descriptors to ctypes. +""" +import ctypes +import warnings + +from typing import Any, Dict, List, Optional + +import numpy as np +import sympy as sp + +from dace import config, dtypes, symbolic + + +def make_ctypes_argument(arg: Any, + argtype: 'Data', + name: Optional[str] = None, + allow_views: Optional[bool] = None, + symbols: Optional[Dict[str, Any]] = None, + callback_retval_references: Optional[List[Any]] = None) -> Any: + """ + Converts a given argument to the expected ``ctypes`` type for passing to compiled SDFG functions. + + :param arg: The argument to convert. + :param argtype: The expected data descriptor type of the argument. + :param name: The name of the argument (for error messages). + :param allow_views: Whether to allow views and references as input. If False, raises an error if a view or + reference is passed. If None (default), uses the global configuration setting + ``compiler.allow_view_arguments``. + :param symbols: An optional symbol mapping between symbol names and their values. Used for evaluating symbolic + sizes in callback arguments. + :param callback_retval_references: A list to store references to callback return values (to avoid garbage + collection of said return values). This object must be kept alive until the + SDFG call is complete. + :return: The argument converted to the appropriate ctypes type. + """ + # Import here to avoid circular imports + from dace.data.core import Array, ContainerArray, Structure + + if allow_views is None: + no_view_arguments = not config.Config.get_bool('compiler', 'allow_view_arguments') + else: + no_view_arguments = not allow_views + a = name or '' + atype = argtype + + result = arg + is_array = dtypes.is_array(arg) + is_ndarray = isinstance(arg, np.ndarray) + is_dtArray = isinstance(argtype, Array) + if not is_array and is_dtArray: + if isinstance(arg, list): + print(f'WARNING: Casting list argument "{a}" to ndarray') + elif arg is None: + if atype.optional is False: # If array cannot be None + raise TypeError(f'Passing a None value to a non-optional array in argument "{a}"') + # Otherwise, None values are passed as null pointers below + elif isinstance(arg, ctypes._Pointer): + pass + elif isinstance(arg, str): + # Cast to bytes + result = ctypes.c_char_p(arg.encode('utf-8')) + else: + raise TypeError(f'Passing an object (type {type(arg).__name__}) to an array in argument "{a}"') + elif is_array and not is_dtArray: + # GPU scalars and return values are pointers, so this is fine + if atype.storage != dtypes.StorageType.GPU_Global and not a.startswith('__return'): + raise TypeError(f'Passing an array to a scalar (type {atype.dtype.ctype}) in argument "{a}"') + elif (is_dtArray and is_ndarray and not isinstance(atype, ContainerArray) + and atype.dtype.as_numpy_dtype() != arg.dtype): + # Make exception for vector types + if (isinstance(atype.dtype, dtypes.vector) and atype.dtype.vtype.as_numpy_dtype() == arg.dtype): + pass + else: + print(f'WARNING: Passing {arg.dtype} array argument "{a}" to a {atype.dtype.type.__name__} array') + elif is_dtArray and is_ndarray and arg.base is not None and not '__return' in a and no_view_arguments: + raise TypeError(f'Passing a numpy view (e.g., sub-array or "A.T") "{a}" to DaCe ' + 'programs is not allowed in order to retain analyzability. ' + 'Please make a copy with "numpy.copy(...)". If you know what ' + 'you are doing, you can override this error in the ' + 'configuration by setting compiler.allow_view_arguments ' + 'to True.') + elif (not isinstance(atype, (Array, Structure)) and not isinstance(atype.dtype, dtypes.callback) + and not isinstance(arg, (atype.dtype.type, sp.Basic)) + and not (isinstance(arg, symbolic.symbol) and arg.dtype == atype.dtype)): + is_int = isinstance(arg, int) + if is_int and atype.dtype.type == np.int64: + pass + elif (is_int and atype.dtype.type == np.int32 and abs(arg) <= (1 << 31) - 1): + pass + elif (is_int and atype.dtype.type == np.uint32 and arg >= 0 and arg <= (1 << 32) - 1): + pass + elif isinstance(arg, float) and atype.dtype.type == np.float64: + pass + elif isinstance(arg, bool) and atype.dtype.type == np.bool_: + pass + elif (isinstance(arg, str) or arg is None) and atype.dtype == dtypes.string: + if arg is None: + result = ctypes.c_char_p(None) + else: + # Cast to bytes + result = ctypes.c_char_p(arg.encode('utf-8')) + else: + warnings.warn(f'Casting scalar argument "{a}" from {type(arg).__name__} to {atype.dtype.type}') + result = atype.dtype.type(arg) + + # Call a wrapper function to make NumPy arrays from pointers. + if isinstance(argtype.dtype, dtypes.callback): + result = argtype.dtype.get_trampoline(result, symbols or {}, callback_retval_references) + # List to array + elif isinstance(result, list) and isinstance(argtype, Array): + result = np.array(result, dtype=argtype.dtype.type) + # Null pointer + elif result is None and isinstance(argtype, Array): + result = ctypes.c_void_p(0) + + # Retain only the element datatype for upcoming checks and casts + actype = argtype.dtype.as_ctypes() + + try: + if dtypes.is_array(result): # `c_void_p` is subclass of `ctypes._SimpleCData`. + result = ctypes.c_void_p(dtypes.array_interface_ptr(result, atype.storage)) + elif not isinstance(result, (ctypes._SimpleCData, ctypes._Pointer)): + result = actype(result) + else: + pass + except TypeError as ex: + raise TypeError(f'Invalid type for scalar argument "{a}": {ex}') + + return result diff --git a/dace/data/ml.py b/dace/data/ml.py new file mode 100644 index 0000000000..5f26c6e3a1 --- /dev/null +++ b/dace/data/ml.py @@ -0,0 +1,113 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +ML-related data descriptors. + +This module contains data descriptors that are specific to machine learning workflows, +such as ParameterArray for automatic differentiation. +""" +import copy + +from dace import properties +from dace.data.core import Array +from dace.sdfg import SDFG, nodes + + +@properties.make_properties +class ParameterArray(Array): + """ + An array for which a gradient can be computed. + """ + # since this can be None, this is not a DataProperty + gradient = properties.Property(dtype=str, desc="The corresponding gradient buffer", default=None, allow_none=True) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def __repr__(self): + return "Parameter" + Array.__repr__(self) + + def add_gradient_buffer(self, sdfg: SDFG, name: str) -> str: + """ + Find or create a gradient buffer for the parameter in the given SDFG. + + :param sdfg: the SDFG containing the parameter + :param name: the name of the parameter + :return: the name of the gradient buffer + """ + + if self.gradient: + return self.gradient + + # First, check if this array already has a gradient buffer in a nested + # SDFG. This happens, for example when pytorch modules are used in the + # frontend. In that case: + # 1. the parser assembles the closure of the module, which adds + # descriptors for all the parameters and their gradients (if they + # are required). + # 2. A nested sdfg is added for the module, with those array names. + # 3. The DaceProgram will then pass these arrays in when the + # DaceProgram is called, using the names from the closure that + # match the names from the NestedSDFG + # 4. When parsing the backward nodes, we want the gradient buffers in + # the closure to match the gradient buffers that we pass in. Thus, + # we need to make sure that we use the same name as the NestedSDFG + # + # Note that we do not currently do any nesting beyond this level, + # because nested modules are converted to one SDFG. + + cands = set() + for state in sdfg.nodes(): + for node in state.nodes(): + if not isinstance(node, nodes.NestedSDFG): + continue + + nested_names = set() + + for edge in state.in_edges(node): + if edge.data.data == name: + nested_names.add(edge.dst_conn) + for edge in state.out_edges(node): + if edge.data.data == name: + nested_names.add(edge.dst_conn) + + for name in nested_names: + nested_desc = node.sdfg.arrays[name] + if isinstance(nested_desc, ParameterArray) and nested_desc.gradient: + cands.add(nested_desc.gradient) + + if len(cands) > 1: + raise ValueError("Multiple gradient buffers found for parameter " + name) + elif len(cands) == 1: + # we found a name of a gradient buffer in a nested SDFG: + # reuse the same name in the outer sdfg if there is a matching descriptor + grad_name = cands.pop() + if grad_name in sdfg.arrays: + self.gradient = grad_name + return grad_name + else: + grad_name = sdfg._find_new_name('gradient_' + name) + + # Create a gradient buffer for the array + grad_desc = copy.deepcopy(self) + grad_desc.__class__ = Array + grad_desc.transient = True + grad_name = sdfg.add_datadesc(grad_name, grad_desc, find_new_name=True) + self.gradient = grad_name + return grad_name + + @staticmethod + def make_parameter(sdfg: SDFG, name: str): + """ + Converts an existing array into a parameter, without copying. + + :param sdfg: the SDFG containing the array. + :param name: the name of the array. + """ + desc = sdfg.arrays[name] + if isinstance(desc, ParameterArray): + return + + new_desc = copy.deepcopy(desc) + new_desc.__class__ = ParameterArray + new_desc.gradient = None + sdfg.arrays[name] = new_desc diff --git a/dace/data/tensor.py b/dace/data/tensor.py new file mode 100644 index 0000000000..4444c31f0f --- /dev/null +++ b/dace/data/tensor.py @@ -0,0 +1,698 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +Tensor data descriptors for sparse tensor formats. + +This module contains classes for representing various sparse tensor storage formats +based on the abstraction described in [https://doi.org/10.1145/3276493]. +""" +import aenum + +from abc import ABC, abstractmethod +from typing import Dict, List, Tuple, Union + +from dace import dtypes, serialize, symbolic +from dace.data.core import Data, Scalar, Structure +from dace.properties import ListProperty, Property, ShapeProperty, SymbolicProperty, TypeClassProperty, make_properties + + +class TensorIterationTypes(aenum.AutoNumberEnum): + """ + Types of tensor iteration capabilities. + + Value (Coordinate Value Iteration) allows to directly iterate over + coordinates such as when using the Dense index type. + + Position (Coordinate Position Iteratation) iterates over coordinate + positions, at which the actual coordinates lie. This is for example the case + with a compressed index, in which the pos array enables one to iterate over + the positions in the crd array that hold the actual coordinates. + """ + Value = () + Position = () + + +class TensorAssemblyType(aenum.AutoNumberEnum): + """ + Types of possible assembly strategies for the individual indices. + + NoAssembly: Assembly is not possible as such. + + Insert: index allows inserting elements at random (e.g. Dense) + + Append: index allows appending to a list of existing coordinates. Depending + on append order, this affects whether the index is ordered or not. This + could be changed by sorting the index after assembly + """ + NoAssembly = () + Insert = () + Append = () + + +class TensorIndex(ABC): + """ + Abstract base class for tensor index implementations. + """ + + @property + @abstractmethod + def iteration_type(self) -> TensorIterationTypes: + """ + Iteration capability supported by this index. + + See TensorIterationTypes for reference. + """ + pass + + @property + @abstractmethod + def locate(self) -> bool: + """ + True if the index supports locate (aka random access), False otw. + """ + pass + + @property + @abstractmethod + def assembly(self) -> TensorAssemblyType: + """ + What assembly type is supported by the index. + + See TensorAssemblyType for reference. + """ + pass + + @property + @abstractmethod + def full(self) -> bool: + """ + True if the level is full, False otw. + + A level is considered full if it encompasses all valid coordinates along + the corresponding tensor dimension. + """ + pass + + @property + @abstractmethod + def ordered(self) -> bool: + """ + True if the level is ordered, False otw. + + A level is ordered when all coordinates that share the same ancestor are + ordered by increasing value (e.g. in typical CSR). + """ + pass + + @property + @abstractmethod + def unique(self) -> bool: + """ + True if coordinate in the level are unique, False otw. + + A level is considered unique if no collection of coordinates that share + the same ancestor contains duplicates. In CSR this is True, in COO it is + not. + """ + pass + + @property + @abstractmethod + def branchless(self) -> bool: + """ + True if the level doesn't branch, false otw. + + A level is considered branchless if no coordinate has a sibling (another + coordinate with same ancestor) and all coordinates in parent level have + a child. In other words if there is a bijection between the coordinates + in this level and the parent level. An example of the is the Singleton + index level in the COO format. + """ + pass + + @property + @abstractmethod + def compact(self) -> bool: + """ + True if the level is compact, false otw. + + A level is compact if no two coordinates are separated by an unlabled + node that does not encode a coordinate. An example of a compact level + can be found in CSR, while the DIA formats range and offset levels are + not compact (they have entries that would coorespond to entries outside + the tensors index range, e.g. column -1). + """ + pass + + @abstractmethod + def fields(self, lvl: int, dummy_symbol: symbolic.SymExpr) -> Dict[str, Data]: + """ + Generates the fields needed for the index. + + :return: a Dict of fields that need to be present in the struct + """ + pass + + def to_json(self): + attrs = serialize.all_properties_to_json(self) + + retdict = {"type": type(self).__name__, "attributes": attrs} + + return retdict + + @classmethod + def from_json(cls, json_obj, context=None): + + # Selecting proper subclass + if json_obj['type'] == "TensorIndexDense": + self = TensorIndexDense.__new__(TensorIndexDense) + elif json_obj['type'] == "TensorIndexCompressed": + self = TensorIndexCompressed.__new__(TensorIndexCompressed) + elif json_obj['type'] == "TensorIndexSingleton": + self = TensorIndexSingleton.__new__(TensorIndexSingleton) + elif json_obj['type'] == "TensorIndexRange": + self = TensorIndexRange.__new__(TensorIndexRange) + elif json_obj['type'] == "TensorIndexOffset": + self = TensorIndexOffset.__new__(TensorIndexOffset) + else: + raise TypeError(f"Invalid data type, got: {json_obj['type']}") + + serialize.set_properties_from_json(self, json_obj['attributes'], context=context) + + return self + + +@make_properties +class TensorIndexDense(TensorIndex): + """ + Dense tensor index. + + Levels of this type encode the the coordinate in the interval [0, N), where + N is the size of the corresponding dimension. This level doesn't need any + index structure beyond the corresponding dimension size. + """ + + _ordered = Property(dtype=bool, default=False) + _unique = Property(dtype=bool) + + @property + def iteration_type(self) -> TensorIterationTypes: + return TensorIterationTypes.Value + + @property + def locate(self) -> bool: + return True + + @property + def assembly(self) -> TensorAssemblyType: + return TensorAssemblyType.Insert + + @property + def full(self) -> bool: + return True + + @property + def ordered(self) -> bool: + return self._ordered + + @property + def unique(self) -> bool: + return self._unique + + @property + def branchless(self) -> bool: + return False + + @property + def compact(self) -> bool: + return True + + def __init__(self, ordered: bool = True, unique: bool = True): + self._ordered = ordered + self._unique = unique + + def fields(self, lvl: int, dummy_symbol: symbolic.SymExpr) -> Dict[str, Data]: + return {} + + def __repr__(self) -> str: + s = "Dense" + + non_defaults = [] + if not self._ordered: + non_defaults.append("¬O") + if not self._unique: + non_defaults.append("¬U") + + if len(non_defaults) > 0: + s += f"({','.join(non_defaults)})" + + return s + + +@make_properties +class TensorIndexCompressed(TensorIndex): + """ + Tensor level that stores coordinates in segmented array. + + Levels of this type are compressed using a segented array. The pos array + holds the start and end positions of the segment in the crd (coordinate) + array that holds the child coordinates corresponding the parent. + """ + + _full = Property(dtype=bool, default=False) + _ordered = Property(dtype=bool, default=False) + _unique = Property(dtype=bool, default=False) + + @property + def iteration_type(self) -> TensorIterationTypes: + return TensorIterationTypes.Position + + @property + def locate(self) -> bool: + return False + + @property + def assembly(self) -> TensorAssemblyType: + return TensorAssemblyType.Append + + @property + def full(self) -> bool: + return self._full + + @property + def ordered(self) -> bool: + return self._ordered + + @property + def unique(self) -> bool: + return self._unique + + @property + def branchless(self) -> bool: + return False + + @property + def compact(self) -> bool: + return True + + def __init__(self, full: bool = False, ordered: bool = True, unique: bool = True): + self._full = full + self._ordered = ordered + self._unique = unique + + def fields(self, lvl: int, dummy_symbol: symbolic.SymExpr) -> Dict[str, Data]: + return { + f"idx{lvl}_pos": dtypes.int32[dummy_symbol], # TODO (later) choose better length + f"idx{lvl}_crd": dtypes.int32[dummy_symbol], # TODO (later) choose better length + } + + def __repr__(self) -> str: + s = "Compressed" + + non_defaults = [] + if self._full: + non_defaults.append("F") + if not self._ordered: + non_defaults.append("¬O") + if not self._unique: + non_defaults.append("¬U") + + if len(non_defaults) > 0: + s += f"({','.join(non_defaults)})" + + return s + + +@make_properties +class TensorIndexSingleton(TensorIndex): + """ + Tensor index that encodes a single coordinate per parent coordinate. + + Levels of this type hold exactly one coordinate for every coordinate in the + parent level. An example can be seen in the COO format, where every + coordinate but the first is encoded in this manner. + """ + + _full = Property(dtype=bool, default=False) + _ordered = Property(dtype=bool, default=False) + _unique = Property(dtype=bool, default=False) + + @property + def iteration_type(self) -> TensorIterationTypes: + return TensorIterationTypes.Position + + @property + def locate(self) -> bool: + return False + + @property + def assembly(self) -> TensorAssemblyType: + return TensorAssemblyType.Append + + @property + def full(self) -> bool: + return self._full + + @property + def ordered(self) -> bool: + return self._ordered + + @property + def unique(self) -> bool: + return self._unique + + @property + def branchless(self) -> bool: + return True + + @property + def compact(self) -> bool: + return True + + def __init__(self, full: bool = False, ordered: bool = True, unique: bool = True): + self._full = full + self._ordered = ordered + self._unique = unique + + def fields(self, lvl: int, dummy_symbol: symbolic.SymExpr) -> Dict[str, Data]: + return { + f"idx{lvl}_crd": dtypes.int32[dummy_symbol], # TODO (later) choose better length + } + + def __repr__(self) -> str: + s = "Singleton" + + non_defaults = [] + if self._full: + non_defaults.append("F") + if not self._ordered: + non_defaults.append("¬O") + if not self._unique: + non_defaults.append("¬U") + + if len(non_defaults) > 0: + s += f"({','.join(non_defaults)})" + + return s + + +@make_properties +class TensorIndexRange(TensorIndex): + """ + Tensor index that encodes a interval of coordinates for every parent. + + The interval is computed from an offset for each parent together with the + tensor dimension size of this level (M) and the parent level (N) parents + corresponding tensor. Given the parent coordinate i, the level encodes the + range of coordinates between max(0, -offset[i]) and min(N, M - offset[i]). + """ + + _ordered = Property(dtype=bool, default=False) + _unique = Property(dtype=bool, default=False) + + @property + def iteration_type(self) -> TensorIterationTypes: + return TensorIterationTypes.Value + + @property + def locate(self) -> bool: + return False + + @property + def assembly(self) -> TensorAssemblyType: + return TensorAssemblyType.NoAssembly + + @property + def full(self) -> bool: + return False + + @property + def ordered(self) -> bool: + return self._ordered + + @property + def unique(self) -> bool: + return self._unique + + @property + def branchless(self) -> bool: + return False + + @property + def compact(self) -> bool: + return False + + def __init__(self, ordered: bool = True, unique: bool = True): + self._ordered = ordered + self._unique = unique + + def fields(self, lvl: int, dummy_symbol: symbolic.SymExpr) -> Dict[str, Data]: + return { + f"idx{lvl}_offset": dtypes.int32[dummy_symbol], # TODO (later) choose better length + } + + def __repr__(self) -> str: + s = "Range" + + non_defaults = [] + if not self._ordered: + non_defaults.append("¬O") + if not self._unique: + non_defaults.append("¬U") + + if len(non_defaults) > 0: + s += f"({','.join(non_defaults)})" + + return s + + +@make_properties +class TensorIndexOffset(TensorIndex): + """ + Tensor index that encodes the next coordinates as offset from parent. + + Given a parent coordinate i and an offset index k, the level encodes the + coordinate j = i + offset[k]. + """ + + _ordered = Property(dtype=bool, default=False) + _unique = Property(dtype=bool, default=False) + + @property + def iteration_type(self) -> TensorIterationTypes: + return TensorIterationTypes.Position + + @property + def locate(self) -> bool: + return False + + @property + def assembly(self) -> TensorAssemblyType: + return TensorAssemblyType.NoAssembly + + @property + def full(self) -> bool: + return False + + @property + def ordered(self) -> bool: + return self._ordered + + @property + def unique(self) -> bool: + return self._unique + + @property + def branchless(self) -> bool: + return True + + @property + def compact(self) -> bool: + return False + + def __init__(self, ordered: bool = True, unique: bool = True): + self._ordered = ordered + self._unique = unique + + def fields(self, lvl: int, dummy_symbol: symbolic.SymExpr) -> Dict[str, Data]: + return { + f"idx{lvl}_offset": dtypes.int32[dummy_symbol], # TODO (later) choose better length + } + + def __repr__(self) -> str: + s = "Offset" + + non_defaults = [] + if not self._ordered: + non_defaults.append("¬O") + if not self._unique: + non_defaults.append("¬U") + + if len(non_defaults) > 0: + s += f"({','.join(non_defaults)})" + + return s + + +@make_properties +class Tensor(Structure): + """ + Abstraction for Tensor storage format. + + This abstraction is based on [https://doi.org/10.1145/3276493]. + """ + + value_dtype = TypeClassProperty(default=dtypes.int32, choices=dtypes.Typeclasses) + tensor_shape = ShapeProperty(default=[]) + indices = ListProperty(element_type=TensorIndex) + index_ordering = ListProperty(element_type=symbolic.SymExpr) + value_count = SymbolicProperty(default=0) + + def __init__(self, + value_dtype: dtypes.Typeclasses, + tensor_shape, + indices: List[Tuple[TensorIndex, Union[int, symbolic.SymExpr]]], + value_count: symbolic.SymExpr, + name: str, + transient: bool = False, + storage: dtypes.StorageType = dtypes.StorageType.Default, + location: Dict[str, str] = None, + lifetime: dtypes.AllocationLifetime = dtypes.AllocationLifetime.Scope, + debuginfo: dtypes.DebugInfo = None): + """ + Constructor for Tensor storage format. + + Below are examples of common matrix storage formats: + + .. code-block:: python + + M, N, nnz = (dace.symbol(s) for s in ('M', 'N', 'nnz')) + + csr = dace.data.Tensor( + dace.float32, + (M, N), + [(dace.data.Dense(), 0), (dace.data.Compressed(), 1)], + nnz, + "CSR_Matrix", + ) + + csc = dace.data.Tensor( + dace.float32, + (M, N), + [(dace.data.Dense(), 1), (dace.data.Compressed(), 0)], + nnz, + "CSC_Matrix", + ) + + coo = dace.data.Tensor( + dace.float32, + (M, N), + [ + (dace.data.Compressed(unique=False), 0), + (dace.data.Singleton(), 1), + ], + nnz, + "CSC_Matrix", + ) + + num_diags = dace.symbol('num_diags') # number of diagonals stored + + diag = dace.data.Tensor( + dace.float32, + (M, N), + [ + (dace.data.Dense(), num_diags), + (dace.data.Range(), 0), + (dace.data.Offset(), 1), + ], + nnz, + "DIA_Matrix", + ) + + Below you can find examples of common 3rd order tensor storage formats: + + .. code-block:: python + + I, J, K, nnz = (dace.symbol(s) for s in ('I', 'J', 'K', 'nnz')) + + coo = dace.data.Tensor( + dace.float32, + (I, J, K), + [ + (dace.data.Compressed(unique=False), 0), + (dace.data.Singleton(unique=False), 1), + (dace.data.Singleton(), 2), + ], + nnz, + "COO_3D_Tensor", + ) + + csf = dace.data.Tensor( + dace.float32, + (I, J, K), + [ + (dace.data.Compressed(), 0), + (dace.data.Compressed(), 1), + (dace.data.Compressed(), 2), + ], + nnz, + "CSF_3D_Tensor", + ) + + :param value_type: data type of the explicitly stored values. + :param tensor_shape: logical shape of tensor (#rows, #cols, etc...) + :param indices: + a list of tuples, each tuple represents a level in the tensor + storage hirachy, specifying the levels tensor index type, and the + corresponding dimension this level encodes (as index of the + tensor_shape tuple above). The order of the dimensions may differ + from the logical shape of the tensor, e.g. as seen in the CSC + format. If an index's dimension is unrelated to the tensor shape + (e.g. in diagonal format where the first index's dimension is the + number of diagonals stored), a symbol can be specified instead. + :param value_count: number of explicitly stored values. + :param name: name of resulting struct. + :param others: See Structure class for remaining arguments + """ + + self.value_dtype = value_dtype + self.tensor_shape = tensor_shape + self.value_count = value_count + + indices, index_ordering = zip(*indices) + self.indices, self.index_ordering = list(indices), list(index_ordering) + + num_dims = len(tensor_shape) + dimension_order = [idx for idx in self.index_ordering if isinstance(idx, int)] + + # all tensor dimensions must occure exactly once in indices + if not sorted(dimension_order) == list(range(num_dims)): + raise TypeError((f"All tensor dimensions must be refferenced exactly once in " + f"tensor indices. (referenced dimensions: {dimension_order}; " + f"tensor dimensions: {list(range(num_dims))})")) + + # assembling permanent and index specific fields + fields = dict( + order=Scalar(dtypes.int32), + dim_sizes=dtypes.int32[num_dims], + value_count=value_count, + values=dtypes.float32[value_count], + ) + + for (lvl, index) in enumerate(indices): + fields.update(index.fields(lvl, value_count)) + + super(Tensor, self).__init__(fields, name, transient, storage, location, lifetime, debuginfo) + + def __repr__(self): + return f"{self.name} (dtype: {self.value_dtype}, shape: {list(self.tensor_shape)}, indices: {self.indices})" + + @staticmethod + def from_json(json_obj, context=None): + if json_obj['type'] != 'Tensor': + raise TypeError("Invalid data type") + + # Create dummy object + tensor = Tensor.__new__(Tensor) + serialize.set_properties_from_json(tensor, json_obj, context=context) + + return tensor diff --git a/dace/frontend/common/einsum.py b/dace/frontend/common/einsum.py index df1c8de34e..d33b0150f3 100644 --- a/dace/frontend/common/einsum.py +++ b/dace/frontend/common/einsum.py @@ -9,7 +9,7 @@ import dace from dace import dtypes, subsets, symbolic -from dace.data import _prod as prod +from dace.utils import prod from dace.sdfg.nodes import AccessNode from dace.sdfg import SDFG, SDFGState, InterstateEdge from dace.memlet import Memlet diff --git a/dace/libraries/linalg/nodes/tensordot.py b/dace/libraries/linalg/nodes/tensordot.py index e2e6e54e46..03b89a5a2c 100644 --- a/dace/libraries/linalg/nodes/tensordot.py +++ b/dace/libraries/linalg/nodes/tensordot.py @@ -4,7 +4,7 @@ import dace.libraries.linalg.environments as environments from dace import library, nodes, properties -from dace.data import _prod +from dace.utils import prod as _prod from dace.libraries.blas import blas_helpers from dace.symbolic import symstr from dace.transformation.transformation import ExpandTransformation diff --git a/dace/libraries/mpi/nodes/gather.py b/dace/libraries/mpi/nodes/gather.py index 8ad2b0df8b..b231ff7cee 100644 --- a/dace/libraries/mpi/nodes/gather.py +++ b/dace/libraries/mpi/nodes/gather.py @@ -1,6 +1,6 @@ # Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. from dace import dtypes, library, properties -from dace.data import _prod +from dace.utils import prod as _prod from dace.libraries.mpi import utils from dace.sdfg import nodes from dace.symbolic import symstr diff --git a/dace/libraries/mpi/nodes/redistribute.py b/dace/libraries/mpi/nodes/redistribute.py index e58ed544b7..19dddc0f8b 100644 --- a/dace/libraries/mpi/nodes/redistribute.py +++ b/dace/libraries/mpi/nodes/redistribute.py @@ -1,6 +1,6 @@ # Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. from dace import dtypes, library, properties, subsets, symbolic -from dace.data import _prod +from dace.utils import prod as _prod from dace.libraries.mpi import utils from dace.sdfg import nodes from dace.transformation.transformation import ExpandTransformation diff --git a/dace/libraries/mpi/nodes/scatter.py b/dace/libraries/mpi/nodes/scatter.py index 04367cbbfb..59cc54f3a7 100644 --- a/dace/libraries/mpi/nodes/scatter.py +++ b/dace/libraries/mpi/nodes/scatter.py @@ -1,6 +1,6 @@ # Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. from dace import dtypes, library, properties -from dace.data import _prod +from dace.utils import prod as _prod from dace.libraries.mpi import utils from dace.sdfg import nodes from dace.symbolic import symstr diff --git a/dace/optimization/data_layout_tuner.py b/dace/optimization/data_layout_tuner.py index 1dab5b3f17..f45325f313 100644 --- a/dace/optimization/data_layout_tuner.py +++ b/dace/optimization/data_layout_tuner.py @@ -5,7 +5,7 @@ import copy import itertools -from typing import Generator, Optional, Tuple, Dict, List, Sequence, Set +from typing import Any, Generator, Optional, Tuple, Dict, List, Sequence, Set from dace import data as dt, SDFG, dtypes from dace.optimization import cutout_tuner @@ -19,6 +19,11 @@ except (ImportError, ModuleNotFoundError): tqdm = lambda x, **kwargs: x +try: + from numpy.typing import ArrayLike +except ImportError: + ArrayLike = Any # type: ignore + class TuningGroups(enum.Enum): Separate = enum.auto() @@ -111,7 +116,7 @@ def pre_evaluate(self, cutout: dace.SDFG, dreport: data_report.InstrumentedDataR cutout.instrument = self.instrument # Prepare original arguments to sub-SDFG from instrumented data report - arguments: Dict[str, dt.ArrayLike] = {} + arguments: Dict[str, ArrayLike] = {} for cstate in cutout.nodes(): for dnode in cstate.data_nodes(): if cutout.arrays[dnode.data].transient: diff --git a/dace/properties.py b/dace/properties.py index b0658d8572..cddc26ab8b 100644 --- a/dace/properties.py +++ b/dace/properties.py @@ -123,10 +123,16 @@ def tmp_func(self): self._category = category if desc is not None and len(desc) > 0: self.__doc__ = desc - elif self.dtype is not None: - self.__doc__ = "Object property of type %s" % self.dtype.__name__ else: - self.__doc__ = "Object property of type %s" % type(self).__name__ + try: + dtype = self.dtype + if dtype is not None: + self.__doc__ = "Object property of type %s" % dtype.__name__ + else: + self.__doc__ = "Object property of type %s" % type(self).__name__ + except (ImportError, AttributeError): + # Handle circular import case - defer docstring generation + self.__doc__ = "Object property of type %s" % type(self).__name__ def __get__(self, obj, objtype=None) -> T: if obj is None: diff --git a/dace/sdfg/analysis/cutout.py b/dace/sdfg/analysis/cutout.py index 432c765aa0..f3a0458f7e 100644 --- a/dace/sdfg/analysis/cutout.py +++ b/dace/sdfg/analysis/cutout.py @@ -8,7 +8,8 @@ import sympy as sp from collections import deque import copy -from typing import Deque, Dict, List, Set, Tuple, Union, Optional, Any +from typing import Any, Deque, Dict, List, Set, Tuple, Union, Optional +from numbers import Number from dace import data, DataInstrumentationType from dace.sdfg import nodes as nd, SDFG, SDFGState, utils as sdutil, InterstateEdge from dace.memlet import Memlet @@ -19,6 +20,11 @@ from dace.transformation.interstate.loop_detection import DetectLoop from dace.transformation.passes.analysis import StateReachability +try: + from numpy.typing import ArrayLike +except ImportError: + ArrayLike = Any # type: ignore + class SDFGCutout(SDFG): @@ -52,12 +58,12 @@ def _dry_run_base_sdfg(self, *args, **kwargs) -> None: self._instrument_base_sdfg() self._base_sdfg(*args, **kwargs) - def find_inputs(self, *args, **kwargs) -> Dict[str, Union[data.ArrayLike, data.Number]]: + def find_inputs(self, *args, **kwargs) -> Dict[str, Union[ArrayLike, Number]]: self._dry_run_base_sdfg(*args, **kwargs) drep = self._base_sdfg.get_instrumented_data() if drep: - vals: Dict[str, Union[data.ArrayLike, data.Number]] = dict() + vals: Dict[str, Union[ArrayLike, Number]] = dict() for ip in self.input_config.union(set(self.symbols)): val = drep.get_first_version(ip) vals[ip] = val diff --git a/dace/transformation/dataflow/map_distribution.py b/dace/transformation/dataflow/map_distribution.py index 5cd551d3ac..eaa8d07f2a 100644 --- a/dace/transformation/dataflow/map_distribution.py +++ b/dace/transformation/dataflow/map_distribution.py @@ -319,7 +319,7 @@ def apply(self, graph: dace.SDFGState, sdfg: dace.SDFG): Px = dace.symbol('Px', dtype=dace.int32, integer=True, positive=True) Py = dace.symbol('Py', dtype=dace.int32, integer=True, positive=True) - from dace.data import _prod + from dace.utils import prod as _prod # NOTE: Maps with step in their ranges are currently not supported if len(map_entry.map.params) == 2: diff --git a/dace/utils.py b/dace/utils.py new file mode 100644 index 0000000000..26e3661be8 --- /dev/null +++ b/dace/utils.py @@ -0,0 +1,58 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +""" +Utility functions for DaCe. + +This module provides general utility functions that are used across various parts of DaCe. +""" + +import math +from typing import Iterable, Sequence, Union + +import sympy + +# Type alias for numeric or symbolic values +NumericType = Union[int, float, sympy.Basic] + + +def prod(sequence: Iterable[NumericType], start: NumericType = 1) -> NumericType: + """ + Computes the product of a sequence of numbers or symbolic expressions. + + This function handles both numeric values and SymPy symbolic expressions, + making it suitable for use with DaCe's symbolic shape calculations. + + :param sequence: An iterable of numbers or symbolic expressions. + :param start: The starting value for the product (default: 1). + :return: The product of all elements in the sequence, multiplied by start. + Returns start if the sequence is empty. + """ + result = start + for item in sequence: + result = result * item + return result + + +def find_new_name(name: str, existing_names: Sequence[str]) -> str: + """ + Returns a name that matches the given ``name`` as a prefix, but does not + already exist in the given existing name set. The behavior is typically + to append an underscore followed by a unique (increasing) number. If the + name does not already exist in the set, it is returned as-is. + + :param name: The given name to find. + :param existing_names: The set of existing names. + :return: A new name that is not in existing_names. + """ + if name not in existing_names: + return name + cur_offset = 0 + new_name = name + '_' + str(cur_offset) + while new_name in existing_names: + cur_offset += 1 + new_name = name + '_' + str(cur_offset) + return new_name + + +def deduplicate(iterable): + """ Removes duplicates in the passed iterable. """ + return type(iterable)([i for i in sorted(set(iterable), key=lambda x: iterable.index(x))]) From cc59d7701ebb80a7d69a4d8a35afe964b999b1f9 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Mon, 8 Dec 2025 08:25:20 -0800 Subject: [PATCH 07/17] Support `dace.map` syntax for struct fields (#2187) --- dace/frontend/python/newast.py | 22 +++-- .../structures/structure_python_test.py | 89 +++++++++++++++++++ 2 files changed, 104 insertions(+), 7 deletions(-) diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index b5e8c72ea8..807e1e80df 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -2011,7 +2011,7 @@ def _parse_map_inputs(self, name: str, params: List[Tuple[str, str]], if symbolic.issymbolic(atom, self.sdfg.constants): # Check for undefined variables atomstr = str(atom) - if atomstr not in self.defined: + if atomstr not in self.defined and atomstr not in self.sdfg.arrays: raise DaceSyntaxError(self, node, 'Undefined variable "%s"' % atom) # Add to global SDFG symbols @@ -3245,8 +3245,14 @@ def _add_access( else: var_name = self.get_target_name() - parent_name = self.scope_vars[name] - parent_array = self.scope_arrays[parent_name] + parent_name = self.scope_vars[until(name, '.')] + if '.' in name: + struct_field = name[name.index('.'):] + parent_name += struct_field + scope_ndict = dace.sdfg.NestedDict(self.scope_arrays) + parent_array = scope_ndict[parent_name] + else: + parent_array = self.scope_arrays[parent_name] has_indirection = (_subset_has_indirection(rng, self) or _subset_is_local_symbol_dependent(rng, self)) strides = list(parent_array.strides) @@ -3419,7 +3425,7 @@ def _add_write_access(self, return self.accesses[(name, rng, 'w')] elif name in self.variables: return (self.variables[name], rng) - elif (name, rng, 'r') in self.accesses or name in self.scope_vars: + elif (name, rng, 'r') in self.accesses or until(name, '.') in self.scope_vars: return self._add_access(name, rng, 'w', target, new_name, arr_type) else: raise NotImplementedError @@ -3527,8 +3533,10 @@ def _visit_assign(self, node, node_target, op, dtype=None, is_return=False): while isinstance(last_subscript.value, ast.Subscript): last_subscript = last_subscript.value if isinstance(target, ast.Subscript) and not isinstance(last_subscript.value, ast.Name): - store_target = copy.copy(last_subscript.value) - store_target.ctx = ast.Store() + store_target = astutils.copy_tree(last_subscript.value) + for n in ast.walk(store_target): # Recursively make attributes into stores + if hasattr(n, 'ctx'): + n.ctx = ast.Store() true_name = self.visit(store_target) # Refresh defined variables and arrays defined_vars = {**self.variables, **self.scope_vars} @@ -3736,7 +3744,7 @@ def _visit_assign(self, node, node_target, op, dtype=None, is_return=False): raise IndexError('Boolean array indexing cannot be combined with indirect access') if self.nested and not new_data and not visited_target: - new_name, new_rng = self._add_write_access(name, rng, target) + new_name, new_rng = self._add_write_access(true_name, rng, target) # Local symbol or local data dependent if _subset_is_local_symbol_dependent(rng, self): new_rng = rng diff --git a/tests/python_frontend/structures/structure_python_test.py b/tests/python_frontend/structures/structure_python_test.py index 9505f8cab7..af317be7d8 100644 --- a/tests/python_frontend/structures/structure_python_test.py +++ b/tests/python_frontend/structures/structure_python_test.py @@ -257,6 +257,92 @@ def csr_to_dense_python(A: CSR, B: dace.float32[M, N]): assert np.allclose(B, ref) +def test_write_structure_in_map(): + M = dace.symbol('M') + N = dace.symbol('N') + Bundle = dace.data.Structure(members={ + "data": dace.data.Array(dace.float32, (M, N)), + "size": dace.data.Scalar(dace.int64) + }, + name="BundleType") + + @dace.program + def init_prog(bundle: Bundle, fill_value: int) -> None: + for index in dace.map[0:bundle.size]: + bundle.data[index, :] = fill_value + + data = np.zeros((10, 5), dtype=np.float32) + fill_value = 42 + inp_struct = Bundle.dtype.base_type.as_ctypes()( + data=data.__array_interface__['data'][0], + size=9, + ) + ref = np.zeros((10, 5), dtype=np.float32) + ref[:9, :] = fill_value + + init_prog.compile()(inp_struct, fill_value, M=10, N=5) + + assert np.allclose(data, ref) + + +def test_readwrite_structure_in_map(): + M = dace.symbol('M') + N = dace.symbol('N') + Bundle = dace.data.Structure(members={ + "data": dace.data.Array(dace.float32, (M, N)), + "data2": dace.data.Array(dace.float32, (M, N)), + "size": dace.data.Scalar(dace.int64) + }, + name="BundleType") + + @dace.program + def copy_prog(bundle: Bundle) -> None: + for index in dace.map[0:bundle.size]: + bundle.data[index, :] = bundle.data2[index, :] + 5 + + data = np.zeros((10, 5), dtype=np.float32) + data2 = np.ones((10, 5), dtype=np.float32) + inp_struct = Bundle.dtype.base_type.as_ctypes()( + data=data.__array_interface__['data'][0], + data2=data2.__array_interface__['data'][0], + size=6, + ) + ref = np.zeros((10, 5), dtype=np.float32) + ref[:6, :] = 6.0 + + copy_prog.compile()(inp_struct, M=10, N=5) + + assert np.allclose(data, ref) + + +def test_write_structure_in_loop(): + M = dace.symbol('M') + N = dace.symbol('N') + Bundle = dace.data.Structure(members={ + "data": dace.data.Array(dace.float32, (M, N)), + "size": dace.data.Scalar(dace.int64) + }, + name="BundleType") + + @dace.program + def init_prog(bundle: Bundle, fill_value: int) -> None: + for index in range(bundle.size): + bundle.data[index, :] = fill_value + + data = np.zeros((10, 5), dtype=np.float32) + fill_value = 42 + inp_struct = Bundle.dtype.base_type.as_ctypes()( + data=data.__array_interface__['data'][0], + size=6, + ) + ref = np.zeros((10, 5), dtype=np.float32) + ref[:6, :] = fill_value + + init_prog.compile()(inp_struct, fill_value, M=10, N=5) + + assert np.allclose(data, ref) + + def test_struct_interface(): M, N, nnz = (dace.symbol(s) for s in ('M', 'N', 'nnz')) CSR = dace.data.Structure(dict(indptr=dace.int32[M + 1], indices=dace.int32[nnz], data=dace.float32[nnz]), @@ -370,6 +456,9 @@ def struct_recursive(A: Struct, B: Struct): test_local_structure() test_rgf() # test_read_structure_gpu() + test_write_structure_in_map() + test_readwrite_structure_in_map() + test_write_structure_in_loop() test_struct_interface() test_struct_recursive() test_struct_recursive_from_dataclass() From 5824ad0be3eb6e19ef476fccba9d52960799e223 Mon Sep 17 00:00:00 2001 From: Yakup Koray Budanaz Date: Mon, 8 Dec 2025 23:53:39 +0000 Subject: [PATCH 08/17] GPU codegen crashes and generates incorrect code with dynamic inputs to seq. maps inside GPU kernels or gpu dev. maps (#2088) GPU codegen crashes and generates incorrect code with dynamic inputs to seq. maps inside GPU kernels or gpu dev. maps --------- Co-authored-by: alexnick83 <31545860+alexnick83@users.noreply.github.com> Co-authored-by: Tal Ben-Nun --- dace/codegen/targets/cuda.py | 63 ++++++++++++++++++++++++++++++++---- 1 file changed, 56 insertions(+), 7 deletions(-) diff --git a/dace/codegen/targets/cuda.py b/dace/codegen/targets/cuda.py index feb5193091..6de2cc0769 100644 --- a/dace/codegen/targets/cuda.py +++ b/dace/codegen/targets/cuda.py @@ -1591,7 +1591,11 @@ def generate_scope(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg_scope: StateSub # Handle dynamic map inputs for e in dace.sdfg.dynamic_map_inputs(state, scope_entry): - kernel_args[str(e.src)] = e.src.desc(sdfg) + if e.data is None: + raise Exception("Dynamic map input's memlet can't be None") + data_name = e.data.data + data_desc = state.sdfg.arrays[data_name] + kernel_args[data_name] = data_desc # Add data from nested SDFGs to kernel arguments extra_call_args = [] @@ -1812,7 +1816,15 @@ def generate_scope(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg_scope: StateSub # make sure dynamic map inputs are properly handled for e in dace.sdfg.dynamic_map_inputs(state, scope_entry): + if e.data is not None and e.data.data == e.dst_conn: + warnings.warn( + f"Dynamic map input name {e.data.data} is same as the dst connector. Will result in a name clash, omitting of code for this assignment is skipped." + ) + comment_out_str = "// Omitted name clash on dynamic map input\n//" + else: + comment_out_str = "" self._localcode.write( + comment_out_str + self._cpu_codegen.memlet_definition(sdfg, e.data, False, e.dst_conn, e.dst.in_connectors[e.dst_conn]), cfg, state_id, scope_entry) @@ -1883,7 +1895,15 @@ def generate_scope(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg_scope: StateSub callsite_stream.write( 'DACE_GPU_CHECK({backend}EventSynchronize(__state->gpu_context->events[{ev}]));'.format( ev=ev, backend=self.backend), cfg, state_id, [e.src, e.dst]) + if e.data is not None and e.data.data == e.dst_conn: + warnings.warn( + f"Dynamic map input name {e.data.data} is same as the dst connector. Will result in a name clash, omitting of code for this assignment is skipped." + ) + comment_out_str = "// Omitted name clash on dynamic map input\n//" + else: + comment_out_str = "" callsite_stream.write( + comment_out_str + self._cpu_codegen.memlet_definition(sdfg, e.data, False, e.dst_conn, e.dst.in_connectors[e.dst_conn]), cfg, state_id, node) @@ -2175,7 +2195,15 @@ def generate_kernel_scope(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg_scope: S # handle dynamic map inputs for e in dace.sdfg.dynamic_map_inputs(cfg.node(state_id), dfg_scope.source_nodes()[0]): + if e.data is not None and e.data.data == e.dst_conn: + warnings.warn( + f"Dynamic map input name {e.data.data} is same as the dst connector. Will result in a name clash, omitting of code for this assignment is skipped." + ) + comment_out_str = "// Omitted name clash on dynamic map input\n//" + else: + comment_out_str = "" kernel_stream.write( + comment_out_str + self._cpu_codegen.memlet_definition(sdfg, e.data, False, e.dst_conn, e.dst.in_connectors[e.dst_conn]), cfg, state_id, dfg_scope.source_nodes()[0]) @@ -2353,9 +2381,16 @@ def generate_devicelevel_scope(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg_sco # They define outside the schedule the bounds of the dynamic Map's for-loop invocation. # NOTE: The value of the dynamic Map's variable may differ inside and outside the schedule. for e in dace.sdfg.dynamic_map_inputs(dfg, scope_entry): + if e.data is not None and e.data.data == e.dst_conn: + warnings.warn( + f"Dynamic map input name {e.data.data} is same as the dst connector. Will result in a name clash, omitting of code for this assignment is skipped." + ) + comment_out_str = "// Omitted name clash on dynamic map input\n//" + else: + comment_out_str = "" callsite_stream.write( - self._cpu_codegen.memlet_definition(sdfg, e.data, False, e.dst_conn, - e.dst.in_connectors[e.dst_conn]), cfg, state_id, scope_entry) + comment_out_str + self._cpu_codegen.memlet_definition( + sdfg, e.data, False, e.dst_conn, e.dst.in_connectors[e.dst_conn]), cfg, state_id, scope_entry) dynmap_var = scope_map.params[0] dynmap_begin = scope_map.range[0][0] @@ -2384,9 +2419,16 @@ def generate_devicelevel_scope(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg_sco param=dynmap_var), cfg, state_id, scope_entry) for e in dace.sdfg.dynamic_map_inputs(dfg, scope_entry): + if e.data is not None and e.data.data == e.dst_conn: + warnings.warn( + f"Dynamic map input name {e.data.data} is same as the dst connector. Will result in a name clash, omitting of code for this assignment is skipped." + ) + comment_out_str = "// Omitted name clash on dynamic map input\n//" + else: + comment_out_str = "" callsite_stream.write( - self._cpu_codegen.memlet_definition(sdfg, e.data, False, e.dst_conn, - e.dst.in_connectors[e.dst_conn]), cfg, state_id, scope_entry) + comment_out_str + self._cpu_codegen.memlet_definition( + sdfg, e.data, False, e.dst_conn, e.dst.in_connectors[e.dst_conn]), cfg, state_id, scope_entry) if dynmap_step != 1: callsite_stream.write( @@ -2419,9 +2461,16 @@ def generate_devicelevel_scope(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg_sco # handle dynamic map inputs for e in dace.sdfg.dynamic_map_inputs(dfg, scope_entry): + if e.data is not None and e.data.data == e.dst_conn: + warnings.warn( + f"Dynamic map input name {e.data.data} is same as the dst connector. Will result in a name clash, omitting of code for this assignment is skipped." + ) + comment_out_str = "// Omitted name clash on dynamic map input\n//" + else: + comment_out_str = "" callsite_stream.write( - self._cpu_codegen.memlet_definition(sdfg, e.data, False, e.dst_conn, - e.dst.in_connectors[e.dst_conn]), cfg, state_id, + comment_out_str + self._cpu_codegen.memlet_definition( + sdfg, e.data, False, e.dst_conn, e.dst.in_connectors[e.dst_conn]), cfg, state_id, scope_entry) # variables that need to be declared + the value they need to be initialized with From 312f37fa8b6b8fef40b84bdfbfb61e19ad56bda1 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Mon, 8 Dec 2025 23:27:59 -0800 Subject: [PATCH 09/17] Modular Code Generation Docs: Add LowerConsume and remove numbering (#2246) Updated the documentation for proposed pass decomposition, including changes to pass names and descriptions for clarity. --- doc/design/codegen.md | 43 ++++++++++++++++++++++++++----------------- 1 file changed, 26 insertions(+), 17 deletions(-) diff --git a/doc/design/codegen.md b/doc/design/codegen.md index f236179c49..b4ff255fee 100644 --- a/doc/design/codegen.md +++ b/doc/design/codegen.md @@ -102,43 +102,43 @@ The `DaCeCodeGenerator` class currently handles numerous responsibilities that s ### Phase 1: Scheduling and Analysis Passes -#### 1. **ValidationPass** +#### **ValidationPass** - **Purpose**: Run SDFG validation prior to code generation - **Input**: Input SDFG - **Output**: None - **Current Location**: `validate.py` -#### 2. **TypeInferencePass** +#### **TypeInferencePass** - **Purpose**: Infer connector types and set default storage/schedule types - **Input**: Input SDFG - **Output**: SDFG with inferred types, pipeline_results["type_info"] - **Current Location**: `infer_types.py` functions -#### 3. **LibraryExpansionPass** +#### **LibraryExpansionPass** - **Purpose**: Expand all library nodes that haven't been expanded - **Input**: Type-inferred SDFG - **Output**: SDFG with expanded library nodes - **Current Location**: `sdfg.expand_library_nodes()` -#### 4. **TypeInferencePass** +#### **TypeInferencePass** - **Purpose**: After expanding library nodes, run a second type inference pass if the SDFG changed - **Input**: Library-expanded SDFG - **Output**: SDFG with inferred types, updated pipeline_results["type_info"] - **Current Location**: `infer_types.py` functions -#### 5. **MetadataCollectionPass** +#### **MetadataCollectionPass** - **Purpose**: Collect free symbols, argument lists, constants, shared transients - **Input**: Expanded SDFG - **Output**: pipeline_results["metadata"] = {symbols, arglist, constants, shared_transients} - **Current Location**: `DaCeCodeGenerator.__init__()` -#### 6. **ControlFlowRaising** +#### **ControlFlowRaising** - **Purpose**: Extract structured control flow from state machines, if Control Flow Regions were not already given - **Input**: SDFG - **Output**: SDFG with Control Flow Regions - **Current Location**: Already exists -#### 7. **AllocationAnalysisPass** +#### **AllocationAnalysisPass** - **Purpose**: Determine allocation lifetimes and scopes for all data containers - **Input**: SDFG with metadata - **Output**: SDFG with allocation/deallocation points stored in node metadata @@ -146,7 +146,7 @@ The `DaCeCodeGenerator` class currently handles numerous responsibilities that s these decisions. - **Current Location**: `DaCeCodeGenerator.determine_allocation_lifetime()` -#### 8. **StreamAssignmentPass** (mostly GPU-specific) +#### **StreamAssignmentPass** (mostly GPU-specific) - **Purpose**: Assign streams for concurrent execution. Currently used for CUDA/HIP streams but can apply to other architectures - **Input**: SDFG - **Output**: SDFG with stream assignments stored in node metadata @@ -159,33 +159,40 @@ The `DaCeCodeGenerator` class currently handles numerous responsibilities that s - **Purpose**: Perform preprocessing modifications on the SDFG based on the code generators that will be used next - **Examples**: `FPGAPreprocessingPass` for FPGAs, `StreamAssignmentPass` for GPUs, `CopyToMap` for heterogeneous targets in general (see below) -#### 9. **LowerAllocations** +#### **LowerConsume** +- **Purpose**: Convert Consume scopes into while loops or kernels, depending on the target (`LowerConsumeCPP`, `LowerConsumeGPU`) +- **Input**: SDFG with consume scopes +- **Output**: SDFG with control flow regions +- **Current Location**: Inline in code generators +- **Note**: This modifies the SDFG structure rather than generating code + +#### **LowerAllocations** - **Purpose**: Add allocation/deallocation annotations (e.g., as tasklets) to the SDFG for each scope - **Input**: SDFG with allocation analysis - **Output**: SDFG with allocation/deallocation tasklets inserted - **Current Location**: `allocate_arrays_in_scope()`, `deallocate_arrays_in_scope()` - **Note**: This modifies the SDFG structure rather than generating code -#### 10. **CopyToMap** +#### **CopyToMap** - **Purpose**: Convert nontrivial memory copies to Map nodes where needed - **Input**: SDFG with targets identified - **Output**: SDFG with transformed copies - **Current Location**: `cuda.py` preprocessing, various target preprocessors - **Applies To**: GPU strided copies, FPGA transfers -#### 11. **LowerTaskletLanguage** +#### **LowerTaskletLanguage** - **Purpose**: Convert Python/generic tasklets to tasklets in the target language (C++/CUDA/etc.) - **Input**: SDFG with tasklets - **Output**: SDFG with lowered tasklets - **Current Location**: Distributed across target generators -#### 12. **LowerMemlets** +#### **LowerMemlets** - **Purpose**: Lower high-level memlets to explicit copy operations - **Input**: SDFG with target analysis - **Output**: SDFG with explicit copies annotated (e.g., as tasklets) - **Current Location**: Embedded in target-specific copy generation -#### 13. **SplitSDFGToTargets** +#### **SplitSDFGToTargets** - **Purpose**: The final lowering step splits the single SDFG into an SDFG per target file. This means that, for example, a GPU kernel map will be converted to an ExternalSDFG call to another SDFG file that contains the kernel. @@ -199,13 +206,13 @@ The `DaCeCodeGenerator` class currently handles numerous responsibilities that s ### Phase 3: Code Generation Passes -#### 14. **GenerateStateStruct** +#### **GenerateStateStruct** - **Purpose**: Generate state struct definitions for persistent data - **Input**: SDFG with allocation info - **Output**: pipeline_results["state_struct"] = {struct_def, struct_init} - **Current Location**: `DaCeCodeGenerator.generate_code()` -#### 15. **GenerateTargetCode** +#### **GenerateTargetCode** - **Purpose**: Generate both frame code and target-specific code for each SDFG file by traversing the graph and emitting code for each element. - **Input**: Split SDFGs with all previous analyses @@ -214,7 +221,7 @@ The `DaCeCodeGenerator` class currently handles numerous responsibilities that s - **Note**: This pass may call individual target code generators (CppCodeGen, GPUCodeGen, FPGACodeGen, etc.) to generate platform-specific code -#### 14. **GenerateHeaders** +#### **GenerateHeaders** - **Purpose**: Generate C/C++ header files for SDFG interface - **Input**: CodeObjects with complete code - **Output**: pipeline_results["headers"] = {call_header, sample_main} @@ -246,6 +253,8 @@ class CodeGenerationPipeline(Pipeline): ConditionalPipeline([ (lambda r: 'cuda' in r.get('targets', []), CopyToMapPass()), (lambda r: 'fpga' in r.get('targets', []), FPGAPreprocessingPass()), + (lambda r: 'cuda' in r.get('targets', []), LowerConsumeGPU()), + (lambda r: 'cpu' in r.get('targets', []), LowerConsumeCPP()), ]), LowerTaskletLanguage(), LowerMemlets(), @@ -349,11 +358,11 @@ dace/codegen/ - Base for other C++ based backends - Sequential execution model - Basic memory management +- Current "CPU" backend functionality (without parallelism) #### 2. **OpenMP Backend** (`targets/openmp.py`) - Extends C++ backend with OpenMP directives - CPU parallelization via OpenMP -- Current "CPU" backend functionality - Shared memory parallelism #### 3. **GPU Backend** (`targets/gpu.py`) From 387f1e8719128d5dbc52e0db7ef531661e13e3a8 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Thu, 11 Dec 2025 19:48:20 -0800 Subject: [PATCH 10/17] Update C++ standard to C++20 (#2253) --- dace/codegen/CMakeLists.txt | 4 ++++ dace/codegen/compiler.py | 1 + dace/config_schema.yml | 13 ++++++++++--- doc/setup/installation.rst | 2 +- 4 files changed, 16 insertions(+), 4 deletions(-) diff --git a/dace/codegen/CMakeLists.txt b/dace/codegen/CMakeLists.txt index e1a5e33947..5a8e6438eb 100644 --- a/dace/codegen/CMakeLists.txt +++ b/dace/codegen/CMakeLists.txt @@ -7,6 +7,7 @@ set(DACE_PROGRAM_NAME "dace_program" CACHE STRING "Name of DaCe program") set(DACE_SRC_DIR "" CACHE STRING "Root directory of generated code files") set(DACE_FILES "" CACHE STRING "List of host code files relative to the root of the source directory") set(DACE_LIBS "" CACHE STRING "Extra libraries") +set(DACE_CPP_STANDARD "20" CACHE STRING "C++ standard to use for compilation (e.g., 14, 17, 20, 23, 26)") set(HLSLIB_PART_NAME "${DACE_XILINX_PART_NAME}") # CUDA @@ -566,6 +567,9 @@ include("targets/mlir/mlir.cmake") add_library(${DACE_PROGRAM_NAME} SHARED ${DACE_CPP_FILES} ${DACE_OBJECTS}) target_link_libraries(${DACE_PROGRAM_NAME} PUBLIC ${DACE_LIBS}) +# Set C++ standard to C++20 (or the configured standard) +set_property(TARGET ${DACE_PROGRAM_NAME} PROPERTY CXX_STANDARD ${DACE_CPP_STANDARD}) + # Add additional required files if(DACE_ENABLE_INTELFPGA) if(DACE_INTELFPGA_MODE STREQUAL "emulator") diff --git a/dace/codegen/compiler.py b/dace/codegen/compiler.py index 00f40da622..585cd66aee 100644 --- a/dace/codegen/compiler.py +++ b/dace/codegen/compiler.py @@ -170,6 +170,7 @@ def configure_and_compile(program_folder, program_name=None, output_stream=None) "-DDACE_SRC_DIR=\"{}\"".format(src_folder), "-DDACE_FILES=\"{}\"".format(";".join(files)), "-DDACE_PROGRAM_NAME={}".format(program_name), + "-DDACE_CPP_STANDARD={}".format(Config.get('compiler', 'cpp_standard')), ] # Get required environments are retrieve the CMake information diff --git a/dace/config_schema.yml b/dace/config_schema.yml index 812e24329e..72e1f784f9 100644 --- a/dace/config_schema.yml +++ b/dace/config_schema.yml @@ -173,6 +173,13 @@ required: The typename of this struct is derived by appending this value to the SDFG's name. Note that the suffix may only contains letters, digits and underscores. + cpp_standard: + type: str + default: "20" + title: C++ standard version + description: > + C++ standard to use for compilation (e.g., 14, 17, 20, 23, 26). + format_code: type: bool default: false @@ -183,7 +190,7 @@ required: format_config_file: type: str default: "" - title: Path the clang-format file + title: Path to the .clang-format file description: > Clang-format file to be used by clang-format, only used if format_code is true @@ -261,7 +268,7 @@ required: type: str title: Arguments description: Compiler argument flags - default: '-std=c++14 -fPIC -Wall -Wextra -O3 -march=native -ffast-math -Wno-unused-parameter -Wno-unused-label' + default: '-fPIC -Wall -Wextra -O3 -march=native -ffast-math -Wno-unused-parameter -Wno-unused-label' default_Windows: '/O2 /fp:fast /arch:AVX2 /D_USRDLL /D_WINDLL /D__restrict__=__restrict' libs: @@ -310,7 +317,7 @@ required: type: str title: hipcc Arguments description: Compiler argument flags for HIP - default: '-std=c++17 -fPIC -O3 -ffast-math -Wno-unused-parameter' + default: '-fPIC -O3 -ffast-math -Wno-unused-parameter' cuda_arch: type: str diff --git a/doc/setup/installation.rst b/doc/setup/installation.rst index e8dbc4d570..b76a41565f 100644 --- a/doc/setup/installation.rst +++ b/doc/setup/installation.rst @@ -14,7 +14,7 @@ Most dependencies will be resolved when the package is installed with ``pip`` or however, it requires two more runtime dependencies to be installed and available in the ``PATH`` environment variable (if not, see :ref:`config` for how to configure different compiler paths): - * A C++14-capable compiler (e.g., gcc 5.3+) + * A C++20-capable compiler (e.g., gcc 10+) * CMake 3.15 or newer. *Note: if CMake cannot be found or is too old, pip will try to install a version but it sometimes fails.* From abc71e9f1722edb1639eeb562320a9a1376f286e Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Thu, 18 Dec 2025 09:28:52 +0100 Subject: [PATCH 11/17] Squashed commit of the following: commit ecb2785ad28d789b889c807f918b45e9242a0fe7 Author: Philip Mueller Date: Wed Dec 17 08:19:40 2025 +0100 Updated the dace updater workflow file. commit f3198efd81b9b8e4311955dcccffbe28a63b5e7d Author: Philip Mueller Date: Wed Dec 17 07:41:26 2025 +0100 Made the update point to the correct repo. commit 96f963a122b6109bb3df3d61afebd15a06305ef2 Merge: 8b7cce58d 387f1e871 Author: Philip Mueller Date: Wed Dec 17 07:37:48 2025 +0100 Merge remote-tracking branch 'spcl/main' into automatic_gt4py_deployment commit 8b7cce58d62305444886d4042706da4f7d1dd872 Author: Philip Mueller Date: Mon Dec 1 09:18:22 2025 +0100 Restored the original workflow files. commit 362ab70099d6a5488a271b89c817724708cf21fe Author: Philip Mueller Date: Mon Dec 1 07:41:40 2025 +0100 Now it has run once, so let's make it less runnable. commit 81b8cfa08f9cc3968bcfd95477bcbb90eb8c930e Author: Philip Mueller Date: Mon Dec 1 07:39:09 2025 +0100 Made it run always. commit 6d71466ea70df4a1dff9ac6ba0efe4d208f5c422 Author: Philip Mueller Date: Mon Dec 1 07:38:11 2025 +0100 Small update. commit eb31e6cc4a0f5016f22cf525ec59f13cbc05ccf0 Author: Philip Mueller Date: Fri Nov 21 15:23:33 2025 +0100 Empty commit in the branch containing the workflow file. commit 2970a75f02e8b12810cc9a732459955b4673edee Author: Philip Mueller Date: Fri Nov 21 15:21:09 2025 +0100 Next step. commit f5d3d9df89822199684e13003ef1fe639fd168e2 Author: Philip Mueller Date: Fri Nov 21 15:17:56 2025 +0100 Let's disable everything. commit 211e415444a62c36cb46016368c33ea4049335f9 Author: Philip Mueller Date: Fri Nov 21 15:10:43 2025 +0100 Disabled the kickstarter. commit d012c26261b18fd7eccae5c1b6bccb150a990d78 Author: Philip Mueller Date: Fri Nov 21 15:05:38 2025 +0100 Updated everything. --- .github/workflows/dace-updater.yml | 48 ++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) create mode 100644 .github/workflows/dace-updater.yml diff --git a/.github/workflows/dace-updater.yml b/.github/workflows/dace-updater.yml new file mode 100644 index 0000000000..aa60eeb193 --- /dev/null +++ b/.github/workflows/dace-updater.yml @@ -0,0 +1,48 @@ +name: Inform the Python package index about a new DaCe release. + +on: + # Trigger for all pushes to tags matching this pattern + push: + tags: + - __gt4py-next-integration_* + + # To "install" this workflow you must enable this trigger, such that the workflow runs at least one. + # You should also disable any processing such that no commit in the index repo is performed. + # See https://stackoverflow.com/a/71057825 + #pull_request: + + # Allows to trigger the update manually. + # NOTE: Is only possible if the workflow file is located on the default and the branch where it should run on. + workflow_dispatch: + +jobs: + update-dace: + runs-on: ubuntu-latest + steps: + - name: Inform Index + shell: bash + run: | + INDEX_ORGANIZATION="gridtools" + INDEX_REPO="python-pkg-index" + + # We are using `github.sha` here to be sure that we transmit an identifier to the index + # that can be checked out. Before we used `github.ref_name` but got strange results + # with it. + DEPENDENCY_REF="${{ github.sha }}" + SOURCE_REPO="dace" + SOURCE_OWNER="gridtools" + + curl -L -v --fail-with-body \ + -X POST \ + -H "Accept: application/vnd.github+json" \ + -H "Authorization: Bearer ${{ secrets.PKG_UPDATE_TOKEN }}" \ + -H "X-GitHub-Api-Version: 2022-11-28" \ + "https://api.github.com/repos/${INDEX_ORGANIZATION}/${INDEX_REPO}/dispatches" \ + -d '{"event_type":"update_package_index","client_payload":{"source_repo":"'"${SOURCE_REPO}"'","source_org":"'"${SOURCE_OWNER}"'","dependency_ref":"'"${DEPENDENCY_REF}"'"}}' + + if [ $? -ne 0 ] + then + echo "POST to '${INDEX_ORGANIZATION}:${INDEX_REPO}' failed. + exit 1 + fi + exit 1 From 8321782388232d799f911f87506bdbeff089aef9 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Thu, 18 Dec 2025 09:29:07 +0100 Subject: [PATCH 12/17] Squashed commit of the following: MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit commit 2a89832d1792cd17e93e7d6506e1ae0f7d4c188e Author: Ioannis Magkanaris Date: Mon Oct 27 09:56:16 2025 +0100 Fixed GPU_TX_MARKER test commit c2401285f4cbc4edab9c267fa1482be018ab61fb Merge: 10160bc1c e38d00617 Author: Ioannis Magkanaris Date: Fri Oct 24 18:29:55 2025 +0200 Merge remote-tracking branch 'upstream/main' into nvtx_ranges commit 10160bc1cf08b1ae8db4f8e3b4a232dc9f4fce1c Author: Ioannis Magkanaris Date: Fri Oct 24 18:27:58 2025 +0200 Fix instrumentation for copies commit d14093c7d8ae761f864738583747b3377a79a68d Author: Ioannis Magkanaris Date: Mon Oct 6 16:16:43 2025 +0200 Make pre-commit happy commit 68942a384e3798b9428f3f00dac2d63d9314f822 Merge: a3063e57e b415f6263 Author: Ioannis Magkanaris Date: Tue Sep 30 17:11:35 2025 +0200 Merge remote-tracking branch 'upstream/main' into nvtx_ranges commit a3063e57e0e10d0c73f68e2a6df92103cfa0d62d Author: Ioannis Magkanaris Date: Tue Sep 23 18:03:26 2025 +0200 Working version of nvtx markers with allocations commit 6788b9712f1fa88baf11838f80cb71ec7453d2c3 Author: Ioannis Magkanaris Date: Tue Sep 23 13:42:41 2025 +0200 Updated functions commit 455ad38ad3169f972dc7b8a0ab7871cae83b8720 Author: Ioannis Magkanaris Date: Tue Sep 23 13:20:51 2025 +0200 Added marker on allocations as well commit 80ce99cd8b2be2f9de2e096fec327be2d17dfadc Author: Ioannis Magkanaris Date: Wed Aug 20 19:02:36 2025 +0300 Avoid profiling tasklets commit 0314386fcc1e86c483c414d01c417c150e168703 Author: Ioannis Magkanaris Date: Wed Aug 20 19:02:28 2025 +0300 Fix get_latest_report_path in case there's no report commit aad5e877805afff6f581d57f361b4bd2b0be4ad2 Author: Ioannis Magkanaris Date: Wed Aug 20 10:05:16 2025 +0200 Remove import of deleted file commit a3ff00eeb1b4e9fe65272ba36fe10391b6bd9ec4 Author: Ioannis Magkanaris Date: Tue Aug 19 18:16:11 2025 +0200 Revert "Improved GPU Copy (#1976)" This reverts commit bc83c4750fe4b84beee8bd3bf7f1d42ae9f587cb. commit ea5f6ffa705a211833abbcaefb7502993293c4c7 Author: Ioannis Magkanaris Date: Tue Aug 19 18:14:35 2025 +0200 Make format happy commit b1ea9afc24d28e7772c7a0a93f2ed60e85fff538 Merge: bbc1fafa0 aabbe4821 Author: Ioannis Magkanaris Date: Tue Aug 19 19:12:43 2025 +0300 Merge branch 'main' into nvtx_ranges commit bbc1fafa0bbfe667f7dd61cfeeb6bbb356172e96 Author: Ioannis Magkanaris Date: Tue Aug 19 18:07:12 2025 +0200 Format a bit better with dace.instrument commit eea658f3e1e64bd2ee89fa44f704752ea162c73d Author: Ioannis Magkanaris Date: Tue Aug 19 18:04:54 2025 +0200 Fixes in gpu_tx_markers.py commit 2f43f7a90c801af70f8ad3c8aeb0092fe6696c4d Author: Ioannis Magkanaris Date: Tue Aug 19 18:04:43 2025 +0200 Remove instrument_sdfg commit 0fdb4df3a2b2cd7823fcfb27b706783204902d55 Author: Ioannis Magkanaris Date: Tue Aug 19 17:28:57 2025 +0200 Small refactoring of if statements in gpu_tx_markers.py commit 73c52bf5825ddc2b24cd30999a96d0f2d02851bb Author: Ioannis Magkanaris Date: Tue Aug 19 17:26:02 2025 +0200 Added on_sdfg_init/exit_begin/end functions commit ff70f2f029c60092d07bc20ba30c4c9582785686 Author: Ioannis Magkanaris Date: Tue Aug 19 16:32:34 2025 +0200 Replaced is with == commit 3d626e06866919a9d0b00d9f9c834216a104597f Author: Ioannis Magkanaris Date: Tue Aug 19 16:31:04 2025 +0200 Fix local and global streams commit 209860deced08dcc3f84285a16ccd2787be6a755 Author: Ioannis Magkanaris Date: Tue Aug 19 16:27:04 2025 +0200 Improve _is_sdfg_in_device_code commit bc83c4750fe4b84beee8bd3bf7f1d42ae9f587cb Author: Philip Müller <147368808+philip-paul-mueller@users.noreply.github.com> Date: Mon Jun 2 15:58:08 2025 +0200 Improved GPU Copy (#1976) Before some 2D copies (especially if they had FORTRAN order) were turned into Maps, see [issue#1953](https://github.com/spcl/dace/issues/1953). This PR modifies the code generator in such a way that such copies are now handled. There is some legacy stuff that should also be looked at. --------- Co-authored-by: Philip Mueller Co-authored-by: Tal Ben-Nun commit df9957125d4bbb97c6e8d453d7fd26db472ff6b4 Author: Ioannis Magkanaris Date: Tue Aug 19 17:34:12 2025 +0300 Apply suggestion from @tbennun Co-authored-by: Tal Ben-Nun commit da00f21b651bead1c061b3b54bff43e9142fd47c Author: Ioannis Magkanaris Date: Mon May 19 17:49:47 2025 +0200 Avoid pushing rocTX markers before initializing HIP since it doesn't work commit a39308bc7b9105ac02b139baf43eb29f78ff50bd Author: Ioannis Magkanaris Date: Fri May 16 15:13:31 2025 +0200 Fix on_copy and on_scope for GPU_TX_MARKERS commit 2d554fa26f2441ccdb9a52c073454018b6bc291a Author: Ioannis Magkanaris Date: Thu May 15 15:20:05 2025 +0200 Removed preprocessor checks by properly placing ranges in NestedSDFGs and small fixes for CPU wrapper includes commit 5937a153d212c3600cf8db38056163c440efeab6 Author: Ioannis Magkanaris Date: Wed May 14 11:33:02 2025 +0200 Refactored a bit GPUTXMarkerProvider commit 9e8ec9ea6b5013e0fad941995bef2923370733ff Author: Ioannis Magkanaris Date: Wed May 14 10:52:26 2025 +0200 Addressed PR comments for checking is the instrumentation is enabled commit c3f1932571f69d32a00d6a1ad7619b4ceb8e9e1a Author: Ioannis Magkanaris Date: Mon May 12 17:29:30 2025 +0200 Small fixes and cleanups commit 366721f28e4b5a01616fa6f251d46d54d61adefc Author: Ioannis Magkanaris Date: Mon May 12 17:23:21 2025 +0200 Fix order of imports in gpu_events.py commit 8ea432725a15704b6588860e39353dcf2f76b228 Author: Ioannis Magkanaris Date: Mon May 12 17:04:56 2025 +0200 Add markers for different SDFGs and states commit 22b372eed787aa621cdcd541eb59206a8952c6be Author: Ioannis Magkanaris Date: Mon May 12 09:45:20 2025 +0200 Revert changes in GPU_Event provider commit e5adaefc9773627ad0c142adbe98d1d64f926aa4 Author: Ioannis Magkanaris Date: Mon May 12 09:34:34 2025 +0200 Allow building with HIP even if rocTX is not found commit b30f4a268a696922889f8025f2d8fd61e9ac624e Author: Ioannis Magkanaris Date: Fri May 9 17:20:34 2025 +0200 Fix formatting commit 747f357a54cff0d5b4b0b0e35c26b7bd1ab04837 Author: Ioannis Magkanaris Date: Fri May 9 17:14:17 2025 +0200 Made test NVTX agnostic and updated documentation commit 646ca9053f5aaa30a97a0cdbb2545f1b47605036 Author: Ioannis Magkanaris Date: Fri May 9 17:05:10 2025 +0200 Use same checks for enabling roctx as CMake commit c28036ba99a5b6092ed126821b2bc39b4e785f0f Author: Ioannis Magkanaris Date: Fri May 9 17:00:19 2025 +0200 Fix compilation for AMD gpu commit 855304d329c6f660fb2bf61b5f9ba5d1b1ee5205 Author: Ioannis Magkanaris Date: Thu May 8 11:58:00 2025 +0200 Fix library names commit 9df4f73bf900c2775cbce381a2abe3606c319baa Author: Ioannis Magkanaris Date: Thu May 8 11:36:29 2025 +0200 Trying to use roctx commit a55aeb7ecdc98e5223bbad903fc557354deb54a8 Author: Ioannis Magkanaris Date: Wed May 7 17:58:37 2025 +0200 Make formatting happy commit a8bcadf175cfde856ec61f8af63947f623002df7 Author: Ioannis Magkanaris Date: Wed May 7 17:50:10 2025 +0200 Renamed NVTX to GPU_TX_MARKERS and added note for AMD GPUs commit 7337233db63b37df28aca68a91f1c8b0b53e7ce1 Author: Ioannis Magkanaris Date: Mon May 5 17:30:35 2025 +0200 Changed nvtxRangePushA to nvtxRangePush commit 74c9117b90500513ded30721ca19cb5b53429c0e Author: Ioannis Magkanaris Date: Mon May 5 17:23:42 2025 +0200 Fix copyright and GPU test commit 989bc3273af779d33d20b28a6b49a30f33b287b1 Author: Ioannis Magkanaris Date: Mon May 5 17:12:59 2025 +0200 Make formatter happy commit 4f572974e4cfc4f63f338c8ccb77dea161a128bb Author: Ioannis Magkanaris Date: Mon May 5 17:09:58 2025 +0200 Remove NVTX markers from LIKWID since LIKWID has its own markers commit a4d2ff8a47997250faa3790088c1ad8be5cdd311 Author: Ioannis Magkanaris Date: Mon May 5 17:08:08 2025 +0200 Improved NVTX markers in likwid commit 1e71171b4a3bc2a7e091fc7f8d8c01af52e363b7 Author: Ioannis Magkanaris Date: Mon May 5 15:42:13 2025 +0200 Update NVTX Provider imports commit 438090f847d8016cbb64edca30cffb7239997164 Author: Ioannis Magkanaris Date: Mon May 5 15:41:56 2025 +0200 Update documentation commit 89b7864eab3046ff1e7a6e7cdb4816c2731ec8dd Author: Ioannis Magkanaris Date: Mon May 5 15:41:48 2025 +0200 Small fix of whiteline in framecode commit ef5355b4a2627aa4267c8b4bc82007f2e0f1f793 Author: Ioannis Magkanaris Date: Mon May 5 15:38:02 2025 +0200 Refactored NVTX Instrumentation provider constructor and test for expected code commit bbf1d3218a4ca80c10c648f001dd6d7ca88730df Author: Ioannis Magkanaris Date: Mon May 5 15:37:16 2025 +0200 Inherit LIKWID_GPU Instrumentation provider from NVTX as well commit 90b50ac61d7e686e0389ad17587fa274a1730a10 Author: Ioannis Magkanaris Date: Fri May 2 18:29:07 2025 +0200 Make GPUEventProvider inherit from NVTXProvider to enable the NVTX markers by default with it commit c584255a3c24b3f29debcbc122990e173a780ef8 Author: Ioannis Magkanaris Date: Fri May 2 18:01:31 2025 +0200 Updated documentation commit 04836fb005f4368c77731bc01ed3b846d5e97c26 Author: Ioannis Magkanaris Date: Fri May 2 18:01:21 2025 +0200 Moved the printing of NVTX range push and pop inside the NVTXProvider commit f5240b2274f3fbc21ea85e3686aaa5d3e0e5e07f Author: Ioannis Magkanaris Date: Fri May 2 17:25:04 2025 +0200 Added NVTX range in CPU wrapper for GPU kernel --- dace/builtin_hooks.py | 3 +- dace/codegen/CMakeLists.txt | 17 +- dace/codegen/instrumentation/__init__.py | 1 + .../codegen/instrumentation/gpu_tx_markers.py | 255 ++++++++++++++++++ dace/codegen/instrumentation/provider.py | 72 +++++ dace/codegen/targets/cuda.py | 9 + dace/codegen/targets/framecode.py | 58 +++- dace/dtypes.py | 1 + dace/sdfg/sdfg.py | 8 +- doc/optimization/profiling.rst | 3 +- doc/source/dace.codegen.instrumentation.rst | 8 + tests/instrumentation_test.py | 36 ++- 12 files changed, 455 insertions(+), 16 deletions(-) create mode 100644 dace/codegen/instrumentation/gpu_tx_markers.py diff --git a/dace/builtin_hooks.py b/dace/builtin_hooks.py index 2a5b49e983..b691cd0296 100644 --- a/dace/builtin_hooks.py +++ b/dace/builtin_hooks.py @@ -96,7 +96,8 @@ def _make_filter_function(filter: Optional[Union[str, Callable[[Any], bool]]], if isinstance(filter, str): # If a string was given, construct predicate based on wildcard name matching if with_attr: - filter_func = lambda elem: fnmatch.fnmatch(elem.name, filter) + filter_func = lambda elem: fnmatch.fnmatch(elem.name, filter) if hasattr(elem, 'name') else fnmatch.fnmatch( + elem.label, filter) else: filter_func = lambda elem: fnmatch.fnmatch(elem, filter) elif callable(filter): diff --git a/dace/codegen/CMakeLists.txt b/dace/codegen/CMakeLists.txt index 5a8e6438eb..bcc683a063 100644 --- a/dace/codegen/CMakeLists.txt +++ b/dace/codegen/CMakeLists.txt @@ -142,7 +142,7 @@ if(DACE_ENABLE_CUDA) set(CMAKE_CUDA_ARCHITECTURES "${LOCAL_CUDA_ARCHITECTURES}") enable_language(CUDA) - list(APPEND DACE_LIBS CUDA::cudart) + list(APPEND DACE_LIBS CUDA::cudart CUDA::nvtx3) add_definitions(-DWITH_CUDA) if (MSVC_IDE) @@ -168,6 +168,21 @@ if(DACE_ENABLE_HIP) # Add libraries such as rocBLAS link_directories(${HIP_PATH}/../lib) + if(ROCM_PATH) + find_path(ROCTX_INCLUDE_DIR roctx.h HINTS ${ROCM_PATH}/include/roctracer ${ROCM_PATH}/roctracer/include) + if(NOT ROCTX_INCLUDE_DIR) + message(WARNING "Could not find roctx.h in ${ROCM_PATH}/include/roctracer or ${ROCM_PATH}/roctracer/include") + endif() + endif() + if(ROCM_PATH AND ROCTX_INCLUDE_DIR) + find_path(ROCTX_LIBRARY_DIR "libroctx64.so" HINTS ${ROCM_PATH}/lib) + if(NOT ROCTX_LIBRARY_DIR) + message(WARNING "Could not find libroctx64.so in ${ROCM_PATH}/lib") + else() + list(APPEND DACE_LIBS "-lroctx64 -L${ROCTX_LIBRARY_DIR}") + include_directories(SYSTEM ${ROCTX_INCLUDE_DIR}) + endif() + endif() endif() # Function for performing deferred variable expansion diff --git a/dace/codegen/instrumentation/__init__.py b/dace/codegen/instrumentation/__init__.py index d357e1a5a3..5ebab3f497 100644 --- a/dace/codegen/instrumentation/__init__.py +++ b/dace/codegen/instrumentation/__init__.py @@ -7,5 +7,6 @@ from .timer import TimerProvider from .gpu_events import GPUEventProvider from .fpga import FPGAInstrumentationProvider +from .gpu_tx_markers import GPUTXMarkersProvider from .data.data_dump import SaveProvider, RestoreProvider diff --git a/dace/codegen/instrumentation/gpu_tx_markers.py b/dace/codegen/instrumentation/gpu_tx_markers.py new file mode 100644 index 0000000000..be94e425fe --- /dev/null +++ b/dace/codegen/instrumentation/gpu_tx_markers.py @@ -0,0 +1,255 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import os +from typing import Union + +from dace import dtypes, registry +from dace.codegen import common +from dace.codegen.prettycode import CodeIOStream +from dace.codegen.instrumentation.provider import InstrumentationProvider +from dace.memlet import Memlet +from dace.sdfg import nodes, SDFG +from dace.sdfg.graph import MultiConnectorEdge +from dace.sdfg.nodes import NestedSDFG +from dace.sdfg.scope import is_devicelevel_gpu_kernel +from dace.sdfg.sdfg import SDFG +from dace.sdfg.state import ControlFlowRegion, SDFGState + + +@registry.autoregister_params(type=dtypes.InstrumentationType.GPU_TX_MARKERS) +class GPUTXMarkersProvider(InstrumentationProvider): + """ Timing instrumentation that adds NVTX/rocTX ranges to SDFGs and states. """ + NVTX_HEADER_INCLUDE = '#include ' + ROCTX_HEADER_INCLUDE = '#include ' + + def __init__(self): + self.backend = common.get_gpu_backend() + # Check if ROCm TX libraries and headers are available + rocm_path = os.getenv('ROCM_PATH', '/opt/rocm') + roctx_header_paths = [ + os.path.join(rocm_path, 'roctracer/include/roctx.h'), + os.path.join(rocm_path, 'include/roctracer/roctx.h') + ] + roctx_library_path = os.path.join(rocm_path, 'lib', 'libroctx64.so') + self.enable_rocTX = any(os.path.isfile(path) + for path in roctx_header_paths) and os.path.isfile(roctx_library_path) + self.include_generated = False + super().__init__() + + def _print_include(self, sdfg: SDFG) -> None: + """ Prints the include statement for the NVTX/rocTX library for a given SDFG. """ + if self.include_generated: + return + if self.backend == 'cuda': + sdfg.append_global_code(self.NVTX_HEADER_INCLUDE, 'frame') + elif self.backend == 'hip': + if self.enable_rocTX: + sdfg.append_global_code(self.ROCTX_HEADER_INCLUDE, 'frame') + else: + raise NameError('GPU backend "%s" not recognized' % self.backend) + self.include_generated = True + + def print_include(self, stream: CodeIOStream) -> None: + """ Prints the include statement for the NVTX/rocTX library in stream. """ + if stream is None: + return + if self.backend == 'cuda': + stream.write(self.NVTX_HEADER_INCLUDE) + elif self.backend == 'hip': + if self.enable_rocTX: + stream.write(self.ROCTX_HEADER_INCLUDE) + else: + raise NameError('GPU backend "%s" not recognized' % self.backend) + + def print_range_push(self, name: str, sdfg: SDFG, stream: CodeIOStream) -> None: + if stream is None: + return + self._print_include(sdfg) + if name is None: + name = 'None' + if self.backend == 'cuda': + stream.write(f'nvtxRangePush("{name}");') + elif self.backend == 'hip': + if self.enable_rocTX: + stream.write(f'roctxRangePush("{name}");') + else: + raise NameError(f'GPU backend "{self.backend}" not recognized') + + def print_range_pop(self, stream: CodeIOStream) -> None: + if stream is None: + return + if self.backend == 'cuda': + stream.write('nvtxRangePop();') + elif self.backend == 'hip': + if self.enable_rocTX: + stream.write('roctxRangePop();') + else: + raise NameError(f'GPU backend "{self.backend}" not recognized') + + def _is_sdfg_in_device_code(self, sdfg: SDFG) -> bool: + """ Check if the SDFG is in device code and not top level SDFG. """ + sdfg_parent_state = sdfg.parent + while sdfg_parent_state is not None: + sdfg_parent_node = sdfg.parent_nsdfg_node + if is_devicelevel_gpu_kernel(sdfg, sdfg_parent_state, sdfg_parent_node): + return True + sdfg_parent_state = sdfg_parent_state.sdfg.parent + return False + + def on_sdfg_begin(self, sdfg: SDFG, local_stream: CodeIOStream, global_stream: CodeIOStream, codegen) -> None: + if sdfg.instrument != dtypes.InstrumentationType.GPU_TX_MARKERS: + return + if self._is_sdfg_in_device_code(sdfg): + # Don't instrument device code + return + self.print_range_push(f'sdfg_{sdfg.name}', sdfg, local_stream) + + def on_sdfg_end(self, sdfg: SDFG, local_stream: CodeIOStream, global_stream: CodeIOStream) -> None: + if sdfg.instrument != dtypes.InstrumentationType.GPU_TX_MARKERS: + return + if self._is_sdfg_in_device_code(sdfg): + # Don't instrument device code + return + self.print_range_pop(local_stream) + + def on_state_begin(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, local_stream: CodeIOStream, + global_stream: CodeIOStream) -> None: + if state.instrument != dtypes.InstrumentationType.GPU_TX_MARKERS: + return + if self._is_sdfg_in_device_code(sdfg): + # Don't instrument device code + return + self.print_range_push(f'state_{state.label}', sdfg, local_stream) + + def on_state_end(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, local_stream: CodeIOStream, + global_stream: CodeIOStream) -> None: + if state.instrument != dtypes.InstrumentationType.GPU_TX_MARKERS: + return + if self._is_sdfg_in_device_code(sdfg): + # Don't instrument device code + return + self.print_range_pop(local_stream) + + def on_copy_begin(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, src_node: nodes.Node, + dst_node: nodes.Node, edge: MultiConnectorEdge[Memlet], local_stream: CodeIOStream, + global_stream: CodeIOStream, copy_shape, src_strides, dst_strides) -> None: + if state.instrument != dtypes.InstrumentationType.GPU_TX_MARKERS: + return + if is_devicelevel_gpu_kernel(sdfg, state, src_node) or is_devicelevel_gpu_kernel(sdfg, state, dst_node): + # Don't instrument device code + return + self.print_range_push(f'copy_{src_node.label}_to_{dst_node.label}', sdfg, local_stream) + + def on_copy_end(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, src_node: nodes.Node, + dst_node: nodes.Node, edge: MultiConnectorEdge[Memlet], local_stream: CodeIOStream, + global_stream: CodeIOStream) -> None: + if state.instrument != dtypes.InstrumentationType.GPU_TX_MARKERS: + return + if is_devicelevel_gpu_kernel(sdfg, state, src_node) or is_devicelevel_gpu_kernel(sdfg, state, dst_node): + # Don't instrument device code + return + self.print_range_pop(local_stream) + + def on_scope_entry(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, node: nodes.EntryNode, + outer_stream: CodeIOStream, inner_stream: CodeIOStream, global_stream: CodeIOStream) -> None: + if node.map.instrument != dtypes.InstrumentationType.GPU_TX_MARKERS: + return + if is_devicelevel_gpu_kernel(sdfg, state, node): + # Don't instrument device code + return + self.print_range_push(f'scope_{node.label}', sdfg, outer_stream) + + def on_scope_exit(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, node: nodes.ExitNode, + outer_stream: CodeIOStream, inner_stream: CodeIOStream, global_stream: CodeIOStream) -> None: + entry_node = state.entry_node(node) + if entry_node.map.instrument != dtypes.InstrumentationType.GPU_TX_MARKERS: + return + if is_devicelevel_gpu_kernel(sdfg, state, entry_node): + # Don't instrument device code + return + self.print_range_pop(outer_stream) + + def on_sdfg_init_begin(self, sdfg: SDFG, callsite_stream: CodeIOStream, global_stream: CodeIOStream) -> None: + if sdfg.instrument != dtypes.InstrumentationType.GPU_TX_MARKERS: + return + if self._is_sdfg_in_device_code(sdfg): + # Don't instrument device code + return + # cannot push rocTX markers before initializing HIP + if self.enable_rocTX: + return + self.print_range_push(f'init_{sdfg.name}', sdfg, callsite_stream) + + def on_sdfg_init_end(self, sdfg: SDFG, callsite_stream: CodeIOStream, global_stream: CodeIOStream) -> None: + if sdfg.instrument != dtypes.InstrumentationType.GPU_TX_MARKERS: + return + if self._is_sdfg_in_device_code(sdfg): + # Don't instrument device code + return + # cannot push rocTX markers before initializing HIP so there's no marker to pop + if self.enable_rocTX: + return + self.print_range_pop(callsite_stream) + + def on_sdfg_exit_begin(self, sdfg: SDFG, callsite_stream: CodeIOStream, global_stream: CodeIOStream) -> None: + if sdfg.instrument != dtypes.InstrumentationType.GPU_TX_MARKERS: + return + if self._is_sdfg_in_device_code(sdfg): + # Don't instrument device code + return + self.print_range_push(f'exit_{sdfg.name}', sdfg, callsite_stream) + + def on_sdfg_exit_end(self, sdfg: SDFG, callsite_stream: CodeIOStream, global_stream: CodeIOStream) -> None: + if sdfg.instrument != dtypes.InstrumentationType.GPU_TX_MARKERS: + return + if self._is_sdfg_in_device_code(sdfg): + # Don't instrument device code + return + self.print_range_pop(callsite_stream) + + def on_allocation_begin(self, sdfg: SDFG, scope: Union[nodes.EntryNode, SDFGState, SDFG], + stream: CodeIOStream) -> None: + if sdfg.instrument != dtypes.InstrumentationType.GPU_TX_MARKERS: + return + # We only want to instrument allocations at the SDFG or state level + if not isinstance(scope, (SDFGState, SDFG)): + return + if self._is_sdfg_in_device_code(sdfg): + # Don't instrument device code + return + self.print_range_push(f'alloc_{sdfg.name}', sdfg, stream) + + def on_allocation_end(self, sdfg: SDFG, scope: Union[nodes.EntryNode, SDFGState, SDFG], + stream: CodeIOStream) -> None: + if sdfg.instrument != dtypes.InstrumentationType.GPU_TX_MARKERS: + return + # We only want to instrument allocations at the SDFG or state level + if not isinstance(scope, (SDFGState, SDFG)): + return + if self._is_sdfg_in_device_code(sdfg): + # Don't instrument device code + return + self.print_range_pop(stream) + + def on_deallocation_begin(self, sdfg: SDFG, scope: Union[nodes.EntryNode, SDFGState, SDFG], + stream: CodeIOStream) -> None: + if sdfg.instrument != dtypes.InstrumentationType.GPU_TX_MARKERS: + return + # We only want to instrument allocations at the SDFG or state level + if not isinstance(scope, (SDFGState, SDFG)): + return + if self._is_sdfg_in_device_code(sdfg): + # Don't instrument device code + return + self.print_range_push(f'dealloc_{sdfg.name}', sdfg, stream) + + def on_deallocation_end(self, sdfg: SDFG, scope: Union[nodes.EntryNode, SDFGState, SDFG], + stream: CodeIOStream) -> None: + if sdfg.instrument != dtypes.InstrumentationType.GPU_TX_MARKERS: + return + # We only want to instrument allocations at the SDFG or state level + if not isinstance(scope, (SDFGState, SDFG)): + return + if self._is_sdfg_in_device_code(sdfg): + # Don't instrument device code + return + self.print_range_pop(stream) diff --git a/dace/codegen/instrumentation/provider.py b/dace/codegen/instrumentation/provider.py index a95c0495ba..dc643df4ca 100644 --- a/dace/codegen/instrumentation/provider.py +++ b/dace/codegen/instrumentation/provider.py @@ -183,3 +183,75 @@ def on_node_end(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, node :param global_stream: Code generator for global (external) code. """ pass + + def on_sdfg_init_begin(self, sdfg: SDFG, local_stream: CodeIOStream, global_stream: CodeIOStream) -> None: + """ Event called at the beginning of SDFG initialization code generation. + + :param sdfg: The generated SDFG object. + :param local_stream: Code generator for the in-function code. + :param global_stream: Code generator for global (external) code. + """ + pass + + def on_sdfg_init_end(self, sdfg: SDFG, local_stream: CodeIOStream, global_stream: CodeIOStream) -> None: + """ Event called at the end of SDFG initialization code generation. + + :param sdfg: The generated SDFG object. + :param local_stream: Code generator for the in-function code. + :param global_stream: Code generator for global (external) code. + """ + pass + + def on_sdfg_exit_begin(self, sdfg: SDFG, local_stream: CodeIOStream, global_stream: CodeIOStream) -> None: + """ Event called at the beginning of SDFG exit code generation. + + :param sdfg: The generated SDFG object. + :param local_stream: Code generator for the in-function code. + :param global_stream: Code generator for global (external) code. + """ + pass + + def on_sdfg_exit_end(self, sdfg: SDFG, local_stream: CodeIOStream, global_stream: CodeIOStream) -> None: + """ Event called at the end of SDFG exit code generation. + + :param sdfg: The generated SDFG object. + :param local_stream: Code generator for the in-function code. + :param global_stream: Code generator for global (external) code. + """ + pass + + def on_allocation_begin(self, sdfg: SDFG, scope: Union[nodes.EntryNode, SDFGState, SDFG], + stream: CodeIOStream) -> None: + """ Event called at the beginning of an allocation code generation. + + :param sdfg: The generated SDFG object. + :param stream: Code generator. + """ + pass + + def on_allocation_end(self, sdfg: SDFG, scope: Union[nodes.EntryNode, SDFGState, SDFG], + lstream: CodeIOStream) -> None: + """ Event called at the end of an allocation code generation. + + :param sdfg: The generated SDFG object. + :param local_stream: Code generator. + """ + pass + + def on_deallocation_begin(self, sdfg: SDFG, scope: Union[nodes.EntryNode, SDFGState, SDFG], + stream: CodeIOStream) -> None: + """ Event called at the beginning of a deallocation code generation. + + :param sdfg: The generated SDFG object. + :param local_stream: Code generator. + """ + pass + + def on_deallocation_end(self, sdfg: SDFG, scope: Union[nodes.EntryNode, SDFGState, SDFG], + lstream: CodeIOStream) -> None: + """ Event called at the end of a deallocation code generation. + + :param sdfg: The generated SDFG object. + :param local_stream: Code generator. + """ + pass diff --git a/dace/codegen/targets/cuda.py b/dace/codegen/targets/cuda.py index 6de2cc0769..55e2fe3241 100644 --- a/dace/codegen/targets/cuda.py +++ b/dace/codegen/targets/cuda.py @@ -1088,6 +1088,11 @@ def _emit_copy(self, state_id: int, src_node: nodes.Node, src_storage: dtypes.St is_c_order = is_fortran_order dims = 1 + for instr in self._dispatcher.instrumentation.values(): + if instr is not None: + instr.on_copy_begin(sdfg, cfg, state_dfg, src_node, dst_node, edge, callsite_stream, None, + copy_shape, src_strides, dst_strides) + if dims > 2: # Currently we only support ND copies when they can be represented # as a 1D copy or as a 2D strided copy @@ -1243,6 +1248,10 @@ def _emit_copy(self, state_id: int, src_node: nodes.Node, src_storage: dtypes.St self._emit_sync(callsite_stream) + for instr in self._dispatcher.instrumentation.values(): + if instr is not None: + instr.on_copy_end(sdfg, cfg, state_dfg, src_node, dst_node, edge, callsite_stream, None) + # Copy within the GPU elif (src_storage in gpu_storage_types and dst_storage in gpu_storage_types): diff --git a/dace/codegen/targets/framecode.py b/dace/codegen/targets/framecode.py index 449e312efa..aecd78a092 100644 --- a/dace/codegen/targets/framecode.py +++ b/dace/codegen/targets/framecode.py @@ -14,6 +14,7 @@ from dace.codegen import dispatcher as disp from dace.codegen.prettycode import CodeIOStream from dace.codegen.common import codeblock_to_cpp, sym2cpp +from dace.codegen.instrumentation.gpu_tx_markers import GPUTXMarkersProvider from dace.codegen.targets.target import TargetCodeGenerator from dace.codegen.tools.type_inference import infer_expr_type from dace.sdfg import SDFG, SDFGState, nodes @@ -254,6 +255,13 @@ def generate_footer(self, sdfg: SDFG, global_stream: CodeIOStream, callsite_stre # Write closing brace of program callsite_stream.write('}', sdfg) + if sdfg.instrument == dtypes.InstrumentationType.GPU_TX_MARKERS: + # Need to make sure that the necessary includes for GPU_TX_MARKERS are present + # in the generated code. + gpu_tx_markers_provider = self._dispatcher.instrumentation.get(dtypes.InstrumentationType.GPU_TX_MARKERS) + if gpu_tx_markers_provider: + gpu_tx_markers_provider.print_include(callsite_stream) + # Write awkward footer to avoid 'extern "C"' issues params_comma = (', ' + params) if params else '' initparams_comma = (', ' + initparams) if initparams else '' @@ -279,11 +287,17 @@ def generate_footer(self, sdfg: SDFG, global_stream: CodeIOStream, callsite_stre callsite_stream.write( f""" DACE_EXPORTED {mangle_dace_state_struct_name(sdfg)} *__dace_init_{sdfg.name}({initparams}) -{{ - int __result = 0; - {mangle_dace_state_struct_name(sdfg)} *__state = new {mangle_dace_state_struct_name(sdfg)}; +{{""", sdfg) - """, sdfg) + # Invoke all instrumentation providers + for instr in self._dispatcher.instrumentation.values(): + if instr is not None: + instr.on_sdfg_init_begin(sdfg, callsite_stream, global_stream) + + callsite_stream.write( + f""" + int __result = 0; + {mangle_dace_state_struct_name(sdfg)} *__state = new {mangle_dace_state_struct_name(sdfg)};""", sdfg) for target in self._dispatcher.used_targets: if target.has_initializer: @@ -304,17 +318,29 @@ def generate_footer(self, sdfg: SDFG, global_stream: CodeIOStream, callsite_stre callsite_stream.write(self._initcode.getvalue(), sdfg) - callsite_stream.write( - f""" + callsite_stream.write(f""" if (__result) {{ delete __state; return nullptr; }} +""", sdfg) + # Invoke all instrumentation providers + for instr in self._dispatcher.instrumentation.values(): + if instr is not None: + instr.on_sdfg_init_end(sdfg, callsite_stream, global_stream) + callsite_stream.write( + f""" return __state; }} DACE_EXPORTED int __dace_exit_{sdfg.name}({mangle_dace_state_struct_name(sdfg)} *__state) {{ +""", sdfg) + # Invoke all instrumentation providers + for instr in self._dispatcher.instrumentation.values(): + if instr is not None: + instr.on_sdfg_exit_begin(sdfg, callsite_stream, global_stream) + callsite_stream.write(f""" int __err = 0; """, sdfg) @@ -349,6 +375,10 @@ def generate_footer(self, sdfg: SDFG, global_stream: CodeIOStream, callsite_stre callsite_stream.write("}") callsite_stream.write('delete __state;\n', sdfg) + # Invoke all instrumentation providers + for instr in self._dispatcher.instrumentation.values(): + if instr is not None: + instr.on_sdfg_exit_end(sdfg, callsite_stream, global_stream) callsite_stream.write('return __err;\n}\n', sdfg) def generate_external_memory_management(self, sdfg: SDFG, callsite_stream: CodeIOStream): @@ -798,6 +828,11 @@ def determine_allocation_lifetime(self, top_sdfg: SDFG): def allocate_arrays_in_scope(self, sdfg: SDFG, cfg: ControlFlowRegion, scope: Union[nodes.EntryNode, SDFGState, SDFG], function_stream: CodeIOStream, callsite_stream: CodeIOStream) -> None: + if len(self.to_allocate[scope]) == 0: + return + for instr in self._dispatcher.instrumentation.values(): + if instr is not None: + instr.on_allocation_begin(sdfg, scope, callsite_stream) """ Dispatches allocation of all arrays in the given scope. """ for tsdfg, state, node, declare, allocate, _ in self.to_allocate[scope]: if state is not None: @@ -809,10 +844,18 @@ def allocate_arrays_in_scope(self, sdfg: SDFG, cfg: ControlFlowRegion, scope: Un self._dispatcher.dispatch_allocate(tsdfg, cfg if state is None else state.parent_graph, state, state_id, node, desc, function_stream, callsite_stream, declare, allocate) + for instr in self._dispatcher.instrumentation.values(): + if instr is not None: + instr.on_allocation_end(sdfg, scope, callsite_stream) def deallocate_arrays_in_scope(self, sdfg: SDFG, cfg: ControlFlowRegion, scope: Union[nodes.EntryNode, SDFGState, SDFG], function_stream: CodeIOStream, callsite_stream: CodeIOStream): + if len(self.to_allocate[scope]) == 0: + return + for instr in self._dispatcher.instrumentation.values(): + if instr is not None: + instr.on_deallocation_begin(sdfg, scope, callsite_stream) """ Dispatches deallocation of all arrays in the given scope. """ for tsdfg, state, node, _, _, deallocate in self.to_allocate[scope]: if not deallocate: @@ -826,6 +869,9 @@ def deallocate_arrays_in_scope(self, sdfg: SDFG, cfg: ControlFlowRegion, scope: self._dispatcher.dispatch_deallocate(tsdfg, state.parent_graph, state, state_id, node, desc, function_stream, callsite_stream) + for instr in self._dispatcher.instrumentation.values(): + if instr is not None: + instr.on_deallocation_end(sdfg, scope, callsite_stream) def generate_code(self, sdfg: SDFG, diff --git a/dace/dtypes.py b/dace/dtypes.py index faadc84a50..28372087e9 100644 --- a/dace/dtypes.py +++ b/dace/dtypes.py @@ -167,6 +167,7 @@ class InstrumentationType(aenum.AutoNumberEnum): LIKWID_GPU = () GPU_Events = () FPGA = () + GPU_TX_MARKERS = () @undefined_safe_enum diff --git a/dace/sdfg/sdfg.py b/dace/sdfg/sdfg.py index bda9d8707e..e1ac82397d 100644 --- a/dace/sdfg/sdfg.py +++ b/dace/sdfg/sdfg.py @@ -1030,8 +1030,12 @@ def get_latest_report_path(self) -> Optional[str]: :return: A path to the latest instrumentation report, or None if one does not exist. """ path = os.path.join(self.build_folder, 'perf') - files = [f for f in os.listdir(path) if f.startswith('report-')] - if len(files) == 0: + try: + files = [f for f in os.listdir(path) if f.startswith('report-')] + except FileNotFoundError: + return None + + if not files: return None return os.path.join(path, sorted(files, reverse=True)[0]) diff --git a/doc/optimization/profiling.rst b/doc/optimization/profiling.rst index 87539e87a8..3f53d4e324 100644 --- a/doc/optimization/profiling.rst +++ b/doc/optimization/profiling.rst @@ -121,7 +121,8 @@ Instrumentation can also collect performance counters on CPUs and GPUs using `LI The :class:`~dace.dtypes.InstrumentationType.LIKWID_Counters` instrumentation type can be configured to collect a wide variety of performance counters on CPUs and GPUs. An example use can be found in the `LIKWID instrumentation code sample `_. - +There is also the :class:`~dace.dtypes.InstrumentationType.GPU_TX_MARKERS` instrumentation type which wraps in NVTX or rocTX markers the DaCe program executed on the GPU. Important parts of the execution of the program on the GPU as the different states, SDFGs and initialization and finalization phases are marked with these markers. +These markers can be used to visualize and measure the GPU activity using the NVIDIA Nsight Systems or ROCm Systems profilers. Instrumentation file format ~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/doc/source/dace.codegen.instrumentation.rst b/doc/source/dace.codegen.instrumentation.rst index d476090d6a..0fc5941097 100644 --- a/doc/source/dace.codegen.instrumentation.rst +++ b/doc/source/dace.codegen.instrumentation.rst @@ -4,6 +4,14 @@ dace.codegen.instrumentation package Submodules ---------- +dace.codegen.instrumentation.gpu_tx_markers module +----------------------------------------------- + +.. automodule:: dace.codegen.instrumentation.gpu_tx_markers + :members: + :undoc-members: + :show-inheritance: + dace.codegen.instrumentation.fpga module ---------------------------------------- diff --git a/tests/instrumentation_test.py b/tests/instrumentation_test.py index 2aa26edf36..120e025812 100644 --- a/tests/instrumentation_test.py +++ b/tests/instrumentation_test.py @@ -4,6 +4,7 @@ import pytest import numpy as np +import re import sys import dace @@ -39,14 +40,17 @@ def onetest(instrumentation: dace.InstrumentationType, size=128): if isinstance(node, nodes.MapEntry) and node.map.label == 'mult': node.map.instrument = instrumentation state.instrument = instrumentation - # Set Timer instrumentation on the whole SDFG - if instrumentation == dace.InstrumentationType.Timer: - sdfg.instrument = instrumentation - if instrumentation == dace.InstrumentationType.GPU_Events: + if instrumentation in [dace.InstrumentationType.GPU_Events, dace.InstrumentationType.GPU_TX_MARKERS]: sdfg.apply_transformations(GPUTransformSDFG) - sdfg(A=A, B=B, C=C, N=size) + with dace.instrument(instrumentation, + filter='*', + annotate_maps=True, + annotate_tasklets=False, + annotate_states=True, + annotate_sdfgs=True): + sdfg(A=A, B=B, C=C, N=size) # Check for correctness assert np.allclose(C, 20 * A @ B) @@ -57,6 +61,22 @@ def onetest(instrumentation: dace.InstrumentationType, size=128): report = sdfg.get_latest_report() print(report) + # Check that the NVTX/rocTX range wrapper is present in the generated CPU code + if instrumentation == dace.InstrumentationType.GPU_TX_MARKERS: + code = sdfg.generate_code()[0].clean_code + tx_include = re.search(r'#include <(nvtx3/nvToolsExt|roctx).h>', code) + assert tx_include is not None + range_push = re.search(r'(nvtx|roctx)RangePush\("sdfg', code) is not None + range_push &= re.search(r'(nvtx|roctx)RangePush\("copy', code) is not None + range_push &= re.search(r'(nvtx|roctx)RangePush\("state', code) is not None + range_push &= re.search(r'(nvtx|roctx)RangePush\("alloc', code) is not None + range_push &= re.search(r'(nvtx|roctx)RangePush\("dealloc', code) is not None + range_push &= re.search(r'(nvtx|roctx)RangePush\("init', code) is not None + range_push &= re.search(r'(nvtx|roctx)RangePush\("exit', code) is not None + assert range_push + range_pop = re.search(r'(nvtx|roctx)RangePop\b', code) + assert range_pop is not None + def test_timer(): onetest(dace.InstrumentationType.Timer) @@ -73,8 +93,14 @@ def test_gpu_events(): onetest(dace.InstrumentationType.GPU_Events) +@pytest.mark.gpu +def test_gpu_tx_markers(): + onetest(dace.InstrumentationType.GPU_TX_MARKERS) + + if __name__ == '__main__': test_timer() test_papi() if len(sys.argv) > 1 and sys.argv[1] == 'gpu': test_gpu_events() + test_gpu_tx_markers() From cf72335833cfbfd4b0b3635a6afd629a80f3865f Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Thu, 18 Dec 2025 09:29:23 +0100 Subject: [PATCH 13/17] Squashed commit of the following: commit 68ffa3bfd2dc3bfed0ef0ae7f276c51d1779e142 Merge: c06954693 d99ad29c4 Author: Philipp Schaad Date: Sun Nov 30 06:17:32 2025 -0600 Merge branch 'main' into make_construct_args_public commit c069546935a19d7182aa9c848445b438947cc0df Merge: 41902c396 408a4819f Author: Philip Mueller Date: Tue Nov 4 11:22:14 2025 +0100 Merge remote-tracking branch 'spcl/main' into make_construct_args_public commit 41902c396f6189e47d8a177fa925a261df4d4065 Author: Philip Mueller Date: Fri Oct 31 16:01:26 2025 +0100 Fixed a bug. commit 65725f9f36f7982a430fcb663b2d1289124ae5af Author: Philip Mueller Date: Fri Oct 31 15:15:03 2025 +0100 This should be enough for bug compatibility. commit daf90e910c5ae776fd1082c7c91126c28aaf5928 Author: Philip Mueller Date: Fri Oct 31 12:58:25 2025 +0100 Updated the thing a bit more. commit 2ddabbdd51684ea691a6e59caae183e0df694768 Merge: 4da0c4ebb b44aeb061 Author: Philip Mueller Date: Fri Oct 31 12:54:19 2025 +0100 Merge remote-tracking branch 'spcl/main' into make_construct_args_public commit 4da0c4ebb32b1ff082241b865b6a1bbbedc85fee Author: Philip Mueller Date: Fri Oct 31 12:53:48 2025 +0100 Made some additional check. commit 69960ce1ce7e6839b5a279cdf551c5afc933d453 Author: Philip Mueller Date: Fri Oct 31 12:00:30 2025 +0100 Forgot to do this. commit 6e1a9ff33d3405c7ce099579ac44943ab784788f Merge: c1214fa6b 1bf217328 Author: Philip Mueller Date: Fri Oct 31 11:25:46 2025 +0100 Merge remote-tracking branch 'spcl/main' into make_construct_args_public commit c1214fa6bfa3ba80cfee6a3fde69c40e812deaea Author: Philip Mueller Date: Fri Oct 31 09:50:41 2025 +0100 Updated the tests and made it clear that you can not return a scalar from an SDFG. commit 9397a230a48508212b30ffc22173007f7f7a3111 Author: Philip Mueller Date: Fri Oct 31 09:40:29 2025 +0100 Implemented the proper handling of tuples of size one. commit e8d909e2939f1ed3ef2c9bfdb8b399eb608ae0ab Author: Philip Mueller Date: Fri Oct 31 09:30:48 2025 +0100 Removed that stupid sclar return value feature that CAN NOT WORK. However, I saw that it also, under the hood sometimes tests if the argument is a pyobject. Since that thing is a pointer it is possible and I implemented it for that. But it was again not implemented properly, since for the case when the return value is passed as a regular argument, it was not checking that, only for managed return values. commit ab110d2692d3bae873b347145875f4c01f515b57 Author: Philip Mueller Date: Fri Oct 31 09:24:45 2025 +0100 Updated the description. commit 899b2a09f38a414d7ea3a800d2bb586d630cf37e Author: Philip Mueller Date: Fri Oct 31 09:24:32 2025 +0100 Fixed some old stuff. commit 7f17e135f32ee9a43da0eb7c3872c305ef750170 Author: Philip Mueller Date: Fri Oct 31 09:08:49 2025 +0100 Fixed a bug, but in a way I do not like. commit c2c1116fc81b96a78898d1bc8650d0dc39b5dd29 Author: Philip Mueller Date: Fri Oct 31 08:40:47 2025 +0100 Removed a missleading comment. commit ded5df80e181405b394c0912f5c9bc563b351e41 Author: Philip Mueller Date: Fri Oct 31 08:04:38 2025 +0100 Made some refactoring to remove some strange DaCe behaviour. commit b029828376745e984b83f80862b61210b3547f51 Author: Philip Mueller Date: Fri Oct 31 08:02:28 2025 +0100 Fixed an issue in safe_call commit b09c9fc238b3605d3ab6e85cf0e2f12023e031b3 Author: Philip Mueller Date: Fri Oct 31 07:17:36 2025 +0100 Included the first bunch of Tal's changes. commit e138b0662223a6beebf2b48e8b6f0b0421dce362 Author: Philip Mueller Date: Thu Oct 30 15:12:23 2025 +0100 Made the 'passed as positional and named argument'-error more explicit. commit f901a3d3889cbeec5195eb7b1ecc35c777a726d1 Author: Philip Mueller Date: Thu Oct 30 15:05:00 2025 +0100 Fixed a bug in a unit test. Due to the refactoring the case that a variable is passed once as positional and as named argument is not detected and asserted. This test however, passed `a` always as positional argument and if `symbolic` is `True` also as named argument. commit 767260d41c05ef5afb3f7ceb34918cc704d1fd68 Author: Philip Mueller Date: Thu Oct 30 14:19:44 2025 +0100 Clarified a comment. commit 2b8123a7513893d8a536a6570dd79eef85c0dc8e Author: Philip Mueller Date: Thu Oct 30 13:56:20 2025 +0100 Made the construct argumt vector function publich and also refactored some things. --- dace/codegen/compiled_sdfg.py | 334 ++++++++++++++------- tests/codegen/external_memory_test.py | 2 +- tests/python_frontend/return_value_test.py | 41 ++- 3 files changed, 264 insertions(+), 113 deletions(-) diff --git a/dace/codegen/compiled_sdfg.py b/dace/codegen/compiled_sdfg.py index b80273ded3..4ba47ee326 100644 --- a/dace/codegen/compiled_sdfg.py +++ b/dace/codegen/compiled_sdfg.py @@ -5,7 +5,7 @@ import re import shutil import subprocess -from typing import Any, Callable, Dict, List, Tuple, Optional, Type, Union +from typing import Any, Callable, Dict, List, Tuple, Optional, Type, Union, Sequence import warnings import tempfile import pickle @@ -178,8 +178,28 @@ def __deepcopy__(self, memodict={}): class CompiledSDFG(object): """ A compiled SDFG object that can be called through Python. - Todo: - Scalar return values are not handled properly, this is a code gen issue. + Essentially this class makes an SDFG callable. Normally a user will not create it + directly but instead it is generated by some utilities such as `SDFG.compile()`. + + The class performs the following tasks: + - It ensures that the SDFG object is properly initialized, either by a direct + call to `initialize()` or the first time it is called. Furthermore, it will + also take care of the finalization if it does out of scope. + - It transforms Python arguments into C arguments. + + Technically there are two ways how the SDFG can be called, the first is using + `__call__()`, i.e. as a normal function. However, this will always processes + the arguments and does some error checking and is thus slow. The second way + is the advanced interface, which allows to decompose the calling into different + subset. For more information see `construct_arguments()`, `fast_call()` and + `convert_return_values()`. + + :note: In previous version the arrays used as return values were sometimes reused. + However, this was changed and every time `construct_arguments()` is called + new arrays are allocated. + :note: It is not possible to return scalars. Note that currently using scalars + as return values is a validation error. The only exception are (probably) + Python objects. """ def __init__(self, sdfg, lib: ReloadableDLL, argnames: List[str] = None): @@ -188,9 +208,14 @@ def __init__(self, sdfg, lib: ReloadableDLL, argnames: List[str] = None): self._lib = lib self._initialized = False self._libhandle = ctypes.c_void_p(0) - self._lastargs = () self.do_not_execute = False + # Contains the pointer arguments that where used to call the SDFG, `__call__()` + # was used. It is also used by `get_workspace_size()`. + # NOTE: Using its content might be dangerous as only the pointers to arrays are + # stored. It is the users responsibility to ensure that they are valid. + self._lastargs = None + lib.load() # Explicitly load the library self._init = lib.get_symbol('__dace_init_{}'.format(sdfg.name)) self._init.restype = ctypes.c_void_p @@ -199,17 +224,27 @@ def __init__(self, sdfg, lib: ReloadableDLL, argnames: List[str] = None): self._cfunc = lib.get_symbol('__program_{}'.format(sdfg.name)) # Cache SDFG return values - self._create_new_arrays: bool = True self._return_syms: Dict[str, Any] = None + # It will contain the shape of the array or the name if the return array is passed as argument. self._retarray_shapes: List[Tuple[str, np.dtype, dtypes.StorageType, Tuple[int], Tuple[int], int]] = [] - self._retarray_is_scalar: List[bool] = [] + # Is only `True` if teh return value is a scalar _and_ a `pyobject`. + self._retarray_is_pyobject: List[bool] = [] self._return_arrays: List[np.ndarray] = [] self._callback_retval_references: List[Any] = [] # Avoids garbage-collecting callback return values + # If there are return values then this is `True` it is is a single value. Note that + # `False` either means that a tuple is returned or there are no return values. + # NOTE: Needed to handle the case of a tuple with one element. + self._is_single_value_ret: bool = False + if '__return' in self._sdfg.arrays: + assert not any(aname.startswith('__return_') for aname in self._sdfg.arrays.keys()) + self._is_single_value_ret = True + # Cache SDFG argument properties self._typedict = self._sdfg.arglist() self._sig = self._sdfg.signature_arglist(with_types=False, arglist=self._typedict) self._free_symbols = self._sdfg.free_symbols + self._constants = self._sdfg.constants self.argnames = argnames if self.argnames is None and len(sdfg.arg_names) != 0: @@ -296,12 +331,21 @@ def get_workspace_sizes(self) -> Dict[dtypes.StorageType, int]: """ Returns the total external memory size to be allocated for this SDFG. + Note that the function queries the sizes of the last call that was made by + `__call__()` or `initialize()`. Calls made by `fast_call()` or `safe_call()` + will not be considered. + :return: A dictionary mapping storage types to the number of bytes necessary to allocate for the SDFG to work properly. + :note: It is the users responsibility that all arguments, especially the array + arguments, remain valid between the call to `__call__()` or `initialize()` + and the call to this function. """ if not self._initialized: raise ValueError('Compiled SDFG is uninitialized, please call ``initialize`` prior to ' 'querying external memory size.') + if self._lastargs is None: + raise ValueError('To use `get_workspace_sizes()` `__call__()` or `initialize()` must be called before.') result: Dict[dtypes.StorageType, int] = {} for storage in self.external_memory_types: @@ -315,15 +359,24 @@ def set_workspace(self, storage: dtypes.StorageType, workspace: Any): """ Sets the workspace for the given storage type to the given buffer. + Note that the function queries the sizes of the last call that was made by + `__call__()` or `initialize()`. Calls made by `fast_call()` or `safe_call()` + will not be considered. + :param storage: The storage type to fill. :param workspace: An array-convertible object (through ``__[cuda_]array_interface__``, see ``array_interface_ptr``) to use for the workspace. + :note: It is the users responsibility that all arguments, especially the array + arguments, remain valid between the call to `__call__()` or `initialize()` + and the call to this function. """ if not self._initialized: raise ValueError('Compiled SDFG is uninitialized, please call ``initialize`` prior to ' 'setting external memory.') if storage not in self.external_memory_types: raise ValueError(f'Compiled SDFG does not specify external memory of {storage}') + if self._lastargs is None: + raise ValueError('To use `get_workspace_sizes()` `__call__()` or `initialize()` must be called before.') func = self._lib.get_symbol(f'__dace_set_external_memory_{storage.name}', None) ptr = dtypes.array_interface_ptr(workspace, storage) @@ -358,12 +411,13 @@ def initialize(self, *args, **kwargs): if self._initialized: return - if len(args) > 0 and self.argnames is not None: - kwargs.update({aname: arg for aname, arg in zip(self.argnames, args)}) - # Construct arguments in the exported C function order - _, initargtuple = self._construct_args(kwargs) + callargtuple, initargtuple = self.construct_arguments(*args, **kwargs) self._initialize(initargtuple) + + # The main reason for setting `_lastargs` here is, to allow calls to `get_workspace_size()`. + self._lastargs = (callargtuple, initargtuple) + return self._libhandle def finalize(self): @@ -388,38 +442,34 @@ def __call__(self, *args, **kwargs): """ Forwards the Python call to the compiled ``SDFG``. - The order of the positional arguments is expected to be the same as in - the ``argnames`` member. The function will roughly perform the - following tasks: - - Change the order of the Python arguments into the one required by - the binary. - - Performing some basic sanity checks. - - Transforming the Python arguments into their ``C`` equivalents. - - Allocate the memory for the return values. - - Call the ``C` function. + The order of the positional arguments is expected to be the same as in the + ``argnames`` member. The function will perform the following tasks: + - Calling ``construct_arguments()`` and creating the argument vector and + allocating the memory for the return values. + - Performing the actual call by means of ``fast_call()``, with enabled error + checks. + - Then it will convert the return value into the expected format by means of + ``convert_return_values()`` and return that value. :note: The memory for the return values is only allocated the first time this function is called. Thus, this function will always return the same objects. To force the allocation of new memory you can call ``clear_return_values()`` in advance. """ - if self.argnames is None and len(args) != 0: - raise KeyError(f"Passed positional arguments to an SDFG that does not accept them.") - elif len(args) > 0 and self.argnames is not None: - kwargs.update( - # `_construct_args` will handle all of its arguments as kwargs. - { - aname: arg - for aname, arg in zip(self.argnames, args) - }) - argtuple, initargtuple = self._construct_args(kwargs) # Missing arguments will be detected here. - # Return values are cached in `self._lastargs`. - return self.fast_call(argtuple, initargtuple, do_gpu_check=True) + argtuple, initargtuple = self.construct_arguments(*args, **kwargs) # Missing arguments will be detected here. + self._lastargs = (argtuple, initargtuple) + self.fast_call(argtuple, initargtuple, do_gpu_check=True) + return self.convert_return_values() def safe_call(self, *args, **kwargs): """ Forwards the Python call to the compiled ``SDFG`` in a separate process to avoid crashes in the main process. Raises an exception if the SDFG execution fails. + + Note the current implementation lacks the proper handling of return values. + Thus output can only be transmitted through inout arguments. """ + if any(aname == '__return' or aname.startswith('__return_') for aname in self.sdfg.arrays.keys()): + raise NotImplementedError('`CompiledSDFG.safe_call()` does not support return values.') # Pickle the SDFG and arguments with tempfile.NamedTemporaryFile(mode='wb', delete=False) as f: @@ -471,24 +521,25 @@ def safe_call(self, *args, **kwargs): def fast_call( self, - callargs: Tuple[Any, ...], - initargs: Tuple[Any, ...], + callargs: Sequence[Any], + initargs: Sequence[Any], do_gpu_check: bool = False, - ) -> Union[Tuple[Any, ...], Any]: + ) -> None: """ - Calls the underlying binary functions directly and bypassing - argument sanitation. + Calls the underlying binary functions directly and bypassing argument sanitation. - This is a faster, but less user friendly version of ``__call__()``. - While ``__call__()`` will transforms its Python arguments such that - they can be forwarded, this function assumes that this processing - was already done by the user. + This is a faster, but less user friendly version of ``__call__()``. While + ``__call__()`` will transforms its Python arguments such that they can be + forwarded and allocate memory for the return values, this function assumes + that this processing was already done by the user. + To build the argument vectors you should use `self.construct_arguments()`. :param callargs: Arguments passed to the actual computation. :param initargs: Arguments passed to the initialization function. :param do_gpu_check: Check if errors happened on the GPU. - :note: You may use `_construct_args()` to generate the processed arguments. + :note: This is an advanced interface. + :note: In previous versions this function also called `convert_return_values()`. """ try: # Call initializer function if necessary, then SDFG @@ -512,8 +563,7 @@ def fast_call( if lasterror is not None: raise RuntimeError( f'An error was detected when calling "{self._sdfg.name}": {self._get_error_text(lasterror)}') - - return self._convert_return_values() + return except (RuntimeError, TypeError, UnboundLocalError, KeyError, cgx.DuplicateDLLError, ReferenceError): self._lib.unload() raise @@ -525,18 +575,40 @@ def __del__(self): self._libhandle = ctypes.c_void_p(0) self._lib.unload() - def _construct_args(self, kwargs) -> Tuple[Tuple[Any], Tuple[Any]]: - """ - Main function that controls argument construction for calling - the C prototype of the SDFG. + def construct_arguments(self, *args: Any, **kwargs: Any) -> Tuple[Tuple[Any], Tuple[Any]]: + """Construct the argument vectors suitable for from its argument. - Organizes arguments first by ``sdfg.arglist``, then data descriptors - by alphabetical order, then symbols by alphabetical order. + The function returns a pair of tuple, that are suitable for `fast_call()`. + The first element of is `callargs`, i.e. the full arguments, while the + second element is `initargs`, which is only used/needed the first time + an SDFG is called. - :note: If not initialized this function will initialize the memory for - the return values, however, it might also reallocate said memory. - :note: This function will also update the internal argument cache. + It is important that this function will also allocate new return values. + The array objects are managed by `self` and remain valid until this + function is called again. However, they are also returned by `self.__call__()`. + + It is also possible to pass the array, that should be used to return a value, + directly as argument. In that case the allocation for that return value will + be skipped. + + :note: In case of arrays, the returned argument vectors only contains the + pointers to the underlying memory. Thus it is the user's responsibility + to ensure that the memory remains allocated until the argument vector + is used. + :note: This is an advanced interface. """ + if self.argnames is None and len(args) != 0: + raise KeyError(f"Passed positional arguments to an SDFG that does not accept them.") + elif len(args) > 0 and self.argnames is not None: + positional_arguments = {aname: avalue for aname, avalue in zip(self.argnames, args)} + if not positional_arguments.keys().isdisjoint(kwargs.keys()): + raise ValueError( + f'The arguments where passed once as positional and named arguments: {set(positional_arguments.keys()).intersection(kwargs.keys())}' + ) + kwargs.update(positional_arguments) + + # NOTE: This might invalidate the elements associated to the return values of + # all argument vectors that were created before. self._initialize_return_values(kwargs) # Add the return values to the arguments, since they are part of the C signature. @@ -566,31 +638,51 @@ def _construct_args(self, kwargs) -> Tuple[Tuple[Any], Tuple[Any]]: argnames = [] sig = [] - # Type checking - cargs = [] no_view_arguments = not Config.get_bool('compiler', 'allow_view_arguments') - for i, (a, arg, atype) in enumerate(zip(argnames, arglist, argtypes)): - carg = dt.make_ctypes_argument(arg, - atype, - a, - allow_views=not no_view_arguments, - symbols=kwargs, - callback_retval_references=self._callback_retval_references) - cargs.append(carg) - - constants = self.sdfg.constants + cargs = tuple( + dt.make_ctypes_argument(aval, + atype, + aname, + allow_views=not no_view_arguments, + symbols=kwargs, + callback_retval_references=self._callback_retval_references) + for aval, atype, aname in zip(arglist, argtypes, argnames)) + symbols = self._free_symbols callparams = tuple((carg, aname) for arg, carg, aname in zip(arglist, cargs, argnames) - if not (symbolic.issymbolic(arg) and (hasattr(arg, 'name') and arg.name in constants))) - - newargs = tuple(carg for carg, aname in callparams) + if not ((hasattr(arg, 'name') and arg.name in self._constants) and symbolic.issymbolic(arg))) + newargs = tuple(carg for carg, _aname in callparams) initargs = tuple(carg for carg, aname in callparams if aname in symbols) - self._lastargs = newargs, initargs - return self._lastargs + return (newargs, initargs) + + def convert_return_values(self) -> Union[Any, Tuple[Any, ...]]: + """Convert the return arguments. + + Execute the `return` statement and return. This function should only be called + after `fast_call()` has been run. + Keep in mid that it is not possible to return scalars (with the exception of + `pyobject`s), they will be always returned as an array with shape `(1,)`. + + :note: This is an advanced interface. + :note: After `fast_call()` returns it is only allowed to call this function once. + """ + # TODO: Make sure that the function is called only once by checking it. + # NOTE: Currently it is not possible to return a scalar value, see `tests/sdfg/scalar_return.py` + if not self._return_arrays: + return None + elif self._is_single_value_ret: + assert len(self._return_arrays) == 1 + return self._return_arrays[0].item() if self._retarray_is_pyobject[0] else self._return_arrays[0] + else: + return tuple(r.item() if is_pyobj else r + for r, is_pyobj in zip(self._return_arrays, self._retarray_is_pyobject)) def clear_return_values(self): - self._create_new_arrays = True + warnings.warn( + 'The "CompiledSDFG.clear_return_values" API is deprecated, as this behaviour has' + ' become the new default, and is a noops.', DeprecationWarning) + pass def _create_array(self, _: str, dtype: np.dtype, storage: dtypes.StorageType, shape: Tuple[int], strides: Tuple[int], total_size: int): @@ -626,52 +718,76 @@ def _initialize_return_values(self, kwargs): # Clear references from last call (allow garbage collection) self._callback_retval_references.clear() - if self._initialized: - if self._return_syms == syms: - if not self._create_new_arrays: - return - else: - self._create_new_arrays = False - # Use stored sizes to recreate arrays (fast path) - self._return_arrays = tuple(kwargs[desc[0]] if desc[0] in kwargs else self._create_array(*desc) - for desc in self._retarray_shapes) - return + if self._initialized and self._return_syms == syms: + # Use stored sizes to recreate arrays (fast path) + self._return_arrays = tuple(kwargs[desc[0]] if desc[0] in kwargs else self._create_array(*desc) + for desc in self._retarray_shapes) + return self._return_syms = syms - self._create_new_arrays = False - - # Initialize return values with numpy arrays - self._retarray_shapes = [] self._return_arrays = [] + self._retarray_shapes = [] + self._retarray_is_pyobject = [] for arrname, arr in sorted(self.sdfg.arrays.items()): - if arrname.startswith('__return') and not arr.transient: - if arrname in kwargs: + if arrname.startswith('__return'): + if arr.transient: + raise ValueError(f'Used the special array name "{arrname}" as transient.') + + elif arrname in kwargs: + # The return value is passed as an argument, in that case store the name in `self._retarray_shapes`. + warnings.warn(f'Return value "{arrname}" is passed as a regular argument.', stacklevel=2) self._return_arrays.append(kwargs[arrname]) - self._retarray_is_scalar.append(isinstance(arr, dt.Scalar)) self._retarray_shapes.append((arrname, )) - continue - if isinstance(arr, dt.Stream): + elif isinstance(arr, dt.Stream): raise NotImplementedError('Return streams are unsupported') - shape = tuple(symbolic.evaluate(s, syms) for s in arr.shape) - dtype = arr.dtype.as_numpy_dtype() - total_size = int(symbolic.evaluate(arr.total_size, syms)) - strides = tuple(symbolic.evaluate(s, syms) * arr.dtype.bytes for s in arr.strides) - shape_desc = (arrname, dtype, arr.storage, shape, strides, total_size) - self._retarray_is_scalar.append(isinstance(arr, dt.Scalar) or isinstance(arr.dtype, dtypes.pyobject)) - self._retarray_shapes.append(shape_desc) - - # Create an array with the properties of the SDFG array - arr = self._create_array(*shape_desc) - self._return_arrays.append(arr) + else: + shape = tuple(symbolic.evaluate(s, syms) for s in arr.shape) + dtype = arr.dtype.as_numpy_dtype() + total_size = int(symbolic.evaluate(arr.total_size, syms)) + strides = tuple(symbolic.evaluate(s, syms) * arr.dtype.bytes for s in arr.strides) + shape_desc = (arrname, dtype, arr.storage, shape, strides, total_size) + self._retarray_shapes.append(shape_desc) + + # Create an array with the properties of the SDFG array + return_array = self._create_array(*shape_desc) + self._return_arrays.append(return_array) + + # BUG COMPATIBILITY(PR#2206): + # In the original version `_retarray_is_pyobject` was named `_retarray_is_scalar`, however + # since scalars could not be returned on an [implementation level](https://github.com/spcl/dace/pull/1609) + # it was essentially useless. But was used for `pyobject` in _some_ cases. And indeed, + # since `pyobject`s are essentially `void` pointers is was, in principle possible, to return/pass + # them as "scalars", read "not inside an array". + # However, if the return value was passed as argument, i.e. the first `elif`, then it + # was ignored if `arr` was a `pyobject`. Only if the return value was managed by `self`, + # i.e. the `else` case, then it was considered, in a way at least. The problem was, that it was + # done using the following check: + # `isinstance(arr, dt.Scalar) or isinstance(arr.dtype, dtypes.pyobject)` + # Because of the `or` that is used, _everything_ whose `dtype` is `pyobject` was classified + # as a scalar `pyobject`, i.e. one element, even if it was in fact an array of millions of `pyobject`s. + # The correct behaviour would be to change the `or` to an `and` but then several unit + # tests (`test_pyobject_return`, `test_pyobject_return_tuple` and `test_nested_autoparse[False]` + # in `tests/python_frontend/callee_autodetect_test.py`) will fail. + # The following code is bug compatible and also allows to pass a `pyobject` directly, i.e. + # through `kwargs`. + if isinstance(arr.dtype, dtypes.pyobject): + if isinstance(arr, dt.Scalar): + # Proper scalar. + self._retarray_is_pyobject.append(True) + elif isinstance(arr, dt.Array): + # An array, let's check if it is just a wrapper for a single value. + if not (len(arr.shape) == 1 and arr.shape[0] == 1): + warnings.warn(f'Decay an array of `pyobject`s with shape {arr.shape} to a single one.', + stacklevel=2) + self._retarray_is_pyobject.append(True) + else: + raise ValueError( + f'Does not know how to handle "{arrname}", which is a {type(arr).__name__} of `pyobject`.') + else: + self._retarray_is_pyobject.append(False) - def _convert_return_values(self): - # Return the values as they would be from a Python function - # NOTE: Currently it is not possible to return a scalar value, see `tests/sdfg/scalar_return.py` - if not self._return_arrays: - return None - elif len(self._return_arrays) == 1: - return self._return_arrays[0].item() if self._retarray_is_scalar[0] else self._return_arrays[0] - else: - return tuple(r.item() if scalar else r for r, scalar in zip(self._return_arrays, self._retarray_is_scalar)) + assert (not self._is_single_value_ret) or (len(self._return_arrays) == 1) + assert len(self._return_arrays) == len(self._retarray_shapes) == len(self._retarray_is_pyobject) + self._return_arrays = tuple(self._return_arrays) diff --git a/tests/codegen/external_memory_test.py b/tests/codegen/external_memory_test.py index 169e050914..47eac55ff3 100644 --- a/tests/codegen/external_memory_test.py +++ b/tests/codegen/external_memory_test.py @@ -30,7 +30,7 @@ def tester(a: dace.float64[N]): a = np.random.rand(20) if symbolic: - extra_args = dict(a=a, N=20) + extra_args = dict(N=20) else: extra_args = {} diff --git a/tests/python_frontend/return_value_test.py b/tests/python_frontend/return_value_test.py index 4a845bea0b..4e704287bc 100644 --- a/tests/python_frontend/return_value_test.py +++ b/tests/python_frontend/return_value_test.py @@ -9,7 +9,15 @@ def test_return_scalar(): def return_scalar(): return 5 - assert return_scalar() == 5 + res = return_scalar() + assert res == 5 + + # Don't be fooled by the test above the return value is an array. If you would + # add the return value annotation to the program, i.e. `-> dace.int32` you would + # get a validation error. + assert isinstance(res, np.ndarray) + assert res.shape == (1, ) + assert res.dtype == np.int64 def test_return_scalar_in_nested_function(): @@ -22,7 +30,15 @@ def nested_function() -> dace.int32: def return_scalar(): return nested_function() - assert return_scalar() == 5 + res = return_scalar() + assert res == 5 + + # Don't be fooled by the test above the return value is an array. If you would + # add the return value annotation to the program, i.e. `-> dace.int32` you would + # get a validation error. + assert isinstance(res, np.ndarray) + assert res.shape == (1, ) + assert res.dtype == np.int32 def test_return_array(): @@ -42,6 +58,8 @@ def return_tuple(): return 5, 6 res = return_tuple() + assert isinstance(res, tuple) + assert len(res) == 2 assert res == (5, 6) @@ -52,6 +70,8 @@ def return_array_tuple(): return 5 * np.ones(5), 6 * np.ones(6) res = return_array_tuple() + assert isinstance(res, tuple) + assert len(res) == 2 assert np.allclose(res[0], 5 * np.ones(5)) assert np.allclose(res[1], 6 * np.ones(6)) @@ -66,10 +86,25 @@ def return_void(a: dace.float64[20]): a = np.random.rand(20) ref = a + 1 - return_void(a) + res = return_void(a) + assert res is None assert np.allclose(a, ref) +def test_return_tuple_1_element(): + + @dace.program + def return_one_element_tuple(a: dace.float64[20]): + return (a + 3.5, ) + + a = np.random.rand(20) + ref = a + 3.5 + res = return_one_element_tuple(a) + assert isinstance(res, tuple) + assert len(res) == 1 + assert np.allclose(res[0], ref) + + def test_return_void_in_if(): @dace.program From 781a8e800391ed756511124bdc39334d36344244 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Thu, 18 Dec 2025 09:29:49 +0100 Subject: [PATCH 14/17] Squashed commit of the following: commit 5b068e775e622375d86c4e261ab911f71ef7b983 Author: Affifboudaoud Date: Sun Nov 23 22:50:46 2025 +0100 Add visited set to avoid visiting same node multiple times --- dace/transformation/dataflow/map_fusion_vertical.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/dace/transformation/dataflow/map_fusion_vertical.py b/dace/transformation/dataflow/map_fusion_vertical.py index eda0e639f5..61a9eadc0f 100644 --- a/dace/transformation/dataflow/map_fusion_vertical.py +++ b/dace/transformation/dataflow/map_fusion_vertical.py @@ -1548,11 +1548,15 @@ def _is_data_accessed_downstream( def next_nodes(node: nodes.Node) -> Iterable[nodes.Node]: return (edge.dst for edge in graph.out_edges(node)) - # Dataflow graph is acyclic, so we do not need to keep a list of - # what we have visited. + # Track visited nodes to avoid exponential blowup from visiting + # the same node multiple times via different paths in the DAG. to_visit: List[nodes.Node] = list(next_nodes(begin)) + visited: Set[nodes.Node] = set() while len(to_visit) > 0: node = to_visit.pop() + if node in visited: + continue + visited.add(node) if isinstance(node, nodes.AccessNode) and node.data == data: return True to_visit.extend(next_nodes(node)) From 4e4f36fdbcdc7e5635b27bdc3522cdb498a6af57 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Thu, 18 Dec 2025 09:30:02 +0100 Subject: [PATCH 15/17] Squashed commit of the following: commit 79941322eddb8f1018328cf60160f2d7fb82fff0 Merge: b8a0fd714 387f1e871 Author: Philipp Schaad Date: Fri Dec 12 09:25:10 2025 +0100 Merge branch 'main' into fix_block_size_config commit b8a0fd71444dac17e7b0c7c8bee7e9fd07ebe442 Author: Edoardo Paone Date: Thu Dec 11 14:14:34 2025 +0100 edit --- dace/codegen/targets/cuda.py | 4 ++-- dace/transformation/dataflow/add_threadblock_map.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/dace/codegen/targets/cuda.py b/dace/codegen/targets/cuda.py index 55e2fe3241..c98095d7e9 100644 --- a/dace/codegen/targets/cuda.py +++ b/dace/codegen/targets/cuda.py @@ -2148,8 +2148,8 @@ def get_kernel_dimensions(self, dfg_scope): # Check block size against configured maximum values, if those can be determined total_bsize = prod(block_size) - total_limit = Config.get('compiler', 'cuda', 'block_size_limit') - lastdim_limit = Config.get('compiler', 'cuda', 'block_size_lastdim_limit') + total_limit = int(Config.get('compiler', 'cuda', 'block_size_limit')) + lastdim_limit = int(Config.get('compiler', 'cuda', 'block_size_lastdim_limit')) if (total_bsize > total_limit) == True: raise ValueError(f'Block size for kernel "{kernelmap_entry.map.label}" ({block_size}) ' f'is larger than the possible number of threads per block ({total_limit}). ' diff --git a/dace/transformation/dataflow/add_threadblock_map.py b/dace/transformation/dataflow/add_threadblock_map.py index 9bc5a8a2a7..febdb12861 100644 --- a/dace/transformation/dataflow/add_threadblock_map.py +++ b/dace/transformation/dataflow/add_threadblock_map.py @@ -76,8 +76,8 @@ def validate_block_size_limits(kernel_map_entry: nodes.MapEntry, block_size: Lis kernel_map_label = kernel_map_entry.map.label total_block_size = product(block_size) - limit = Config.get('compiler', 'cuda', 'block_size_limit') - lastdim_limit = Config.get('compiler', 'cuda', 'block_size_lastdim_limit') + limit = int(Config.get('compiler', 'cuda', 'block_size_limit')) + lastdim_limit = int(Config.get('compiler', 'cuda', 'block_size_lastdim_limit')) if (total_block_size > limit) == True: raise ValueError(f'Block size for kernel "{kernel_map_label}" ({block_size}) ' From c18b015d18d89f0466a3d6345d317ff646d61b4b Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Thu, 18 Dec 2025 09:30:17 +0100 Subject: [PATCH 16/17] Squashed commit of the following: commit c9f93fd0604a59a5ffaba54939891b886916ef00 Author: Edoardo Paone Date: Thu Dec 18 00:00:20 2025 +0100 fix state fusion for write-write hazard --- .../transformation/interstate/state_fusion.py | 4 +- tests/transformations/state_fusion_test.py | 46 +++++++++++++++++++ 2 files changed, 49 insertions(+), 1 deletion(-) diff --git a/dace/transformation/interstate/state_fusion.py b/dace/transformation/interstate/state_fusion.py index 121d74380d..c76cd289e0 100644 --- a/dace/transformation/interstate/state_fusion.py +++ b/dace/transformation/interstate/state_fusion.py @@ -155,7 +155,9 @@ def _check_paths(self, first_state: SDFGState, second_state: SDFGState, match_no continue path_found |= True node2 = next(n for n in second_input if n.data == match.data) - if not all(nx.has_path(second_state._nx, node2, n) for n in nodes_second): + if not all( + second_state.in_degree(n) == 1 and nx.has_path(second_state._nx, node2, n) + for n in nodes_second): fail = True break # We keep looking for a potential match with a path that fail to find diff --git a/tests/transformations/state_fusion_test.py b/tests/transformations/state_fusion_test.py index aece80f574..3fdd236303 100644 --- a/tests/transformations/state_fusion_test.py +++ b/tests/transformations/state_fusion_test.py @@ -185,6 +185,51 @@ def state_fusion_test(A: dace.int32[20, 20]): assert len(sdfg.nodes()) == 1 +def test_write_write_path_multiple_producers(): + """ + The check for Write-Write hazard, in state fusion, should only apply to single + path between first and second state. The second state, in this SDFG, contains + two edges writing to the node, where one edge partially overrides the data which + was written in the first state. This constitutes a real Write-Write hazard, thus + state fusion should not be applied. + """ + sdfg = dace.SDFG('state_fusion_test') + A, A_desc = sdfg.add_array('A', [10, 10], dace.int32) + t, _ = sdfg.add_temp_transient_like(A_desc) + s1 = sdfg.add_state() + s2 = sdfg.add_state_after(s1) + + a1_node = s1.add_access(A) + s1.add_mapped_tasklet("write1", + map_ranges={ + "i": "0:10", + "j": "9" + }, + code='out = -1', + inputs={}, + outputs={'out': dace.Memlet("A[i, j]")}, + output_nodes={a1_node}, + external_edges=True) + s1.add_nedge(a1_node, s1.add_access(t), sdfg.make_array_memlet(A)) + + a2_node = s2.add_access(A) + s2.add_nedge(s2.add_access(t), a2_node, dace.Memlet('A[0:10, 0:5]')) + s2.add_mapped_tasklet("write2", + map_ranges={ + "i": "0:10", + "j": "5:10" + }, + code='out = -2', + inputs={}, + outputs={'out': dace.Memlet("A[i, j]")}, + output_nodes={a2_node}, + external_edges=True) + + sdfg.validate() + sdfg.apply_transformations_repeated(StateFusion) + assert len(sdfg.nodes()) == 2 + + def test_write_write_no_overlap(): """ Two states where both write to different ranges of an array. @@ -502,6 +547,7 @@ def test_check_paths(): test_two_cc_fusion_separate() test_two_cc_fusion_together() test_write_write_path() + test_write_write_path_multiple_producers() test_write_write_no_overlap() test_read_write_no_overlap() test_array_in_middle_no_overlap() From cd52c4ba29e7afa3e116bc6c3e4824df6e18c13c Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Thu, 18 Dec 2025 09:32:23 +0100 Subject: [PATCH 17/17] Set version file. --- dace/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dace/version.py b/dace/version.py index 1f356cc57b..171e86eccb 100644 --- a/dace/version.py +++ b/dace/version.py @@ -1 +1 @@ -__version__ = '1.0.0' +__version__ = '43!2025.12.18'