beautiful_mnist_torch uses torch.compile TinyJIT working with TINY_BACKEND=1, see test_compile.py #3

Description

Overview / goal

  • Provide a tinygrad-based compatibility shim so upstream hlb-CIFAR10 code can be executed without changing its imports or training loop, while using tinygrad as the backend (a drop-in "tiny torch" shim).
  • Priority: API compatibility first (smoke runs), performance and multi-GPU/TinyJit improvements later.

Acceptance criteria (copy from issue)

  • Upstream hlb-CIFAR10's main script can be executed without modifying its source imports/logic and runs a short smoke training (e.g., STEPS=5) using tinygrad as the backend.
  • No runtime AttributeError / ImportError due to missing API surface.
  • Basic training loop completes and produces a numeric loss (no crash).

High-level design

  • Implement a small compatibility package (or use existing extra/torch_backend) that exposes the minimal torch-like API that hlb-CIFAR10 expects: top-level objects, nn submodule, optim module, dtypes, helpers, and any extras (e.g., OneCycleLR).
  • Prefer thin wrappers mapping upstream names/signatures to tinygrad implementations rather than reimplementing behaviors (a minimal sketch of this approach follows this list).
  • Start with single-device correctness and a smoke test. Expand the shim until all required API calls used by the upstream code run cleanly.
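
A minimal sketch of this thin-wrapper approach, assuming the shim lives under extra/torch_compat; the path, the exact re-exports, and the SGD signature difference shown are illustrative assumptions, not a fixed layout:

```python
# extra/torch_compat/__init__.py -- illustrative sketch, not a final layout.
# Re-export tinygrad symbols under the names the upstream script imports;
# anything whose signature differs gets a thin wrapper instead.
from tinygrad import Tensor, TinyJit, Device, GlobalCounters, Variable  # noqa: F401
from tinygrad import nn, dtypes  # noqa: F401
from tinygrad.nn.state import get_state_dict  # noqa: F401
from tinygrad.helpers import Context, getenv, colored, prod  # noqa: F401

from tinygrad.nn.optim import SGD as _SGD

# Example thin wrapper: pin down the keyword order/defaults the upstream code
# expects and forward to tinygrad's SGD. Add wrappers like this only where the
# smoke run actually reports a signature difference.
def SGD(params, lr, momentum=0.0, nesterov=False, weight_decay=0.0):
  return _SGD(params, lr=lr, momentum=momentum, nesterov=nesterov, weight_decay=weight_decay)
```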

Minimal API checklist (observed from the example)

Top-level and global objects

  • Tensor, Device.DEFAULT, GlobalCounters, TinyJit, Variable (expose these at top-level).
nn / model / datasets

  • nn.Conv2d, nn.Linear, nn.BatchNorm2d (or a compatible replacement).
  • nn.datasets.cifar() loader (or an adapter that returns numpy arrays/Tensors as upstream expects).

state / serialization

  • get_state_dict(model) -> mapping of name -> parameter tensor.

optim / scheduler

  • optim.SGD(params, lr, momentum, nesterov, weight_decay) and OptimizerGroup compatibility as used by upstream.
  • extra.lr_scheduler.OneCycleLR adapter (thin wrapper around the existing extra/lr_scheduler.py).

dtypes / helpers

  • dtypes.default_float, float32, int, int32.
  • helpers.Context(...), BEAM, WINO, getenv(), colored(), prod().

Tensor methods / operations used by hlb-CIFAR10

  • Random: randperm, randint, rand
  • Creation: arange, ones, zeros
  • Access & transform: reshape, cast, float, numpy, astype, one_hot, gather, where, flip, contiguous, expand, cat, pad, sequential
  • NN ops: conv2d, max_pool2d, batchnorm, quick_gelu, log_softmax, argmax
  • Reductions / math: mean, sum, mul, div, add, pow, rsqrt
  • Autograd: detach, backward, realize, assign
  • Device / shape: .device attribute, .shape attribute

Control

  • Tensor.train() context manager and the Tensor.training flag (an illustrative usage pattern follows this list)

Extras (optional)

  • extra/bench_log and other small timing utilities (optional but helpful in smoke tests)
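
To see how this surface fits together, here is a stripped-down training-step pattern in the style of hlb_cifar10.py; it is illustrative only (the tiny model, random data, and hyperparameters are made up, not the upstream code):

```python
from tinygrad import Tensor, nn
from tinygrad.nn.state import get_state_dict
from tinygrad.nn.optim import SGD

class TinyNet:
  def __init__(self): self.l1 = nn.Linear(32 * 32 * 3, 10)
  def __call__(self, x: Tensor) -> Tensor: return self.l1(x.reshape(x.shape[0], -1))

model = TinyNet()
params = list(get_state_dict(model).values())
opt = SGD(params, lr=0.01, momentum=0.9, nesterov=True, weight_decay=5e-4)

with Tensor.train():                          # flips the Tensor.training flag
  for step in range(5):                       # STEPS=5-style smoke run
    x = Tensor.rand(32, 3, 32, 32)            # stand-in for a CIFAR batch
    y = Tensor.randint(32, low=0, high=10)    # stand-in labels
    loss = model(x).sparse_categorical_crossentropy(y)
    opt.zero_grad()
    loss.backward()
    opt.step()
    print(step, loss.item())
```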

Files to inspect first (high-value)

  • tinygrad/examples/hlb_cifar10.py — canonical usage; extract the exact method signatures and everything it imports (you already have this list).
  • extra/lr_scheduler.py — OneCycleLR expected API (an adapter sketch follows this list).
  • tinygrad/nn/state.py — get_state_dict implementation.
  • tinygrad/nn/optim* — check SGD behavior and parameter group handling.
  • tinygrad/helpers.py — Context, getenv, BEAM, WINO, colored, prod.
  • tinygrad/nn/datasets — cifar() loader.
  • extra/torch_backend/* — examine backend.py, wrapped_tensor.cpp, test_compile.py if you plan to reuse the torch shim approach.
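
Where the upstream call site and extra/lr_scheduler.OneCycleLR disagree on keyword names, a thin adapter can bridge them. The tinygrad-side keyword names below are assumptions; verify them against extra/lr_scheduler.py before relying on this:

```python
# Hypothetical OneCycleLR adapter: accepts torch-style keywords and forwards to
# the tinygrad scheduler. The tinygrad-side keyword names are assumed.
from extra.lr_scheduler import OneCycleLR as _TinyOneCycleLR

def OneCycleLR(optimizer, max_lr, total_steps=None, epochs=None, steps_per_epoch=None,
               pct_start=0.3, div_factor=25.0, final_div_factor=1e4, **_unused):
  if total_steps is None:  # torch also allows epochs * steps_per_epoch
    total_steps = epochs * steps_per_epoch
  return _TinyOneCycleLR(optimizer, max_lr=max_lr, div_factor=div_factor,
                         final_div_factor=final_div_factor, total_steps=total_steps,
                         pct_start=pct_start)
```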

Suggested implementation steps (ordered, with a recommended iteration plan)

  1. Prepare a reproducible smoke-run script and environment

    • Make a small script reproducing upstream invocation but with STEPS=5, BS=... etc.
    • Use PYTHONPATH or a small shim package to ensure the upstream code imports from your compatibility layer.
    • Example run (single-device smoke):
      • TINY_BACKEND=1 STEPS=5 BS=32 python3 examples/hlb_cifar10.py
    • Goal: get through an entire training loop for STEPS=5 and get a numeric loss with no exceptions.
  2. Start with an adapter / compatibility module

    • Implement a minimal tinytorch or torch_compat module (could be placed under extra/torch_compat or similar) that re-exports tinygrad symbols under the names expected by upstream. Example exports:
      • from tinygrad import Tensor, TinyJit, Device, GlobalCounters, Variable
      • from tinygrad import nn as nn (and adapt or wrap components when signatures differ)
    • Provide thin wrappers for:
      • optim.SGD signature differences.
      • nn.BatchNorm2d → map to UnsyncedBatchNorm used in hlb_cifar10.py if needed.
      • OneCycleLR — adapter to extra/lr_scheduler.OneCycleLR with matching API.
  3. Implement missing Tensor methods (only as needed)

    • Run the smoke test and see AttributeError traces.
    • For each missing method used in hlb_cifar10.py, implement either in tinygrad code or as a shim function (wrapper returning compatible result).
    • Prefer implementing missing helpers in the compatibility module when the functionality is simple (e.g., signature adaptation), and implement true ops in the tinygrad Tensor when they are fundamental (e.g., gather, or batchnorm if missing). A sketch of the shim pattern appears after this list.
  4. Ensure correct parameter naming / get_state_dict

    • Confirm get_state_dict returns the param mapping names expected by optimizer / EMA code.
    • If upstream expects param names in a particular format (e.g., module.layer.weight), ensure the tinygrad state exports consistent names.
  5. Add a smoke test harness and automated run

    • Add a small test script in extra/compat_tests/smoke_hlb_cifar10.py that invokes train_cifar() with a small STEPS value and asserts no crash and a finite numeric loss (a sketch appears under the smoke test checklist below).
    • Make the test deterministic by seeding (Tensor.manual_seed and random.seed).
  6. Iterate until the smoke test passes

    • Each failure should produce a short PR that either:
      • Exposes missing API in the compatibility shim, or
      • Fixes tinygrad internals (e.g., add gather or broadcasting behavior).
    • Keep changes small and test-driven.
  7. Optional: upstream drop-in behavior

    • Option A (recommended for speed): add a small import shim so that when the upstream project imports torch, an environment variable can redirect it to tinygrad (requires the built-in extra/torch_backend integration).
    • Option B: provide a tinytorch package and instruct users to run the upstream code with PYTHONPATH=extra/torch_compat (this is the safest and least invasive option).
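
As a concrete instance of the step-3 shim pattern, a simple missing alias can be patched onto Tensor from the compatibility module. The astype example below is hypothetical; add aliases like this only when the smoke run actually fails with an AttributeError for them:

```python
from tinygrad import Tensor, dtypes

# Numpy-style .astype as an alias for the existing .cast (only if it is missing).
if not hasattr(Tensor, "astype"):
  Tensor.astype = Tensor.cast

x = Tensor.ones(4).astype(dtypes.float16)  # now behaves like the numpy-style call
print(x.dtype)
```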

Concrete debug & repro tips

  • Use small batch sizes (BS) and STEPS=3-5 to get fast feedback.
  • Use environment variables for toggles and seeds:
    • Tensor.manual_seed(seed) and random.seed(seed) are used in hlb_cifar10.py.
  • Run with verbose exceptions and TORCH_DEBUG if you hit torch/backend interactions.
  • If you hit numpy / dtype mismatches, make conversions explicit with .astype(np.float32) or Tensor(...).cast(...) (see the snippet after this list).
  • To pinpoint missing API: run the smoke harness, reproduce the AttributeError and trace to the first missing symbol. Implement the smallest shim that satisfies that call.
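
For the dtype tip above, the two explicit conversion routes look like this (the values are placeholders):

```python
import numpy as np
from tinygrad import Tensor, dtypes

arr = np.arange(12).reshape(3, 4)        # int64 on most platforms
x = Tensor(arr.astype(np.float32))       # convert on the numpy side ...
y = Tensor(arr).cast(dtypes.float32)     # ... or on the tinygrad side
assert x.dtype == y.dtype == dtypes.float32
```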

Integration points already present in repo (useful to reuse)

  • extra/torch_backend/backend.py — a fairly complete PyTorch compatibility backend you can reuse or take inspiration from.
  • extra/torch_backend/wrapped_tensor.cpp — a compiled extension used to create opaque PyTorch tensors that refer to tinygrad Tensors. Useful if you want to support drop-in torch semantics at the C++/PyTorch level.
  • tinygrad/examples/beautiful_mnist_torch.py and extra/torch_backend/test_compile.py — examples showing how to use TINY_BACKEND and (experimentally) torch.compile with tinygrad.
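
A hedged sketch of how the torch-backend route is exercised, modeled loosely on beautiful_mnist_torch.py; the backend import path and the "tiny" device name should be confirmed against that file and extra/torch_backend/backend.py, since the registration details may differ between tinygrad versions:

```python
import torch
import extra.torch_backend.backend  # noqa: F401  registers the backend (import path assumed)

device = torch.device("tiny")       # device name assumed from the torch_backend examples
x = torch.randn(4, 4, device=device)
print((x @ x).cpu())
```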

Smoke test checklist (what to run locally)

  • Ensure minimal dependencies are installed (numpy, PIL if needed).
  • Run:
    • STEPS=5 BS=32 GPUS=1 python3 tinygrad/examples/hlb_cifar10.py
    • If using the torch shim / privateuse backend: TINY_BACKEND=1 STEPS=5 BS=32 python3 tinygrad/examples/hlb_cifar10.py
  • Expected: no AttributeError/ImportError; the training loop completes STEPS iterations and returns a numeric loss.
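
A sketch of the step-5 smoke harness, assuming train_cifar() is importable from the example and that STEPS/BS are read from the environment at import time; the import path and the return convention are assumptions to verify against examples/hlb_cifar10.py:

```python
# extra/compat_tests/smoke_hlb_cifar10.py -- illustrative smoke harness.
import math, os, random

os.environ.setdefault("STEPS", "5")   # set before the example module reads its env vars
os.environ.setdefault("BS", "32")

from tinygrad import Tensor
from examples.hlb_cifar10 import train_cifar  # import path assumed

def test_smoke():
  random.seed(0)
  Tensor.manual_seed(0)
  loss = train_cifar()                                # assumed to return the final loss
  assert loss is None or math.isfinite(float(loss))   # no crash; finite loss if returned

if __name__ == "__main__":
  test_smoke()
```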

CI proposal

  • Add a CI job that runs the smoke test in a matrix:
    • env: STEPS=5 BS=32 (short), GPUS=1
    • install minimal dependencies
    • run smoke harness
  • Mark the job optional/allowed to fail until you trust the shim.
  • For the C++ extension (wrapped_tensor.cpp) compilation in CI, ensure build tools are available, or guard the job to skip compilation if the toolchain is missing.

Prioritized backlog (small PR-sized tasks)

  1. Create extra/torch_compat (or extra/tinytorch) that re-exports and adapts tinygrad symbols.
  2. Add smoke test script and run locally; log the initial failures.
  3. Implement shims for missing top-level functions and small Tensor methods invoked early (e.g., gather, pad, sequential).
  4. Fix get_state_dict parameter format if required.
  5. Adapter for optim.SGD parameter groups matching upstream behavior.
  6. Adapter for OneCycleLR usage patterns.
  7. Optional: make extra/torch_backend drop-in when TINY_BACKEND=1 (reuse existing code).
  8. Add CI smoke test.

PR / review checklist

  • Minimal, focused changes: each PR should fix a single missing function or small compatibility mismatch.
  • Unit / smoke tests added or updated.
  • No whitespace-only changes mixed with logic changes (project style).
  • Keep lines < 150 chars and 2-space indentation in code changes.
  • Document any API incompatibilities that remain.

Risks / known complications

  • Some ops (e.g., batchnorm internals, efficient conv2d) are numerically sensitive and may require careful tinygrad implementations for correctness.
  • Multi-GPU / sharding, TinyJit interactions, or torch.compile may introduce complex conversion issues — postpone until basic API compatibility is stable.
  • The C++ extension (wrapped_tensor.cpp) requires a working build toolchain, which can make CI more complex; consider not requiring the compiled extension at first.
