hlb-CIFAR10 works with tiny torch backend without code changes to the base repo #1

Description

Make hlb-CIFAR10 work with a drop-in "tiny torch" backend

Summary

  • Provide a tinygrad-based compatibility layer so the upstream hlb-CIFAR10 code can run unchanged while using tinygrad as the backend (i.e., a drop-in "tiny torch" shim).

Acceptance criteria

  • Upstream hlb-CIFAR10's main script runs without modifying its source imports or logic and completes a short smoke-training run (e.g., STEPS=5) with tinygrad as the backend (see the smoke-test sketch after this list).
  • No runtime AttributeError/ImportError due to missing API surface.
  • Basic training loop completes and produces a numeric loss (no crash).
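
A minimal sketch of what such a smoke test could look like, written as a small Python driver. The paths ("path/to/tinytorch_shim", "hlb-CIFAR10/main.py") are placeholders, and treating STEPS as an environment variable follows this issue's wording rather than any confirmed upstream setting.

    # smoke_test.py -- hypothetical smoke-test driver; adjust paths to the actual checkouts.
    import os, subprocess, sys

    env = dict(os.environ)
    env["STEPS"] = "5"  # short run, per the acceptance criteria above
    env["PYTHONPATH"] = os.pathsep.join(filter(None, [
        "path/to/tinytorch_shim",   # hypothetical location of the compatibility shim
        "path/to/tinygrad",         # tinygrad checkout
        env.get("PYTHONPATH", ""),
    ]))

    # "hlb-CIFAR10/main.py" stands in for the upstream entry point.
    result = subprocess.run([sys.executable, "hlb-CIFAR10/main.py"], env=env)
    sys.exit(result.returncode)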

Minimal API checklist (observed from examples/hlb_cifar10.py)

  • Top-level objects: Tensor, Device.DEFAULT, GlobalCounters, TinyJit, Variable
  • nn: Conv2d, Linear, BatchNorm2d (or compatible), nn.datasets.cifar()
  • state: get_state_dict(model), returning a mapping from parameter name to parameter tensor
  • optim: SGD(params, lr, momentum, nesterov, weight_decay), OptimizerGroup
  • dtypes: default_float, float32, int, int32
  • helpers: Context(...), BEAM, WINO, getenv(), colored(), prod()
  • Tensor methods/ops: randperm, randint, rand, arange, ones, zeros, reshape, cast, float, numpy, astype, one_hot, gather, where, flip, contiguous, expand, cat, pad, conv2d, max_pool2d, batchnorm, max, quick_gelu, log_softmax, argmax, mean, sum, mul, div, add, pow, rsqrt, detach, backward, realize, assign, to_, shard_, plus the .device and .shape attributes
  • Control: Tensor.train() context manager and Tensor.training flag
  • Utilities used from extras: extra/lr_scheduler.OneCycleLR and extra/bench_log (for timing; optional for compatibility)
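
The checklist above maps fairly directly onto a single re-export module. Below is a minimal sketch of such a module, assuming a recent tinygrad; the package name (tinytorch) and the exact import paths are assumptions to verify against examples/hlb_cifar10.py and the installed tinygrad version.

    # tinytorch/__init__.py -- hypothetical compatibility layer re-exporting the names above.
    # Import paths assume a recent tinygrad; verify each one against the installed version.
    from tinygrad import Tensor, Device, GlobalCounters, TinyJit, Variable, dtypes
    from tinygrad import nn                       # Conv2d, Linear, BatchNorm, nn.datasets.cifar
    from tinygrad.nn import optim                 # SGD, OptimizerGroup
    from tinygrad.nn.state import get_state_dict
    from tinygrad.helpers import Context, BEAM, WINO, getenv, colored, prod

    __all__ = [
        "Tensor", "Device", "GlobalCounters", "TinyJit", "Variable", "dtypes",
        "nn", "optim", "get_state_dict",
        "Context", "BEAM", "WINO", "getenv", "colored", "prod",
    ]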

Suggested implementation steps

  1. Create a small compatibility package (e.g., tinytorch or torch_compat) that maps the minimal torch-like symbols used by hlb-CIFAR10 to tinygrad equivalents.
    • Export names expected by upstream imports (Tensor, nn, optim, dtypes, Device, etc.).
    • Provide thin wrappers for any mismatched signatures (e.g., an SGD or OneCycleLR adapter; see the sketch after this list).
  2. Implement any missing Tensor methods or minimal shims that examples/hlb_cifar10.py relies on (or adapt mapping in the shim).
  3. Add a smoke test: run upstream hlb-CIFAR10 main.py with PYTHONPATH pointing to the shim and tinygrad, STEPS=5, and confirm it runs end-to-end.
  4. Iterate: fix missing behaviors discovered by the smoke test and expand coverage until the acceptance criteria are met.
  5. Optionally, add a CI job that runs the smoke test on push.
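
As an illustration of the thin-wrapper idea in step 1, here is a hedged sketch of adapters around tinygrad's SGD and extra/lr_scheduler's OneCycleLR. The wrapper names (make_sgd, make_one_cycle) and the assumption that keyword names line up are mine; check the actual signatures in tinygrad/nn/optim and extra/lr_scheduler.py before relying on them.

    # Hypothetical adapters; names and keyword mappings are assumptions, not upstream API.
    from tinygrad.nn.optim import SGD
    from extra.lr_scheduler import OneCycleLR

    def make_sgd(params, lr, momentum=0.0, nesterov=False, weight_decay=0.0):
        # tinygrad's SGD already takes these keywords, so this wrapper mainly exists to
        # absorb or translate any torch-only arguments the upstream call site may pass.
        return SGD(params, lr=lr, momentum=momentum, nesterov=nesterov, weight_decay=weight_decay)

    def make_one_cycle(optimizer, max_lr, total_steps, **kwargs):
        # Placeholder adapter: forward the common arguments and translate the rest
        # (pct_start, div_factor, ...) once extra/lr_scheduler.py's signature is confirmed.
        return OneCycleLR(optimizer, max_lr=max_lr, total_steps=total_steps, **kwargs)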

Files to inspect first

  • examples/hlb_cifar10.py (reference: exact API usage and behavior)
  • extra/lr_scheduler.py (OneCycleLR expected API)
  • tinygrad/nn/state.py (get_state_dict implementation)
  • tinygrad/nn/optim* (SGD behavior)
  • tinygrad/helpers.py (Context, getenv, BEAM, WINO, colored)
  • tinygrad/nn/datasets module (cifar loader)
  • any Tensor implementation files for methods like gather, batchnorm, conv2d, pad, shard_/to_

Risks / notes

  • Some ops are performance-sensitive or backend-specific (sharding, multi-GPU device strings, TinyJit, Context flags). For a first pass, focus on API correctness and smoke-running on a single device.
  • whitening() in the example precomputes weights with NumPy; ensure the dtype/shape conversions between NumPy arrays and tinygrad Tensors are compatible (see the sketch after this list).
  • EMA and BatchNorm running statistics touch internal parameter naming; get_state_dict must expose consistent names.
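
A minimal sketch of the NumPy-to-Tensor hand-off mentioned above, using an explicit dtype so precomputed weights match the model's default float. The shape and variable names here are illustrative placeholders, not taken from the example.

    # Illustrative only: shapes and names are placeholders, not the example's actual code.
    import numpy as np
    from tinygrad import Tensor, dtypes

    np_weights = np.random.randn(12, 3, 2, 2).astype(np.float32)  # stand-in whitening filters
    t = Tensor(np_weights, dtype=dtypes.float32)  # explicit dtype avoids default_float surprises

    back = t.numpy()  # round-trip back to NumPy when needed
    assert back.dtype == np.float32 and back.shape == np_weights.shape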

Owner / contact

  • Owner: @ppaleja
  • Suggested reviewers: maintainers familiar with tinygrad Tensor API and optimizers.

Estimated effort

  • Prototype shim + smoke test: 1–2 days
  • Full multi-GPU/TinyJit compatibility and polish: additional days, depending on the gaps found

Labels

  • enhancement (New feature or request)