Problem summary
- The `extra/torch_backend` shim contains many workarounds that materialize tensors or force contiguity instead of correctly supporting arbitrary strides / `as_strided` semantics. Those hacks show up in places like `empty_strided`, `_copy_from`, `_as_strided`, and elsewhere.
- The movement-op layer (`extra/to_movement_ops.py`) currently treats stride changes very restrictively (it only treats -1/1 flips as `STRIDE`), which prevents correct reconstruction of arbitrary views.
- Result: certain PyTorch programs (especially ones relying on non-contiguous layouts, `as_strided`, or specific stride-based views) either get incorrect results or require expensive/incorrect copies.
Top-level goal (north star)
- Make the `tiny` PyTorch backend fully correct for PyTorch view/stride semantics so that:
  - `aten::as_strided`, `aten::empty_strided`, and related view APIs are supported without forcing contiguity or doing ad-hoc copies.
  - Copies and assignments preserve requested shape + stride semantics where possible.
  - The backend's behavior matches PyTorch for shape/stride/offset semantics (including negative strides) with well-defined performance characteristics (prefer lazy, minimal-realize paths).
Primary acceptance criteria
- Correctness:
  - `aten::as_strided` and `aten::empty_strided` return tensors with the exact requested shape/stride/storage_offset semantics (not just a contiguous copy, unless that is the only valid representation).
  - Copies between tensors with different stride patterns produce identical results to PyTorch (including negative strides).
  - No new regressions for existing PyTorch backend tests (`extra/torch_backend/torch_tests.py` should pass or show reduced failures).
- Tests:
  - Add unit tests exercising positive/negative/non-unit strides and `empty_strided`.
  - Add an end-to-end smoke test: run a small upstream script (e.g., hlb-CIFAR10 with steps=5) with the `tiny` torch shim and demonstrate a numeric loss (no crash).
- Maintainability:
  - Remove ad-hoc comments like "this only solves some cases" where a proper implementation exists.
  - Document any remaining unavoidable materializations and why they're necessary.
Key files to inspect and why
- `extra/torch_backend/backend.py`
  - Central PyTorch compatibility layer. Look for: `aten::as_strided`/`_as_strided`, `aten::empty_strided`/`empty.memory_format`, `_copy_from`, `_reshape_alias`, and uses of `.contiguous()`/`.realize()` as workarounds.
  - These are where the current hacks and TODOs are concentrated.
- `extra/to_movement_ops.py`
  - Converts `ShapeTracker` views into a sequence of MovementOps (RESHAPE, PERMUTE, EXPAND, PAD, SHRINK, STRIDE, AS_STRIDED).
  - Current `STRIDE` handling asserts only [-1, 1] values and maps to an axis flip. This is a key limitation to fix.
- `extra/torch_backend/wrapped_tensor.cpp`
  - C++ glue used to map a tinygrad `Tensor` to a PyTorch `torch.Tensor`. Must ensure stride + offset metadata is passed correctly.
- `extra/torch_backend/torch_tests.py`
  - Test harness for PyTorch ops against the `tiny` backend. Use this to validate behavior.
- `tinygrad/tinygrad/tensor.py`
  - Core semantics for views, `contiguous()`, `realize()`, `assign()`, and other methods the backend relies on.
- Any ShapeTracker / View API that the movement layer expects
  - Many files import `tinygrad.shape.shapetracker` and `tinygrad.shape.view` (see `extra/to_movement_ops.py`), so ensure you find the canonical implementation in the tree (the codebase may place it under `tinygrad/shape` or `tinygrad/shape/*`).
Key technical changes to enable proper strides
- Represent arbitrary stride patterns in the view machinery
  - Ensure `View`/`ShapeTracker` can represent:
    - per-dimension stride (positive/negative integer)
    - storage offset (index into the flat buffer)
    - masks (for shrinks/pads) and a contiguous flag
  - If the representation already exists, confirm the API and tests; if gaps exist, extend it.
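To make the requirements above concrete, here is a minimal sketch of the state such a view must carry and how it maps a logical index to a flat buffer position. This is illustrative only; the field names and `index` helper do not match tinygrad's actual `View` API.

```python
from dataclasses import dataclass
from typing import Optional, Tuple

@dataclass(frozen=True)
class StridedView:
    # Illustrative only: not tinygrad's actual View class.
    shape: Tuple[int, ...]    # per-dimension sizes
    strides: Tuple[int, ...]  # per-dimension strides; may be negative or > 1
    offset: int = 0           # storage offset into the flat buffer
    mask: Optional[Tuple[Tuple[int, int], ...]] = None  # valid (lo, hi) per dim, for pad/shrink
    contiguous: bool = False  # True iff this is a plain row-major layout

    def index(self, idx: Tuple[int, ...]) -> int:
        """Map a multi-dimensional logical index to a flat buffer index."""
        assert len(idx) == len(self.shape)
        return self.offset + sum(i * st for i, st in zip(idx, self.strides))

# a 3x2 view reading backwards along dim 0 of a length-6 buffer
v = StridedView(shape=(3, 2), strides=(-2, 1), offset=4)
```

With negative strides, the offset must point at the logical first element so that every reachable flat index stays non-negative.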
- Movement op generation / application
  - Update `extra/to_movement_ops.py`:
    - Remove the assert that restricts `STRIDE` to [-1, 1].
    - Either implement `STRIDE` semantics that can express arbitrary stride steps (e.g., step > 1) and negative strides, or prefer `AS_STRIDED` semantics (explicitly generate an `AS_STRIDED` op carrying the stride tuple + offset) when arbitrary stride is needed.
    - Ensure `apply_mop` can apply these ops to a `ShapeTracker` and that the resulting scratch shape/buffer size calculation is correct (the `get_buffer_size` and `get_real_view` helpers are relevant).
    - Keep `to_movement_ops`'s correctness checks (e.g., `test_rebuild`) and extend them to validate arbitrary stride cases.
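As a reference for what arbitrary `STRIDE`/`AS_STRIDED` handling must reproduce, the target read semantics can be modeled in a few lines of pure Python (this is a specification sketch, not tinygrad code): each output element at multi-index `idx` reads flat position `offset + sum(idx[d] * strides[d])`.

```python
from itertools import product

def as_strided_read(buf, shape, strides, offset=0):
    """Reference as_strided semantics: gather from a flat buffer.

    The element at multi-index idx comes from
    buf[offset + sum(idx[d] * strides[d] for each dim d)].
    Strides may be negative or have magnitude > 1.
    """
    return [buf[offset + sum(i * st for i, st in zip(idx, strides))]
            for idx in product(*(range(s) for s in shape))]

buf = list(range(10))  # flat storage: [0..9]
# step-2 view: shape (5,), stride (2,)  -> [0, 2, 4, 6, 8]
fwd = as_strided_read(buf, (5,), (2,))
# negative stride (a flip): shape (5,), stride (-2,), offset 8 -> [8, 6, 4, 2, 0]
rev = as_strided_read(buf, (5,), (-2,), offset=8)
```

The current [-1, 1] restriction only covers the `rev`-style flip with unit step; both cases above must round-trip through `to_movement_ops` after the fix.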
- Allocation for strided buffers
  - Implement `aten::empty_strided` to allocate a base buffer of the correct minimal size (via `get_buffer_size`) and return a view with the requested `shape`, `stride`, and `storage_offset` (not a forced-`.contiguous()` tensor).
  - Confirm APIs to construct a `View` with explicit stride and offset exist (e.g., `View.create(...)`, already used in `backend.py`).
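The minimal allocation `empty_strided` needs follows from the shape/stride/offset triple. The sketch below assumes the standard rule (one past the largest reachable flat index); it is not the actual `get_buffer_size` helper, just the arithmetic it should implement.

```python
def min_buffer_size(shape, strides, offset=0):
    """Smallest flat buffer (in elements) covering every reachable index.

    The largest reachable flat index is
    offset + sum((size - 1) * stride over dims with positive stride).
    Negative strides only reach below the offset, so they don't grow the
    high end, but the offset must keep the low end >= 0.
    """
    if any(s == 0 for s in shape):
        return 0  # an empty tensor needs no storage
    hi = offset + sum((s - 1) * st for s, st in zip(shape, strides) if st > 0)
    lo = offset + sum((s - 1) * st for s, st in zip(shape, strides) if st < 0)
    assert lo >= 0, "offset too small for the negative strides"
    return hi + 1

# contiguous 3x4: 0 + 2*4 + 3*1 + 1 = 12 elements
size_a = min_buffer_size((3, 4), (4, 1))
# dim 0 flipped (stride -4) with offset 8 still fits in 12 elements
size_b = min_buffer_size((3, 4), (-4, 1), offset=8)
```

Note that overlapping views (e.g., stride 0 for broadcasts) need less storage than `numel`, which is why sizing must come from the strides rather than from the shape alone.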
- Copy / assign semantics
  - Improve the `_copy_from` and `assign` paths:
    - If the dest layout matches the requested stride exactly, perform the assign without extra copies.
    - If not, produce an intermediate view that maps the source into the target layout without forcing full realization or, if necessary, create a temporary buffer that has the exact layout (and limit such cases).
    - Avoid blind `src = src.contiguous()` fixes in `_copy_from`. Instead, add code to transform the layout correctly using movement ops where possible.
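The behavior `_copy_from` must preserve is elementwise: walk the logical indices and map each through the source and destination layouts independently. A naive pure-Python model of that contract (the real implementation should lower this to movement ops/kernels, not a loop):

```python
from itertools import product

def strided_copy(src_buf, dst_buf, shape, src_strides, dst_strides,
                 src_off=0, dst_off=0):
    """Copy between two layouts of the same logical shape, element by element.

    No contiguity is assumed on either side; each logical index is mapped
    through its own (strides, offset) pair.
    """
    for idx in product(*(range(s) for s in shape)):
        s = src_off + sum(i * st for i, st in zip(idx, src_strides))
        d = dst_off + sum(i * st for i, st in zip(idx, dst_strides))
        dst_buf[d] = src_buf[s]

# copy a contiguous 2x3 source into a column-major destination
src = [1, 2, 3, 4, 5, 6]
dst = [0] * 6
strided_copy(src, dst, (2, 3), src_strides=(3, 1), dst_strides=(1, 2))
```

A blind `src.contiguous()` only happens to be correct when the destination is itself contiguous; for any other destination layout the strides on the write side must be honored too.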
- C++ wrapper metadata
  - Update `extra/torch_backend/wrapped_tensor.cpp` to pass `strides` and `storage_offset` metadata into the PyTorch wrapper tensors so PyTorch clients see correct shapes & strides.
  - Ensure `unwrap`/`wrap` preserve stride + offset semantics and device mapping.
Testing strategy (unit + integration)
- Unit tests to add:
  - `test_as_strided_basic`: create a base tensor and use `as_strided` to create a view with arbitrary stride (positive, negative, step > 1) and assert read/write semantics match PyTorch.
  - `test_empty_strided_alloc`: `empty_strided` returns a tensor with correct `shape` and `stride`, and `numel` corresponds to the correct buffer size; writing then reading yields the expected layout.
  - `test_copy_between_strides`: copy between tensors with different strides (source contiguous, dest non-contiguous; source non-contiguous, dest contiguous; both non-contiguous with different patterns). Validate values match PyTorch.
  - `test_negative_stride`: slicing with negative steps (e.g., `tensor.flip(dim)`) and `as_strided` negative strides should behave like PyTorch.
  - Add tests to `extra/to_movement_ops.py` to validate `to_movement_ops` and `apply_mop` with arbitrary stride/view combinations (cover `STRIDE` and/or `AS_STRIDED` cases).
- Integration smoke:
  - Run `extra/torch_backend/torch_tests.py` to exercise many aten ops with the `tiny` backend.
  - Run `examples/hlb_cifar10.py` (or the upstream hlb-CIFAR10 main script) with `STEPS=5` using the `tiny` torch shim to verify training proceeds end-to-end.
- Continuous validation:
- Add CI job (optional first, then required) that runs the smoke test on PRs touching backend/shape code.
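The flip invariant that `test_negative_stride` should pin down can be stated backend-independently: flipping dim `d` equals an `as_strided` view with that dim's stride negated and the offset moved to the last element along `d`. A pure-Python statement of the invariant follows (the real test would compare `tiny`-backend tensors against PyTorch; the `read` helper is a local reference gather, not backend code):

```python
from itertools import product

def read(buf, shape, strides, offset):
    # reference as_strided gather, flattened row-major
    return [buf[offset + sum(i * st for i, st in zip(idx, strides))]
            for idx in product(*(range(s) for s in shape))]

buf = list(range(12))
shape, strides, offset = (3, 4), (4, 1), 0

# flip along dim 0: negate that stride, move offset to the last row
flipped = read(buf, shape, (-4, 1), offset + (shape[0] - 1) * 4)

# reference result: the rows of the original view, in reverse order
rows = [read(buf, (4,), (1,), r * 4) for r in range(shape[0])]
expected = [x for row in reversed(rows) for x in row]
assert flipped == expected
```

The same pattern extends to step > 1 and multi-dim flips, which is why the unit tests above enumerate stride sign/magnitude combinations rather than a single case.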
Concrete step-by-step plan / milestones
- Investigation (0.5–1 day)
  - Reproduce failing cases: run `extra/torch_backend/torch_tests.py` and any existing failing tests that hint at stride issues.
  - Find the exact places in `backend.py` that currently force `.contiguous()` or otherwise bypass stride semantics.
  - Confirm where the `View` and `ShapeTracker` implementations live and their API (if not located, search the repo for `class View` / `class ShapeTracker`).
- Movement ops fix (1–2 days)
  - Remove the `STRIDE` restriction; implement richer `STRIDE` or `AS_STRIDED` handling in `apply_mop`.
  - Add unit tests for `to_movement_ops` with non-trivial stride patterns.
- `empty_strided` and `as_strided` (1 day)
  - Implement `aten::empty_strided` to allocate the correct buffer size and return a proper view (no forced contiguous).
  - Rework `_as_strided`/`_reshape_alias` to use the improved movement-op path instead of heavyweight or incorrect fallbacks.
- Copy/assign semantics (1–2 days)
  - Rework `_copy_from` to avoid blind `contiguous()` when possible; perform the minimal necessary transformation or create a layout-matching temporary buffer.
- C++ wrapper updates (0.5–1 day)
  - Ensure `wrapped_tensor.cpp` exposes the stride + offset metadata properly.
- Testing & cleanup (1–2 days)
  - Add the tests listed above.
  - Run `extra/torch_backend/torch_tests.py` and fix any regressions.
  - Add the hlb-CIFAR10 smoke run and iterate until it succeeds on small step counts.
- PR: include tests, documentation, and a short migration note.
  - Title: "torch backend: proper support for arbitrary strides and as_strided"
  - Include before/after test logs for `torch_tests.py` and the smoke test.
Helpful immediate edits you can make now (quick wins)
- Replace comments that call out hacks with `TODO` + a link to this document and a short note on the expected semantics. This frames future work and prevents accidental re-hacks.
- Add targeted unit test(s) for `as_strided` that currently fail; these will guide the implementation and act as regression tests.
Risk analysis & edge cases
- Performance: correct stride support might require extra kernel logic or buffer rearrangement depending on the computation backend. Prioritize correctness first; later optimize to avoid copies where possible.
- Multi-device / sharded tensors: if `to()`/`shard()` paths move data across devices, stride semantics must be preserved or explicitly documented as normalized (e.g., made contiguous) on transfer. Start with single-device correctness.
- ShapeTracker / symbolic dims: if shapes are symbolic (Variables), movement-op generation must still produce correct ops or defer realization. Ensure `to_movement_ops` handles symbolic cases robustly (it already contains symbolic checks; extend if needed).
- Some code paths previously relied on `realize()` side effects. Replacing hacks may expose latent bugs requiring careful testing.
PR checklist (what a completed PR should include)
- Implementation changes in `extra/to_movement_ops.py` and `extra/torch_backend/backend.py`.
- Any necessary updates to `extra/torch_backend/wrapped_tensor.cpp`.
- New unit tests for `as_strided`, `empty_strided`, negative strides, and copying between varied layouts.
- Updated or added comments documenting the design and any trade-offs.
- No unrelated whitespace or style changes (follow tinygrad style rules).
- Run `extra/torch_backend/torch_tests.py` and include the test results in the PR description (or CI logs).
Suggested branch and PR title
- Branch name: `fix/torch-backend-strides`
- PR title: "torch backend: support arbitrary strides / as_strided without hacks"
References (local files to start with)
- `extra/torch_backend/backend.py`
- `extra/to_movement_ops.py`
- `extra/torch_backend/wrapped_tensor.cpp`
- `extra/torch_backend/torch_tests.py`
- `tinygrad/tinygrad/tensor.py`