Fix the PyTorch backend without hacks for strides #2

@ppaleja

Description

Problem summary

  • The extra/torch_backend shim contains many workarounds that materialize tensors or force contiguity instead of correctly supporting arbitrary strides/as_strided semantics. Those hacks show up in places such as empty_strided, _copy_from, and _as_strided.
  • The movement-op layer (extra/to_movement_ops.py) currently treats stride changes very restrictively (it only treats -1/1 flips as STRIDE), which prevents correct reconstruction of arbitrary views.
  • Result: certain PyTorch programs (especially ones relying on non-contiguous layouts, as_strided, or specific stride-based views) either get incorrect results or require expensive/incorrect copies.

Top-level goal (north star)

  • Make the tiny PyTorch backend fully correct for PyTorch semantics regarding views and strides so that:
    • aten::as_strided, aten::empty_strided, and related view APIs are supported without forcing contiguity or doing ad-hoc copies.
    • Copies and assignments preserve requested shape + stride semantics where possible.
    • The backend’s behavior matches PyTorch for shape/stride/offset semantics (including negative strides) with well-defined performance characteristics (prefer lazy, minimal-realize paths).

Primary acceptance criteria

  • Correctness:
    • aten::as_strided and aten::empty_strided return tensors with the exact requested shape/stride/storage_offset semantics, not merely a contiguous copy unless that is the only valid representation (see the reference example after this list).
    • Copies between tensors with different stride patterns produce identical results to PyTorch (including negative strides).
    • No new regressions for existing PyTorch backend tests (extra/torch_backend/torch_tests.py should pass, or at minimum show strictly fewer failures than before).
  • Tests:
    • Add unit tests exercising positive/negative/non-unit strides and empty_strided.
    • Add an end-to-end smoke test: run a small upstream script (e.g., hlb-CIFAR10 with steps=5) with the tiny torch shim and verify it produces a finite numeric loss without crashing.
  • Maintainability:
    • Remove ad-hoc comments like "this only solves some cases" where a proper implementation exists.
    • Document any remaining unavoidable materializations and why they’re necessary.
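
For reference, this is the stock PyTorch behavior the criteria above target: as_strided addresses flat storage directly (element (i, j) lives at storage_offset + i*stride[0] + j*stride[1]), so views may overlap and must alias their base tensor.

```python
import torch

base = torch.arange(6)
v = base.as_strided((2, 2), (1, 2), storage_offset=1)
print(v)        # tensor([[1, 3], [2, 4]])
v[0, 0] = 100   # views write through to the shared storage
print(base[1])  # tensor(100)
```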

Key files to inspect and why

  • extra/torch_backend/backend.py
    • Central PyTorch compatibility layer. Look for:
      • aten::as_strided / _as_strided
      • aten::empty_strided / empty.memory_format
      • _copy_from, _reshape_alias, uses of .contiguous() / .realize() as workarounds
    • These are where the current hacks and TODOs are concentrated.
  • extra/to_movement_ops.py
    • Converts ShapeTracker views into a sequence of MovementOps (RESHAPE, PERMUTE, EXPAND, PAD, SHRINK, STRIDE, AS_STRIDED).
    • Current STRIDE handling asserts only [-1, 1] values and maps to axis flip. This is a key limitation to fix.
  • extra/torch_backend/wrapped_tensor.cpp
    • C++ glue used to map tinygrad Tensor to PyTorch torch.Tensor. Must ensure stride + offset metadata is passed correctly.
  • extra/torch_backend/torch_tests.py
    • Test harness for PyTorch ops against the tiny backend. Use this to validate behavior.
  • tinygrad/tinygrad/tensor.py
    • Core semantics for views, contiguous(), realize(), assign(), and other methods the backend relies on.
  • Any ShapeTracker / View API that the movement layer expects
    • Many files import tinygrad.shape.shapetracker and tinygrad.shape.view (see extra/to_movement_ops.py), so locate the canonical implementations in the tree (likely under tinygrad/shape/).

Key technical changes to enable proper strides

  1. Represent arbitrary stride patterns in the view machinery
    • Ensure View/ShapeTracker can represent:
      • per-dimension stride (positive/negative integer)
      • storage offset (index into flat buffer)
      • masks (for shrinks/pads) and contiguous flag
    • If the representation already exists, confirm the API and tests; if gaps exist, extend it (see the item 1 sketch after this list).
  2. Movement op generation / application
    • Update extra/to_movement_ops.py:
      • Remove the assert that restricts STRIDE to [-1,1].
      • Either:
        • Implement STRIDE semantics that can express arbitrary stride steps (e.g., step > 1) and negative strides, or
        • Prefer AS_STRIDED semantics (explicitly generate an AS_STRIDED op carrying the stride tuple + offset) when arbitrary stride is needed.
      • Ensure apply_mop can apply these ops to a ShapeTracker and that the resulting scratch shape/buffer size calculation is correct (get_buffer_size and get_real_view helpers are relevant).
    • Keep to_movement_ops's correctness checks (e.g., test_rebuild) and extend them to validate arbitrary stride cases (see the item 2 sketch after this list).
  3. Allocation for strided buffers
    • Implement aten::empty_strided to allocate a base buffer of the correct minimal size (via get_buffer_size) and return a view with the requested shape, stride, and storage_offset (not a forced .contiguous() tensor).
    • Confirm that APIs to construct a View with explicit stride and offset exist (e.g., View.create(...), already used in backend.py); the item 3 sketch after this list shows the sizing rule.
  4. Copy / assign semantics
    • Improve _copy_from and assign paths:
      • If dest layout matches requested stride exactly, perform assign without extra copies.
      • If not, produce an intermediate view that maps the source into the target layout without forcing full realization, or, if necessary, create a temporary buffer with the exact layout (and limit such cases).
    • Avoid blind src = src.contiguous() fixes in _copy_from. Instead, transform the layout correctly using movement ops where possible (see the item 4 sketch after this list).
  5. C++ wrapper metadata
    • Update extra/torch_backend/wrapped_tensor.cpp to pass strides and storage_offset metadata into PyTorch wrapper tensors so PyTorch clients see correct shapes & strides.
    • Ensure unwrap/wrap preserve stride + offset semantics and device mapping (the item 5 sketch after this list shows the invariant from the PyTorch side).
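
The sketches below illustrate items 1–5 in order. Helper and device names not present in the tree are hypothetical placeholders, and all signatures should be verified against the actual code.

Item 1 — constructing a view with explicit strides and offset, assuming the View.create(...) factory and ShapeTracker constructor that the files above already import:

```python
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View

# A (3, 2) view over flat storage: element (i, j) lives at 1 + 3*i + 2*j.
v = View.create(shape=(3, 2), strides=(3, 2), offset=1)
st = ShapeTracker((v,))
assert not st.contiguous  # non-canonical strides and nonzero offset
```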
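
Item 2 — one plausible shape for AS_STRIDED handling in apply_mop. Because as_strided is defined against flat storage, it replaces the accumulated view stack rather than composing with it:

```python
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View

def apply_as_strided(st: ShapeTracker, arg) -> ShapeTracker:
  # arg carries the full (shape, strides, storage_offset) triple.
  shape, strides, offset = arg
  return ShapeTracker((View.create(shape=shape, strides=strides, offset=offset),))

st = apply_as_strided(ShapeTracker.from_shape((10,)), ((3, 2), (3, 2), 1))
assert st.shape == (3, 2)
```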
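
Item 3 — the sizing rule empty_strided needs; get_buffer_size should implement the same arithmetic:

```python
def min_buffer_size(shape, strides, storage_offset=0) -> int:
  # The furthest element sits at storage_offset + sum((size-1) * |stride|),
  # so the base buffer needs one element more than that index. (With negative
  # strides, the offset must also keep the smallest reachable index >= 0;
  # that check is omitted here.)
  if any(s == 0 for s in shape): return 0  # zero-size tensors need no storage
  return storage_offset + sum((s - 1) * abs(st) for s, st in zip(shape, strides)) + 1

assert min_buffer_size((3, 2), (3, 2), storage_offset=1) == 10
```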
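
Item 4 — the fast-path layout check for _copy_from; layouts_match is a hypothetical helper name:

```python
import torch

def layouts_match(a: torch.Tensor, b: torch.Tensor) -> bool:
  # When the logical layouts agree exactly, _copy_from can assign directly
  # with no intermediate; otherwise it must remap via movement ops.
  return (a.shape == b.shape and a.stride() == b.stride()
          and a.storage_offset() == b.storage_offset())

a = torch.empty_strided((3, 2), (1, 3))  # column-major
b = torch.empty_strided((3, 2), (2, 1))  # row-major
assert not layouts_match(a, b)           # needs a remap, not a flat memcpy
```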
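
Item 5 — the invariant the C++ wrapper must establish, checked from the PyTorch side. This assumes the shim registers its device under the name "tiny"; adjust to the actual registration:

```python
import torch
import extra.torch_backend.backend  # noqa: F401 -- registers the backend device

t = torch.arange(10., device="tiny").as_strided((3, 2), (3, 2), storage_offset=1)
assert t.shape == (3, 2)
assert t.stride() == (3, 2)
assert t.storage_offset() == 1
```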

Testing strategy (unit + integration)

  • Unit tests to add:
    • test_as_strided_basic: create a base tensor and use as_strided to create views with arbitrary strides (positive, negative, step > 1) and assert read/write semantics match PyTorch (a sketch follows this section).
    • test_empty_strided_alloc: empty_strided returns a tensor with correct shape and stride and numel corresponds to correct buffer size; writing then reading yields expected layout.
    • test_copy_between_strides: copy between tensors with different strides (source contiguous, dest non-contig; source non-contig, dest contiguous; both non-contig with different patterns). Validate values match PyTorch.
    • test_negative_stride: slicing with negative steps (e.g., tensor.flip(dim)) and as_strided negative strides should behave like PyTorch.
    • Add tests to extra/to_movement_ops.py to validate to_movement_ops and apply_mop with arbitrary stride/view combinations (cover STRIDE and/or AS_STRIDED cases).
  • Integration smoke:
    • Run extra/torch_backend/torch_tests.py to exercise many aten ops with tiny backend.
    • Run examples/hlb_cifar10.py (or the upstream hlb-CIFAR10 main script) with `STEPS=5` using the tiny torch shim to verify training proceeds end-to-end.

  • Continuous validation:
    • Add CI job (optional first, then required) that runs the smoke test on PRs touching backend/shape code.
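
A minimal sketch of test_as_strided_basic, again assuming the shim's device is named "tiny"; it checks the tiny backend against the CPU reference:

```python
import torch
import extra.torch_backend.backend  # noqa: F401

def test_as_strided_basic():
  base = torch.arange(10, dtype=torch.float32)
  ref = base.as_strided((3, 2), (3, 2), storage_offset=1)
  out = base.to("tiny").as_strided((3, 2), (3, 2), storage_offset=1).cpu()
  assert torch.equal(out, ref)
```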

Concrete step-by-step plan / milestones

  1. Investigation (0.5–1 day)
    • Reproduce failing cases: run extra/torch_backend/torch_tests.py and any existing failing tests that hint at stride issues.
    • Find the exact places in backend.py that currently force .contiguous() or otherwise bypass stride semantics.
    • Confirm where View and ShapeTracker implementations live and their API (if not located, search the repo for class View / ShapeTracker).
  2. Movement ops fix (1–2 days)
    • Remove the STRIDE restriction; implement a richer STRIDE or AS_STRIDED handling in apply_mop.
    • Add unit tests for to_movement_ops with non-trivial stride patterns.
  3. empty_strided and as_strided (1 day)
    • Implement aten::empty_strided to allocate correct buffer size and return a proper view (no forced contiguous).
    • Rework _as_strided / _reshape_alias to use the improved movement-op path instead of heavyweight or incorrect fallbacks.
  4. Copy/assign semantics (1–2 days)
    • Rework _copy_from to avoid blind contiguous() when possible; perform minimal necessary transformation or create layout-matching temporary buffer.
  5. C++ wrapper updates (0.5–1 day)
    • Ensure wrapped_tensor.cpp exposes the stride + offset metadata properly.
  6. Testing & cleanup (1–2 days)
    • Add tests listed above.
    • Run extra/torch_backend/torch_tests.py and fix any regressions.
    • Add the hlb-CIFAR10 smoke run and iterate until it succeeds on small steps.
  7. PR: include tests, documentation, and a short migration note.
    • Title: "torch backend: proper support for arbitrary strides and as_strided"
    • Include before/after test logs for torch_tests.py and smoke test.

Helpful immediate edits you can make now (quick wins)

  • Replace comments that call out hacks with a TODO plus a link to this issue and a short note on the expected semantics. This frames the future work and prevents accidental re-hacking.
  • Add targeted unit test(s) for as_strided that currently fail — these will guide the implementation and act as regression tests.

Risk analysis & edge cases

  • Performance: correct stride support might require extra kernel logic or buffer rearrangement depending on the computation backend. Prioritize correctness first; later optimize to avoid copies where possible.
  • Multi-device / sharded tensors: if to()/shard() paths move data across devices, stride semantics must be preserved or explicitly documented as normalized (e.g. contiguous) on transfer. Start with single-device correctness.
  • ShapeTracker / symbolic dims: if shapes are symbolic (Variables), movement-op generation must still produce correct ops or defer realization. Ensure to_movement_ops handles symbolic cases robustly (it already contains symbolic checks — extend if needed).
  • Some code paths previously relied on "realize()" side-effects. Replacing hacks may expose latent bugs requiring careful testing.

PR checklist (what a completed PR should include)

  • Implementation changes in extra/to_movement_ops.py and extra/torch_backend/backend.py.
  • Any necessary updates to extra/torch_backend/wrapped_tensor.cpp.
  • New unit tests for as_strided, empty_strided, negative strides, and copying between varied layouts.
  • Updated or added comments documenting the design and any trade-offs.
  • No unrelated whitespace or style changes (follow tinygrad style rules).
  • Run extra/torch_backend/torch_tests.py and include test results in PR description (or CI logs).

Suggested branch and PR title

  • Branch name: fix/torch-backend-strides
  • PR title: torch backend: support arbitrary strides / as_strided without hacks

References (local files to start with)

  • extra/torch_backend/backend.py
  • extra/to_movement_ops.py
  • extra/torch_backend/wrapped_tensor.cpp
  • extra/torch_backend/torch_tests.py
  • tinygrad/tinygrad/tensor.py
