Overview / goal
- Provide a tinygrad-based compatibility shim so upstream `hlb-CIFAR10` code can be executed without changing its imports or training loop, while using tinygrad as the backend (a drop-in "tiny torch" shim).
- Priority: API compatibility first (smoke runs); performance and multi-GPU/TinyJit improvements later.
Acceptance criteria (copy from issue)
- Upstream `hlb-CIFAR10`'s main script can be executed without modifying its source imports/logic and runs a short smoke training (e.g., `STEPS=5`) using tinygrad as the backend.
- No runtime `AttributeError`/`ImportError` due to missing API surface.
- Basic training loop completes and produces a numeric loss (no crash).
High-level design
- Implement a small compatibility package (or use the existing `extra/torch_backend`) that exposes the minimal torch-like API that `hlb-CIFAR10` expects: top-level objects, an `nn` submodule, an `optim` module, dtypes, helpers, and any extras (e.g., `OneCycleLR`).
- Prefer thin wrappers mapping upstream names/signatures to tinygrad implementations rather than reimplementing behaviors.
- Start with single-device correctness and a smoke test. Expand the shim until all required API calls used by the upstream code run cleanly.
Minimal API checklist (observed from the example)
- Top-level and global objects: `Tensor`, `Device.DEFAULT`, `GlobalCounters`, `TinyJit`, `Variable` (expose these at top level).
- nn / model / datasets: `nn.Conv2d`, `nn.Linear`, `nn.BatchNorm2d` (or a compatible replacement); `nn.datasets.cifar()` loader (or an adapter that returns numpy arrays/Tensors as upstream expects).
- state / serialization: `get_state_dict(model)` -> mapping of name -> parameter tensor.
- optim / scheduler: `optim.SGD(params, lr, momentum, nesterov, weight_decay)` and `OptimizerGroup` compatibility as used by upstream; `extra.lr_scheduler.OneCycleLR` adapter (thin wrapper around the existing `extra/lr_scheduler.py`).
- dtypes / helpers: `dtypes.default_float`, `float32`, `int`, `int32`; `helpers.Context(...)`, `BEAM`, `WINO`, `getenv()`, `colored()`, `prod()`.
- Tensor methods / operations used by hlb-CIFAR10:
  - Randoms: `randperm`, `randint`, `rand`
  - Creation: `arange`, `ones`, `zeros`
  - Access & transform: `reshape`, `cast`, `float`, `numpy`, `astype`, `one_hot`, `gather`, `where`, `flip`, `contiguous`, `expand`, `cat`, `pad`, `sequential`
  - NN ops: `conv2d`, `max_pool2d`, `batchnorm`, `quick_gelu`, `log_softmax`, `argmax`
  - Reductions / math: `mean`, `sum`, `mul`, `div`, `add`, `pow`, `rsqrt`
  - Autograd: `detach`, `backward`, `realize`, `assign`
  - Device / shape: `.device` attribute, `.shape` attribute
- Control: `Tensor.train()` context manager and the `Tensor.training` flag
- Extras (optional): `extra/bench_log` and other small utilities for timing (optional but helpful in smoke tests)

A minimal re-export sketch covering these names follows this checklist.
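As a rough illustration, a compatibility module could start as plain re-exports like the sketch below. The module path and layout are illustrative only, and the `OneCycleLR` re-export assumes the tinygrad repo's `extra/` directory is importable (e.g., via `PYTHONPATH`):

```python
# extra/torch_compat/__init__.py -- minimal re-export sketch (path and layout are illustrative)
from tinygrad import Tensor, TinyJit, Device, GlobalCounters, Variable  # top-level objects
from tinygrad import nn, dtypes                       # Conv2d/Linear/BatchNorm2d, default_float/float32/int32
from tinygrad.nn import optim                         # SGD, OptimizerGroup
from tinygrad.nn.state import get_state_dict          # name -> parameter tensor mapping
from tinygrad.helpers import Context, BEAM, WINO, getenv, colored, prod
from extra.lr_scheduler import OneCycleLR             # requires the repo's extra/ on PYTHONPATH
```

Anything whose signature differs from what upstream expects gets a thin wrapper instead of a bare re-export (see the implementation steps below).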
Files to inspect first (high-value)
- `tinygrad/examples/hlb_cifar10.py` — canonical usage; extract exact method signatures and everything it imports (you already have this list).
- `extra/lr_scheduler.py` — OneCycleLR expected API.
- `tinygrad/nn/state.py` — `get_state_dict` implementation.
- `tinygrad/nn/optim*` — check `SGD` behavior and parameter group handling.
- `tinygrad/helpers.py` — `Context`, `getenv`, `BEAM`, `WINO`, `colored`, `prod`.
- `tinygrad/nn/datasets` — `cifar()` loader.
- `extra/torch_backend/*` — examine `backend.py`, `wrapped_tensor.cpp`, `test_compile.py` if you plan to reuse the torch shim approach.
Suggested implementation steps (ordered, with a recommended iteration plan)
1. Prepare a reproducible smoke-run script and environment
   - Make a small script reproducing the upstream invocation but with `STEPS=5`, `BS=...`, etc.
   - Use `PYTHONPATH` or a small shim package to ensure the upstream code imports from your compatibility layer.
   - Example run (single-device smoke): `TINY_BACKEND=1 STEPS=5 BS=32 python3 examples/hlb_cifar10.py`
   - Goal: get through an entire training loop for `STEPS=5` and obtain a numeric loss with no exceptions.
2. Start with an adapter / compatibility module
   - Implement a minimal `tinytorch` or `torch_compat` module (could be placed under `extra/torch_compat` or similar) that re-exports tinygrad symbols under the names expected by upstream. Example exports: `from tinygrad import Tensor, TinyJit, Device, GlobalCounters, Variable` and `from tinygrad import nn` (adapt or wrap components when signatures differ).
   - Provide thin wrappers for: `optim.SGD` signature differences; `nn.BatchNorm2d` (map to the `UnsyncedBatchNorm` used in `hlb_cifar10.py` if needed); `OneCycleLR` (adapter to `extra/lr_scheduler.OneCycleLR` with a matching API). A wrapper sketch follows these steps.
3. Implement missing Tensor methods (only as needed)
   - Run the smoke test and inspect the `AttributeError` traces.
   - For each missing method used in `hlb_cifar10.py`, implement it either in tinygrad code or as a shim function (a wrapper returning a compatible result).
   - Prefer implementing missing helpers in the compatibility module when the functionality is simple (e.g., signature adaptation), and implement true ops on the tinygrad Tensor when they are fundamental (e.g., `gather`, `batchnorm` if missing).
4. Ensure correct parameter naming / `get_state_dict`
   - Confirm `get_state_dict` returns the parameter-name mapping expected by the optimizer / EMA code.
   - If upstream expects param names in a particular format (e.g., `module.layer.weight`), ensure the tinygrad state exports consistent names. A name-check sketch follows these steps.
5. Add a smoke test harness and automated run
   - Add a small test script in `extra/compat_tests/smoke_hlb_cifar10.py` that invokes `train_cifar()` with a small `STEPS` and asserts no crash and a numeric (finite) loss. A harness sketch follows these steps.
   - Make the test deterministic by seeding (`Tensor.manual_seed` and `random.seed`).
6. Iterate until the smoke test passes
   - Each failure should produce a short PR that either exposes the missing API in the compatibility shim, or fixes tinygrad internals (e.g., adds `gather` or broadcasting behavior).
   - Keep changes small and test-driven.
7. Optional: upstream drop-in behavior
   - Option A (recommended for speed): add a small import shim so that when the upstream project imports `torch`, an environment variable can redirect it to tinygrad (requires the built-in `extra/torch_backend` integration).
   - Option B: provide a `tinytorch` package and instruct users to run the upstream code with `PYTHONPATH=extra/torch_compat` (this is the safest and least invasive).
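For step 2, a thin-wrapper sketch for the `optim.SGD` signature differences. The keyword names on the tinygrad side are assumptions to verify against `tinygrad/nn/optim*`; `OneCycleLR` can be adapted with the same pattern against `extra/lr_scheduler.py`:

```python
# torch-style SGD call adapted to tinygrad's optim.SGD (sketch; verify tinygrad's
# keyword names against tinygrad/nn/optim* before relying on this)
from tinygrad.nn import optim

def SGD(params, lr, momentum=0.0, nesterov=False, weight_decay=0.0, **unused):
  # upstream may pass torch-only kwargs (e.g. dampening); accept and ignore them,
  # but consider logging so silent behavior differences stay visible
  return optim.SGD(params, lr=lr, momentum=momentum, nesterov=nesterov, weight_decay=weight_decay)
```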
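For step 4, a quick sanity check of the name format produced by `get_state_dict` (the toy model is illustrative only):

```python
# verify that get_state_dict yields dotted names like "conv.weight" that the
# upstream optimizer / EMA code can key on (toy model for illustration)
from tinygrad import nn
from tinygrad.nn.state import get_state_dict

class ToyModel:
  def __init__(self):
    self.conv = nn.Conv2d(3, 8, 3)
    self.fc = nn.Linear(8, 10)

print(list(get_state_dict(ToyModel()).keys()))
# expect something like ['conv.weight', 'conv.bias', 'fc.weight', 'fc.bias']
```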
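For steps 1 and 5, a smoke-harness sketch. The `train_cifar()` entry point, its return value, and the env-var handling are assumptions; check `examples/hlb_cifar10.py` for the real interface:

```python
# extra/compat_tests/smoke_hlb_cifar10.py -- smoke-test sketch
import math, os, random

os.environ["STEPS"] = "5"   # keep the run short
os.environ["BS"] = "32"     # small batch for fast feedback

from tinygrad import Tensor

def test_smoke():
  Tensor.manual_seed(1337)  # deterministic run
  random.seed(1337)
  from examples.hlb_cifar10 import train_cifar  # upstream entry point (assumed name)
  loss = train_cifar()
  assert loss is not None, "train_cifar() returned nothing"
  value = loss.item() if hasattr(loss, "item") else float(loss)  # scalar Tensor or plain float
  assert math.isfinite(value), f"non-finite loss: {value}"

if __name__ == "__main__":
  test_smoke()
  print("smoke test passed")
```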
Concrete debug & repro tips
- Use small batch sizes (`BS`) and `STEPS=3-5` to get fast feedback.
- Use environment variables for toggles and seeds: `Tensor.manual_seed(seed)` and `random.seed(seed)` are used in `hlb_cifar10.py`.
- Run with verbose exceptions and `TORCH_DEBUG` if you hit torch/backend interactions.
- If you hit numpy / dtype mismatches, make conversions explicit with `.astype(np.float32)` or `Tensor(...).cast(...)` (see the sketch below).
- To pinpoint missing API: run the smoke harness, reproduce the `AttributeError`, and trace it to the first missing symbol. Implement the smallest shim that satisfies that call.
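A small sketch of the explicit-conversion advice above (array shapes and names are illustrative):

```python
# make numpy <-> tinygrad dtype handling explicit to avoid silent float64 promotion
import numpy as np
from tinygrad import Tensor, dtypes

batch = np.random.rand(32, 3, 32, 32).astype(np.float32)  # explicit numpy dtype
x = Tensor(batch).cast(dtypes.float32)                     # explicit tinygrad dtype
print(x.dtype, x.shape)
```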
Integration points already present in repo (useful to reuse)
- `extra/torch_backend/backend.py` — a fairly complete PyTorch compatibility backend you can reuse or take inspiration from.
- `extra/torch_backend/wrapped_tensor.cpp` — a compiled extension used to create opaque PyTorch tensors that refer to a tinygrad `Tensor`. Useful if you want to support drop-in `torch` semantics at the C++/PyTorch level.
- `tinygrad/examples/beautiful_mnist_torch.py` and `extra/torch_backend/test_compile.py` — examples showing how to use `TINY_BACKEND` and (experimentally) `torch.compile` with tinygrad.
Smoke test checklist (what to run locally)
- Ensure minimal dependencies are installed (numpy, PIL if needed).
- Run: `STEPS=5 BS=32 GPUS=1 python3 tinygrad/examples/hlb_cifar10.py`
- If using the torch shim / privateuse backend: `TINY_BACKEND=1 STEPS=5 BS=32 python3 tinygrad/examples/hlb_cifar10.py`
- Expected: no `AttributeError`/`ImportError`; the training loop completes `STEPS` iterations and returns a numeric loss.
CI proposal
- Add a CI job that runs the smoke test in a matrix:
  - env: `STEPS=5 BS=32` (short), `GPUS=1`
  - install minimal dependencies
  - run the smoke harness
- Mark the job optional / allowed to fail until you trust the shim.
- For the C++ extension (`wrapped_tensor.cpp`) compilation in CI, ensure build tools are available, or guard the job to skip compilation if the toolchain is missing.
Prioritized backlog (small PR-sized tasks)
- Create `extra/torch_compat` (or `extra/tinytorch`) that re-exports and adapts tinygrad symbols.
- Add the smoke test script and run it locally; log the initial failures.
- Implement shims for missing top-level functions and the small Tensor methods invoked early (e.g., `gather`, `pad`, `sequential`).
- Fix the `get_state_dict` parameter format if required.
- Adapter for `optim.SGD` parameter groups matching upstream behavior.
- Adapter for `OneCycleLR` usage patterns.
- Optional: make `extra/torch_backend` drop-in when `TINY_BACKEND=1` (reuse existing code).
- Add the CI smoke test.
PR / review checklist
- Minimal, focused changes: each PR should fix a single missing function or small compatibility mismatch.
- Unit / smoke tests added or updated.
- No whitespace-only changes mixed with logic changes (project style).
- Keep lines < 150 chars and 2-space indentation in code changes.
- Document any API incompatibilities that remain.
Risks / known complications
- Some ops (e.g., `batchnorm` internals, efficient `conv2d`) are sensitive and may require careful tinygrad implementations for correctness.
- Multi-GPU / sharding, `TinyJit` interactions, or `torch.compile` may introduce complex conversion issues — postpone these until basic API compatibility is stable.
- The C++ extension (`wrapped_tensor.cpp`) requires a working build toolchain, which can make CI more complex. Consider initially avoiding any requirement on the compiled extension.