MLX-Flash is a batteries-included training framework for Apple's MLX ecosystem. It provides high-level abstractions for building, fine-tuning, and benchmarking models on Apple Silicon without sacrificing the flexibility of MLX primitives.
The initial release focuses on language modeling with GPT-2, LoRA-based adaptation, and production-friendly training loops. Stay tuned for more. ✨
- 🔩 Unified trainer API with callbacks, checkpointing, and scheduler support.
- 🧠 Language modeling toolkit: tokenizer helpers, datasets, and a GPT-2 wrapper powered by mlx-lm.
- ⚡ Apple Silicon first: smart defaults for memory usage, mixed precision, and throughput.
- 🧮 Optimization utilities: gradient clipping, linear warmup schedulers, LoRA/QLoRA adapters.
- 📈 Benchmarks & monitoring: memory heuristics, metric tracking, and optional visualization hooks (LoRA on 7B @ seq. 512 runs comfortably on 16–24 GB unified memory with micro-batching).
python -m venv .venv
source .venv/bin/activate
pip install -e .

Optional extras:
# Memory monitoring support
pip install .[monitoring]
# Visualization helpers (training curves)
pip install .[viz]
# Rich terminal progress UI
pip install .[ui]
# LoRA / QLoRA adapters (via mlx-lm)
pip install .[adapters]
# Development tooling (linting & tests)
pip install .[dev]

Extras are optional feature bundles (see the Setuptools documentation for details). The table below summarises what each extra installs:
| Extra | Purpose | Key dependencies |
|---|---|---|
| monitoring | System memory stats | psutil |
| viz | Training curve plots | matplotlib |
| ui | Rich terminal progress bars | rich |
| adapters | LoRA/QLoRA fusion & HF uploads | mlx-lm, huggingface_hub |
| dev | Linting & tests | pytest, pytest-cov, ruff |
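Because extras are plain optional dependencies, feature availability can be checked at runtime with import guards. A minimal sketch (not MLX-Flash's actual detection logic; the import names are simply the packages listed above):

```python
import importlib.util

# Optional packages pulled in by each extra (import names assumed from the table above).
OPTIONAL_PACKAGES = {
    "monitoring": ["psutil"],
    "viz": ["matplotlib"],
    "ui": ["rich"],
    "adapters": ["mlx_lm", "huggingface_hub"],
}

def available_extras() -> dict[str, bool]:
    """Report which optional feature bundles appear to be importable."""
    return {
        extra: all(importlib.util.find_spec(pkg) is not None for pkg in pkgs)
        for extra, pkgs in OPTIONAL_PACKAGES.items()
    }

if __name__ == "__main__":
    for extra, ok in available_extras().items():
        print(f"{extra:<11} {'installed' if ok else 'missing'}")
```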
For CI/CD or longer-lived projects, consider capturing a pinned lock file (for
example with pip-tools, uv, or poetry) alongside pyproject.toml to
improve reproducibility.
Requirements: Python 3.10+, macOS on Apple Silicon, and the mlx runtime. Tested on macOS 15 (Sequoia) with Python 3.10–3.12 and MLX 0.29.x; performance benefits from Apple Silicon's unified memory.
The fastest way to get started is with the provided example script. Prepare a plain-text corpus (one prompt per line) and run:
# One-time conversion from Hugging Face to MLX format
mlx_lm convert --hf-path openai-community/gpt2 --mlx-path models/gpt2-base-mlx

# Run from the repository root after `pip install -e .` (or export PYTHONPATH=.)
python -m mlx_flash.examples.gpt2_finetune \
--train-file data/train.txt \
--eval-file data/valid.txt \
--model models/gpt2-base-mlx \
--output-dir checkpoints/gpt2 \
--epochs 3 \
--batch-size 4 \
--block-size 256
# --lora # optional: enable LoRA adapters
# --qlora # optional: 4-bit quantize + LoRA

This command will:
- Load GPT-2 weights via mlx-lm (use mlx_lm convert --hf-path openai-community/gpt2 to download/convert from Hugging Face). See the mlx-lm README for conversion and quantization details.
- Build tokenized datasets with automatic padding and masking (illustrated in the sketch below).
- Train using the high-level LanguageModelTrainer.
- Save periodic checkpoints plus a final .safetensors snapshot.
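The padding and masking step is conceptually simple: each line is tokenized, truncated to block_size, padded, and paired with a mask so padded positions do not contribute to the loss. A toy sketch of the idea, with a stand-in byte-level encoder and an assumed pad id of 0 (this is not the LanguageModelDataset implementation):

```python
BLOCK_SIZE = 256
PAD_ID = 0  # assumed pad token id, for illustration only

def encode(text: str) -> list[int]:
    # Stand-in byte-level "tokenizer"; the real pipeline uses the GPT-2 BPE tokenizer.
    return list(text.encode("utf-8"))

def to_example(line: str) -> tuple[list[int], list[int]]:
    """Tokenize one prompt, truncate/pad to BLOCK_SIZE, and build a loss mask."""
    ids = encode(line)[:BLOCK_SIZE]
    mask = [1] * len(ids) + [0] * (BLOCK_SIZE - len(ids))  # 1 = real token, 0 = padding
    ids = ids + [PAD_ID] * (BLOCK_SIZE - len(ids))
    return ids, mask

ids, mask = to_example("Hello MLX world!")
assert len(ids) == len(mask) == BLOCK_SIZE
```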
During training you will see a live terminal progress bar (if the ui extra is installed) alongside the logging output.
MLX-Flash lets you toggle low-rank adapters directly from the CLI. Install the
adapters extra (pip install '.[adapters]') to pull in the mlx-lm toolkit,
then use LoRA to fine-tune while keeping the base weights frozen, or enable
QLoRA to quantize the backbone to 4 bits before attaching adapters:
# LoRA
python -m mlx_flash.examples.gpt2_finetune \
--train-file data/train.txt \
--model models/gpt2-base-mlx \
--output-dir checkpoints/gpt2-lora \
--lora --lora-rank 16 --lora-alpha 32
# QLoRA (4-bit quantization + LoRA adapters)
python -m mlx_flash.examples.gpt2_finetune \
--train-file data/train.txt \
--model models/gpt2-base-mlx \
--output-dir checkpoints/gpt2-qlora \
--qlora --lora-rank 16 --qlora-bits 4 --qlora-group-size 64

You can fine-tune the adapter hyperparameters via the --lora-* flags, and the quantization behaviour with --qlora-*. See python -m mlx_flash.examples.gpt2_finetune --help for the full list of options.
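For intuition, --lora-rank and --lora-alpha control the shape and scaling of the low-rank update: the frozen weight W is augmented with a trainable delta scaled by alpha / rank. A self-contained sketch of that computation in raw MLX (illustrative only, not MLX-Flash's adapter implementation):

```python
import mlx.core as mx

d_in, d_out, rank, alpha = 768, 768, 16, 32
scale = alpha / rank

# Frozen base weight plus trainable low-rank factors. B starts at zero so
# training begins from the base model's behaviour; only A and B are trained.
W = mx.random.normal((d_out, d_in)) * 0.02
A = mx.random.normal((rank, d_in)) * 0.01
B = mx.zeros((d_out, rank))

def lora_linear(x: mx.array) -> mx.array:
    # y = x W^T + (alpha / rank) * x A^T B^T
    return x @ W.T + scale * ((x @ A.T) @ B.T)

x = mx.random.normal((4, d_in))
print(lora_linear(x).shape)  # (4, 768)
```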
Every checkpoint is saved as .safetensors by default (MLX supports this format
in most tooling, though certain metadata conventions are still evolving). To fuse the trained
adapters into the base model, optionally de-quantize to full precision, and emit a
Hugging Face-compatible directory (optionally uploading it), supply
--fuse-save-path and --push-to-hf when running the example script:
python -m mlx_flash.examples.gpt2_finetune \
--train-file data/codealpaca/train.txt \
--model ./gpt2-base-mlx \
--output-dir checkpoints/gpt2-codealpaca \
--qlora \
--fuse-save-path fused/gpt2-codealpaca \
--push-to-hf your-hf-username/gpt2-codealpaca \
--hf-branch main

This produces fused/gpt2-codealpaca with de-quantized .safetensors weights ready for distribution. Uploading requires huggingface_hub (included in the adapters extra) and an authenticated CLI (huggingface-cli login).
If you omit --fuse-save-path, no fusion occurs. When fusion is requested, the example script de-quantizes and saves full-precision fused weights by default; add --fuse-keep-quantization to skip de-quantization, or invoke mlx_lm fuse yourself without --de-quantize.
You can also use mlx_lm fuse --upload-repo ... directly if you prefer the
original mlx-lm tooling for publishing.
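Once the fused directory exists, a quick sanity check is to load it with mlx-lm's Python helpers and generate a few tokens (this assumes mlx-lm's standard load/generate API and the fused path used above):

```python
from mlx_lm import load, generate

# Load the fused, Hugging Face-compatible directory produced via --fuse-save-path.
model, tokenizer = load("fused/gpt2-codealpaca")

# Generate a short completion to confirm the fused weights behave sensibly.
print(generate(model, tokenizer, prompt="Write a Python function that", max_tokens=64))
```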
We ship a synthetic 10k-sample text corpus so you can try MLX-Flash immediately:
# Optional: regenerate data/sample_{train,valid}.txt
python scripts/generate_sample_corpus.py
# Then launch training from the repo root after installing (or export PYTHONPATH=.)
python -m mlx_flash.examples.gpt2_finetune \
--train-file data/sample_train.txt \
--eval-file data/sample_valid.txt \
--model models/gpt2-base-mlx \
--output-dir checkpoints/sample_run \
--epochs 1 \
--batch-size 8 \
--block-size 256

Prefer TOML configs? Start from configs/sample_training.toml and adapt the values for your task.
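If you would rather point the script at your own data, any UTF-8 text file with one prompt per line works. A minimal sketch that writes a small train/valid split (not the shipped scripts/generate_sample_corpus.py; the file names here are just examples):

```python
from pathlib import Path
import random

# Toy prompts; substitute your own task-specific lines (one prompt per line).
prompts = [f"Prompt {i}: explain unified memory in one sentence." for i in range(1000)]
random.shuffle(prompts)

split = int(0.9 * len(prompts))
Path("data").mkdir(exist_ok=True)
Path("data/my_train.txt").write_text("\n".join(prompts[:split]) + "\n", encoding="utf-8")
Path("data/my_valid.txt").write_text("\n".join(prompts[split:]) + "\n", encoding="utf-8")
```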
Please ensure your dataset’s license allows fine-tuning and redistribution, and that the upstream model’s license (see its Hugging Face card) permits sharing derivatives before uploading fused weights.
MLX-Flash complements mlx-lm. Use the table below to map familiar commands to their MLX-Flash equivalents:
| Workflow | mlx-lm command/API | MLX-Flash path |
|---|---|---|
| Convert Hugging Face weights to MLX | mlx_lm convert --hf-path repo -q | Same command; point LanguageModelConfig(model_path=...) to the generated dir |
| Quick generation / chat | mlx_lm.generate, mlx_lm.chat | Load fused weights (mlx_lm.utils.fetch_from_hub) or call trainer.generate |
| LoRA / QLoRA fine-tuning | mlx_lm.lora CLI or custom trainers | python -m mlx_flash.examples.gpt2_finetune --lora/--qlora ... |
| Fuse adapters into base model | mlx_lm fuse --model ... --adapter-path ... | Same command, or pass --fuse-save-path to the example script |
| Push model to Hugging Face | mlx_lm fuse --upload-repo ... | --push-to-hf repo/name (requires huggingface_hub) |
Key differences:
- Project scaffolding: MLX-Flash bundles the trainer, configs, and callbacks so you can script training runs or import the trainer in your own pipelines without rewriting boilerplate.
- Checkpoints: adapters and fused models are saved as .safetensors by default. No need to choose between .npz and .safetensors.
- Automation: the example script can optionally fuse, de-quantize, and upload in one step (--fuse-save-path, --push-to-hf).
- UI & monitoring: Rich progress bars (pip install '.[ui]'), evaluation hooks, and future metrics integrations come for free.
If you are already comfortable with mlx-lm for conversion and lightweight
serving, keep using it. Reach for MLX-Flash whenever you need a turn-key
training loop, adapters, or an easy path from fine-tuning → fusion →
distribution.
Heads-up: not every mlx_lm.* sub-command ships in every environment. For example, mlx_lm.chat may not be installed by default; fall back to mlx_lm.generate or upgrade the package if you rely on those REPL helpers.
| Scenario | Precision / Method | Peak RAM to plan for | M1 16 GB | M2 24 GB | M3 24 GB | M4 32 GB |
|---|---|---|---|---|---|---|
| Inference only (7B) | 4-bit quantized | ~3.5–5 GB | ✅ | ✅ | ✅ | ✅ |
| LoRA + QLoRA FT (7B) | 4-bit base + LoRA | ~7–12 GB | ✅ | ✅ | ✅ | ✅ |
| LoRA FT (7B) | fp16 base + LoRA | ~14–20 GB | ❌ | ✅ | ✅ | ✅ |
| Full fine-tune (7B) | fp16 weights/grads/opt | ~70–120 GB | ❌ | ❌ | ❌ | ❌ |
Why these ranges
- 4-bit inference for 7B ≈ 3.5 GB (4 bits/param). Training adds activations + optimizer state + LoRA deltas → roughly ~2×, so ~7 GB is a common QLoRA target.
- Full fp16 fine-tune rule-of-thumb ≈ ~16 GB per 1B params (weights+grads+opt+acts) → 7B in fp16 is >100 GB, not feasible on laptop unified memory.
- LoRA trims trainable parameters drastically (>60% reduction vs full FT), keeping runs feasible on 16–32 GB Macs.
- Tip: MLX lets you mix bit-widths (e.g., 4-bit most layers, 6-bit embeddings/projections) to balance quality & memory.
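The arithmetic behind these ranges is easy to reproduce. A back-of-the-envelope estimator using the same rules of thumb (the multipliers are heuristics, not measurements):

```python
def weights_gb(params_billions: float, bits_per_param: float) -> float:
    """Memory for the weights alone, in GB."""
    return params_billions * 1e9 * bits_per_param / 8 / 1e9

# 7B at 4-bit, inference only: ~3.5 GB of weights.
print(round(weights_gb(7, 4), 1))

# QLoRA training: activations + optimizer state + LoRA deltas roughly double it (~7 GB).
print(round(weights_gb(7, 4) * 2, 1))

# Full fp16 fine-tune: ~16 GB per 1B params (weights + grads + optimizer + activations).
print(7 * 16)  # 112 GB -> far beyond laptop unified memory
```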
| Chip (typical UM) | Comfortable LLM FT mode | Max “comfortable” base size | Vision/Audio FT (ViT-B / keyword cls.) | Notes |
|---|---|---|---|---|
| M1 (16 GB) | QLoRA (4-bit) | 7B (LoRA r=8–16) | ✅ (small/medium backbones) | Stay quantized; keep seq len modest; micro-batch + grad accumulation. |
| M2 (24 GB) | QLoRA or fp16-LoRA | 7B (QLoRA) / ~3B (fp16-LoRA) | ✅✅ | +~50% mem BW vs M1 helps throughput; 24 GB gives headroom for longer context. |
| M3 (24 GB) | QLoRA or fp16-LoRA | 7B (QLoRA) / ~3–4B (fp16-LoRA) | ✅✅ | Small-set 7B FT is viable; scale cautiously with longer context/batches. |
| M4 (32 GB) | QLoRA or fp16-LoRA | 7B (QLoRA) / ~5–7B partial fp16-LoRA | ✅✅ | Higher bandwidth SKUs (e.g., Pro/Max) ease longer context & micro-batches. |
Planning ranges for single-epoch SFT with QLoRA on a 7B model, seq_len≈512, mixed precision, micro-batching + grad accumulation:
| Chip | “Quick win” set (≈15–45 min) | “Solid small” set (≈1–2 h) | “Deeper” set (≈3–5 h) |
|---|---|---|---|
| M1 (16 GB) | 5k–20k samples | 20k–60k | 60k–120k |
| M2 (24 GB) | 10k–30k | 30k–100k | 100k–200k |
| M3 (24 GB) | 10k–40k | 40k–120k | 120k–250k |
| M4 (32 GB) | 15k–50k | 50k–150k | 150k–300k |
How to adapt these to your run
- Time scales roughly linearly with total tokens processed (see the worked example below):
  - tokens_per_epoch ≈ n_samples × seq_len
  - steps/epoch = tokens_per_epoch / (global_batch × seq_len)
  - wall_time ≈ steps × step_time
- Trade micro-batch (RAM) ↔ grad accumulation (time) to fit memory.
- Tiny demo sets (a few thousand samples) can complete in minutes on M3/M4; larger, cleaner sets improve quality.
Note: “Global batch” = micro_batch × grad_accum × data_parallel_replicas (for local single-device runs, data_parallel_replicas = 1).
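Plugging concrete numbers into the formulas above shows where the planning ranges come from, e.g. 20k samples at seq_len 512 with a global batch of 16 (the ~1 s/step figure is an assumption; measure your own):

```python
# Worked example for the formulas above.
n_samples, seq_len = 20_000, 512
micro_batch, grad_accum = 4, 4
global_batch = micro_batch * grad_accum              # single device, so no data-parallel factor

tokens_per_epoch = n_samples * seq_len               # ~10.24M tokens
steps_per_epoch = tokens_per_epoch // (global_batch * seq_len)   # = n_samples / global_batch = 1250

step_time_s = 1.0                                    # assumed; depends on chip, model, and settings
wall_time_min = steps_per_epoch * step_time_s / 60
print(f"{steps_per_epoch} steps ≈ {wall_time_min:.0f} min per epoch")
```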
The table below summarises typical configurations (LoRA/QLoRA on GPT-2 / 7B class models) and what to expect on Apple Silicon. Numbers are based on public reports from the MLX community and internal smoke tests; adjust for your exact dataset and thermal conditions.
| Unified memory | Suggested setup | Expected runtime* | Notes |
|---|---|---|---|
| 8–12 GB (M1/M2 Air) | batch_size=4, block_size=256, gradient accumulation 4, LoRA rank 16 | 60–90 min for 2k samples | Keep other apps closed; QLoRA reduces activation footprint further |
| 16–24 GB (M2/M3 Pro) | batch_size=6, block_size=384, gradient accumulation 6, LoRA rank 32–48 | 35–60 min for 2k samples | Matches public 7B LoRA write-ups (e.g. Niklas Heidloff) |
| 32 GB+ (M3/M4 Pro/Max) | batch_size=10–12, block_size=512, no accumulation, LoRA rank 48–64 | 20–30 min for 2k samples | Plenty of headroom for longer sequences / larger datasets |
*Measured on CodeAlpaca-mini (≈2k examples). Full 20k runs scale roughly linearly with the number of steps.
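Gradient accumulation, used throughout the tables above, simply sums micro-batch gradients and takes one optimizer step, so the effective batch grows without the extra memory. A minimal sketch in raw MLX with a stand-in linear model (illustrative only; the trainer handles this for you, and none of this is MLX-Flash's internal code):

```python
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from mlx.utils import tree_map

model = nn.Linear(16, 4)                 # stand-in for a real language model
optimizer = optim.AdamW(learning_rate=1e-4)

def loss_fn(model, x, y):
    return nn.losses.mse_loss(model(x), y)

loss_and_grad = nn.value_and_grad(model, loss_fn)

grad_accum, accumulated = 4, None
for _ in range(grad_accum):
    x, y = mx.random.normal((2, 16)), mx.random.normal((2, 4))   # micro-batch of 2
    loss, grads = loss_and_grad(model, x, y)
    accumulated = grads if accumulated is None else tree_map(lambda a, b: a + b, accumulated, grads)

# Average the summed gradients, then take one optimizer step for the whole "global" batch.
optimizer.update(model, tree_map(lambda g: g / grad_accum, accumulated))
mx.eval(model.parameters(), optimizer.state)
```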
from pathlib import Path
from mlx_flash.core.config import TrainingConfig, CheckpointConfig
from mlx_flash.data import DataLoader
from mlx_flash.language import (
LanguageModelDataset,
LanguageModelTrainer,
LoRAConfig,
TokenizerWrapper,
TokenizerConfig,
GPT2LMHeadWrapper,
LanguageModelConfig,
)
# Tokenizer & datasets
tokenizer = TokenizerWrapper(TokenizerConfig(name_or_path="mlx-community/gpt2"))
train_dataset = LanguageModelDataset(["Hello MLX world!"], tokenizer, block_size=128)
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
# Model & trainer
model = GPT2LMHeadWrapper(LanguageModelConfig(model_path="mlx-community/gpt2"))
config = TrainingConfig(
num_epochs=1,
checkpoint=CheckpointConfig(output_dir=Path("checkpoints")),
)
trainer = LanguageModelTrainer(model, train_loader, config=config)
trainer.train()
print(trainer.generate("Once upon a time", max_new_tokens=32))
# Enable LoRA adapters programmatically
peft_model = GPT2LMHeadWrapper(
LanguageModelConfig(
model_path="mlx-community/gpt2",
lora=LoRAConfig(rank=16, alpha=32, dropout=0.05),
)
)

- mlx_lm.chat not found? Install a recent version of mlx-lm or use mlx_lm.generate as an alternative REPL.
- Fused model size looks unchanged? The example script de-quantizes by default; add --fuse-keep-quantization (or use mlx_lm fuse without --de-quantize) to preserve the original quantized weights.
- Which GPT-2 should I fine-tune? Convert any Hugging Face GPT-2 with mlx_lm convert --hf-path openai-community/gpt2 --mlx-path models/gpt2-base-mlx.
mlx_flash/
├── core/ # Trainer, callbacks, configuration, logging
├── language/ # GPT-2 wrappers, datasets, tokenizer helpers
├── models/ # Base wrappers and registry
├── optimization/ # Optimizers, losses, learning-rate schedulers
├── data/ # Lightweight dataloader
├── utils/ # Memory, metrics, and visualization helpers
├── benchmarks/ # Micro-benchmarks for throughput
└── examples/ # End-to-end fine-tuning scripts
pytest -q

Small unit tests ensure the trainer loop, dataset pipeline, and configuration helpers work as expected. Add the dev extra to install all testing dependencies.
We are just getting started: MLX-Flash will keep growing alongside the MLX ecosystem. In the meantime, check out community projects such as mlx-vlm and the broader MLX GitHub org for a glimpse of what’s coming.
Contributions, issues, and feedback are welcome!
If you build on MLX-Flash, please cite the project:
@misc{CuriousProgrammer,
  author = {S Susant Achary},
  title = {MLX-Flash},
  year = {2025},
}
Released under the MIT License.