Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions .github/workflows/check_docs_build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,20 @@ jobs:
with:
persist-credentials: false

- name: Set up Python
uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0
- name: Install uv and set the Python version
uses: astral-sh/setup-uv@eb1897b8dc4b5d5bfe39a428a8f2304605e0983c # v7.0.0
with:
python-version: '3.12'
cache: 'pip' # caching pip dependencies
enable-cache: true

- name: Set venv
run: uv venv --python 3.12 $GITHUB_WORKSPACE/venv

- name: Install dependencies
run: pip install -r dependencies/requirements/requirements_docs.txt
run: . $GITHUB_WORKSPACE/venv/bin/activate && uv pip install -r dependencies/requirements/requirements_docs.txt

- name: Build documentation
run: |
. $GITHUB_WORKSPACE/venv/bin/activate
uv pip install -e . --no-deps
sphinx-build -W -b html docs docs/_build/html
126 changes: 126 additions & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,16 @@
https://www.sphinx-doc.org/en/master/usage/configuration.html
"""

import os
import sys
import logging
from sphinx.util import logging as sphinx_logging

# Absolute or relative path to the repository root. Overridable via the
# MAXTEXT_REPO_ROOT environment variable; defaults to the parent of this
# docs/ directory. (Note: os.path.join with a single argument is a no-op,
# so the plain double-dirname is used directly.)
MAXTEXT_REPO_ROOT = os.environ.get(
    "MAXTEXT_REPO_ROOT", os.path.dirname(os.path.dirname(__file__))
)
# Make the MaxText package under src/ importable for autodoc.
sys.path.insert(0, os.path.abspath(os.path.join(MAXTEXT_REPO_ROOT, "src")))

# -- Project information -----------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information

Expand All @@ -37,6 +47,10 @@
"myst_nb",
"sphinx_design",
"sphinx_copybutton",
"sphinx.ext.napoleon",
"sphinx.ext.autodoc",
"sphinx.ext.autosummary",
"sphinx.ext.viewcode",
]

templates_path = ["_templates"]
Expand Down Expand Up @@ -79,4 +93,116 @@
"run_maxtext/run_maxtext_via_multihost_job.md",
"run_maxtext/run_maxtext_via_multihost_runner.md",
"reference/core_concepts/llm_calculator.ipynb",
"reference/api_generated/modules.rst",
"reference/api_generated/install_maxtext_extra_deps.rst",
"reference/api_generated/install_maxtext_extra_deps.install_github_deps.rst",
]

autosummary_generate = True
autodoc_typehints = "description"
autodoc_member_order = "bysource"
autodoc_mock_imports = ["jetstream", "vllm", "torch", "tensorflow_datasets", "tpu_inference"]

# Suppress specific warnings
suppress_warnings = [
"autodoc.import_object",
]

# -- Autogenerate API documentation ------------------------------------------
def run_apidoc(_):
  """Runs sphinx-apidoc to generate API documentation.

  This function is connected to the Sphinx build process and is triggered to
  automatically generate the reStructuredText (RST) files for the API
  documentation from the docstrings in the MaxText source code.

  Args:
    _: The Sphinx application object. Not used.

  Raises:
    FileNotFoundError: If MAXTEXT_REPO_ROOT does not point at the repo root.
    SystemExit: If the sphinx-apidoc subprocess fails.
  """
  # Running apidoc in a subprocess rather than directly within the Sphinx
  # process, especially on macOS, avoids potential multiprocessing/forking
  # issues like the "mutex lock failed" error.
  # pylint: disable=import-outside-toplevel
  import subprocess

  os.environ["OBJC_DISABLE_INITIALIZE_FORK_SAFETY"] = "1"

  # Validate with an explicit raise instead of `assert` (asserts are
  # stripped when Python runs with -O).
  if not os.path.isfile(os.path.join(MAXTEXT_REPO_ROOT, "pyproject.toml")):
    raise FileNotFoundError(
        f"MAXTEXT_REPO_ROOT does not look like the repo root: {MAXTEXT_REPO_ROOT}"
    )

  # The path where the generated RST files will be stored
  output_path = os.path.join(MAXTEXT_REPO_ROOT, "docs", "reference", "api_generated")

  # Command to run sphinx-apidoc
  # Note: We use `sys.executable -m sphinx.ext.apidoc` to ensure we're using
  # the apidoc from the same Python environment as Sphinx.
  command = [
      sys.executable,
      "-m",
      "sphinx.ext.apidoc",
      "--module-first",
      "--force",
      "--separate",
      "--output-dir",
      output_path,
      os.path.join(MAXTEXT_REPO_ROOT, "src"),
      # Paths to exclude
      os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "experimental"),
      os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "inference_mlperf"),
      os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "scratch_code"),
      # Paths to exclude due to import errors
      os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "utils"),
      os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "inference", "decode_multi.py"),
      os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "inference", "offline_engine.py"),
      os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "benchmark_chunked_prefill.py"),
      os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "decode.py"),
      os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "inference_microbenchmark.py"),
      os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "inference_microbenchmark_sweep.py"),
      os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "load_and_quantize_checkpoint.py"),
      os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "maxengine.py"),
      os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "maxengine_config.py"),
      os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "maxengine_server.py"),
      os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "prefill_packing.py"),
  ]

  # Run the command and check for errors. The fork-safety variable was
  # already written into os.environ above, so the inherited environment is
  # passed as-is (the previous explicit merge was redundant).
  try:
    print("Running sphinx-apidoc...")
    subprocess.check_call(command)
  except subprocess.CalledProcessError as e:
    print(f"sphinx-apidoc failed with error: {e}", file=sys.stderr)
    sys.exit(1)


class FilterSphinxWarnings(logging.Filter):
  """Filter autosummary 'duplicate object description' warnings.

  These warnings are unnecessary as they do not cause missing documentation
  or rendering issues, so it is safe to filter them out.
  """

  def __init__(self, app):
    # The Sphinx application object; kept for access by filter logic if needed.
    self.app = app
    super().__init__()

  def filter(self, record: logging.LogRecord) -> bool:
    """Return False (drop) for duplicate-object-description warnings."""
    msg = record.getMessage()
    # BUG FIX: the previous prefix was the Portuguese translation
    # ("descrição duplicada de objeto"), which never matches because Sphinx
    # emits this warning in English — so nothing was ever filtered.
    filter_out = ("duplicate object description",)
    return not msg.strip().startswith(filter_out)


def setup(app):
  """Set up the Sphinx application with custom behavior.

  Generates the API reference RST files via sphinx-apidoc and installs a
  logging filter that drops known-noisy duplicate-description warnings.

  Args:
    app: The Sphinx application object.
  """
  # Connect the apidoc generation to the Sphinx build process.
  # (Removed leftover debug print of the app object.)
  run_apidoc(None)

  # Set up custom logging filters on Sphinx's warning stream handler so
  # matching warnings are dropped before being emitted.
  logger = logging.getLogger("sphinx")
  warning_handler = next(
      (h for h in logger.handlers if isinstance(h, sphinx_logging.WarningStreamHandler)),
      None,
  )
  # Guard against no handler being registered: the previous starred
  # unpacking raised ValueError on an empty list.
  if warning_handler is not None:
    warning_handler.filters.insert(0, FilterSphinxWarnings(app))
2 changes: 1 addition & 1 deletion docs/guides/optimization/benchmark_and_performance.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ Different quantization recipes are available, including` "int8", "fp8", "fp8_ful

For v6e and earlier generation TPUs, use the "int8" recipe. For v7x and later generation TPUs, use "fp8_full". GPUs should use “fp8_gpu” for NVIDIA and "nanoo_fp8" for AMD.

See [](quantization).
See [](quantization-doc).

### Choose sharding strategy

Expand Down
2 changes: 1 addition & 1 deletion docs/guides/optimization/custom_model.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ Use these general runtime configurations to improve your model's performance.

## Step 3. Choose efficient sharding strategies using Roofline Analysis

To achieve good performance, it's often necessary to co-design the model's dimensions (like the MLP dimension) along with the sharding strategy. We have included examples for [v5p](https://docs.cloud.google.com/tpu/docs/v5p), [Trillium](https://docs.cloud.google.com/tpu/docs/v6e), and [Ironwood](https://docs.cloud.google.com/tpu/docs/tpu7x) that demonstrate which sharding approaches work well for specific models. We recommend reading [](sharding) and Jax’s [scaling book](https://jax-ml.github.io/scaling-book/sharding/).
To achieve good performance, it's often necessary to co-design the model's dimensions (like the MLP dimension) along with the sharding strategy. We have included examples for [v5p](https://docs.cloud.google.com/tpu/docs/v5p), [Trillium](https://docs.cloud.google.com/tpu/docs/v6e), and [Ironwood](https://docs.cloud.google.com/tpu/docs/tpu7x) that demonstrate which sharding approaches work well for specific models. We recommend reading [](sharding_on_TPUs) and Jax’s [scaling book](https://jax-ml.github.io/scaling-book/sharding/).

| TPU Type | ICI Arithmetic Intensity |
|---|---|
Expand Down
6 changes: 6 additions & 0 deletions docs/reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,11 @@ Key concepts including checkpointing strategies, quantization, tiling, and Mixtu
:::
::::

## 📚 API Reference

Find comprehensive API documentation for MaxText modules, classes, and functions in the [API Reference page](reference/api_reference).


```{toctree}
:hidden:
:maxdepth: 1
Expand All @@ -58,4 +63,5 @@ reference/performance_metrics.md
reference/models.md
reference/architecture.md
reference/core_concepts.md
reference/api_reference.rst
```
26 changes: 26 additions & 0 deletions docs/reference/api_reference.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
..
Copyright 2024 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


API Reference
=============

This section contains the complete API documentation for the ``MaxText`` library, automatically generated from the source code docstrings.

.. toctree::
:maxdepth: 4
:caption: Package Modules
:glob:

api_generated/MaxText
4 changes: 2 additions & 2 deletions docs/reference/core_concepts/quantization.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@
See the License for the specific language governing permissions and
limitations under the License.
-->
(quantization)=
(quantization-doc)=
# Quantization

Quantization in deep learning is the process of reducing the precision of numbers used to represent a model's weights and/or activations. Instead of using higher-precision floating-point formats like 32-bit floats (`float32`) or 16-bit brain floats (`bfloat16`), quantization maps these values to lower-precision numerical formats, most commonly 8-bit integers (`int8`) or floats (`fp8`).

MaxText supports quantization via both the [AQT](https://github.com/google/aqt) and [Qwix](https://github.com/google/qwix) libraries. Qwix is the recommended approach, providing a non-intrusive way to apply Quantized Training (QT).

## Why use quantization?
## Why use quantization?

The drive to use lower-precision formats like `int8` or `fp8` stems from significant performance advantages:

Expand Down
12 changes: 7 additions & 5 deletions src/MaxText/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,13 +514,13 @@ def load_state_if_possible(
enable_orbax_v1: bool flag for enabling Orbax v1.
checkpoint_conversion_fn: function for converting checkpoint to Orbax v1.
source_checkpoint_layout: Optional checkpoint context to use for loading,
provided in string format with the default being "orbax".
provided in string format with the default being "orbax".

Returns:
A tuple of (train_state, train_state_params) where full_train_state captures
a full reload and train_state_params just the params for a partial reload.
At most one will be non-None. Both can be None if neither checkpoint is
set.
A tuple of `(train_state, train_state_params)`
where full_train_state captures a full reload and `train_state_params`
just the params for a partial reload. At most one will be non-None. Both
can be None if neither checkpoint is set.
"""

if checkpoint_manager is not None:
Expand Down Expand Up @@ -615,8 +615,10 @@ def map_to_pspec(data):

def setup_checkpoint_logger(config) -> Any | None: # pytype: disable=attribute-error
"""Setup checkpoint logger.

Args:
config

Returns:
CloudLogger
"""
Expand Down
20 changes: 11 additions & 9 deletions src/MaxText/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def largest_batch_size(base_argv, policy, min_pdb, max_pdb=64) -> int:
max_pdb: The maximum per_device_batch_size to test.

Returns:
The largest per_device_batch_size within the range that does not result in an OOM error.
The largest `per_device_batch_size` within the range that does not result in an OOM error.
"""
print(f"Starting binary search for the largest batch size between {min_pdb} and {max_pdb}.")

Expand Down Expand Up @@ -345,8 +345,9 @@ def get_parameter_value(config_tuple, prefix):

Returns:
A tuple of (bool, str or None).
- (True, value) if the prefix is found.
- (False, None) if the prefix is not found.

* (True, value) if the prefix is found.
* (False, None) if the prefix is not found.
"""
for item in config_tuple:
if item.startswith(prefix):
Expand All @@ -364,12 +365,13 @@ def find_batch_size(base_argv):
Parses the base arguments to find the 'per_device_batch_size'.

Args:
base_argv: The tuple of command-line arguments.
base_argv: The tuple of command-line arguments.

Returns:
A tuple of (bool, int or None):
- (True, batch_size) if 'per_device_batch_size=...' was found.
- (False, None) if it was not found.
A tuple of (bool, int or None)

* (True, batch_size) if `per_device_batch_size=...` was found.
* (False, None) if it was not found.
"""
pdb_provided, pdb_str = get_parameter_value(base_argv, prefix="per_device_batch_size=")

Expand All @@ -384,10 +386,10 @@ def find_remat_policy_tensor_names(base_argv):
to be considered for rematerialization.

Args:
base_argv: The tuple of command-line arguments.
base_argv: The tuple of command-line arguments.

Returns:
A list of tensor names that were passed as flags.
A list of tensor names that were passed as flags.
"""
full_tensor_list = [
"context",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,13 @@ def GEMMA2_HF_WEIGHTS_TO_SHAPE(config):
"""Returns mapping between HuggingFace weights path and weights shape.

Args:
config (dict): Model configuration dictionary, defined in `model_configs.py`
config (dict): Model configuration dictionary, defined in `model_configs.py`

Returns:
dict: A mapping where:
- Keys are HuggingFace model parameter paths
- Values are parameter shape as a List
dict: A mapping where:

* Keys are HuggingFace model parameter paths
* Values are parameter shape as a List
"""

mapping = {
Expand Down
Loading
Loading