
⚡ MLX-Flash: Train Anything on Apple Silicon.

MLX-Flash is a batteries-included training framework for Apple's MLX ecosystem. It provides high-level abstractions for building, fine-tuning, and benchmarking models on Apple Silicon without sacrificing the flexibility of MLX primitives.

The initial release focuses on language modeling with GPT-2, LoRA-based adaptation, and production-friendly training loops. Stay tuned for more. ✨

✨ Highlights

  • 🔩 Unified trainer API with callbacks, checkpointing, and scheduler support.
  • 🧠 Language modeling toolkit: tokenizer helpers, datasets, and GPT-2 wrapper powered by mlx-lm.
  • Apple Silicon first: smart defaults for memory usage, mixed precision, and throughput.
  • 🧮 Optimization utilities: gradient clipping, linear warmup schedulers, LoRA/QLoRA adapters.
  • 📈 Benchmarks & monitoring: memory heuristics, metric tracking, and optional visualization hooks (LoRA on 7B @ seq. 512 runs comfortably on 16–24 GB unified memory with micro-batching).

📦 Installation

python -m venv .venv
source .venv/bin/activate
pip install -e .

Optional extras:

# Memory monitoring support
pip install '.[monitoring]'

# Visualization helpers (training curves)
pip install '.[viz]'

# Rich terminal progress UI
pip install '.[ui]'

# LoRA / QLoRA adapters (via mlx-lm)
pip install '.[adapters]'

# Development tooling (linting & tests)
pip install '.[dev]'

Extras are optional feature bundles (see the Setuptools documentation for details). The table below summarises what each extra installs:

Extra Purpose Key dependencies
monitoring System memory stats psutil
viz Training curve plots matplotlib
ui Rich terminal progress bars rich
adapters LoRA/QLoRA fusion & HF uploads mlx-lm, huggingface_hub
dev Linting & tests pytest, pytest-cov, ruff

For CI/CD or longer-lived projects, consider capturing a pinned lock file (for example with pip-tools, uv, or poetry) alongside pyproject.toml to improve reproducibility.

Requirements: Python 3.10+, macOS on Apple Silicon, and the mlx runtime. Tested on macOS 15 (Sequoia) with Python 3.10–3.12 and MLX 0.29.x; performance benefits from unified memory on Apple Silicon.

Quickstart: Fine-Tune GPT-2

The fastest way to get started is with the provided example script. Prepare a plain-text corpus (one prompt per line) and run:

# One-time conversion from Hugging Face to MLX format
mlx_lm convert --hf-path openai-community/gpt2 --mlx-path models/gpt2-base-mlx
# Run from the repository root after `pip install -e .` (or export PYTHONPATH=.)
python -m mlx_flash.examples.gpt2_finetune \
  --train-file data/train.txt \
  --eval-file data/valid.txt \
  --model models/gpt2-base-mlx \
  --output-dir checkpoints/gpt2 \
  --epochs 3 \
  --batch-size 4 \
  --block-size 256
#   --lora               # optional: enable LoRA adapters
#   --qlora              # optional: 4-bit quantize + LoRA

This command will:

  1. Load GPT-2 weights via mlx-lm (use mlx_lm convert --hf-path openai-community/gpt2 to download/convert from Hugging Face). See the mlx-lm README for conversion and quantization details.
  2. Build tokenized datasets with automatic padding and masking.
  3. Train using the high-level LanguageModelTrainer.
  4. Save periodic checkpoints plus a final .safetensors snapshot.

During training you will see a live terminal progress bar (install the ui extra) alongside logging output.

Parameter-efficient adapters (LoRA & QLoRA)

MLX-Flash lets you toggle low-rank adapters directly from the CLI. Install the adapters extra (pip install '.[adapters]') to pull in the mlx-lm toolkit, then use LoRA to fine-tune while keeping the base weights frozen, or enable QLoRA to quantize the backbone to 4 bits before attaching adapters:

# LoRA
python -m mlx_flash.examples.gpt2_finetune \
  --train-file data/train.txt \
  --model models/gpt2-base-mlx \
  --output-dir checkpoints/gpt2-lora \
  --lora --lora-rank 16 --lora-alpha 32

# QLoRA (4-bit quantization + LoRA adapters)
python -m mlx_flash.examples.gpt2_finetune \
  --train-file data/train.txt \
  --model models/gpt2-base-mlx \
  --output-dir checkpoints/gpt2-qlora \
  --qlora --lora-rank 16 --qlora-bits 4 --qlora-group-size 64

You can fine-tune the adapter hyperparameters via --lora-* flags, and the quantization behaviour with --qlora-*. See python -m mlx_flash.examples.gpt2_finetune --help for the full list of options.

Export fused full-precision weights & upload to Hugging Face

Every checkpoint is saved as .safetensors by default (MLX supports this format in most tooling, though certain metadata conventions are still evolving). To fuse the trained adapters into the base model, optionally de-quantize to full precision, and emit a Hugging Face-compatible directory (optionally uploading it), supply --fuse-save-path and --push-to-hf when running the example script:

python -m mlx_flash.examples.gpt2_finetune \
  --train-file data/codealpaca/train.txt \
  --model ./gpt2-base-mlx \
  --output-dir checkpoints/gpt2-codealpaca \
  --qlora \
  --fuse-save-path fused/gpt2-codealpaca \
  --push-to-hf your-hf-username/gpt2-codealpaca \
  --hf-branch main

This produces fused/gpt2-codealpaca with de-quantized .safetensors weights ready for distribution. Uploading requires huggingface_hub (included in the adapters extra) and an authenticated CLI (huggingface-cli login). If you omit --fuse-save-path, no fusion occurs. When fusion is enabled, the example script emits full-precision fused weights by default; add --fuse-keep-quantization to preserve quantization, or invoke mlx_lm fuse yourself without --de-quantize. You can also use mlx_lm fuse --upload-repo ... directly if you prefer the original mlx-lm tooling for publishing.
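
Prefer to script the upload instead of using --push-to-hf? The huggingface_hub Python API (pulled in by the adapters extra) can publish the fused directory directly. A minimal sketch, reusing the paths and repo name from the command above:

from huggingface_hub import HfApi

api = HfApi()  # uses the token stored by `huggingface-cli login`
api.create_repo("your-hf-username/gpt2-codealpaca", exist_ok=True)
api.upload_folder(
    folder_path="fused/gpt2-codealpaca",  # fused, de-quantized weights
    repo_id="your-hf-username/gpt2-codealpaca",
    commit_message="Upload fused GPT-2 weights",
)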

Use the bundled sample corpus

We ship a synthetic 10k-sample text corpus so you can try MLX-Flash immediately:

# Optional: regenerate data/sample_{train,valid}.txt
python scripts/generate_sample_corpus.py

# Then launch training from the repo root after installing (or export PYTHONPATH=.)
python -m mlx_flash.examples.gpt2_finetune \
  --train-file data/sample_train.txt \
  --eval-file data/sample_valid.txt \
  --model models/gpt2-base-mlx \
  --output-dir checkpoints/sample_run \
  --epochs 1 \
  --batch-size 8 \
  --block-size 256

Prefer TOML configs? Start from configs/sample_training.toml and adapt the values for your task.
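
The exact schema is defined by the bundled file, but as a rough illustration (the TOML keys here are hypothetical; only the TrainingConfig fields from the Programmatic Usage section below are assumed), a config could be loaded along these lines:

import tomllib  # Python 3.11+; install the tomli backport on 3.10
from pathlib import Path

from mlx_flash.core.config import TrainingConfig, CheckpointConfig

with open("configs/sample_training.toml", "rb") as fh:
    raw = tomllib.load(fh)

# Hypothetical layout: a [training] table with num_epochs and a nested
# [training.checkpoint] table with output_dir. Check the bundled file
# for the real keys.
config = TrainingConfig(
    num_epochs=raw["training"]["num_epochs"],
    checkpoint=CheckpointConfig(
        output_dir=Path(raw["training"]["checkpoint"]["output_dir"])
    ),
)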

Data & licensing

Please ensure your dataset’s license allows fine-tuning and redistribution, and that the upstream model’s license (see its Hugging Face card) permits sharing derivatives before uploading fused weights.

Coming from MLX-LM

MLX-Flash complements mlx-lm. Use the table below to map familiar commands to their MLX-Flash equivalents:

Workflow mlx-lm command/API MLX-Flash path
Convert Hugging Face weights to MLX mlx_lm convert --hf-path repo -q Same command; point LanguageModelConfig(model_path=...) to the generated dir
Quick generation / chat mlx_lm.generate, mlx_lm.chat Load fused weights (mlx_lm.utils.fetch_from_hub) or call trainer.generate
LoRA / QLoRA fine-tuning mlx_lm.lora CLI or custom trainers python -m mlx_flash.examples.gpt2_finetune --lora/--qlora ...
Fuse adapters into base model mlx_lm fuse --model ... --adapter-path ... Same command, or pass --fuse-save-path to the example script
Push model to Hugging Face mlx_lm fuse --upload-repo ... --push-to-hf repo/name (requires huggingface_hub)

Key differences:

  • Project scaffolding: MLX-Flash bundles the trainer, configs, and callbacks so you can script training runs or import the trainer in your own pipelines without rewriting boilerplate.
  • Checkpoints: adapters and fused models are saved as .safetensors by default. No need to choose between .npz/.safetensors.
  • Automation: the example script can optionally fuse, de-quantize, and upload in one step (--fuse-save-path, --push-to-hf).
  • UI & monitoring: Rich progress bars (pip install '.[ui]'), evaluation hooks, and future metrics integrations come for free.

If you are already comfortable with mlx-lm for conversion and lightweight serving, keep using it. Reach for MLX-Flash whenever you need a turn-key training loop, adapters, or an easy path from fine-tuning → fusion → distribution.

Heads-up: not every mlx_lm.* sub-command ships in every environment. For example, mlx_lm.chat may not be installed by default; fall back to mlx_lm.generate or upgrade the package if you rely on those REPL helpers.
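
If you only need quick generation from the fused output, mlx-lm's Python API is usually the shortest path. A minimal sketch (exact signatures vary slightly across mlx-lm releases; the path is the fusion output from the earlier example):

from mlx_lm import load, generate

# Point at the fused directory (or a Hugging Face repo id).
model, tokenizer = load("fused/gpt2-codealpaca")
print(generate(model, tokenizer, prompt="Once upon a time", max_tokens=32))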

📏 Memory & Throughput Planning (Apple Silicon • MLX)

A) Memory headroom you should plan for (LLMs; 7B as a reference)

Scenario Precision / Method Peak RAM to plan for
Inference only (7B) 4-bit quantized ~3.5–5 GB
LoRA + QLoRA FT (7B) 4-bit base + LoRA ~7–12 GB
LoRA FT (7B) fp16 base + LoRA ~14–20 GB ⚠️ tight
Full fine-tune (7B) fp16 weights/grads/opt ~70–120 GB

Why these ranges

  • 4-bit inference for 7B ≈ 3.5 GB (4 bits/param). Training adds activations + optimizer state + LoRA deltas → roughly ~2×, so ~7 GB is a common QLoRA target.
  • Full fp16 fine-tune rule of thumb ≈ ~16 GB per 1B params (weights + grads + optimizer state + activations) → 7B in fp16 is >100 GB, which is not feasible in laptop-class unified memory (see the quick arithmetic sketch after this list).
  • LoRA trims trainable parameters drastically (>60% reduction vs full FT), keeping runs feasible on 16–32 GB Macs.
  • Tip: MLX lets you mix bit-widths (e.g., 4-bit most layers, 6-bit embeddings/projections) to balance quality & memory.
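
The quick arithmetic behind these ranges, as a rough planning sketch (the constants are the heuristics above, not measurements):

# Back-of-the-envelope memory planning for an n-billion-parameter model.
def gb_4bit_inference(n_params_b: float) -> float:
    # 4 bits/param = 0.5 bytes/param -> ~0.5 GB per 1B params (weights only).
    return 0.5 * n_params_b

def gb_qlora_finetune(n_params_b: float) -> float:
    # Activations + optimizer state + LoRA deltas roughly double the 4-bit footprint.
    return 2 * gb_4bit_inference(n_params_b)

def gb_full_fp16_finetune(n_params_b: float) -> float:
    # ~16 GB per 1B params (weights + grads + optimizer state + activations).
    return 16 * n_params_b

for n in (1, 3, 7):
    print(f"{n}B params: 4-bit inference ~{gb_4bit_inference(n):.1f} GB, "
          f"QLoRA FT ~{gb_qlora_finetune(n):.1f} GB, "
          f"full fp16 FT ~{gb_full_fp16_finetune(n):.0f} GB")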

B) Quick planning guide by chip (what’s comfortable to run)

Chip (typical unified memory) Comfortable LLM FT mode Max “comfortable” base size Vision/Audio FT (ViT-B / keyword cls.) Notes
M1 (16 GB) QLoRA (4-bit) 7B (LoRA r=8–16) ✅ (small/medium backbones) Stay quantized; keep seq len modest; micro-batch + grad accumulation.
M2 (24 GB) QLoRA or fp16-LoRA 7B (QLoRA) / ~3B (fp16-LoRA) ✅✅ +~50% mem BW vs M1 helps throughput; 24 GB gives headroom for longer context.
M3 (24 GB) QLoRA or fp16-LoRA 7B (QLoRA) / ~3–4B (fp16-LoRA) ✅✅ Small-set 7B FT is viable; scale cautiously with longer context/batches.
M4 (32 GB) QLoRA or fp16-LoRA 7B (QLoRA) / ~5–7B partial fp16-LoRA ✅✅ Higher bandwidth SKUs (e.g., Pro/Max) ease longer context & micro-batches.

C) Dataset size & time-to-train planning (rule-of-thumb)

Planning ranges for single-epoch SFT with QLoRA on a 7B model, seq_len≈512, mixed precision, micro-batching + grad accumulation:

Chip “Quick win” set (≈15–45 min) “Solid small” set (≈1–2 h) “Deeper” set (≈3–5 h)
M1 (16 GB) 5k–20k samples 20k–60k 60k–120k
M2 (24 GB) 10k–30k 30k–100k 100k–200k
M3 (24 GB) 10k–40k 40k–120k 120k–250k
M4 (32 GB) 15k–50k 50k–150k 150k–300k

How to adapt these to your run

  • Time scales roughly linearly with total tokens processed:
    • tokens_per_epoch ≈ n_samples × seq_len
    • steps/epoch = tokens_per_epoch / (global_batch × seq_len)
    • wall_time ≈ steps × step_time
  • Trade micro-batch size (RAM) against gradient accumulation (time) to fit memory.
  • Tiny demo sets (a few thousand samples) can complete in minutes on M3/M4; larger, cleaner sets improve quality.

Note: “global batch” = micro_batch × grad_accum × data_parallel_replicas (for local single-device runs, data_parallel_replicas = 1). A worked example of these formulas follows below.
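
A worked example of those formulas in Python (step_time_s is a placeholder you should measure on your own machine with a short dry run):

# Rough wall-clock planning for one SFT epoch.
n_samples   = 20_000   # training samples
seq_len     = 512      # tokens per sample (planning assumption above)
micro_batch = 2        # sequences per forward/backward pass
grad_accum  = 8        # accumulation steps per optimizer update
step_time_s = 1.5      # measured seconds per optimizer step (placeholder)

global_batch     = micro_batch * grad_accum  # single device: replicas = 1
tokens_per_epoch = n_samples * seq_len
steps_per_epoch  = tokens_per_epoch // (global_batch * seq_len)  # = n_samples // global_batch
wall_time_hours  = steps_per_epoch * step_time_s / 3600

print(f"{tokens_per_epoch:,} tokens/epoch, {steps_per_epoch:,} steps, ~{wall_time_hours:.1f} h/epoch")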

The table below summarises typical configurations (LoRA/QLoRA on GPT-2 / 7B class models) and what to expect on Apple Silicon. Numbers are based on public reports from the MLX community and internal smoke tests; adjust for your exact dataset and thermal conditions.

Unified memory Suggested setup Expected runtime* Notes
8–12 GB (M1/M2 Air) batch_size=4, block_size=256, gradient accumulation 4, LoRA rank 16 60–90 min for 2k samples Keep other apps closed; QLoRA reduces the memory footprint further
16–24 GB (M2/M3 Pro) batch_size=6, block_size=384, gradient accumulation 6, LoRA rank 32–48 35–60 min for 2k samples Matches public 7B LoRA write-ups (e.g. Niklas Heidloff)
32 GB+ (M3/M4 Pro/Max) batch_size=10–12, block_size=512, no accumulation, LoRA rank 48–64 20–30 min for 2k samples Plenty of headroom for longer sequences / larger datasets

*Measured on CodeAlpaca-mini (≈2k examples). Full 20k runs scale roughly linearly with the number of steps.

Programmatic Usage

from pathlib import Path

from mlx_flash.core.config import TrainingConfig, CheckpointConfig
from mlx_flash.data import DataLoader
from mlx_flash.language import (
    LanguageModelDataset,
    LanguageModelTrainer,
    LoRAConfig,
    TokenizerWrapper,
    TokenizerConfig,
    GPT2LMHeadWrapper,
    LanguageModelConfig,
)

# Tokenizer & datasets
tokenizer = TokenizerWrapper(TokenizerConfig(name_or_path="mlx-community/gpt2"))
train_dataset = LanguageModelDataset(["Hello MLX world!"], tokenizer, block_size=128)
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)

# Model & trainer
model = GPT2LMHeadWrapper(LanguageModelConfig(model_path="mlx-community/gpt2"))
config = TrainingConfig(
    num_epochs=1,
    checkpoint=CheckpointConfig(output_dir=Path("checkpoints")),
)

trainer = LanguageModelTrainer(model, train_loader, config=config)
trainer.train()
print(trainer.generate("Once upon a time", max_new_tokens=32))

# Enable LoRA adapters programmatically
peft_model = GPT2LMHeadWrapper(
    LanguageModelConfig(
        model_path="mlx-community/gpt2",
        lora=LoRAConfig(rank=16, alpha=32, dropout=0.05),
    )
)

Troubleshooting FAQ

  • mlx_lm.chat not found? Install a recent version of mlx-lm, or fall back to mlx_lm.generate for one-off prompts.
  • Fused model size looks unchanged. The example script de-quantizes by default; add --fuse-keep-quantization (or use mlx_lm fuse without --de-quantize) to preserve the original quantized weights.
  • Which GPT-2 should I fine-tune? Convert any Hugging Face GPT‑2 with mlx_lm convert --hf-path openai-community/gpt2 --mlx-path models/gpt2-base-mlx.

Project Layout

mlx_flash/
├── core/          # Trainer, callbacks, configuration, logging
├── language/      # GPT-2 wrappers, datasets, tokenizer helpers
├── models/        # Base wrappers and registry
├── optimization/  # Optimizers, losses, learning-rate schedulers
├── data/          # Lightweight dataloader
├── utils/         # Memory, metrics, and visualization helpers
├── benchmarks/    # Micro-benchmarks for throughput
└── examples/      # End-to-end fine-tuning scripts

Testing

pytest -q

Small unit tests ensure the trainer loop, dataset pipeline, and configuration helpers work as expected. Add the dev extra to install all testing dependencies.
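
For orientation only, a configuration smoke test in that style might look like the sketch below; it assumes the TrainingConfig/CheckpointConfig fields from the Programmatic Usage section are exposed as attributes, which may differ from the repository's actual tests:

# tests/test_config_smoke.py — hypothetical example, not part of the repo.
from pathlib import Path

from mlx_flash.core.config import CheckpointConfig, TrainingConfig


def test_training_config_holds_values(tmp_path: Path) -> None:
    config = TrainingConfig(
        num_epochs=1,
        checkpoint=CheckpointConfig(output_dir=tmp_path / "ckpts"),
    )
    assert config.num_epochs == 1
    assert config.checkpoint.output_dir == tmp_path / "ckpts"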

🚧 Looking Ahead (WIP)

We are just getting started, and MLX-Flash will keep growing alongside the MLX ecosystem. In the meantime, check out community projects such as mlx-vlm and the broader MLX GitHub org for a glimpse of what’s coming.

Contributions, issues, and feedback are welcome!

Citation

If you build on MLX-Flash, please cite the project:

@misc{achary2025mlxflash,
  author       = {S Susant Achary},
  title        = {MLX-Flash},
  year         = {2025},
  howpublished = {\url{https://github.com/SSusantAchary/mlx-flash}},
}

Released under the MIT License.
