Make hlb-CIFAR10 work with a drop-in "tiny torch" backend
Summary
- Provide a tinygrad-based compatibility layer so the upstream hlb-CIFAR10 code can run unchanged while using tinygrad as the backend (i.e., a drop-in "tiny torch" shim).
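A minimal sketch of the drop-in idea, assuming a hypothetical package named tinytorch whose only job is to re-export tinygrad objects under the names the upstream script imports. The package name, file layout, and import paths below are illustrative assumptions, not a confirmed design, and tinygrad's top-level exports vary between versions.

```python
# tinytorch/__init__.py -- hypothetical shim layout (illustrative only).
# Goal: upstream imports resolve to tinygrad objects with no edits to the
# upstream source. Import paths assume a recent tinygrad and may differ
# between versions; verify against the installed tinygrad before relying on them.
from tinygrad import Tensor, Device, GlobalCounters, TinyJit, dtypes
from tinygrad import nn                      # Conv2d, Linear, BatchNorm*, datasets
from tinygrad.nn import optim                # SGD, OptimizerGroup
from tinygrad.nn.state import get_state_dict
# Variable's import path has moved across tinygrad versions; resolve it per
# the installed version rather than hard-coding a path here.

__all__ = ["Tensor", "Device", "GlobalCounters", "TinyJit", "dtypes",
           "nn", "optim", "get_state_dict"]
```

With the package directory on PYTHONPATH, the upstream script would pick these names up at import time; whether the package ultimately needs to be spelled torch, tinytorch, or torch_compat depends on how upstream writes its imports, which should be read from the upstream source first.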
Acceptance criteria
- Upstream hlb-CIFAR10's main script runs with no changes to its source imports or logic and completes a short smoke training run (e.g., STEPS=5) using tinygrad as the backend.
- No runtime AttributeError/ImportError due to missing API surface.
- Basic training loop completes and produces a numeric loss (no crash).
Minimal API checklist (observed from examples/hlb_cifar10.py)
- Top-level objects: Tensor, Device.DEFAULT, GlobalCounters, TinyJit, Variable
- nn: Conv2d, Linear, BatchNorm2d (or compatible), nn.datasets.cifar()
- state: get_state_dict(model) -> name -> parameter tensor mapping
- optim: SGD(params, lr, momentum, nesterov, weight_decay), OptimizerGroup
- dtypes: default_float, float32, int, int32
- helpers: Context(...), BEAM, WINO, getenv(), colored(), prod()
- Tensor methods/ops: randperm, randint, rand, arange, ones, zeros, reshape, cast, float, numpy, astype, one_hot, gather, where, flip, contiguous, expand, cat, pad, conv2d, max_pool2d, batchnorm, max, quick_gelu, log_softmax, argmax, mean, sum, mul, div, add, pow, rsqrt, detach, backward, realize, assign, to_, shard_, plus the .device and .shape attributes
- Control: Tensor.train() context manager and Tensor.training flag
- Utilities used from extras: extra/lr_scheduler.OneCycleLR and extra/bench_log (for timing; optional for compatibility)
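Before building the full shim, a small coverage probe like the one below could be run to catch ImportError/AttributeError early (the "no missing API surface" acceptance criterion). The module paths are assumptions based on the current tinygrad layout and should be adjusted to the installed version.

```python
# coverage_probe.py -- hypothetical helper: try to resolve each symbol the
# checklist above names and report anything missing before wiring up the shim.
# BatchNorm naming has varied across tinygrad versions (BatchNorm vs
# BatchNorm2d), so both are probed; a miss on one of them is expected.
import importlib

CANDIDATES = [
    ("tinygrad", ["Tensor", "Device", "GlobalCounters", "TinyJit", "Variable", "dtypes"]),
    ("tinygrad.nn", ["Conv2d", "Linear", "BatchNorm", "BatchNorm2d"]),
    ("tinygrad.nn.state", ["get_state_dict"]),
    ("tinygrad.nn.optim", ["SGD", "OptimizerGroup"]),
    ("tinygrad.nn.datasets", ["cifar"]),
    ("tinygrad.helpers", ["Context", "BEAM", "WINO", "getenv", "colored", "prod"]),
]

missing = []
for module_name, symbols in CANDIDATES:
    try:
        mod = importlib.import_module(module_name)
    except ImportError as e:
        missing.append((module_name, "*", str(e)))
        continue
    for sym in symbols:
        if not hasattr(mod, sym):
            missing.append((module_name, sym, "attribute not found"))

if missing:
    for module_name, sym, reason in missing:
        print(f"MISSING {module_name}.{sym}: {reason}")
else:
    print("all checklist symbols resolved")
```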
Suggested implementation steps
- Create a small compatibility package (e.g., tinytorch or torch_compat) that maps the minimal torch-like symbols used by hlb-CIFAR10 to tinygrad equivalents.
- Export names expected by upstream imports (Tensor, nn, optim, dtypes, Device, etc.).
- Provide thin wrappers for any mismatched signatures (e.g., SGD, OneCycleLR adapter).
- Implement any missing Tensor methods or minimal shims that examples/hlb_cifar10.py relies on (or adapt mapping in the shim).
- Add a smoke test: run upstream hlb-CIFAR10's main.py with PYTHONPATH pointing to the shim and tinygrad, STEPS=5, and confirm it runs end-to-end (a runner sketch follows this list).
- Iterate: fix missing behaviors discovered by the smoke test and expand coverage until the acceptance criteria are met.
- Optionally: add CI job that runs the smoke test on push.
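A sketch of the smoke-test runner from the step above. The ./shim and ./hlb-CIFAR10/main.py paths, and the assumption that the upstream script honors a STEPS environment variable, are placeholders to be replaced with the real checkout layout and whatever step-count knob upstream actually exposes.

```python
# smoke_test.py -- minimal end-to-end check (hypothetical paths; adjust to the
# actual upstream checkout and the shim package location).
import os, subprocess, sys

env = dict(os.environ)
# Put the shim and tinygrad ahead of everything else on the import path.
env["PYTHONPATH"] = os.pathsep.join(["./shim", "./tinygrad", env.get("PYTHONPATH", "")])
env["STEPS"] = "5"   # short smoke run, per the acceptance criteria

proc = subprocess.run(
    [sys.executable, "./hlb-CIFAR10/main.py"],
    env=env, capture_output=True, text=True, timeout=1800,
)
print(proc.stdout[-2000:])  # tail of the log for quick inspection

# Pass/fail: the run must exit cleanly and report a numeric loss somewhere.
assert proc.returncode == 0, proc.stderr[-2000:]
assert "loss" in proc.stdout.lower(), "no loss reported in output"
print("smoke test passed")
```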
Files to inspect first
- examples/hlb_cifar10.py (reference: exact API usage and behavior)
- extra/lr_scheduler.py (OneCycleLR expected API)
- tinygrad/nn/state.py (get_state_dict implementation)
- tinygrad/nn/optim* (SGD behavior; a thin-wrapper sketch follows this list)
- tinygrad/helpers.py (Context, getenv, BEAM, WINO, colored)
- tinygrad/nn/datasets module (cifar loader)
- any Tensor implementation files for methods like gather, batchnorm, conv2d, pad, shard_/to_
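Where signatures differ, thin forwarding wrappers are usually enough. The sketch below assumes the SGD argument list given in the checklist above (params, lr, momentum, nesterov, weight_decay); the real keyword names and defaults should be confirmed against tinygrad/nn/optim before use, and the same approach would apply to an OneCycleLR adapter once extra/lr_scheduler.py has been read.

```python
# optim_compat.py -- hypothetical thin wrapper: translate the keyword
# arguments the upstream script passes into the SGD signature the checklist
# above lists for tinygrad. Verify names/defaults against tinygrad/nn/optim.
from tinygrad.nn import optim as tg_optim

def SGD(params, lr, momentum=0.0, nesterov=False, weight_decay=0.0, **ignored):
    # Drop torch-only kwargs (e.g. dampening, foreach) that tinygrad's
    # optimizer does not understand, and forward the rest unchanged.
    if ignored:
        print(f"SGD compat: ignoring unsupported kwargs {sorted(ignored)}")
    return tg_optim.SGD(params, lr=lr, momentum=momentum,
                        nesterov=nesterov, weight_decay=weight_decay)

# A similar forwarding wrapper could adapt extra/lr_scheduler.OneCycleLR; its
# expected constructor arguments should be read from extra/lr_scheduler.py
# rather than assumed here.
```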
Risks / notes
- Some ops are performance-sensitive or backend-specific (sharding, multi-GPU device strings, TinyJit, Context flags). For a first pass, focus on API correctness and a smoke run on a single device.
- whitening() in the example uses NumPy operations to precompute weights; ensure dtype/shape conversions are compatible.
- EMA and BatchNorm running statistics touch internal parameter naming; get_state_dict must expose consistent names (see the sketch below).
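A quick way to check the naming risk above: build a tiny module and print what get_state_dict exposes. The dotted names mentioned in the comments are what recent tinygrad versions appear to produce for nested attributes, but this should be confirmed against tinygrad/nn/state.py; the BatchNorm class name also differs between versions.

```python
# state_names.py -- inspect the parameter names get_state_dict exposes, so the
# shim's EMA and BatchNorm running-stat handling can rely on stable keys.
from tinygrad import Tensor
from tinygrad.nn import Conv2d, BatchNorm  # spelled BatchNorm2d in some versions
from tinygrad.nn.state import get_state_dict

class TinyBlock:
    def __init__(self):
        self.conv = Conv2d(3, 8, 3)
        self.bn = BatchNorm(8)
    def __call__(self, x: Tensor) -> Tensor:
        return self.bn(self.conv(x)).relu()

model = TinyBlock()
for name, tensor in get_state_dict(model).items():
    # Expect dotted attribute paths such as "conv.weight" or "bn.running_mean";
    # the upstream EMA code must see the same names on every run.
    print(name, tensor.shape)
```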
Owner / contact
- Owner: @ppaleja (current user)
- Suggested reviewers: maintainers familiar with tinygrad Tensor API and optimizers.
Estimated effort
- Prototype shim + smoke test: 1–2 days
- Full compatibility for multi-GPU/TinyJit polishing: additional days depending on gaps