From d1eb99103caa429bf24662a51dd36611bdd88a62 Mon Sep 17 00:00:00 2001 From: Rob Mulla Date: Tue, 16 Dec 2025 11:07:45 -0500 Subject: [PATCH 1/5] docs: streamline first run tutorial and landing page --- docs/tutorials.md | 29 +++++++++++++++++++ docs/tutorials/first_run.md | 57 ++++++++++++++++++++++++++----------- 2 files changed, 70 insertions(+), 16 deletions(-) diff --git a/docs/tutorials.md b/docs/tutorials.md index 090bb00b0..8c3de2d6c 100644 --- a/docs/tutorials.md +++ b/docs/tutorials.md @@ -16,8 +16,37 @@ # Tutorials +Welcome to the MaxText tutorials! If you are new here, we recommend starting with the **Getting Started** guide below. + +## 🚀 New to MaxText? + +If you haven't installed MaxText yet, please check our [Installation Guide](install_maxtext.md) first. + +Once installed, follow our **First Run** tutorial to get your first model training on a TPU: + +* **[Getting Started: First Run](tutorials/first_run.md)** + * *Goal:* Train your first LLM model on a single TPU. + * *Time:* ~15 minutes. + * *Topic:* Cloud Storage setup, Running `train.py`, and Verify with `decode.py`. + +--- + +## All Tutorials + +Below is a list of all available tutorials organized by topic. + +### Basics +* [Getting Started: First Run](tutorials/first_run.md) - The essential starting point. + +### Training +* [Pre-training](tutorials/pretraining.md) - Learn how to run large-scale pre-training jobs. + +### Post-Training +* [Post-Training Index](tutorials/post_training_index.md) - a collection of guides for fine-tuning, RLHF, and other post-training workflows. + ```{toctree} :maxdepth: 1 +:hidden: tutorials/first_run.md tutorials/pretraining.md diff --git a/docs/tutorials/first_run.md b/docs/tutorials/first_run.md index 9e16f034f..5f8a44c20 100644 --- a/docs/tutorials/first_run.md +++ b/docs/tutorials/first_run.md @@ -20,9 +20,23 @@ This topic provides a basic introduction to get your MaxText workload up and running on single host and multihost environments using Cloud TPUs or NVIDIA GPUs. To help you get familiar with MaxText, we recommend starting with a single host first and then moving to multihost. ## Prerequisites: Set up storage and configure MaxText -1. To store logs and checkpoints, [Create a Cloud Storage bucket](https://cloud.google.com/storage/docs/creating-buckets) in your project. To run MaxText, the TPU or GPU VMs must have read/write permissions for the bucket. These permissions are granted by service account roles, such as the `STORAGE ADMIN` role. -2. MaxText reads a yaml file for configuration. We also recommend reviewing the configurable options in `configs/base.yml`. This file includes a decoder-only model of ~1B parameters. The configurable options can be overwritten from the command line. For instance, you can change the `steps` or `log_period` by either modifying `configs/base.yml` or by passing in `steps` and `log_period` as additional arguments to the `train.py` call. Set `base_output_directory` to a folder in the bucket you just created. +### 1. Set up Environment Variables +To make the commands in this tutorial copy-paste friendly, set your bucket name as an environment variable: +```bash +export BUCKET_NAME=your-bucket-name +# Example: export BUCKET_NAME=maxtext-test-runs +``` + +### 2. Create Cloud Storage Bucket +Create a bucket to store logs and checkpoints: +```bash +gcloud storage buckets create gs://${BUCKET_NAME} --location=us-central1 +# Note: Ensure your TPU VM has read/write access to this bucket. +``` + +### 3. Configuration +MaxText reads a yaml file for configuration. The default configuration is in `configs/base.yml` (decoder-only model, ~1B params). We will override specific parameters (like `base_output_directory`) via command line arguments. ## Local development for single host This procedure describes how to run MaxText on a single GPU or TPU host. @@ -33,31 +47,42 @@ multiple hosts but is a good way to learn about MaxText. 1. [Create and SSH to the single host VM of your choice](https://cloud.google.com/tpu/docs/managing-tpus-tpu-vm). You can use any available single host TPU, such as `v5litepod-8`, `v5p-8`, or `v4-8`. 2. Clone MaxText onto that TPU VM. -3. Within the root directory of the cloned repo, install dependencies and pre-commit hook by running: -```sh +3. Make sure you are in the `MaxText` root directory (where `setup.sh` was run): +```bash +cd ~/MaxText +``` + +4. Within the root directory of the cloned repo, install dependencies and pre-commit hook by running: +```bash python3 -m venv ~/venv-maxtext source ~/venv-maxtext/bin/activate bash tools/setup/setup.sh pre-commit install ``` -4. After installation completes, run training on synthetic data with the following command: -```sh -python3 -m MaxText.train src/MaxText/configs/base.yml \ - run_name=$YOUR_JOB_NAME \ - base_output_directory=gs:// \ +5. After installation completes, run training on synthetic data with the following command: +```bash +# Set a unique run name +export RUN_NAME=run_$(date +%Y%m%d_%H%M%S) + +# Run training +python3 -m MaxText.train configs/base.yml \ + run_name=$RUN_NAME \ + base_output_directory=gs://${BUCKET_NAME} \ dataset_type=synthetic \ steps=10 ``` -Optional: If you want to try training on a Hugging Face dataset, see [Data Input Pipeline](../guides/data_input_pipeline.md) for data input options. +Optional: If you want to try training on a Hugging Face dataset, see [Data Input Pipeline](data-input-pipeline) for data input options. -5. To demonstrate model output, run the following command: -```sh -python3 -m MaxText.decode src/MaxText/configs/base.yml \ - run_name=$YOUR_JOB_NAME \ - base_output_directory=gs:// \ +6. To demonstrate model output, we can run decoding (inference). +> **Note:** We use the same `RUN_NAME` and `BUCKET_NAME` to automatically load the checkpoint we just trained. + +```bash +python3 -m MaxText.decode configs/base.yml \ + run_name=$RUN_NAME \ + base_output_directory=gs://${BUCKET_NAME} \ per_device_batch_size=1 ``` -This command uses a model with randomly initialized weights, so the outputs are also random. To get high quality output you need pass in a checkpoint, typically via the `load_parameters_path` argument. +This command uses the checkpoint from the training run. If no checkpoint is found (e.g. if you changed the run name), it will initialize with random weights. ### Run MaxText via notebook From 6bf904751dcd841500f4ce7f454f1b8322f1b59c Mon Sep 17 00:00:00 2001 From: Rob Mulla Date: Tue, 16 Dec 2025 11:21:15 -0500 Subject: [PATCH 2/5] docs: streamline first_run config explanation --- docs/tutorials/first_run.md | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/docs/tutorials/first_run.md b/docs/tutorials/first_run.md index 5f8a44c20..1f92c1383 100644 --- a/docs/tutorials/first_run.md +++ b/docs/tutorials/first_run.md @@ -29,14 +29,32 @@ export BUCKET_NAME=your-bucket-name ``` ### 2. Create Cloud Storage Bucket -Create a bucket to store logs and checkpoints: +To store logs and checkpoints, [Create a Cloud Storage bucket](https://cloud.google.com/storage/docs/creating-buckets) in your project. To run MaxText, the TPU or GPU VMs must have read/write permissions for the bucket. These permissions are granted by service account roles, such as the `STORAGE ADMIN` role. + ```bash gcloud storage buckets create gs://${BUCKET_NAME} --location=us-central1 # Note: Ensure your TPU VM has read/write access to this bucket. ``` ### 3. Configuration -MaxText reads a yaml file for configuration. The default configuration is in `configs/base.yml` (decoder-only model, ~1B params). We will override specific parameters (like `base_output_directory`) via command line arguments. +MaxText reads a yaml file for configuration. The default configuration is in `configs/base.yml`. +**Note:** This default configuration defines a **Decoder-only** model with **~1B parameters**. + +You can review the full [base.yml](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/configs/base.yml) file, but here are the key configurable options: + +```yaml +# configs/base.yml snippet +model_name: "default" +weight_dtype: "float32" +base_emb_dim: 2048 +base_num_query_heads: 16 +base_num_kv_heads: 16 +base_mlp_dim: 7168 +base_num_decoder_layers: 16 +head_dim: 128 +``` + +The configurable options can be overwritten from the command line. For instance, you can change the `steps` or `log_period` by either modifying `configs/base.yml` or by passing in `steps` and `log_period` as additional arguments to the `train.py` call. ## Local development for single host This procedure describes how to run MaxText on a single GPU or TPU host. From 4f29481345e9f315607a37dff1ca9d20713cc00f Mon Sep 17 00:00:00 2001 From: Rob Mulla Date: Tue, 16 Dec 2025 11:32:23 -0500 Subject: [PATCH 3/5] docs: render demo_decoding notebook directly --- docs/conf.py | 3 + docs/tutorials.md | 20 +- docs/tutorials/demo_decoding.ipynb | 438 +++++++++++++++++++++++++++++ docs/tutorials/first_run.md | 2 +- 4 files changed, 447 insertions(+), 16 deletions(-) create mode 100644 docs/tutorials/demo_decoding.ipynb diff --git a/docs/conf.py b/docs/conf.py index 21f27650a..fbc2ece3d 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -59,6 +59,9 @@ ] myst_linkify_fuzzy_links = False +# Notebook execution mode +nb_execution_mode = "off" + # Theme-specific options # https://sphinx-book-theme.readthedocs.io/en/stable/reference.html html_theme_options = { diff --git a/docs/tutorials.md b/docs/tutorials.md index 8c3de2d6c..2d8a6b096 100644 --- a/docs/tutorials.md +++ b/docs/tutorials.md @@ -18,37 +18,27 @@ Welcome to the MaxText tutorials! If you are new here, we recommend starting with the **Getting Started** guide below. -## 🚀 New to MaxText? +## New to MaxText? If you haven't installed MaxText yet, please check our [Installation Guide](install_maxtext.md) first. -Once installed, follow our **First Run** tutorial to get your first model training on a TPU: - -* **[Getting Started: First Run](tutorials/first_run.md)** - * *Goal:* Train your first LLM model on a single TPU. - * *Time:* ~15 minutes. - * *Topic:* Cloud Storage setup, Running `train.py`, and Verify with `decode.py`. +Once installed, follow our **[First Run](tutorials/first_run.md)** tutorial to get your first model training on a TPU. This tutorial will guide you through the process of training a model with MaxText and verifying the results. --- -## All Tutorials - Below is a list of all available tutorials organized by topic. -### Basics -* [Getting Started: First Run](tutorials/first_run.md) - The essential starting point. - -### Training +### Additional Tutorials * [Pre-training](tutorials/pretraining.md) - Learn how to run large-scale pre-training jobs. - -### Post-Training * [Post-Training Index](tutorials/post_training_index.md) - a collection of guides for fine-tuning, RLHF, and other post-training workflows. ```{toctree} :maxdepth: 1 :hidden: +:caption: Tutorials tutorials/first_run.md tutorials/pretraining.md tutorials/post_training_index.md +tutorials/demo_decoding.ipynb ``` diff --git a/docs/tutorials/demo_decoding.ipynb b/docs/tutorials/demo_decoding.ipynb new file mode 100644 index 000000000..9b913318e --- /dev/null +++ b/docs/tutorials/demo_decoding.ipynb @@ -0,0 +1,438 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "e017d77b", + "metadata": {}, + "source": [ + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/AI-Hypercomputer/maxtext/blob/main/src/MaxText/examples/demo_decoding.ipynb)\n", + " \n", + "# Qwen3-0.6B Decoding Demo" + ] + }, + { + "cell_type": "markdown", + "id": "dc85cefe-8f29-47db-a8f3-4e8fbb354eb5", + "metadata": {}, + "source": [ + "## Prerequisites" + ] + }, + { + "cell_type": "markdown", + "id": "55e3ce9e-8968-4d68-ba2b-b36c616b52a9", + "metadata": {}, + "source": [ + "### Change Runtime Type\n", + "\n", + "**Instructions:**\n", + "1. Navigate to the menu at the top of the screen.\n", + "2. Click on **Runtime**.\n", + "3. Select **Change runtime type** from the dropdown menu.\n", + "4. Select **v5e-1 TPU** as the **Hardware accelerator**.\n", + "5. Click on **Save**." + ] + }, + { + "cell_type": "markdown", + "id": "bf5e0f3f-5833-4260-a31d-b156249d67ab", + "metadata": {}, + "source": [ + "### Get Your Hugging Face Token\n", + "\n", + "To access model checkpoint from the Hugging Face Hub, you need to authenticate with a personal access token.\n", + "\n", + "**Follow these steps to get your token:**\n", + "\n", + "1. **Navigate to the Access Tokens page** in your Hugging Face account settings. You can go there directly by visiting this URL:\n", + " * [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)\n", + "\n", + "2. **Create a new token** by clicking the **\"+ Create new token\"** button.\n", + "\n", + "3. **Give your token a name** and assign it a **`read` role**. The `read` role is sufficient for downloading models.\n", + "\n", + "4. **Copy the generated token**. You will need to paste it in the next step.\n", + "\n", + "**Follow these steps to store your token:**\n", + "\n", + "1. On the left sidebar of your Colab window, click the key icon (the Secrets tab).\n", + "\n", + "2. Click **\"+ Add new secret\"**.\n", + "\n", + "3. Set the Name as **HF_TOKEN**.\n", + "\n", + "4. Paste your token into the Value field.\n", + "\n", + "5. Ensure the Notebook access toggle is turned On." + ] + }, + { + "cell_type": "markdown", + "id": "8a6deec5-b64a-4bc6-86c4-c24696c66f17", + "metadata": {}, + "source": [ + "## Installation: MaxText & Other Dependencies" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9b2d4a66-99de-404c-aac3-18b1af4af78e", + "metadata": {}, + "outputs": [], + "source": [ + "# Install uv, a fast Python package installer\n", + "!pip install uv\n", + "\n", + "# Install MaxText and dependencies\n", + "!uv pip install maxtext --resolution=lowest\n", + "!python3 -m MaxText.install_maxtext_extra_deps\n", + "\n", + "# Use nest_asyncio to allow nested event loops in notebooks\n", + "!uv pip install nest_asyncio\n", + "\n", + "# Install the PyTorch library\n", + "!uv pip install torch" + ] + }, + { + "cell_type": "markdown", + "id": "5a07fd61-35b7-4aa9-93cd-49ef89fb550d", + "metadata": {}, + "source": [ + "### Restart Session\n", + "To apply certain changes, you need to restart the session.\n", + "\n", + "**Instructions:**\n", + "1. Navigate to the menu at the top of the screen.\n", + "2. Click on **Runtime**.\n", + "3. Select **Restart session** from the dropdown menu.\n", + "\n", + "You will be asked to confirm the action in a pop-up dialog. Click on **Yes**." + ] + }, + { + "cell_type": "markdown", + "id": "2f1ebdb1-dcf4-417b-9c29-3461e06aa9cf", + "metadata": {}, + "source": [ + "## Imports" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a8e986cb", + "metadata": {}, + "outputs": [], + "source": [ + "%%capture\n", + "import datetime\n", + "import jax\n", + "import os\n", + "import nest_asyncio\n", + "import numpy as np\n", + "\n", + "import MaxText as mt\n", + "from MaxText import common_types\n", + "from MaxText import inference_utils\n", + "from MaxText import maxtext_utils\n", + "from MaxText import max_logging\n", + "from MaxText import pyconfig\n", + "from MaxText.input_pipeline import _input_pipeline_utils\n", + "from MaxText.utils.ckpt_conversion import to_maxtext\n", + "\n", + "from google.colab import userdata\n", + "from huggingface_hub import login\n", + "\n", + "MAXTEXT_PKG_DIR = os.path.dirname(mt.__file__)\n", + "\n", + "nest_asyncio.apply()" + ] + }, + { + "cell_type": "markdown", + "id": "c4f53124", + "metadata": {}, + "source": [ + "## Sanity Test: Checking for Available TPU Devices" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a545acd8", + "metadata": {}, + "outputs": [], + "source": [ + "jax.distributed.initialize() # distributed.initialize should only be called once.\n", + "jax.devices()" + ] + }, + { + "cell_type": "markdown", + "id": "be0113d9-0cb6-45aa-9fa4-7e543db7645e", + "metadata": {}, + "source": [ + "## Model Configurations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b80080fa-473b-4683-b0c9-765af43efd49", + "metadata": {}, + "outputs": [], + "source": [ + "MODEL_NAME = \"qwen3-0.6b\"\n", + "PROMPT = \"I love to\"\n", + "RUN_NAME = datetime.datetime.now().strftime(\"%Y-%m-%d-%H-%M-%S\")\n", + "MODEL_CHECKPOINT_PATH = f\"/tmp/checkpoints/{MODEL_NAME}/{RUN_NAME}/unscanned\"\n", + "\n", + "HF_TOKEN = userdata.get(\"HF_TOKEN\")\n", + "login(token=HF_TOKEN)\n", + "max_logging.log(\"Authenticated with Hugging Face successfully!\")" + ] + }, + { + "cell_type": "markdown", + "id": "03ff53b8-b931-4190-bcac-d6ca885cbbc8", + "metadata": {}, + "source": [ + "## Download Model Checkpoint From Hugging Face" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1fd3578c-5763-410e-8b61-72d7415628bd", + "metadata": {}, + "outputs": [], + "source": [ + "%%capture\n", + "argv = [\n", + " \"\", # This is a placeholder, it's not actually used by the script's logic\n", + " f\"{MAXTEXT_PKG_DIR}/configs/base.yml\",\n", + " f\"model_name={MODEL_NAME}\",\n", + " f\"base_output_directory={MODEL_CHECKPOINT_PATH}\",\n", + " f\"hf_access_token={HF_TOKEN}\",\n", + " \"use_multimodal=false\",\n", + " \"scan_layers=false\",\n", + "]\n", + "\n", + "to_maxtext.main(argv)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "94a9fd37-95e5-4075-837e-de0f1666d55f", + "metadata": {}, + "outputs": [], + "source": [ + "max_logging.log(f\"Model checkpoint can be found at: {MODEL_CHECKPOINT_PATH}/0/items\")" + ] + }, + { + "cell_type": "markdown", + "id": "0cf4bbe4-6485-4ce7-8aef-cf3df3810e52", + "metadata": {}, + "source": [ + "## Initialize Configurations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "32f44079-87ae-4ed4-a008-c0dbb8aaf8c0", + "metadata": {}, + "outputs": [], + "source": [ + "%%capture\n", + "config = pyconfig.initialize(\n", + " [\"\", f\"{MAXTEXT_PKG_DIR}/configs/base.yml\"],\n", + " per_device_batch_size=1.0,\n", + " run_name=\"test\",\n", + " max_target_length=4,\n", + " max_prefill_predict_length=4,\n", + " tokenizer_path=f\"{MAXTEXT_PKG_DIR}/assets/qwen3-tokenizer\",\n", + " load_parameters_path=f\"{MODEL_CHECKPOINT_PATH}/0/items\",\n", + " model_name=MODEL_NAME,\n", + " async_checkpointing=False,\n", + " prompt=PROMPT,\n", + " scan_layers=False,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a637f43b-6fcc-4305-af5b-d9d30d464bb6", + "metadata": {}, + "outputs": [], + "source": [ + "max_logging.log(\"Decode configurations initialized.\")" + ] + }, + { + "cell_type": "markdown", + "id": "cd502094-1694-410a-91e9-25bbb8dfb33a", + "metadata": {}, + "source": [ + "## Initialize Decode State" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d2d2de93", + "metadata": {}, + "outputs": [], + "source": [ + "model = mt.from_config(config)\n", + "mesh = model.mesh\n", + "init_rng = jax.random.PRNGKey(config.init_weights_seed)\n", + "state, _ = maxtext_utils.setup_decode_state(model, config, init_rng, mesh, None)\n", + "max_logging.log(\"Decode state initialized.\")" + ] + }, + { + "cell_type": "markdown", + "id": "ed4b59a7", + "metadata": {}, + "source": [ + "## Get Tokenizer" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "35584129-3c45-45ad-b2a2-a56f98d27f06", + "metadata": {}, + "outputs": [], + "source": [ + "tokenizer = _input_pipeline_utils.get_tokenizer(\n", + " f\"{MAXTEXT_PKG_DIR}/assets/qwen3-tokenizer\",\n", + " \"huggingface\",\n", + " add_bos=True,\n", + " add_eos=False,\n", + ")\n", + "max_logging.log(\"Tokenizer loaded succuessfully.\")" + ] + }, + { + "cell_type": "markdown", + "id": "32a252ae", + "metadata": {}, + "source": [ + "## Prepare Inputs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d2d2d0c5", + "metadata": {}, + "outputs": [], + "source": [ + "input_ids = tokenizer.encode(config.prompt)\n", + "\n", + "# Pad input_ids to max_target_length\n", + "padded_ids = np.zeros(config.max_target_length, dtype=np.int32)\n", + "padded_ids[: len(input_ids)] = input_ids\n", + "ids = np.asarray(padded_ids, dtype=np.int32)\n", + "\n", + "s = (config.global_batch_size_to_train_on, config.max_target_length)\n", + "decoder_segment_ids = np.zeros(s) + common_types.DECODING_ACTIVE_SEQUENCE_INDICATOR\n", + "decoder_positions = np.stack(\n", + " [np.arange(config.max_target_length, dtype=np.int32) for _ in range(config.global_batch_size_to_train_on)]\n", + ")\n", + "\n", + "ids = np.stack([ids for _ in range(config.global_batch_size_to_train_on)])\n", + "max_logging.log(\n", + " f\"input_ids={input_ids}, \\n\\nids={ids}, \\n\\ndecoder_segment_ids = {decoder_segment_ids}, \\n\\ndecoder_positions= {decoder_positions}\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "647018c1", + "metadata": {}, + "source": [ + "## Run Forward Pass" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7436751b", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "full_train_logits = model.apply(\n", + " state.params,\n", + " ids,\n", + " decoder_positions,\n", + " decoder_segment_ids,\n", + " enable_dropout=False,\n", + " rngs={\"aqt\": init_rng},\n", + ")\n", + "full_train_logits = jax.experimental.multihost_utils.process_allgather(full_train_logits)\n", + "max_logging.log(f\"{full_train_logits[0, 0, :]=}\")" + ] + }, + { + "cell_type": "markdown", + "id": "5640ab55", + "metadata": {}, + "source": [ + "## Generate Text with Greedy Decoding" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bb06c0c9", + "metadata": {}, + "outputs": [], + "source": [ + "selected_logits = jax.lax.dynamic_slice(\n", + " full_train_logits, (0, 0, full_train_logits.shape[2] - 2, 0), (1, 1, 1, full_train_logits.shape[3])\n", + ")\n", + "\n", + "# Consider the greedily sampled token\n", + "init_rng, new_rng = jax.random.split(init_rng)\n", + "first_generated_token = inference_utils.sampling(\n", + " selected_logits,\n", + " new_rng,\n", + " config.decode_sampling_strategy, # \"greedy\"\n", + ")\n", + "output = tokenizer.decode([first_generated_token.item()])\n", + "max_logging.log(f\"Next predicted token is `{output}` for the input prompt: `{config.prompt}`.\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/tutorials/first_run.md b/docs/tutorials/first_run.md index 1f92c1383..5835e2d2c 100644 --- a/docs/tutorials/first_run.md +++ b/docs/tutorials/first_run.md @@ -107,7 +107,7 @@ This command uses the checkpoint from the training run. If no checkpoint is foun In the same TPU VM where you just installed all the dependencies of MaxText, You can also run training and decoding in MaxText via Notebook (for e.g., via Jupyter or Colab). #### Decoding in MaxText via notebook -You can use [demo_decoding.ipynb](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/examples/demo_decoding.ipynb) to try out decoding on MaxText's `Llama3.1-8b` model implementation. In this notebook, we give `"I love to"` as the prompt, and the greedily sampled first output token is `" cook"`. Please remember to provide the path to your `Llama3.1-8b` checkpoint for the `load_parameters_path` argument in the config inside the notebook. You can use [to_maxtext.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/utils/ckpt_conversion/to_maxtext.py) to create a MaxText/Orbax checkpoint from a Huggingface checkpoint. +You can use [demo_decoding.ipynb](demo_decoding.ipynb) to try out decoding on MaxText's `Llama3.1-8b` model implementation. In this notebook, we give `"I love to"` as the prompt, and the greedily sampled first output token is `" cook"`. Please remember to provide the path to your `Llama3.1-8b` checkpoint for the `load_parameters_path` argument in the config inside the notebook. You can use [to_maxtext.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/utils/ckpt_conversion/to_maxtext.py) to create a MaxText/Orbax checkpoint from a Huggingface checkpoint. ### Run MaxText on NVIDIA GPUs 1. Use `bash dependencies/scripts/docker_build_dependency_image.sh DEVICE=gpu` to build a container with the required dependencies. From 32e7f27089206ca84208a6d207b4fc7ca14aa34a Mon Sep 17 00:00:00 2001 From: Rob Mulla Date: Tue, 16 Dec 2025 11:38:36 -0500 Subject: [PATCH 4/5] docs: use symlink for notebook and ignore build artifacts --- .gitignore | 1 + docs/tutorials.md | 1 - docs/tutorials/demo_decoding.ipynb | 439 +---------------------- src/MaxText/examples/demo_decoding.ipynb | 5 +- 4 files changed, 6 insertions(+), 440 deletions(-) mode change 100644 => 120000 docs/tutorials/demo_decoding.ipynb diff --git a/.gitignore b/.gitignore index 7881d3a00..3611cec51 100644 --- a/.gitignore +++ b/.gitignore @@ -148,3 +148,4 @@ dmypy.json # Gemini CLI .gemini/ gha-creds-*.json +docs/jupyter_execute/ diff --git a/docs/tutorials.md b/docs/tutorials.md index 2d8a6b096..ff2a17f61 100644 --- a/docs/tutorials.md +++ b/docs/tutorials.md @@ -40,5 +40,4 @@ Below is a list of all available tutorials organized by topic. tutorials/first_run.md tutorials/pretraining.md tutorials/post_training_index.md -tutorials/demo_decoding.ipynb ``` diff --git a/docs/tutorials/demo_decoding.ipynb b/docs/tutorials/demo_decoding.ipynb deleted file mode 100644 index 9b913318e..000000000 --- a/docs/tutorials/demo_decoding.ipynb +++ /dev/null @@ -1,438 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "e017d77b", - "metadata": {}, - "source": [ - "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/AI-Hypercomputer/maxtext/blob/main/src/MaxText/examples/demo_decoding.ipynb)\n", - " \n", - "# Qwen3-0.6B Decoding Demo" - ] - }, - { - "cell_type": "markdown", - "id": "dc85cefe-8f29-47db-a8f3-4e8fbb354eb5", - "metadata": {}, - "source": [ - "## Prerequisites" - ] - }, - { - "cell_type": "markdown", - "id": "55e3ce9e-8968-4d68-ba2b-b36c616b52a9", - "metadata": {}, - "source": [ - "### Change Runtime Type\n", - "\n", - "**Instructions:**\n", - "1. Navigate to the menu at the top of the screen.\n", - "2. Click on **Runtime**.\n", - "3. Select **Change runtime type** from the dropdown menu.\n", - "4. Select **v5e-1 TPU** as the **Hardware accelerator**.\n", - "5. Click on **Save**." - ] - }, - { - "cell_type": "markdown", - "id": "bf5e0f3f-5833-4260-a31d-b156249d67ab", - "metadata": {}, - "source": [ - "### Get Your Hugging Face Token\n", - "\n", - "To access model checkpoint from the Hugging Face Hub, you need to authenticate with a personal access token.\n", - "\n", - "**Follow these steps to get your token:**\n", - "\n", - "1. **Navigate to the Access Tokens page** in your Hugging Face account settings. You can go there directly by visiting this URL:\n", - " * [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)\n", - "\n", - "2. **Create a new token** by clicking the **\"+ Create new token\"** button.\n", - "\n", - "3. **Give your token a name** and assign it a **`read` role**. The `read` role is sufficient for downloading models.\n", - "\n", - "4. **Copy the generated token**. You will need to paste it in the next step.\n", - "\n", - "**Follow these steps to store your token:**\n", - "\n", - "1. On the left sidebar of your Colab window, click the key icon (the Secrets tab).\n", - "\n", - "2. Click **\"+ Add new secret\"**.\n", - "\n", - "3. Set the Name as **HF_TOKEN**.\n", - "\n", - "4. Paste your token into the Value field.\n", - "\n", - "5. Ensure the Notebook access toggle is turned On." - ] - }, - { - "cell_type": "markdown", - "id": "8a6deec5-b64a-4bc6-86c4-c24696c66f17", - "metadata": {}, - "source": [ - "## Installation: MaxText & Other Dependencies" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9b2d4a66-99de-404c-aac3-18b1af4af78e", - "metadata": {}, - "outputs": [], - "source": [ - "# Install uv, a fast Python package installer\n", - "!pip install uv\n", - "\n", - "# Install MaxText and dependencies\n", - "!uv pip install maxtext --resolution=lowest\n", - "!python3 -m MaxText.install_maxtext_extra_deps\n", - "\n", - "# Use nest_asyncio to allow nested event loops in notebooks\n", - "!uv pip install nest_asyncio\n", - "\n", - "# Install the PyTorch library\n", - "!uv pip install torch" - ] - }, - { - "cell_type": "markdown", - "id": "5a07fd61-35b7-4aa9-93cd-49ef89fb550d", - "metadata": {}, - "source": [ - "### Restart Session\n", - "To apply certain changes, you need to restart the session.\n", - "\n", - "**Instructions:**\n", - "1. Navigate to the menu at the top of the screen.\n", - "2. Click on **Runtime**.\n", - "3. Select **Restart session** from the dropdown menu.\n", - "\n", - "You will be asked to confirm the action in a pop-up dialog. Click on **Yes**." - ] - }, - { - "cell_type": "markdown", - "id": "2f1ebdb1-dcf4-417b-9c29-3461e06aa9cf", - "metadata": {}, - "source": [ - "## Imports" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a8e986cb", - "metadata": {}, - "outputs": [], - "source": [ - "%%capture\n", - "import datetime\n", - "import jax\n", - "import os\n", - "import nest_asyncio\n", - "import numpy as np\n", - "\n", - "import MaxText as mt\n", - "from MaxText import common_types\n", - "from MaxText import inference_utils\n", - "from MaxText import maxtext_utils\n", - "from MaxText import max_logging\n", - "from MaxText import pyconfig\n", - "from MaxText.input_pipeline import _input_pipeline_utils\n", - "from MaxText.utils.ckpt_conversion import to_maxtext\n", - "\n", - "from google.colab import userdata\n", - "from huggingface_hub import login\n", - "\n", - "MAXTEXT_PKG_DIR = os.path.dirname(mt.__file__)\n", - "\n", - "nest_asyncio.apply()" - ] - }, - { - "cell_type": "markdown", - "id": "c4f53124", - "metadata": {}, - "source": [ - "## Sanity Test: Checking for Available TPU Devices" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a545acd8", - "metadata": {}, - "outputs": [], - "source": [ - "jax.distributed.initialize() # distributed.initialize should only be called once.\n", - "jax.devices()" - ] - }, - { - "cell_type": "markdown", - "id": "be0113d9-0cb6-45aa-9fa4-7e543db7645e", - "metadata": {}, - "source": [ - "## Model Configurations" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b80080fa-473b-4683-b0c9-765af43efd49", - "metadata": {}, - "outputs": [], - "source": [ - "MODEL_NAME = \"qwen3-0.6b\"\n", - "PROMPT = \"I love to\"\n", - "RUN_NAME = datetime.datetime.now().strftime(\"%Y-%m-%d-%H-%M-%S\")\n", - "MODEL_CHECKPOINT_PATH = f\"/tmp/checkpoints/{MODEL_NAME}/{RUN_NAME}/unscanned\"\n", - "\n", - "HF_TOKEN = userdata.get(\"HF_TOKEN\")\n", - "login(token=HF_TOKEN)\n", - "max_logging.log(\"Authenticated with Hugging Face successfully!\")" - ] - }, - { - "cell_type": "markdown", - "id": "03ff53b8-b931-4190-bcac-d6ca885cbbc8", - "metadata": {}, - "source": [ - "## Download Model Checkpoint From Hugging Face" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1fd3578c-5763-410e-8b61-72d7415628bd", - "metadata": {}, - "outputs": [], - "source": [ - "%%capture\n", - "argv = [\n", - " \"\", # This is a placeholder, it's not actually used by the script's logic\n", - " f\"{MAXTEXT_PKG_DIR}/configs/base.yml\",\n", - " f\"model_name={MODEL_NAME}\",\n", - " f\"base_output_directory={MODEL_CHECKPOINT_PATH}\",\n", - " f\"hf_access_token={HF_TOKEN}\",\n", - " \"use_multimodal=false\",\n", - " \"scan_layers=false\",\n", - "]\n", - "\n", - "to_maxtext.main(argv)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "94a9fd37-95e5-4075-837e-de0f1666d55f", - "metadata": {}, - "outputs": [], - "source": [ - "max_logging.log(f\"Model checkpoint can be found at: {MODEL_CHECKPOINT_PATH}/0/items\")" - ] - }, - { - "cell_type": "markdown", - "id": "0cf4bbe4-6485-4ce7-8aef-cf3df3810e52", - "metadata": {}, - "source": [ - "## Initialize Configurations" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "32f44079-87ae-4ed4-a008-c0dbb8aaf8c0", - "metadata": {}, - "outputs": [], - "source": [ - "%%capture\n", - "config = pyconfig.initialize(\n", - " [\"\", f\"{MAXTEXT_PKG_DIR}/configs/base.yml\"],\n", - " per_device_batch_size=1.0,\n", - " run_name=\"test\",\n", - " max_target_length=4,\n", - " max_prefill_predict_length=4,\n", - " tokenizer_path=f\"{MAXTEXT_PKG_DIR}/assets/qwen3-tokenizer\",\n", - " load_parameters_path=f\"{MODEL_CHECKPOINT_PATH}/0/items\",\n", - " model_name=MODEL_NAME,\n", - " async_checkpointing=False,\n", - " prompt=PROMPT,\n", - " scan_layers=False,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a637f43b-6fcc-4305-af5b-d9d30d464bb6", - "metadata": {}, - "outputs": [], - "source": [ - "max_logging.log(\"Decode configurations initialized.\")" - ] - }, - { - "cell_type": "markdown", - "id": "cd502094-1694-410a-91e9-25bbb8dfb33a", - "metadata": {}, - "source": [ - "## Initialize Decode State" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d2d2de93", - "metadata": {}, - "outputs": [], - "source": [ - "model = mt.from_config(config)\n", - "mesh = model.mesh\n", - "init_rng = jax.random.PRNGKey(config.init_weights_seed)\n", - "state, _ = maxtext_utils.setup_decode_state(model, config, init_rng, mesh, None)\n", - "max_logging.log(\"Decode state initialized.\")" - ] - }, - { - "cell_type": "markdown", - "id": "ed4b59a7", - "metadata": {}, - "source": [ - "## Get Tokenizer" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "35584129-3c45-45ad-b2a2-a56f98d27f06", - "metadata": {}, - "outputs": [], - "source": [ - "tokenizer = _input_pipeline_utils.get_tokenizer(\n", - " f\"{MAXTEXT_PKG_DIR}/assets/qwen3-tokenizer\",\n", - " \"huggingface\",\n", - " add_bos=True,\n", - " add_eos=False,\n", - ")\n", - "max_logging.log(\"Tokenizer loaded succuessfully.\")" - ] - }, - { - "cell_type": "markdown", - "id": "32a252ae", - "metadata": {}, - "source": [ - "## Prepare Inputs" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d2d2d0c5", - "metadata": {}, - "outputs": [], - "source": [ - "input_ids = tokenizer.encode(config.prompt)\n", - "\n", - "# Pad input_ids to max_target_length\n", - "padded_ids = np.zeros(config.max_target_length, dtype=np.int32)\n", - "padded_ids[: len(input_ids)] = input_ids\n", - "ids = np.asarray(padded_ids, dtype=np.int32)\n", - "\n", - "s = (config.global_batch_size_to_train_on, config.max_target_length)\n", - "decoder_segment_ids = np.zeros(s) + common_types.DECODING_ACTIVE_SEQUENCE_INDICATOR\n", - "decoder_positions = np.stack(\n", - " [np.arange(config.max_target_length, dtype=np.int32) for _ in range(config.global_batch_size_to_train_on)]\n", - ")\n", - "\n", - "ids = np.stack([ids for _ in range(config.global_batch_size_to_train_on)])\n", - "max_logging.log(\n", - " f\"input_ids={input_ids}, \\n\\nids={ids}, \\n\\ndecoder_segment_ids = {decoder_segment_ids}, \\n\\ndecoder_positions= {decoder_positions}\"\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "647018c1", - "metadata": {}, - "source": [ - "## Run Forward Pass" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7436751b", - "metadata": { - "scrolled": true - }, - "outputs": [], - "source": [ - "full_train_logits = model.apply(\n", - " state.params,\n", - " ids,\n", - " decoder_positions,\n", - " decoder_segment_ids,\n", - " enable_dropout=False,\n", - " rngs={\"aqt\": init_rng},\n", - ")\n", - "full_train_logits = jax.experimental.multihost_utils.process_allgather(full_train_logits)\n", - "max_logging.log(f\"{full_train_logits[0, 0, :]=}\")" - ] - }, - { - "cell_type": "markdown", - "id": "5640ab55", - "metadata": {}, - "source": [ - "## Generate Text with Greedy Decoding" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "bb06c0c9", - "metadata": {}, - "outputs": [], - "source": [ - "selected_logits = jax.lax.dynamic_slice(\n", - " full_train_logits, (0, 0, full_train_logits.shape[2] - 2, 0), (1, 1, 1, full_train_logits.shape[3])\n", - ")\n", - "\n", - "# Consider the greedily sampled token\n", - "init_rng, new_rng = jax.random.split(init_rng)\n", - "first_generated_token = inference_utils.sampling(\n", - " selected_logits,\n", - " new_rng,\n", - " config.decode_sampling_strategy, # \"greedy\"\n", - ")\n", - "output = tokenizer.decode([first_generated_token.item()])\n", - "max_logging.log(f\"Next predicted token is `{output}` for the input prompt: `{config.prompt}`.\")" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.11" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/docs/tutorials/demo_decoding.ipynb b/docs/tutorials/demo_decoding.ipynb new file mode 120000 index 000000000..01d1e9781 --- /dev/null +++ b/docs/tutorials/demo_decoding.ipynb @@ -0,0 +1 @@ +../../src/MaxText/examples/demo_decoding.ipynb \ No newline at end of file diff --git a/src/MaxText/examples/demo_decoding.ipynb b/src/MaxText/examples/demo_decoding.ipynb index 9b913318e..a6a88c094 100644 --- a/src/MaxText/examples/demo_decoding.ipynb +++ b/src/MaxText/examples/demo_decoding.ipynb @@ -431,8 +431,11 @@ "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.11" + }, + "myst": { + "orphan": true } }, "nbformat": 4, "nbformat_minor": 5 -} +} \ No newline at end of file From 0219497c7a4dc613ab5c69a4945d5842ececd670 Mon Sep 17 00:00:00 2001 From: Rob Mulla Date: Tue, 16 Dec 2025 11:43:36 -0500 Subject: [PATCH 5/5] docs: nest notebook under first_run and fix warnings --- docs/conf.py | 2 ++ docs/tutorials/first_run.md | 9 ++++++++- src/MaxText/examples/demo_decoding.ipynb | 3 --- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index fbc2ece3d..a095c50b1 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -79,4 +79,6 @@ "run_maxtext/run_maxtext_via_multihost_job.md", "run_maxtext/run_maxtext_via_multihost_runner.md", "reference/core_concepts/llm_calculator.ipynb", + "jupyter_execute", + "_build", ] diff --git a/docs/tutorials/first_run.md b/docs/tutorials/first_run.md index 5835e2d2c..262fb6028 100644 --- a/docs/tutorials/first_run.md +++ b/docs/tutorials/first_run.md @@ -107,7 +107,14 @@ This command uses the checkpoint from the training run. If no checkpoint is foun In the same TPU VM where you just installed all the dependencies of MaxText, You can also run training and decoding in MaxText via Notebook (for e.g., via Jupyter or Colab). #### Decoding in MaxText via notebook -You can use [demo_decoding.ipynb](demo_decoding.ipynb) to try out decoding on MaxText's `Llama3.1-8b` model implementation. In this notebook, we give `"I love to"` as the prompt, and the greedily sampled first output token is `" cook"`. Please remember to provide the path to your `Llama3.1-8b` checkpoint for the `load_parameters_path` argument in the config inside the notebook. You can use [to_maxtext.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/utils/ckpt_conversion/to_maxtext.py) to create a MaxText/Orbax checkpoint from a Huggingface checkpoint. +You can use [demo_decoding.ipynb](demo_decoding) to try out decoding on MaxText's `Llama3.1-8b` model implementation. In this notebook, we give `"I love to"` as the prompt, and the greedily sampled first output token is `" cook"`. Please remember to provide the path to your `Llama3.1-8b` checkpoint for the `load_parameters_path` argument in the config inside the notebook. You can use [to_maxtext.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/utils/ckpt_conversion/to_maxtext.py) to create a MaxText/Orbax checkpoint from a Huggingface checkpoint. + +```{toctree} +:maxdepth: 1 +:hidden: + +demo_decoding.ipynb +``` ### Run MaxText on NVIDIA GPUs 1. Use `bash dependencies/scripts/docker_build_dependency_image.sh DEVICE=gpu` to build a container with the required dependencies. diff --git a/src/MaxText/examples/demo_decoding.ipynb b/src/MaxText/examples/demo_decoding.ipynb index a6a88c094..965eb7481 100644 --- a/src/MaxText/examples/demo_decoding.ipynb +++ b/src/MaxText/examples/demo_decoding.ipynb @@ -431,9 +431,6 @@ "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.11" - }, - "myst": { - "orphan": true } }, "nbformat": 4,