From 329ed7b7523b61a2cfeec33511f25d2160f530f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Melissa=20Weber=20Mendon=C3=A7a?= Date: Thu, 18 Dec 2025 18:24:44 -0300 Subject: [PATCH 1/3] Add automatic API documentation generation Co-authored-by: Samuel Marks <807580+SamuelMarks@users.noreply.github.com> --- docs/conf.py | 124 +++++++++++ .../optimization/benchmark_and_performance.md | 2 +- docs/guides/optimization/custom_model.md | 2 +- docs/reference.md | 6 + docs/reference/api_reference.rst | 26 +++ docs/reference/core_concepts/quantization.md | 4 +- src/MaxText/generate_param_only_checkpoint.py | 16 +- src/MaxText/inference/kvcache.py | 15 +- src/MaxText/inference/offline_engine.py | 31 +-- src/MaxText/inference/page_manager.py | 42 ++-- .../vllm/maxtext_vllm_adapter/adapter.py | 8 +- src/MaxText/integration/vllm/setup.py | 5 +- src/MaxText/kernels/jax_flash_attention.py | 6 +- src/MaxText/kernels/megablox/backend.py | 24 ++- src/MaxText/layers/attention_op.py | 125 ++++++----- src/MaxText/layers/moe.py | 36 ++-- src/MaxText/layers/normalizations.py | 10 +- src/MaxText/layers/pipeline.py | 26 ++- src/MaxText/maxtext_utils.py | 42 ++-- src/MaxText/multimodal_utils.py | 200 ++++++++++-------- src/MaxText/pyconfig_deprecated.py | 8 +- src/MaxText/sequence_packing.py | 40 ++-- src/MaxText/sft/sft_trainer.py | 7 +- src/MaxText/tokenizer.py | 9 +- .../utils/ckpt_conversion/utils/hf_shape.py | 16 +- .../utils/ckpt_conversion/utils/utils.py | 11 +- 26 files changed, 533 insertions(+), 308 deletions(-) create mode 100644 docs/reference/api_reference.rst diff --git a/docs/conf.py b/docs/conf.py index 646b60f86e..8d633d16eb 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -22,6 +22,16 @@ https://www.sphinx-doc.org/en/master/usage/configuration.html """ +import os +import sys +import logging +from sphinx.util import logging as sphinx_logging + +MAXTEXT_REPO_ROOT = os.environ.get( + "MAXTEXT_REPO_ROOT", os.path.join(os.path.dirname(os.path.dirname(__file__))) +) +sys.path.insert(0, os.path.abspath(os.path.join(MAXTEXT_REPO_ROOT, "src"))) + # -- Project information ----------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information @@ -37,6 +47,10 @@ "myst_nb", "sphinx_design", "sphinx_copybutton", + "sphinx.ext.napoleon", + "sphinx.ext.autodoc", + "sphinx.ext.autosummary", + "sphinx.ext.viewcode", ] templates_path = ["_templates"] @@ -79,4 +93,114 @@ "run_maxtext/run_maxtext_via_multihost_job.md", "run_maxtext/run_maxtext_via_multihost_runner.md", "reference/core_concepts/llm_calculator.ipynb", + "reference/api_generated/modules.rst", + "reference/api_generated/install_maxtext_extra_deps.rst", + "reference/api_generated/install_maxtext_extra_deps.install_github_deps.rst", +] + +autosummary_generate = True +autodoc_typehints = "description" +autodoc_member_order = "bysource" +autodoc_mock_imports = ["jetstream", "vllm", "torch", "tensorflow_datasets", "tpu_inference"] + +# Suppress specific warnings +suppress_warnings = [ + "autodoc.import_object", ] + +# -- Autogenerate API documentation ------------------------------------------ +def run_apidoc(_): + """Runs sphinx-apidoc to generate API documentation. + + This function is connected to the Sphinx build process and is triggered to + automatically generate the reStructuredText (RST) files for the API + documentation from the docstrings in the MaxText source code. + + Args: + _: The Sphinx application object. Not used. + """ + # directly within the Sphinx process, especially on macOS, as it avoids + # potential multiprocessing/forking issues like the "mutex lock failed" error. + # pylint: disable=import-outside-toplevel + import subprocess + + os.environ["OBJC_DISABLE_INITIALIZE_FORK_SAFETY"] = "1" + + assert os.path.isfile(os.path.join(MAXTEXT_REPO_ROOT, "pyproject.toml")) + + # The path where the generated RST files will be stored + output_path = os.path.join(MAXTEXT_REPO_ROOT, "docs", "reference", "api_generated") + + # Command to run sphinx-apidoc + # Note: We use `sys.executable -m sphinx.ext.apidoc` to ensure we're using + # the apidoc from the same Python environment as Sphinx. + command = [ + sys.executable, + "-m", + "sphinx.ext.apidoc", + "--module-first", + "--force", + "--separate", + "--output-dir", + output_path, + os.path.join(MAXTEXT_REPO_ROOT, "src"), + # Paths to exclude + os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "experimental"), + os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "inference_mlperf"), + os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "scratch_code"), + # Paths to exclude due to import errors + os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "utils"), + os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "inference", "decode_multi.py"), + os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "inference", "offline_engine.py"), + os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "benchmark_chunked_prefill.py"), + os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "decode.py"), + os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "inference_microbenchmark.py"), + os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "inference_microbenchmark_sweep.py"), + os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "load_and_quantize_checkpoint.py"), + os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "maxengine.py"), + os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "maxengine_config.py"), + os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "maxengine_server.py"), + os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "prefill_packing.py"), + ] + + # Run the command and check for errors + try: + print("Running sphinx-apidoc...") + subprocess.check_call( + command, env={**os.environ, **{"OBJC_DISABLE_INITIALIZE_FORK_SAFETY": "1"}} + ) + except subprocess.CalledProcessError as e: + print(f"sphinx-apidoc failed with error: {e}", file=sys.stderr) + sys.exit(1) + + +class FilterSphinxWarnings(logging.Filter): + """Filter autosummary 'duplicate object description' warnings. + + These warnings are unnecessary as they do not cause missing documentation + or rendering issues, so it is safe to filter them out. + """ + + def __init__(self, app): + self.app = app + super().__init__() + + def filter(self, record: logging.LogRecord) -> bool: + msg = record.getMessage() + filter_out = ("descrição duplicada de objeto",) + return not msg.strip().startswith(filter_out) + + +def setup(app): + # Connect the apidoc generation to the Sphinx build process + run_apidoc(None) + print("running:", app) + + # Set up custom logging filters + logger = logging.getLogger("sphinx") + warning_handler, *_ = [ + h + for h in logger.handlers + if isinstance(h, sphinx_logging.WarningStreamHandler) + ] + warning_handler.filters.insert(0, FilterSphinxWarnings(app)) diff --git a/docs/guides/optimization/benchmark_and_performance.md b/docs/guides/optimization/benchmark_and_performance.md index f0d1b15433..25626065b9 100644 --- a/docs/guides/optimization/benchmark_and_performance.md +++ b/docs/guides/optimization/benchmark_and_performance.md @@ -69,7 +69,7 @@ Different quantization recipes are available, including` "int8", "fp8", "fp8_ful For v6e and earlier generation TPUs, use the "int8" recipe. For v7x and later generation TPUs, use "fp8_full". GPUs should use “fp8_gpu” for NVIDIA and "nanoo_fp8" for AMD. -See [](quantization). +See [](quantization-doc). ### Choose sharding strategy diff --git a/docs/guides/optimization/custom_model.md b/docs/guides/optimization/custom_model.md index 962df428ad..407d99f6f1 100644 --- a/docs/guides/optimization/custom_model.md +++ b/docs/guides/optimization/custom_model.md @@ -83,7 +83,7 @@ Use these general runtime configurations to improve your model's performance. ## Step 3. Choose efficient sharding strategies using Roofline Analysis -To achieve good performance, it's often necessary to co-design the model's dimensions (like the MLP dimension) along with the sharding strategy. We have included examples for [v5p](https://docs.cloud.google.com/tpu/docs/v5p), [Trillium](https://docs.cloud.google.com/tpu/docs/v6e), and [Ironwood](https://docs.cloud.google.com/tpu/docs/tpu7x) that demonstrate which sharding approaches work well for specific models. We recommend reading [](sharding) and Jax’s [scaling book](https://jax-ml.github.io/scaling-book/sharding/). +To achieve good performance, it's often necessary to co-design the model's dimensions (like the MLP dimension) along with the sharding strategy. We have included examples for [v5p](https://docs.cloud.google.com/tpu/docs/v5p), [Trillium](https://docs.cloud.google.com/tpu/docs/v6e), and [Ironwood](https://docs.cloud.google.com/tpu/docs/tpu7x) that demonstrate which sharding approaches work well for specific models. We recommend reading [](sharding_on_TPUs) and Jax’s [scaling book](https://jax-ml.github.io/scaling-book/sharding/). | TPU Type | ICI Arithmetic Intensity | |---|---| diff --git a/docs/reference.md b/docs/reference.md index 990ed3ed8f..46af77b63b 100644 --- a/docs/reference.md +++ b/docs/reference.md @@ -50,6 +50,11 @@ Key concepts including checkpointing strategies, quantization, tiling, and Mixtu ::: :::: +## 📚 API Reference + +Find comprehensive API documentation for MaxText modules, classes, and functions in the [API Reference page](reference/api_reference). + + ```{toctree} :hidden: :maxdepth: 1 @@ -58,4 +63,5 @@ reference/performance_metrics.md reference/models.md reference/architecture.md reference/core_concepts.md +reference/api_reference.rst ``` diff --git a/docs/reference/api_reference.rst b/docs/reference/api_reference.rst new file mode 100644 index 0000000000..6b7d39739a --- /dev/null +++ b/docs/reference/api_reference.rst @@ -0,0 +1,26 @@ +.. + Copyright 2024 Google LLC + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + +API Reference +============= + +This section contains the complete API documentation for the ``MaxText`` library, automatically generated from the source code docstrings. + +.. toctree:: + :maxdepth: 4 + :caption: Package Modules + :glob: + + api_generated/MaxText diff --git a/docs/reference/core_concepts/quantization.md b/docs/reference/core_concepts/quantization.md index 701564b68d..b494125ec9 100644 --- a/docs/reference/core_concepts/quantization.md +++ b/docs/reference/core_concepts/quantization.md @@ -13,14 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. --> -(quantization)= +(quantization-doc)= # Quantization Quantization in deep learning is the process of reducing the precision of numbers used to represent a model's weights and/or activations. Instead of using higher-precision floating-point formats like 32-bit floats (`float32`) or 16-bit brain floats (`bfloat16`), quantization maps these values to lower-precision numerical formats, most commonly 8-bit integers (`int8`) or floats (`fp8`). MaxText supports quantization via both the [AQT](https://github.com/google/aqt) and [Qwix](https://github.com/google/qwix) libraries. Qwix is the recommended approach, providing a non-intrusive way to apply Quantized Training (QT). -## Why use quantization? +## Why use quantization? The drive to use lower-precision formats like `int8` or `fp8` stems from significant performance advantages: diff --git a/src/MaxText/generate_param_only_checkpoint.py b/src/MaxText/generate_param_only_checkpoint.py index 456d416a9c..e2a4ffee63 100644 --- a/src/MaxText/generate_param_only_checkpoint.py +++ b/src/MaxText/generate_param_only_checkpoint.py @@ -14,11 +14,14 @@ # pylint: disable=g-bad-todo, abstract-method, consider-using-with """Transforms a "full state" including optimizer state to a bfloat16 "parameter state" without optimizer state. - This typically used for turning a state output by training.py into a state than can be consumed by decode.py. - The input "fullstate" is passed in via: - load_full_state_path. - The output "parameter state" is output to the checkpoint directory. Additionally it is cast down to bf16. +This typically used for turning a state output by training.py into a state than can be consumed by decode.py. + +The input "fullstate" is passed in via:: + + load_full_state_path. + +The output "parameter state" is output to the checkpoint directory. Additionally it is cast down to bf16. """ import os.path @@ -157,8 +160,9 @@ def _save_decode_checkpoint(config, state, checkpoint_manager): def generate_decode_checkpoint(config): """ Generate an decode checkpoint from a given training checkpoint. - - Training checkpoint is loaded from config.load_full_state_path. - - Inference checkpoint will be saved at the config's checkpoint directory. + + * Training checkpoint is loaded from config.load_full_state_path. + * Inference checkpoint will be saved at the config's checkpoint directory. """ devices_array = maxtext_utils.create_device_mesh(config) diff --git a/src/MaxText/inference/kvcache.py b/src/MaxText/inference/kvcache.py index 54cdb00a31..366d3d13be 100644 --- a/src/MaxText/inference/kvcache.py +++ b/src/MaxText/inference/kvcache.py @@ -638,14 +638,14 @@ def update_ar_key_value( """Adds a single token's results to the ar kv cache Args: - one_token_key (Array): Key of one token to add to the cache - one_token_value (Array): Value of one token to add to the cache - cached_ar_key (tuple[nnx.Cache, nnx.Cache|None],): Cached keys to add new token key to, possibly with scale - cached_ar_value (tuple[nnx.Cache, nnx.Cache|None],: Cached values to add new token value to, possible with scale - one_hot_indices (Array): Location of the new token within the cache + one_token_key (Array): Key of one token to add to the cache + one_token_value (Array): Value of one token to add to the cache + cached_ar_key (tuple[nnx.Cache, nnx.Cache|None],): Cached keys to add new token key to, possibly with scale + cached_ar_value (tuple[nnx.Cache, nnx.Cache|None],): Cached values to add new token value to, possible with scale + one_hot_indices (Array): Location of the new token within the cache Returns: - tuple[Array, Array]: Updated caches for key and value with new token info added + tuple[Array, Array]: Updated caches for key and value with new token info added """ cached_key, cached_key_scale = key_caches @@ -758,7 +758,8 @@ def kv_cache_autoregressive( decoder_segment_ids: [b, 1] -- marking segment ids for tokens Returns: - tuple of (key, value, segment_id) for both prefill and ar cache, + tuple of (key, value, segment_id) for both prefill and ar cache + Raises: ValueError: when key/value shape is not [batch, 1, num_heads, heads_dim]. """ diff --git a/src/MaxText/inference/offline_engine.py b/src/MaxText/inference/offline_engine.py index 255b2a0791..e04067056b 100644 --- a/src/MaxText/inference/offline_engine.py +++ b/src/MaxText/inference/offline_engine.py @@ -15,24 +15,25 @@ """ Offline Inference Engine -Example usage: - offline_engine = OfflineEngine( - config=maxtext_config, - params=None, - enable_batch_prefill=True, - ) +Example usage:: + + offline_engine = OfflineEngine( + config=maxtext_config, + params=None, + enable_batch_prefill=True, + ) - input_data = [ - jax.numpy.arange(80), - jax.numpy.arange(90), - jax.numpy.arange(100), - ] + input_data = [ + jax.numpy.arange(80), + jax.numpy.arange(90), + jax.numpy.arange(100), + ] - results = offline_engine.batch_inference(input_data) + results = offline_engine.batch_inference(input_data) - for completion_output in results: - text = offline_engine.tokenizer.decode(completion_output.token_ids) - max_logging.log(f"Output: {text}") + for completion_output in results: + text = offline_engine.tokenizer.decode(completion_output.token_ids) + max_logging.log(f"Output: {text}") """ import os diff --git a/src/MaxText/inference/page_manager.py b/src/MaxText/inference/page_manager.py index 71f29c40ef..fc0cf2684e 100644 --- a/src/MaxText/inference/page_manager.py +++ b/src/MaxText/inference/page_manager.py @@ -225,16 +225,16 @@ def _reserve_pages_for_group( `released_state` unchanged (effectively leaving the group empty). Args: - released_state: The global `PageState` after pages for `page_group_id` - have already been released. - page_group_id: The index of the page group to allocate pages for. - true_length: The target sequence length for the prefill. MUST BE > 0. - tokens_per_page: The capacity of each page. - max_pages_per_group: The maximum number of pages the group can hold. + released_state: The global `PageState` after pages for `page_group_id` + have already been released. + page_group_id: The index of the page group to allocate pages for. + true_length: The target sequence length for the prefill. MUST BE > 0. + tokens_per_page: The capacity of each page. + max_pages_per_group: The maximum number of pages the group can hold. Returns: - A new `PageState` with pages allocated for the group and its state updated, - or the input `released_state` if allocation failed due to resource limits. + A new `PageState` with pages allocated for the group and its state updated, + or the input `released_state` if allocation failed due to resource limits. """ num_pages_needed = (true_length + tokens_per_page - 1) // tokens_per_page last_token_abs_idx = true_length - 1 @@ -336,6 +336,7 @@ def _update_decode_pages_global( """Updates pages globally for one step of autoregressive decoding. This function performs the following steps for all page groups simultaneously: + 1. Increments `sequence_lengths` for groups marked as `has_active_page`. 2. Calculates the new `active_page_position` based on the incremented length. 3. Determines which active groups now require a new page because their sequence @@ -420,8 +421,8 @@ class PageManager: the `PageState`. It uses the concept of page groups, where each group typically corresponds to a single request or sequence being processed. - Example: - ```python + Example:: + # Initialize a PageManager from configuration config = YourConfig(...) # Set pagedattn_num_pages, etc. page_manager = PageManager(config) @@ -444,7 +445,6 @@ class PageManager: page_state=state, page_group_id=0 ) - ``` """ def __init__(self, config: Config): @@ -515,15 +515,14 @@ def update_prefill_pages(self, page_state: PageState, page_group_id: int, true_l Raises: ValueError: If `page_group_id` or `true_length` are outside their valid ranges. - Example: - ```python + Example:: + # Reserve pages for a 16-token sequence in group 0 state = page_manager.update_prefill_pages( page_state=state, page_group_id=0, true_length=16 ) - ``` """ if page_group_id < 0 or page_group_id >= self.max_page_groups: raise ValueError(f"PageManager: page_group_id ({page_group_id}) out of range [0, {self.max_page_groups})") @@ -553,11 +552,10 @@ def update_decode_pages(self, page_state: PageState) -> PageState: Groups that required and successfully obtained a new page will have their `num_pages_used`, `page_map`, and `active_page` updated. - Example: - ```python + Example:: + # Advance state for all active sequences by one decode step state = page_manager.update_decode_pages(state) - ``` """ return _update_decode_pages_global(page_state, self.tokens_per_page, self.max_pages_per_group) @@ -583,14 +581,13 @@ def release_pages(self, page_state: PageState, page_group_id: int) -> PageState: Raises: ValueError: If `page_group_id` is outside its valid range. - Example: - ```python + Example:: + # Release all pages currently held by group 0 state = page_manager.release_pages( page_state=state, page_group_id=0 ) - ``` """ if page_group_id < 0 or page_group_id >= self.max_page_groups: raise ValueError(f"PageManager: page_group_id ({page_group_id}) out of range [0, {self.max_page_groups})") @@ -607,11 +604,10 @@ def get_initial_page_state(self) -> PageState: An initialized `PageState` object where all pages are free (except possibly 0) and no groups are active. - Example: - ```python + Example:: + # Get a fresh, empty page state initial_state = page_manager.get_initial_page_state() - ``` """ return initialize_page_state( num_pages=self.num_pages, diff --git a/src/MaxText/integration/vllm/maxtext_vllm_adapter/adapter.py b/src/MaxText/integration/vllm/maxtext_vllm_adapter/adapter.py index 9899ef9416..271728255e 100644 --- a/src/MaxText/integration/vllm/maxtext_vllm_adapter/adapter.py +++ b/src/MaxText/integration/vllm/maxtext_vllm_adapter/adapter.py @@ -27,7 +27,13 @@ from MaxText.common_types import MODEL_MODE_AUTOREGRESSIVE from MaxText.globals import MAXTEXT_PKG_DIR -from tpu_inference.layers.common.attention_metadata import AttentionMetadata +try: + from tpu_inference.layers.common.attention_metadata import AttentionMetadata +except ImportError: + # Mock for documentation build or environments without tpu_inference + class AttentionMetadata: + input_positions: jax.Array + from vllm.config import VllmConfig diff --git a/src/MaxText/integration/vllm/setup.py b/src/MaxText/integration/vllm/setup.py index 2fc41b6e35..9a71def5b3 100644 --- a/src/MaxText/integration/vllm/setup.py +++ b/src/MaxText/integration/vllm/setup.py @@ -16,9 +16,10 @@ from setuptools import setup -setup( +if __name__ == "__main__": + setup( name="maxtext_vllm_adapter", version="0.1.0", packages=["maxtext_vllm_adapter"], entry_points={"vllm.general_plugins": ["register_maxtext_vllm_adapter = maxtext_vllm_adapter:register"]}, -) + ) diff --git a/src/MaxText/kernels/jax_flash_attention.py b/src/MaxText/kernels/jax_flash_attention.py index 8c89bd001c..3ab8d242a1 100644 --- a/src/MaxText/kernels/jax_flash_attention.py +++ b/src/MaxText/kernels/jax_flash_attention.py @@ -68,8 +68,10 @@ def flash_attention_block_masked( Returns: If save_residuals is True, returns a tuple containing: - - The output of the attention computation. - - A dict of (logsumexp, max_logits) + + * The output of the attention computation. + * A dict of (logsumexp, max_logits) + Otherwise, returns the output of the attention computation. """ batch_size, num_q_heads, q_seq_len, qk_head_dim_size = q.shape diff --git a/src/MaxText/kernels/megablox/backend.py b/src/MaxText/kernels/megablox/backend.py index c35fab8f9f..53afd99f41 100644 --- a/src/MaxText/kernels/megablox/backend.py +++ b/src/MaxText/kernels/megablox/backend.py @@ -93,16 +93,18 @@ def make_group_metadata( the output for each group. Returns: - tuple of: - group_offsets: A 1d, jnp.ndarray with shape [num_groups+1] and jnp.int32 - dtype. group_offsets[i] indicates the row at which group [i] starts in - the lhs matrix and group_offsets[i-1] = m. - group_ids: A 1d, jnp.ndarray with shape [m_tiles + num_groups] and - jnp.int32 dtype. group_ids[i] indicates which group grid index 'i' will - work on. - m_tile_ids: A 1d, jnp.ndarray with shape [m_tiles + num_groups] and - jnp.int32. m_tile_ids[i] indicates which m-dimension tile grid index 'i' - will work on. + tuple of + + * group_offsets: A 1d, jnp.ndarray with shape [num_groups+1] and jnp.int32 + dtype. group_offsets[i] indicates the row at which group [i] starts in + the lhs matrix and group_offsets[i-1] = m. + * group_ids: A 1d, jnp.ndarray with shape [m_tiles + num_groups] and + jnp.int32 dtype. group_ids[i] indicates which group grid index 'i' will + work on. + * m_tile_ids: A 1d, jnp.ndarray with shape [m_tiles + num_groups] and + jnp.int32. m_tile_ids[i] indicates which m-dimension tile grid index 'i' + will work on. + num_tiles: The number of m-dimension tiles to execute. """ num_groups = group_sizes.shape[0] @@ -595,7 +597,7 @@ def tgmm( testing and debugging. Returns: - A 3d, jnp.ndarray with shape [num_groups, k, n]. + A 3d, jnp.ndarray with shape [num_groups, k, n]. """ if group_offset is None: group_offset = jnp.array([0], dtype=jnp.int32) diff --git a/src/MaxText/layers/attention_op.py b/src/MaxText/layers/attention_op.py index d136f975cd..a637034250 100644 --- a/src/MaxText/layers/attention_op.py +++ b/src/MaxText/layers/attention_op.py @@ -113,9 +113,9 @@ def apply_mask_to_logits(logits: Array, mask: Array): The mask is represented as a tensor with some dtype where 0 represents true and values below a large negative number (here set to - get_large_negative_number(logits.dtype) / 2) represent false. Applying the mask + `get_large_negative_number(logits.dtype) / 2)` represent false. Applying the mask leaves the logits alone in the true case and replaces them by - get_large_negative_number(logits.dtype) in the false case. Previously, this was + `get_large_negative_number(logits.dtype)` in the false case. Previously, this was done by adding the logits to the mask; however, this leads to a bad fusion decision in the compiler that saves the values in memory rather than just the predicate. This implementation avoids that problem. @@ -219,7 +219,7 @@ def _generate_chunk_attention_mask(mask_shape: tuple[int, int], chunk_size: int, within the same chunk, and causally within that chunk). Args: - mask_shape: The desired shape of the mask (q_seq_len, kv_seq_len). + mask_shape: The desired shape of the mask `(q_seq_len, kv_seq_len)`. chunk_size: The size of the attention chunks. Returns: @@ -227,7 +227,7 @@ def _generate_chunk_attention_mask(mask_shape: tuple[int, int], chunk_size: int, allowed according to chunked causal rules, and False otherwise. Raises: - ValueError: If chunk_window_size is None or not positive. + ValueError: If `chunk_window_size` is None or not positive. """ row_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 0) + q_offset @@ -245,10 +245,10 @@ def _make_block_mask_indices(bidirectional_mask): """Creates block mask identifying segments based on a bidirectional mask. Args: - bidirectional_mask: boolean mask, e.g. [011110011010]. + bidirectional_mask: boolean mask, e.g. `[011110011010]`. Returns: - block mask for segments, e.g. [011110022030]. + block mask for segments, e.g. `[011110022030]`. """ # Left pad 0. padded_mask = jnp.pad(bidirectional_mask, [(0, 0), (1, 0)], constant_values=0) @@ -259,18 +259,21 @@ def _make_block_mask_indices(bidirectional_mask): def _make_bidirectional_block_mask(bidirectional_mask): """Creates bidirectional block mask from bidirectional_mask, where True corresponds to image tokens. - bidirectional_mask shape: [B, L] - bidirectional_block_mask shape: [B, L, L] - Examples: - bidirectional_mask = [[0, 1, 1, 1, 0, 0]] - bidirectional_block_mask = [[ - [False, False, False, False, False, False], - [False, True, True, True, False, False], - [False, True, True, True, False, False], - [False, True, True, True, False, False], - [False, False, False, False, False, False], - [False, False, False, False, False, False], - ]] + + `bidirectional_mask` shape: [B, L] + `bidirectional_block_mask` shape: [B, L, L] + + Examples:: + + bidirectional_mask = [[0, 1, 1, 1, 0, 0]] + bidirectional_block_mask = [[ + [False, False, False, False, False, False], + [False, True, True, True, False, False], + [False, True, True, True, False, False], + [False, True, True, True, False, False], + [False, False, False, False, False, False], + [False, False, False, False, False, False], + ]] """ q_block_indices = _make_block_mask_indices(bidirectional_mask) kv_block_indices = q_block_indices @@ -556,23 +559,26 @@ def generate_attention_mask( allowed to attend to each other. The masking logic can enforce: - 1. **Sequence Separation:** Using `decoder_segment_ids`, attention is - confined within distinct sequences in a batch. This is crucial when - multiple unrelated sequences are packed together. - 2. **Causality:** Preventing attention to future positions. This is - standard for autoregressive decoding. For chunked prefill, as - described in the SARATHI paper [2], causality is adjusted based - on `previous_chunk` information. - 3. **Specialized Attention Patterns:** Depending on `self.attention_type`, - it can apply: - * Local Sliding Window Attention: Restricts attention to a - fixed-size window around each query position. - * Chunk Attention: Divides sequences into chunks and applies - masking at the chunk level. - 4. **Bidirectional Attention for Sub-sequences:** If `bidirectional_mask` - is provided (e.g., for image tokens in a multimodal model), - those parts of the sequence can attend bidirectionally, and this - mask is OR-ed with other generated masks. + + 1. **Sequence Separation:** Using `decoder_segment_ids`, attention is + confined within distinct sequences in a batch. This is crucial when + multiple unrelated sequences are packed together. + 2. **Causality:** Preventing attention to future positions. This is + standard for autoregressive decoding. For chunked prefill, as + described in the SARATHI paper [2], causality is adjusted based + on `previous_chunk` information. + 3. **Specialized Attention Patterns:** Depending on `self.attention_type`, + it can apply: + + * Local Sliding Window Attention: Restricts attention to a + fixed-size window around each query position. + * Chunk Attention: Divides sequences into chunks and applies + masking at the chunk level. + + 4. **Bidirectional Attention for Sub-sequences:** If `bidirectional_mask` + is provided (e.g., for image tokens in a multimodal model), + those parts of the sequence can attend bidirectionally, and this + mask is OR-ed with other generated masks. The overall approach and specific masking techniques are influenced by efficient attention mechanisms like those found in the Pallas MHA @@ -580,30 +586,30 @@ def generate_attention_mask( Args: query: The query tensor, typically of shape - `[batch_size, q_sequence_length, num_heads, head_dim]`. - Used primarily for deriving sequence length. + `[batch_size, q_sequence_length, num_heads, head_dim]`. + Used primarily for deriving sequence length. key: The key tensor, typically of shape - `[batch_size, kv_sequence_length, num_heads, head_dim]`. - Used primarily for deriving sequence length. + `[batch_size, kv_sequence_length, num_heads, head_dim]`. + Used primarily for deriving sequence length. decoder_segment_ids: Optional `Array` of shape `[batch_size, q_sequence_length]`. - Identifies distinct sequences within the batch. Attention is - restricted to elements within the same segment ID. In autoregressive - mode, specific values (e.g., `common_types.DECODING_ACTIVE_SEQUENCE_INDICATOR`) - can mark the currently active sequence for decoding. + Identifies distinct sequences within the batch. Attention is + restricted to elements within the same segment ID. In autoregressive + mode, specific values (e.g., `common_types.DECODING_ACTIVE_SEQUENCE_INDICATOR`) + can mark the currently active sequence for decoding. model_mode: A string (e.g., `common_types.MODEL_MODE_AUTOREGRESSIVE`, - `MODEL_MODE_PREFILL`) indicating the operational - mode. This significantly influences mask generation, particularly - how causality and segment separation are handled. + `MODEL_MODE_PREFILL`) indicating the operational + mode. This significantly influences mask generation, particularly + how causality and segment separation are handled. previous_chunk: Optional. Information about previously processed - key/value chunks, often a tensor representing the previous keys/values. - Used to correctly offset causal masks in chunked attention or - streaming scenarios. Its shape might be - `[batch_size, prev_kv_sequence_length, ...]`. + key/value chunks, often a tensor representing the previous keys/values. + Used to correctly offset causal masks in chunked attention or + streaming scenarios. Its shape might be + `[batch_size, prev_kv_sequence_length, ...]`. bidirectional_mask: Optional `Array` of shape `[batch_size, kv_sequence_length]`. - If provided, this boolean mask indicates tokens (e.g., image tokens) - that are allowed to attend bidirectionally. The resulting - block-wise bidirectional mask is combined with other masks using a - logical OR. + If provided, this boolean mask indicates tokens (e.g., image tokens) + that are allowed to attend bidirectionally. The resulting + block-wise bidirectional mask is combined with other masks using a + logical OR. Returns: An `Array` representing the attention mask, broadcastable to the shape @@ -614,10 +620,10 @@ def generate_attention_mask( the inputs and configuration. References: - [1] JAX Pallas MHA Flash Attention: - https://github.com/jax-ml/jax/blob/main/jax/experimental/pallas/ops/tpu/flash_attention.py - [2] SARATHI: Efficient LLM Inference by Piggybacking Decodes with - Chunked Prefills - ArXiv:2308.16369 (https://arxiv.org/abs/2308.16369) + [1]: JAX Pallas MHA Flash Attention: + https://github.com/jax-ml/jax/blob/main/jax/experimental/pallas/ops/tpu/flash_attention.py + [2]: SARATHI: Efficient LLM Inference by Piggybacking Decodes with + Chunked Prefills - ArXiv:2308.16369 (https://arxiv.org/abs/2308.16369) """ mask = None if model_mode == MODEL_MODE_AUTOREGRESSIVE: @@ -1361,6 +1367,7 @@ def cudnn_flash_attention( model_mode: str = MODEL_MODE_TRAIN, ) -> Array: """CUDNN Flash Attention with Transformer Engine. + 1. Stable API, supports MHA, GQA, SWA, Packing and Context Parallelism 2. Context Parallelism currently only supports causal masking and no packing """ @@ -1496,6 +1503,7 @@ def compute_local_attention( sinks: Array | None = None, ) -> tuple[Array, Array, Array]: """Computes the attention of a local subset of the kv cache. + Local attention results will need to be combined with any other local attentions and normalized Based on https://github.com/google-research/google-research/blob/master/scaling_transformer_inference_efficiency/attention.py @@ -1843,6 +1851,7 @@ def __call__( # pylint: disable=protected-access class LoadBalancedCausalMask(splash_attention_mask._ComputableMask): """Lazy causal mask, prevents the model from attending to future tokens. + Attributes: offset: Offset of q start wrt kv. A positive offset shifts the bottom triangle upward, a negative one shifts it downward. A negative offset diff --git a/src/MaxText/layers/moe.py b/src/MaxText/layers/moe.py index 51f98d8de8..5d32ee101d 100644 --- a/src/MaxText/layers/moe.py +++ b/src/MaxText/layers/moe.py @@ -110,13 +110,11 @@ def random_routing(rng_key, gate_logits, num_experts_per_tok): Returns: A tuple containing: - - top_k_indices: JAX array of shape (batch_size, sequence_length, - num_experts_per_tok) - representing the indices of the selected experts for each - token. - - top_k_weights: JAX array of shape (batch_size, sequence_length, - num_experts_per_tok) - representing the weights for the selected experts. + + * top_k_indices: JAX array of shape `(batch_size, sequence_length, num_experts_per_tok)` + representing the indices of the selected experts for each token. + * top_k_weights: JAX array of shape `(batch_size, sequence_length, num_experts_per_tok)` + representing the weights for the selected experts. """ bs, seq_len, num_experts = gate_logits.shape selected_num = bs * seq_len * num_experts_per_tok @@ -525,9 +523,9 @@ def deepseek_routing(self, gate_logits: jax.Array, pre_bias_logits: jax.Array) - pre_bias_logits: Array of shape `(batch, seq,num_experts)`. Returns: - - top_k_weights: `(batch, seq, num_experts_per_tok)` array of weight values for + top_k_weights: `(batch, seq, num_experts_per_tok)` array of weight values for each selected expert. - - top_k_indices: `(batch, seq, num_experts_per_tok)` array of indices + top_k_indices: `(batch, seq, num_experts_per_tok)` array of indices identifying the selected experts for each token. """ expert_mask = 1 if self.config.n_routing_groups == -1 else self.expert_group_mask(gate_logits) @@ -641,18 +639,17 @@ def local_permute( """Permutes tokens locally within an expert shard. This function prepares the input tokens for processing by the experts - located - on the current shard. It groups the tokens by their assigned local expert - index (0 to local_expert_size - 1). + located on the current shard. It groups the tokens by their assigned local + expert index (0 to `local_expert_size - 1`). Args: inputs: The input data (tokens) assigned to the experts on this shard. Shape `[tokens, emb_dim]`. global_group_sizes: The count of tokens assignments for each global expert - across all the batch shards. Shape `[num_batch_shards, num_experts]. + across all the batch shards. Shape `[num_batch_shards, num_experts]`. local_expert_size: The number of experts handled by the current shard. shard_index: The index of the current expert shard (0 to - num_expert_parallelism - 1). + `num_expert_parallelism - 1`). is_offset: If True, assumes `inputs` are pre-sorted by global expert ID and selects the slice relevant to this shard's assigned experts. If False, assumes that `inputs` corresponding to the shard's experts start @@ -662,11 +659,12 @@ def local_permute( Returns: A tuple containing: - sorted_inputs: Input data permuted local expert ID. - sorted_indices: Indices used to permute the inputs. - local_group_size: Number of tokens assigned to each local expert on this - shard. - sorted_experts_ids: expert ID corresponding to each token of the permuted + + * `sorted_inputs`: Input data permuted local expert ID. + * `sorted_indices`: Indices used to permute the inputs. + * `local_group_size`: Number of tokens assigned to each local expert on this + shard. + * `sorted_experts_ids`: expert ID corresponding to each token of the permuted inputs. """ diff --git a/src/MaxText/layers/normalizations.py b/src/MaxText/layers/normalizations.py index 358809a4ca..d81009189a 100644 --- a/src/MaxText/layers/normalizations.py +++ b/src/MaxText/layers/normalizations.py @@ -82,13 +82,13 @@ def __call__(self, x: jnp.ndarray, out_sharding: NamedSharding | None = None) -> def Qwen3NextRMSNorm(num_features: int, eps: float, dtype: DType, weight_dtype: DType, *, rngs: nnx.Rngs): """ - Used for input and post attention layernorms - in Qwen3NextDecoderLayer. + Used for input and post attention layernorms in `Qwen3NextDecoderLayer`. This normalization layer is specific to Qwen3-Next. Key characteristics: - 1. The learnable scale parameter `scale` is initialized to ZEROS. - 2. The scale is applied as `(1.0 + self.scale)`, making the initial scale effectively 1.0. - This matches the PyTorch implementation of Qwen3NextRMSNorm. + + 1. The learnable scale parameter `scale` is initialized to ZEROS. + 2. The scale is applied as `(1.0 + self.scale)`, making the initial scale effectively 1.0. + This matches the PyTorch implementation of Qwen3NextRMSNorm. """ return nnx.data( RMSNorm( diff --git a/src/MaxText/layers/pipeline.py b/src/MaxText/layers/pipeline.py index c7284fb22c..d37827bb07 100644 --- a/src/MaxText/layers/pipeline.py +++ b/src/MaxText/layers/pipeline.py @@ -44,7 +44,7 @@ class Pipeline(nn.Module): config: Importantly contains num_pipeline_microbatches, num_pipeline_repeats. layers: A module instance that each stage can execute. It can either be a single layer such as a LlamaDecoderLayer instance or scanned/looped set of decoder layers to execute multiple layers per stage. - mesh: The device mesh of the system. + mesh: The device mesh of the system. remat_policy: Remat policy to use for the loop iterations """ @@ -87,14 +87,14 @@ def iterations_to_complete_first_microbatch(self): def init_states(self, inputs): """Initialize components of state: state_io, shift, circular_storage and circular_storage_mover - Assumes input has already been reshaped into microbatches: [num_micro_batches, micro_batch_size, sequence, embed] + Assumes input has already been reshaped into microbatches: `[num_micro_batches, micro_batch_size, sequence, embed]` Returns a dictionary with properties - shift: zeros shape [num_stages, micro_size, sequence, embed] - prev_outputs: same shape as shift, only used when pipeline_delay_activation_forwarding is set to true, else None - state_io: reshaped inputs [num_stages, microbatches/stages, micro_size, sequence, embed] - circ_storage: zeros [num_stages, microbatches, micro_size, sequence, embed] when needed, else None - circ_storage_mover: zeros[num_stages, micro_size, sequence, embed] when needed, else None + shift: zeros shape `[num_stages, micro_size, sequence, embed]` + prev_outputs: same shape as shift, only used when `pipeline_delay_activation_forwarding` is set to true, else None + state_io: reshaped inputs `[num_stages, microbatches/stages, micro_size, sequence, embed]` + circ_storage: zeros `[num_stages, microbatches, micro_size, sequence, embed]` when needed, else None + circ_storage_mover: zeros `[num_stages, micro_size, sequence, embed]` when needed, else None loop_iteration: scalar set initially to 0. """ @@ -167,9 +167,9 @@ def init_states(self, inputs): def get_iteration_inputs(self, loop_iteration, state_io, circ_storage, shift): """ Construct stages_in: the global array that is operated on for this iteration, shape same as - shift=[stages, micro_size, sequence, embed] + `shift=[stages, micro_size, sequence, embed]`. This is almost a rotated version of the last outputs, except for the first stage which must grab a new batch from - state_io or an old one from circ_storage + `state_io` or an old one from `circ_storage` """ # Setup potential input from state_io, which has a rotating microbatch index (size of microbatches_per_stage) @@ -230,16 +230,19 @@ def get_microbatch_and_repeat_ids(self, loop_iteration): def vmap_parallel_gather(self, weights, repeat_ids, repeat_dim_in_weights, stages_dim_in_weights): """Use vmap to implement a sharded parallel gather. + Parallel gather means each stage has its own weights, and gets one slice from it. + Args: weights: Per-stage data to be gathered from. repeat_ids: Integer tensor of shape [num_stages], the repeats of the stages. repeat_dim_in_weights: The dimension in weights where repeat_ids are applied. The output will not have this dimension. stages_dim_in_weights: The dimension in weights that represents parallel stages. + Returns: The per-stage gathered values. The shape is weights.shape but with repeat_dim_in_weights - removed. + removed. """ def _gather_one(x, repeat_id): @@ -280,6 +283,7 @@ def _gather_one(x, i): def get_new_loop_state(self, output, loop_state): """ Update the various buffers given the output of the most recent iteration + * state_io: rotates left/up by 1 (the whole created in the last slot is filled with the most recent pipeline output) * Pushing inputs up from top of state_io into first stage of shift * Pulling outputs up from last stage of shift into bottom of state_io @@ -385,7 +389,7 @@ def permute_output_micro_per_stage_dim(self, output): def get_current_stage_weights(self, pipeline_weights, loop_iteration): """ Gets the current weights used for one iteration. Outputs a pytree whose arrays have leading dimension of stages, e.g. - {'mlp': 'wo': [stages, mlp, embed]}. Stage 0 will use the 0th index of this pytree, Stage 1 the 1st index, etc. + `{'mlp': 'wo': [stages, mlp, embed]}`. Stage 0 will use the 0th index of this pytree, Stage 1 the 1st index, etc. For non-circular pipelines, this simply returns all weights - every weight is used in every iteraiton. However for circular pipelines each stage grabs only the weights corresponding to the current repeat. """ diff --git a/src/MaxText/maxtext_utils.py b/src/MaxText/maxtext_utils.py index 0929ec7757..0b7bf7b53d 100644 --- a/src/MaxText/maxtext_utils.py +++ b/src/MaxText/maxtext_utils.py @@ -349,9 +349,10 @@ def calculate_ffn_mamtul_tflops_per_device(config, mlp_dim): """Helper function to calculate matmul TFLOP in ffn based on MLP dimension. Applies to: - - Dense FFN layers (mlp_dim = config.mlp_dim). - - MoE FFN layers (mlp_dim = config.moe_mlp_dim), - need to scale by shared_experts or num_experts_per_tok. + + * Dense FFN layers (mlp_dim = config.mlp_dim). + * MoE FFN layers (mlp_dim = config.moe_mlp_dim), + need to scale by shared_experts or num_experts_per_tok. """ ffn1_flops = ( 2 * config.per_device_batch_size * config.max_target_length * mlp_dim * config.emb_dim * len(config.mlp_activations) @@ -391,10 +392,11 @@ def get_dense_moe_layers(config): def calculate_gemma3_vision_layers_tflops_per_device(config): """ Estimate TFLOPs for Gemma3 vision encoder (ViT-style). + Returns: - total_tflops: Total TFLOPs (counts for fwd + bwd + optimizer) - learnable_weight_tflops: TFLOPs from learnable weights (patch embedding, qkv, MLP, projections) - attention_tflops: TFLOPs from attention multiplications + total_tflops: Total TFLOPs (counts for fwd + bwd + optimizer) + learnable_weight_tflops: TFLOPs from learnable weights (patch embedding, qkv, MLP, projections) + attention_tflops: TFLOPs from attention multiplications """ # Config values B = config.per_device_batch_size @@ -445,10 +447,11 @@ def calculate_gemma3_vision_layers_tflops_per_device(config): def calculate_llama4_vision_layers_tflops_per_device(config): """ Estimate TFLOPs for Llama4 vision encoder (ViT-style). + Returns: - total_tflops: Total TFLOPs (counts for fwd + bwd + optimizer) - learnable_weight_tflops: TFLOPs from learnable weights (patch embedding, qkv, MLP, projections) - attention_tflops: TFLOPs from attention multiplications + total_tflops: Total TFLOPs (counts for fwd + bwd + optimizer) + learnable_weight_tflops: TFLOPs from learnable weights (patch embedding, qkv, MLP, projections) + attention_tflops: TFLOPs from attention multiplications """ # Config values B = config.per_device_batch_size @@ -698,12 +701,12 @@ def get_nested_value(dictionary, nested_key, default=None): Retrieves a value from a nested key in a dictionary. Args: - dictionary: The dictionary to search in. - nested_key: A tuple representing the nested key, e.g., ('level1', 'level2', 'key'). - default: The value to return if the nested key is not found. + dictionary: The dictionary to search in. + nested_key: A tuple representing the nested key, e.g., ('level1', 'level2', 'key'). + default: The value to return if the nested key is not found. Returns: - The value associated with the nested key, or the default value if not found. + The value associated with the nested key, or the default value if not found. """ current_level = dictionary @@ -769,6 +772,7 @@ def get_abstract_param(model, config): def setup_decode_state(model, config, rng, mesh, checkpoint_manager): """Setup decode state by loading params from a checkpoint. + Args: model: the flax model to initialize config: config object @@ -1102,12 +1106,16 @@ def create_device_mesh(config, devices=None): def create_learning_rate_schedule(config): - """Creates a warmup and cosine decay learning rate schedule: + """Creates a warmup and cosine decay learning rate schedule + We take inspiration from Llama2's learning rate (LR) schedule, see https://arxiv.org/pdf/2307.09288.pdf section 2.2 + Learning rate schedule has either two or three parts: - 1) Linear warmup from 0 to [learning_rate] over steps 0 to [learning_rate_schedule_steps * warmup_steps_fraction] - 2) Cosine from [learning_rate] to [learning_rate * cosine_learning_rate_final_fraction] until learning_rate_schedule_steps - 3) Constant learning rate of 0 from learning_rate_schedule_steps to steps. + + 1. Linear warmup from 0 to `[learning_rate]` over steps 0 to `[learning_rate_schedule_steps * warmup_steps_fraction]` + 2. Cosine from `[learning_rate]` to `[learning_rate * cosine_learning_rate_final_fraction]` until learning_rate_schedule_steps + 3. Constant learning rate of 0 from `learning_rate_schedule_steps` to steps. + The zero learning rate section can be used to more accurately measure the fully trained model's performance. """ diff --git a/src/MaxText/multimodal_utils.py b/src/MaxText/multimodal_utils.py index f4cf093056..23d841315d 100644 --- a/src/MaxText/multimodal_utils.py +++ b/src/MaxText/multimodal_utils.py @@ -114,10 +114,12 @@ def load_image_from_path(image_path): def _normalize_images(images, mean, std): """Normalize the image to zero mean and unit variance. Change the image mean and std based on parameters mean and std. + Args: images: The images to normalize. mean: tuple[float, float, float]. std: tuple[float, float, float]. + Returns: The normalized images. """ @@ -130,8 +132,10 @@ def get_factors(dividend: int): """ Calculate all factors of a given number, i.e. a divisor that leaves no remainder. For example, if dividend=12, it will return {1, 2, 3, 4, 6, 12}. + Args: dividend (int): The number to find factors for. + Returns: set: A set containing all factors of the number. """ @@ -170,11 +174,13 @@ def get_best_resolution( ) -> tuple[int, int]: """ Get the best resolution for the image based on the possible resolutions. + Args: img_height (int): The height of the image. image_width (int): The width of the image. possible_resolutions (list): A list of possible resolutions. resize_to_max_canvas (bool): Whether to resize to max canvas or not. + Returns: tuple: The best resolution for the image. """ @@ -379,11 +385,15 @@ def pre_process_llama4_image(image: np.ndarray | list[np.ndarray]) -> Preprocess """ Pre-process image for Llama4 model. Find best resolution and split into tiles with an additional global tile. Original implementation from image_processing_llama4.py: http://shortn/_VXLgQ1lmkz + Args: image: The np.array image [H, W, C] or images [N, H, W, C] to pre-process. + Returns: The pre-processed image in np.array [N, NUM_TILES, C, TILE_SIZE, TILE_SIZE]. + Example: + image of (536, 640, 3), its best_resolution = (672, 672), image split into 4 tiles of (336, 336) Additional global tile of (336, 336) is added, and the final output image_tiles is (1, 5, 3, 336, 336). """ @@ -457,9 +467,11 @@ def pre_process_llama4_image(image: np.ndarray | list[np.ndarray]) -> Preprocess def pre_process_image(image, model_name): """Pre-process image according to different model's requirements. + Args: image: The np.array image [H, W, C] or images [N, H, W, C] to pre-process. model_name: The config.model_name that specifies the image preprocess ways. + Returns: The PreprocessorOutput instance containing image in np.array [H, W, C] or [N, H, W, C]. """ @@ -613,15 +625,17 @@ def add_extra_tokens_for_images_llama4(tokens, processor_output: PreprocessorOut def get_tokens_for_this_image(this_aspect_ratio, num_patches_per_chunk): """Constructs the token sequence for a single image in Llama4. + This function generates a list of special tokens that represent an image, including its tiled structure (if applicable) and a global representation. The sequence includes: - - A beginning-of-image token. - - Patch tokens for each local tile, interspersed with tile separators - if the image is divided into multiple tiles (ratio_h * ratio_w > 1). - - A fake image token placeholder for the global image representation. - - Patch tokens associated with the global image representation. - - An end-of-image token. + + * A beginning-of-image token. + * Patch tokens for each local tile, interspersed with tile separators + if the image is divided into multiple tiles (`ratio_h * ratio_w > 1`). + * A fake image token placeholder for the global image representation. + * Patch tokens associated with the global image representation. + * An end-of-image token. Args: this_aspect_ratio: A tuple (ratio_h, ratio_w) representing the number @@ -635,21 +649,22 @@ def get_tokens_for_this_image(this_aspect_ratio, num_patches_per_chunk): Example: If `this_aspect_ratio` is [2, 2] and `num_patches_per_chunk` is 4, - the output will be: - [ - LLAMA4_BEGIN_IMAGE_TOKEN, - LLAMA4_PATCH_TOKEN, LLAMA4_PATCH_TOKEN, LLAMA4_PATCH_TOKEN, LLAMA4_PATCH_TOKEN, - LLAMA4_TILE_X_SEPARATOR_TOKEN, - LLAMA4_PATCH_TOKEN, LLAMA4_PATCH_TOKEN, LLAMA4_PATCH_TOKEN, LLAMA4_PATCH_TOKEN, - LLAMA4_TILE_Y_SEPARATOR_TOKEN, - LLAMA4_PATCH_TOKEN, LLAMA4_PATCH_TOKEN, LLAMA4_PATCH_TOKEN, LLAMA4_PATCH_TOKEN, - LLAMA4_TILE_X_SEPARATOR_TOKEN, - LLAMA4_PATCH_TOKEN, LLAMA4_PATCH_TOKEN, LLAMA4_PATCH_TOKEN, LLAMA4_PATCH_TOKEN, - LLAMA4_TILE_Y_SEPARATOR_TOKEN, - LLAMA4_FAKE_IMAGE_TOKEN, - LLAMA4_PATCH_TOKEN, LLAMA4_PATCH_TOKEN, LLAMA4_PATCH_TOKEN, LLAMA4_PATCH_TOKEN, - LLAMA4_END_IMAGE_TOKEN - ], total 27 tokens. + the output will be:: + + [ + LLAMA4_BEGIN_IMAGE_TOKEN, + LLAMA4_PATCH_TOKEN, LLAMA4_PATCH_TOKEN, LLAMA4_PATCH_TOKEN, LLAMA4_PATCH_TOKEN, + LLAMA4_TILE_X_SEPARATOR_TOKEN, + LLAMA4_PATCH_TOKEN, LLAMA4_PATCH_TOKEN, LLAMA4_PATCH_TOKEN, LLAMA4_PATCH_TOKEN, + LLAMA4_TILE_Y_SEPARATOR_TOKEN, + LLAMA4_PATCH_TOKEN, LLAMA4_PATCH_TOKEN, LLAMA4_PATCH_TOKEN, LLAMA4_PATCH_TOKEN, + LLAMA4_TILE_X_SEPARATOR_TOKEN, + LLAMA4_PATCH_TOKEN, LLAMA4_PATCH_TOKEN, LLAMA4_PATCH_TOKEN, LLAMA4_PATCH_TOKEN, + LLAMA4_TILE_Y_SEPARATOR_TOKEN, + LLAMA4_FAKE_IMAGE_TOKEN, + LLAMA4_PATCH_TOKEN, LLAMA4_PATCH_TOKEN, LLAMA4_PATCH_TOKEN, LLAMA4_PATCH_TOKEN, + LLAMA4_END_IMAGE_TOKEN + ], total 27 tokens. """ img_tokens = [LLAMA4_BEGIN_IMAGE_TOKEN] @@ -710,16 +725,14 @@ def add_extra_tokens_for_images_gemma3( If the model has images, we expand each `` token by the image placeholder tokens. - Example: + Example:: - ```python - input = [..., x, , y, ...] - output = [ - ..., x, \n\n, , SOFT_TOKEN_PLACEHOLDER, - SOFT_TOKEN_PLACEHOLDER, ..., SOFT_TOKEN_PLACEHOLDER, - SOFT_TOKEN_PLACEHOLDER, , \n\n, y, ... - ] - ``` + input = [..., x, , y, ...] + output = [ + ..., x, \n\n, , SOFT_TOKEN_PLACEHOLDER, + SOFT_TOKEN_PLACEHOLDER, ..., SOFT_TOKEN_PLACEHOLDER, + SOFT_TOKEN_PLACEHOLDER, , \n\n, y, ... + ] The `\n\n` tokens are added to match how the model was trained. @@ -863,72 +876,79 @@ def merge_mm_embeddings( mask, image_masks: np.ndarray | jnp.ndarray | None = None, ) -> np.ndarray | jnp.ndarray: - """Merges text and vision embeddings based on a mask. + """Merges text and vision embeddings based on a mask. - This function handles two primary formats for vision embeddings: - 1. Tiled Format (e.g., Llama4): Vision embeddings are provided as a batch of - images and their tiles, with shape (B * N, T, K, D). These are flattened - into a single sequence of vision tokens per batch item. - 2. Simple Format (e.g., Gemma3): Vision embeddings are provided as - (B, N, K, D) and are flattened into a sequence of vision tokens. + This function handles two primary formats for vision embeddings: - Args: - text_embeddings: (B, S, D) array of text embeddings. - vision_embeddings: Vision embeddings in one of two formats: - - (B * N, T, K, D) for tiled inputs. - - (B, N, K, D) for simple inputs. - (B=batch_size, S=seq_len, D=embedding_dim, N=num_images, - T=num_tiles, K=toks_per_image) - mask: (B, S) boolean or integer array where non-zero positions - indicate where vision embeddings should be placed. - image_masks: (Optional) A mask for the vision tokens. - - (B * N, T) for tiled inputs, indicating valid tiles. - - If None, all vision embeddings are assumed to be valid. + 1. Tiled Format (e.g., Llama4): Vision embeddings are provided as a batch of + images and their tiles, with shape (B * N, T, K, D). These are flattened + into a single sequence of vision tokens per batch item. + 2. Simple Format (e.g., Gemma3): Vision embeddings are provided as + (B, N, K, D) and are flattened into a sequence of vision tokens + (B=batch_size, S=seq_len, D=embedding_dim, N=num_images, T=num_tiles, + K=toks_per_image) - Returns: - A (B, S, D) array of merged embeddings. - """ - # Input Validation and Shape Unpacking - batch_size, _, d_model = text_embeddings.shape - # The number of tokens per image/tile is the second to last dimension. - num_toks_per_image = vision_embeddings.shape[-2] - - if d_model != vision_embeddings.shape[-1]: - raise ValueError( - "Embedding dimension mismatch between text and vision embeddings:" f" {d_model} vs {vision_embeddings.shape[-1]}" - ) + Args: + text_embeddings: (B, S, D) array of text embeddings. + vision_embeddings: Vision embeddings in one of two formats: + + * (B * N, T, K, D) for tiled inputs. + * (B, N, K, D) for simple inputs. - # Reshape Vision Embeddings to a unified (B, S_vision, D) format - # This single reshape robustly handles both documented cases: - # Case 1: (B * N, T, K, D) -> (B, N*T*K, D) - # Case 2: (B, N, K, D) -> (B, N*K, D) - flat_vision_embeddings = vision_embeddings.reshape(batch_size, -1, d_model) - - # Process Optional Image Masks - flat_image_token_masks = None - if image_masks is not None: - # Handle the tiled case where image_masks batch dimension is (B * N) - if image_masks.shape[0] != batch_size: - if image_masks.shape[0] % batch_size != 0: + mask: (B, S) boolean or integer array where non-zero positions + indicate where vision embeddings should be placed. + image_masks: (Optional) A mask for the vision tokens. + + * (B * N, T) for tiled inputs, indicating valid tiles. + * If None, all vision embeddings are assumed to be valid. + + Returns: + A (B, S, D) array of merged embeddings. + """ + # Input Validation and Shape Unpacking + batch_size, _, d_model = text_embeddings.shape + # The number of tokens per image/tile is the second to last dimension. + num_toks_per_image = vision_embeddings.shape[-2] + + if d_model != vision_embeddings.shape[-1]: raise ValueError( - "Batch dimension of image_masks must be a multiple of the text" - f" batch size. Got {image_masks.shape[0]} and {batch_size}." + "Embedding dimension mismatch between text and vision embeddings:" + f" {d_model} vs {vision_embeddings.shape[-1]}" ) - # Reshape from (B * N, T) to (B, N * T) - flat_image_tile_masks = image_masks.reshape(batch_size, -1) - else: - # This handles cases where image_masks is already (B, ...) - flat_image_tile_masks = image_masks.reshape(batch_size, -1) - - # Expand the tile-level mask to a token-level mask to match the embeddings. - # A mask of shape (B, N*T) becomes (B, N*T*K) by repeating each element K times. - flat_image_token_masks = jnp.repeat(flat_image_tile_masks, repeats=num_toks_per_image, axis=1) - - # Vmap the inner merge function over the batch dimension - return jax.vmap( - _merge_mm_embeddings_inner, # Assumes this function is defined elsewhere - in_axes=(0, 0, 0, None if flat_image_token_masks is None else 0), - )(text_embeddings, flat_vision_embeddings, mask, flat_image_token_masks) + + # Reshape Vision Embeddings to a unified (B, S_vision, D) format + # This single reshape robustly handles both documented cases: + # Case 1: (B * N, T, K, D) -> (B, N*T*K, D) + # Case 2: (B, N, K, D) -> (B, N*K, D) + flat_vision_embeddings = vision_embeddings.reshape(batch_size, -1, d_model) + + # Process Optional Image Masks + flat_image_token_masks = None + if image_masks is not None: + # Handle the tiled case where image_masks batch dimension is (B * N) + if image_masks.shape[0] != batch_size: + if image_masks.shape[0] % batch_size != 0: + raise ValueError( + "Batch dimension of image_masks must be a multiple of the text" + f" batch size. Got {image_masks.shape[0]} and {batch_size}." + ) + # Reshape from (B * N, T) to (B, N * T) + flat_image_tile_masks = image_masks.reshape(batch_size, -1) + else: + # This handles cases where image_masks is already (B, ...) + flat_image_tile_masks = image_masks.reshape(batch_size, -1) + + # Expand the tile-level mask to a token-level mask to match the embeddings. + # A mask of shape (B, N*T) becomes (B, N*T*K) by repeating each element K times. + flat_image_token_masks = jnp.repeat( + flat_image_tile_masks, repeats=num_toks_per_image, axis=1 + ) + + # Vmap the inner merge function over the batch dimension + return jax.vmap( + _merge_mm_embeddings_inner, # Assumes this function is defined elsewhere + in_axes=(0, 0, 0, None if flat_image_token_masks is None else 0), + )(text_embeddings, flat_vision_embeddings, mask, flat_image_token_masks) def _merge_mm_embeddings_inner( diff --git a/src/MaxText/pyconfig_deprecated.py b/src/MaxText/pyconfig_deprecated.py index 856b0065dc..44663dbafd 100644 --- a/src/MaxText/pyconfig_deprecated.py +++ b/src/MaxText/pyconfig_deprecated.py @@ -1285,9 +1285,11 @@ def validate_and_update_keys(raw_keys, model_keys, config_name: str): def get_individual_scales(scale): """Choose appropriate scales for individual dimensions based on global scale We choose to rotate between doubling: - num_head and mlp_dim - embed_dim - num_layers + + * num_head and mlp_dim + * embed_dim + * num_layers + Any one of these steps is not a perfect doubling, although going through a cycle of three is a near perfect 8x scaling except for the linear -> softmax -> output step""" diff --git a/src/MaxText/sequence_packing.py b/src/MaxText/sequence_packing.py index 74e5c09ea5..5a4387dd14 100644 --- a/src/MaxText/sequence_packing.py +++ b/src/MaxText/sequence_packing.py @@ -29,31 +29,39 @@ def pack_dataset( Each example in the output dataset represents several examples in the input dataset. For each key in the input dataset, two additional keys are created: - _segmentation: an int32 tensor identifying the parts - representing the original example. - _position: an int32 tensor identifying the position within the original - example. + + * `_segmentation`: an int32 tensor identifying the parts representing the + original example. + * `_position`: an int32 tensor identifying the position within the + original example. + Example: Two input examples get combined to form an output example. - The input examples are: - {"inputs": [8, 7, 1, 0], "targets":[4, 1, 0]} - {"inputs": [2, 3, 4, 1], "targets":[5, 6, 1]} - The output example is: - { - "inputs": [8, 7, 1, 2, 3, 4, 1, 0, 0, 0] - "inputs_segmentation": [1, 1, 1, 2, 2, 2, 2, 0, 0, 0] - "inputs_position": [0, 1, 2, 0, 1, 2, 3, 0, 0, 0] - "targets": [4, 1, 5, 6, 1, 0, 0, 0, 0, 0] - "targets_segmentation": [1, 1, 2, 2, 2, 0, 0, 0, 0, 0] - "targets_position": [0, 1, 0, 1, 2, 0, 0, 0, 0, 0] - } + The input examples are:: + + {"inputs": [8, 7, 1, 0], "targets":[4, 1, 0]} + {"inputs": [2, 3, 4, 1], "targets":[5, 6, 1]} + + The output example is:: + + { + "inputs": [8, 7, 1, 2, 3, 4, 1, 0, 0, 0] + "inputs_segmentation": [1, 1, 1, 2, 2, 2, 2, 0, 0, 0] + "inputs_position": [0, 1, 2, 0, 1, 2, 3, 0, 0, 0] + "targets": [4, 1, 5, 6, 1, 0, 0, 0, 0, 0] + "targets_segmentation": [1, 1, 2, 2, 2, 0, 0, 0, 0, 0] + "targets_position": [0, 1, 0, 1, 2, 0, 0, 0, 0, 0] + } + 0 represents padding in both the inputs and the outputs. Sequences in the incoming examples are truncated to length "length", and the sequences in the output examples all have fixed (padded) length "length". + Args: dataset: a tf.data.Dataset key2length: an integer, or a dict from feature-key to integer keys: a list of strings (e.g. ["inputs", "targets"]) + Returns: a tf.data.Dataset """ diff --git a/src/MaxText/sft/sft_trainer.py b/src/MaxText/sft/sft_trainer.py index 6e55187444..ee42d4e0e6 100644 --- a/src/MaxText/sft/sft_trainer.py +++ b/src/MaxText/sft/sft_trainer.py @@ -18,7 +18,9 @@ are defined inside `src/MaxText/configs/sft.yml`. Example command: -Training & Evaluation: + +Training & Evaluation:: + python3 -m MaxText.sft.sft_trainer src/MaxText/configs/sft.yml \ run_name=$RUN_NAME base_output_directory=$BASE_OUTPUT_DIRECTORY \ model_name=$MODEL_NAME load_parameters_path=$CHECKPOINT_PATH \ @@ -26,7 +28,8 @@ per_device_batch_size=1 max_target_length=1024 \ eval_interval=2 eval_steps=2 steps=10 profiler=xplane weight_dtype=bfloat16 -Training: +Training:: + python3 -m MaxText.sft.sft_trainer src/MaxText/configs/sft.yml \ run_name=$RUN_NAME base_output_directory=$BASE_OUTPUT_DIRECTORY \ model_name=$MODEL_NAME load_parameters_path=$CHECKPOINT_PATH \ diff --git a/src/MaxText/tokenizer.py b/src/MaxText/tokenizer.py index 691179b387..3d3b5c9637 100644 --- a/src/MaxText/tokenizer.py +++ b/src/MaxText/tokenizer.py @@ -97,18 +97,19 @@ def encode( s (str): The input string to be encoded. bos (bool): Whether to prepend the beginning-of-sequence token. eos (bool): Whether to append the end-of-sequence token. - allowed_tokens ("all"|set[str]): allowed special tokens in string - disallowed_tokens ("all"|set[str]): special tokens that raise an error when in string + allowed_tokens (`"all"|set[str]`): allowed special tokens in string + disallowed_tokens (`"all"|set[str]`): special tokens that raise an error when in string Returns: list[int]: A list of token IDs. By default, setting disallowed_special=() encodes a string by ignoring special tokens. Specifically: - - Setting `disallowed_special` to () will cause all text corresponding + + * Setting `disallowed_special` to () will cause all text corresponding to special tokens to be encoded as natural text (insteading of raising an error). - - Setting `allowed_special` to "all" will treat all text corresponding + * Setting `allowed_special` to "all" will treat all text corresponding to special tokens to be encoded as special tokens. """ assert isinstance(s, str) diff --git a/src/MaxText/utils/ckpt_conversion/utils/hf_shape.py b/src/MaxText/utils/ckpt_conversion/utils/hf_shape.py index c423b79478..f38f21c173 100644 --- a/src/MaxText/utils/ckpt_conversion/utils/hf_shape.py +++ b/src/MaxText/utils/ckpt_conversion/utils/hf_shape.py @@ -214,12 +214,13 @@ def DEEPSEEK_HF_WEIGHTS_TO_SHAPE(config): This mapping is derived by matching the provided config dictionary against the model's parameter dump. - To check this mapping, dump the huggingface model shapes: - from transformers import AutoModelForCausalLM - model_name = "deepseek-ai/DeepSeek-V3" - model = AutoModelForCausalLM.from_pretrained(model_name, dtype="auto") - for name, val in model.named_parameters(): - print(name, val.shape) + To check this mapping, dump the huggingface model shapes:: + + from transformers import AutoModelForCausalLM + model_name = "deepseek-ai/DeepSeek-V3" + model = AutoModelForCausalLM.from_pretrained(model_name, dtype="auto") + for name, val in model.named_parameters(): + print(name, val.shape) Args: config (dict): Model configuration dictionary (from HF DeepseekV3Config.to_dict()) @@ -436,7 +437,8 @@ def GPT_OSS_HF_WEIGHTS_TO_SHAPE(config): def QWEN3_HF_WEIGHTS_TO_SHAPE(config): """Returns mapping between HuggingFace Qwen3 weights path and the HuggingFace weights shape. - To check this mapping, dump the huggingface model shapes: + To check this mapping, dump the huggingface model shapes:: + from transformers import AutoModelForCausalLM model_name = "Qwen/Qwen3-0.6B" model = AutoModelForCausalLM.from_pretrained(model_name, dtype="auto") diff --git a/src/MaxText/utils/ckpt_conversion/utils/utils.py b/src/MaxText/utils/ckpt_conversion/utils/utils.py index afa6c6f631..931cf24e99 100644 --- a/src/MaxText/utils/ckpt_conversion/utils/utils.py +++ b/src/MaxText/utils/ckpt_conversion/utils/utils.py @@ -119,9 +119,10 @@ def process_maxtext_param( a single tensor or a list of tensors for N-to-1 mappings) and transforming it into one or more Hugging Face compatible parameters. It handles various scenarios including: - - 1-to-1 mappings (single MaxText param to single HF param). - - N-to-1 mappings (multiple MaxText params combined into a single HF param). - - Stacked MaxText parameters (e.g., scanned layers or MoE experts) that need + + * 1-to-1 mappings (single MaxText param to single HF param). + * N-to-1 mappings (multiple MaxText params combined into a single HF param). + * Stacked MaxText parameters (e.g., scanned layers or MoE experts) that need to be unstacked into individual Hugging Face parameters. Args: @@ -505,8 +506,8 @@ def save_model_files( ): """ Saves model files (config and weights) to the specified directory. - When uploading to GCS/HF hub, - *.safetensors are uploaded from memory to remote, no local storage is used to save disk usage + When uploading to GCS/HF hub, `*.safetensors` are uploaded from memory to + remote, no local storage is used to save disk usage """ if output_dir.startswith("hf://"): From d7c1e2ed2105a6499cd6238155e9ada65b7330d7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Melissa=20Weber=20Mendon=C3=A7a?= Date: Fri, 19 Dec 2025 17:44:04 -0300 Subject: [PATCH 2/3] Formatting fixes --- docs/conf.py | 126 +++++++++--------- src/MaxText/checkpointing.py | 12 +- src/MaxText/estimator.py | 20 +-- .../baselines/context/hf_shape.py | 9 +- .../baselines/context/param_mapping.py | 53 ++++---- .../code_evaluation_agent.py | 21 +-- .../agent/code_evaluation_agent/utils.py | 23 ++-- .../agent/code_generation_agent/llm_agent.py | 7 +- .../llm_code_generation.py | 19 +-- .../make_pytorch_file.py | 32 ++--- .../database_operations.py | 47 +++---- .../integrative_rag_agent/get_model_info.py | 13 +- .../llm_rag_code_conversion.py | 25 ++-- .../llm_rag_embedding_generation.py | 9 +- .../scrap_all_python_blocks.py | 21 +-- .../sort_components_in_hierarchical_order.py | 13 +- .../agent/integrative_rag_agent/utils.py | 2 +- .../get_files_in_hierarchical_order.py | 35 +++-- .../orchestration_agent/split_python_file.py | 31 ++--- .../agent/orchestration_agent/utils.py | 108 +++++++-------- .../self_debugging_agent.py | 83 ++++++------ .../agent/self_debugging_agent/utils.py | 42 +++--- src/MaxText/experimental/rl/grpo_trainer.py | 43 +++--- src/MaxText/experimental/rl/grpo_utils.py | 11 +- src/MaxText/gradient_accumulation.py | 9 +- src/MaxText/inference/offline_engine.py | 34 ++--- src/MaxText/inference/paged_attention.py | 14 +- .../inference/scripts/sharding_utils.py | 45 ++++--- .../input_pipeline/_input_pipeline_utils.py | 34 ++--- .../_tfds_data_processing_c4_mlperf.py | 10 ++ .../vllm/maxtext_vllm_adapter/adapter.py | 18 +-- src/MaxText/layers/attention_mla.py | 7 +- src/MaxText/layers/attention_op.py | 45 ++++--- src/MaxText/layers/attentions.py | 5 +- src/MaxText/layers/decoders.py | 2 +- src/MaxText/layers/embeddings.py | 22 +-- src/MaxText/layers/gemma3.py | 2 + src/MaxText/layers/llama4.py | 9 +- src/MaxText/layers/multi_token_prediction.py | 14 +- src/MaxText/layers/pipeline.py | 2 +- src/MaxText/layers/qwen3.py | 49 +++---- src/MaxText/max_utils.py | 6 +- src/MaxText/model_creation_utils.py | 13 +- src/MaxText/multimodal_utils.py | 66 ++++----- src/MaxText/rl/evaluate_rl.py | 34 ++--- src/MaxText/sequence_packing.py | 10 +- src/MaxText/sharding.py | 20 +-- src/MaxText/tokenizer.py | 16 +-- src/MaxText/train_tokenizer.py | 4 + src/MaxText/vocabulary_tiling.py | 1 + 50 files changed, 683 insertions(+), 613 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 8d633d16eb..4c028c2d9f 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -110,68 +110,68 @@ # -- Autogenerate API documentation ------------------------------------------ def run_apidoc(_): - """Runs sphinx-apidoc to generate API documentation. - - This function is connected to the Sphinx build process and is triggered to - automatically generate the reStructuredText (RST) files for the API - documentation from the docstrings in the MaxText source code. - - Args: - _: The Sphinx application object. Not used. - """ - # directly within the Sphinx process, especially on macOS, as it avoids - # potential multiprocessing/forking issues like the "mutex lock failed" error. - # pylint: disable=import-outside-toplevel - import subprocess - - os.environ["OBJC_DISABLE_INITIALIZE_FORK_SAFETY"] = "1" - - assert os.path.isfile(os.path.join(MAXTEXT_REPO_ROOT, "pyproject.toml")) - - # The path where the generated RST files will be stored - output_path = os.path.join(MAXTEXT_REPO_ROOT, "docs", "reference", "api_generated") - - # Command to run sphinx-apidoc - # Note: We use `sys.executable -m sphinx.ext.apidoc` to ensure we're using - # the apidoc from the same Python environment as Sphinx. - command = [ - sys.executable, - "-m", - "sphinx.ext.apidoc", - "--module-first", - "--force", - "--separate", - "--output-dir", - output_path, - os.path.join(MAXTEXT_REPO_ROOT, "src"), - # Paths to exclude - os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "experimental"), - os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "inference_mlperf"), - os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "scratch_code"), - # Paths to exclude due to import errors - os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "utils"), - os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "inference", "decode_multi.py"), - os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "inference", "offline_engine.py"), - os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "benchmark_chunked_prefill.py"), - os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "decode.py"), - os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "inference_microbenchmark.py"), - os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "inference_microbenchmark_sweep.py"), - os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "load_and_quantize_checkpoint.py"), - os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "maxengine.py"), - os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "maxengine_config.py"), - os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "maxengine_server.py"), - os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "prefill_packing.py"), - ] - - # Run the command and check for errors - try: - print("Running sphinx-apidoc...") - subprocess.check_call( - command, env={**os.environ, **{"OBJC_DISABLE_INITIALIZE_FORK_SAFETY": "1"}} - ) - except subprocess.CalledProcessError as e: - print(f"sphinx-apidoc failed with error: {e}", file=sys.stderr) - sys.exit(1) + """Runs sphinx-apidoc to generate API documentation. + + This function is connected to the Sphinx build process and is triggered to + automatically generate the reStructuredText (RST) files for the API + documentation from the docstrings in the MaxText source code. + + Args: + _: The Sphinx application object. Not used. + """ + # directly within the Sphinx process, especially on macOS, as it avoids + # potential multiprocessing/forking issues like the "mutex lock failed" error. + # pylint: disable=import-outside-toplevel + import subprocess + + os.environ["OBJC_DISABLE_INITIALIZE_FORK_SAFETY"] = "1" + + assert os.path.isfile(os.path.join(MAXTEXT_REPO_ROOT, "pyproject.toml")) + + # The path where the generated RST files will be stored + output_path = os.path.join(MAXTEXT_REPO_ROOT, "docs", "reference", "api_generated") + + # Command to run sphinx-apidoc + # Note: We use `sys.executable -m sphinx.ext.apidoc` to ensure we're using + # the apidoc from the same Python environment as Sphinx. + command = [ + sys.executable, + "-m", + "sphinx.ext.apidoc", + "--module-first", + "--force", + "--separate", + "--output-dir", + output_path, + os.path.join(MAXTEXT_REPO_ROOT, "src"), + # Paths to exclude + os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "experimental"), + os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "inference_mlperf"), + os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "scratch_code"), + # Paths to exclude due to import errors + os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "utils"), + os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "inference", "decode_multi.py"), + os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "inference", "offline_engine.py"), + os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "benchmark_chunked_prefill.py"), + os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "decode.py"), + os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "inference_microbenchmark.py"), + os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "inference_microbenchmark_sweep.py"), + os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "load_and_quantize_checkpoint.py"), + os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "maxengine.py"), + os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "maxengine_config.py"), + os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "maxengine_server.py"), + os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "prefill_packing.py"), + ] + + # Run the command and check for errors + try: + print("Running sphinx-apidoc...") + subprocess.check_call( + command, env={**os.environ, **{"OBJC_DISABLE_INITIALIZE_FORK_SAFETY": "1"}} + ) + except subprocess.CalledProcessError as e: + print(f"sphinx-apidoc failed with error: {e}", file=sys.stderr) + sys.exit(1) class FilterSphinxWarnings(logging.Filter): @@ -192,6 +192,8 @@ def filter(self, record: logging.LogRecord) -> bool: def setup(app): + """Set up the Sphinx application with custom behavior.""" + # Connect the apidoc generation to the Sphinx build process run_apidoc(None) print("running:", app) diff --git a/src/MaxText/checkpointing.py b/src/MaxText/checkpointing.py index 27fb674ff3..60a858a59e 100644 --- a/src/MaxText/checkpointing.py +++ b/src/MaxText/checkpointing.py @@ -514,13 +514,13 @@ def load_state_if_possible( enable_orbax_v1: bool flag for enabling Orbax v1. checkpoint_conversion_fn: function for converting checkpoint to Orbax v1. source_checkpoint_layout: Optional checkpoint context to use for loading, - provided in string format with the default being "orbax". + provided in string format with the default being "orbax". Returns: - A tuple of (train_state, train_state_params) where full_train_state captures - a full reload and train_state_params just the params for a partial reload. - At most one will be non-None. Both can be None if neither checkpoint is - set. + A tuple of `(train_state, train_state_params)` + where full_train_state captures a full reload and `train_state_params` + just the params for a partial reload. At most one will be non-None. Both + can be None if neither checkpoint is set. """ if checkpoint_manager is not None: @@ -615,8 +615,10 @@ def map_to_pspec(data): def setup_checkpoint_logger(config) -> Any | None: # pytype: disable=attribute-error """Setup checkpoint logger. + Args: config + Returns: CloudLogger """ diff --git a/src/MaxText/estimator.py b/src/MaxText/estimator.py index 8cdb3c0d0b..1d4cece95d 100644 --- a/src/MaxText/estimator.py +++ b/src/MaxText/estimator.py @@ -177,7 +177,7 @@ def largest_batch_size(base_argv, policy, min_pdb, max_pdb=64) -> int: max_pdb: The maximum per_device_batch_size to test. Returns: - The largest per_device_batch_size within the range that does not result in an OOM error. + The largest `per_device_batch_size` within the range that does not result in an OOM error. """ print(f"Starting binary search for the largest batch size between {min_pdb} and {max_pdb}.") @@ -345,8 +345,9 @@ def get_parameter_value(config_tuple, prefix): Returns: A tuple of (bool, str or None). - - (True, value) if the prefix is found. - - (False, None) if the prefix is not found. + + * (True, value) if the prefix is found. + * (False, None) if the prefix is not found. """ for item in config_tuple: if item.startswith(prefix): @@ -364,12 +365,13 @@ def find_batch_size(base_argv): Parses the base arguments to find the 'per_device_batch_size'. Args: - base_argv: The tuple of command-line arguments. + base_argv: The tuple of command-line arguments. Returns: - A tuple of (bool, int or None): - - (True, batch_size) if 'per_device_batch_size=...' was found. - - (False, None) if it was not found. + A tuple of (bool, int or None) + + * (True, batch_size) if `per_device_batch_size=...` was found. + * (False, None) if it was not found. """ pdb_provided, pdb_str = get_parameter_value(base_argv, prefix="per_device_batch_size=") @@ -384,10 +386,10 @@ def find_remat_policy_tensor_names(base_argv): to be considered for rematerialization. Args: - base_argv: The tuple of command-line arguments. + base_argv: The tuple of command-line arguments. Returns: - A list of tensor names that were passed as flags. + A list of tensor names that were passed as flags. """ full_tensor_list = [ "context", diff --git a/src/MaxText/experimental/agent/ckpt_conversion_agent/baselines/context/hf_shape.py b/src/MaxText/experimental/agent/ckpt_conversion_agent/baselines/context/hf_shape.py index 29b7476961..b969505a5d 100644 --- a/src/MaxText/experimental/agent/ckpt_conversion_agent/baselines/context/hf_shape.py +++ b/src/MaxText/experimental/agent/ckpt_conversion_agent/baselines/context/hf_shape.py @@ -21,12 +21,13 @@ def GEMMA2_HF_WEIGHTS_TO_SHAPE(config): """Returns mapping between HuggingFace weights path and weights shape. Args: - config (dict): Model configuration dictionary, defined in `model_configs.py` + config (dict): Model configuration dictionary, defined in `model_configs.py` Returns: - dict: A mapping where: - - Keys are HuggingFace model parameter paths - - Values are parameter shape as a List + dict: A mapping where: + + * Keys are HuggingFace model parameter paths + * Values are parameter shape as a List """ mapping = { diff --git a/src/MaxText/experimental/agent/ckpt_conversion_agent/baselines/context/param_mapping.py b/src/MaxText/experimental/agent/ckpt_conversion_agent/baselines/context/param_mapping.py index cd251bebf2..2e0e919cc9 100644 --- a/src/MaxText/experimental/agent/ckpt_conversion_agent/baselines/context/param_mapping.py +++ b/src/MaxText/experimental/agent/ckpt_conversion_agent/baselines/context/param_mapping.py @@ -23,24 +23,24 @@ def GEMMA2_MAXTEXT_TO_HF_PARAM_MAPPING(config, scan_layers=False): """Returns mapping between MaxText and HuggingFace Gemma2 weight paths. Args: - config (dict): Model configuration dictionary containing at least - 'num_hidden_layers'. - scan_layers (bool, optional): Whether the MaxText model uses layer - scanning optimization. When True, decoder layers are stacked into a - single tensor. Defaults to False. + config (dict): Model configuration dictionary containing at least + 'num_hidden_layers'. + scan_layers (bool, optional): Whether the MaxText model uses layer + scanning optimization. When True, decoder layers are stacked into a + single tensor. Defaults to False. Returns: - dict: A mapping where keys are MaxText parameter paths and values are - either single strings (HF parameter path) for unscanned parameters or - lists of strings (HF parameter paths) for stacked layers when - `scan_layers=True`. + dict: A mapping where keys are MaxText parameter paths and values are + either single strings (HF parameter path) for unscanned parameters or + lists of strings (HF parameter paths) for stacked layers when + `scan_layers=True`. Notes: - - MaxText uses a paired layer approach where two HF decoder layers are - treated as one MaxText decoder layer. - - MaxText layer `i` corresponds to HF layers `2i` and `2i+1`. - - Local components map to even-numbered HF decoder layers (0, 2, 4...). - - Global components map to odd-numbered HF decoder layers (1, 3, 5...). + * MaxText uses a paired layer approach where two HF decoder layers are + treated as one MaxText decoder layer. + * MaxText layer `i` corresponds to HF layers `2i` and `2i+1`. + * Local components map to even-numbered HF decoder layers (0, 2, 4...). + * Global components map to odd-numbered HF decoder layers (1, 3, 5...). """ nlayers = config["num_hidden_layers"] @@ -160,20 +160,21 @@ def GEMMA2_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, scan_layers=False, saving_to_hf=F Gemma2, including operations like padding, reshaping, and scaling. Args: - config (dict): Model configuration dictionary that must contain: - - num_hidden_layers (int): Number of layers in the model. - - head_dim (int): Dimension of attention heads. - - hidden_size (int): Model's hidden dimension size. - scan_layers (bool, optional): Controls the output format for layer - parameters. True for batched, False for individual. Defaults to False. - saving_to_hf (bool, optional): Determines the direction of transformation. - True for MaxText to HuggingFace, False for the reverse. Defaults to - False. + config (dict): Model configuration dictionary that must contain: + + * num_hidden_layers (int): Number of layers in the model. + * head_dim (int): Dimension of attention heads. + * hidden_size (int): Model's hidden dimension size. + scan_layers (bool, optional): Controls the output format for layer + parameters. True for batched, False for individual. Defaults to False. + saving_to_hf (bool, optional): Determines the direction of transformation. + True for MaxText to HuggingFace, False for the reverse. Defaults to + False. Returns: - dict: A mapping from MaxText parameter names to transformation functions. - The value can be a single function or a list of functions to be - applied sequentially. + dict: A mapping from MaxText parameter names to transformation functions. + The value can be a single function or a list of functions to be + applied sequentially. """ nlayers = config["num_hidden_layers"] diff --git a/src/MaxText/experimental/agent/code_evaluation_agent/code_evaluation_agent.py b/src/MaxText/experimental/agent/code_evaluation_agent/code_evaluation_agent.py index 588cae0098..6e94481a60 100644 --- a/src/MaxText/experimental/agent/code_evaluation_agent/code_evaluation_agent.py +++ b/src/MaxText/experimental/agent/code_evaluation_agent/code_evaluation_agent.py @@ -76,13 +76,14 @@ def get_file_pairs(pytorch_path, jax_path): directories, filtering out any files in the JAX directory that start with "__". Args: - pytorch_path: The path to the directory containing PyTorch files. - jax_path: The path to the directory containing JAX files. + pytorch_path: The path to the directory containing PyTorch files. + jax_path: The path to the directory containing JAX files. Returns: - A tuple containing two lists of strings: - - The first list contains the full paths to the common PyTorch files. - - The second list contains the full paths to the common JAX files. + A tuple containing two lists of strings + + * The first list contains the full paths to the common PyTorch files. + * The second list contains the full paths to the common JAX files. """ pytorch_files = os.listdir(pytorch_path) jax_files = list(filter(lambda x: not x.startswith("__"), os.listdir(jax_path))) @@ -99,12 +100,12 @@ def make_test_case_and_run(args, python_file, jax_file): a penalty is applied. Args: - args (argparse.Namespace): The command-line arguments. - python_file: The path to the PyTorch code file. - jax_file: The path to the JAX code file. + args (argparse.Namespace): The command-line arguments. + python_file: The path to the PyTorch code file. + jax_file: The path to the JAX code file. Returns: - A tuple containing the number of passed and failed test cases. + A tuple containing the number of passed and failed test cases. """ response = None try: @@ -200,7 +201,7 @@ def parse_args(): Parses command-line arguments for file or folder processing. Returns: - argparse.Namespace: The parsed command-line arguments. + argparse.Namespace: The parsed command-line arguments. """ parser = argparse.ArgumentParser(description="Code Evaluation Agent") parser.add_argument( diff --git a/src/MaxText/experimental/agent/code_evaluation_agent/utils.py b/src/MaxText/experimental/agent/code_evaluation_agent/utils.py index 585a0f4682..148fb672d7 100644 --- a/src/MaxText/experimental/agent/code_evaluation_agent/utils.py +++ b/src/MaxText/experimental/agent/code_evaluation_agent/utils.py @@ -57,11 +57,11 @@ def get_last_defined_module(code_str): `ast.ClassDef` node. Args: - code_str: A string containing Python code. + code_str: A string containing Python code. Returns: - The name of the last defined function or class, or a syntax error message - if the code is invalid. + The name of the last defined function or class, or a syntax error message + if the code is invalid. """ try: tree = ast.parse(code_str) @@ -87,16 +87,17 @@ def run_pytest_capture_output(test_file: str, code_folder: None | str = None) -> to capture all print statements and error messages from the test run. Args: - test_file: The path to the pytest file to run. - code_folder: The directory to change into before running the tests. + test_file: The path to the pytest file to run. + code_folder: The directory to change into before running the tests. Returns: - A tuple containing: - - output (str): The complete stdout and stderr from the test run. - - exit_code (int): The exit code of the pytest process (0 for success, non-zero otherwise). - - is_dependency_error (bool): True if a common dependency error was found in the output. - - passed (int): The number of tests that passed. - - failed (int): The number of tests that failed. + A tuple containing + + * output (str): The complete stdout and stderr from the test run. + * exit_code (int): The exit code of the pytest process (0 for success, non-zero otherwise). + * is_dependency_error (bool): True if a common dependency error was found in the output. + * passed (int): The number of tests that passed. + * failed (int): The number of tests that failed. """ current_path = os.path.abspath(".") try: diff --git a/src/MaxText/experimental/agent/code_generation_agent/llm_agent.py b/src/MaxText/experimental/agent/code_generation_agent/llm_agent.py index fe7815b6d1..006232b80b 100644 --- a/src/MaxText/experimental/agent/code_generation_agent/llm_agent.py +++ b/src/MaxText/experimental/agent/code_generation_agent/llm_agent.py @@ -73,12 +73,11 @@ def __call__(self, memory_list): chat dictionary format. Args: - memory_list (str | list): A single message string or a list of - message dictionaries in the required - model format. + memory_list (str | list): A single message string or a list of message + dictionaries in the required model format. Returns: - google.generativeai.types.GenerateContentResponse: The response from the model. + google.generativeai.types.GenerateContentResponse: The response from the model. """ if isinstance(memory_list, str): memory_list = {"role": "user", "parts": memory_list} diff --git a/src/MaxText/experimental/agent/code_generation_agent/llm_code_generation.py b/src/MaxText/experimental/agent/code_generation_agent/llm_code_generation.py index b702648cbb..e4395bfebf 100644 --- a/src/MaxText/experimental/agent/code_generation_agent/llm_code_generation.py +++ b/src/MaxText/experimental/agent/code_generation_agent/llm_code_generation.py @@ -106,10 +106,10 @@ def get_chat_dict(input_message=""): Creates a chat dictionary for a user message. Args: - input_message (str, optional): The user's message. Defaults to "". + input_message (str, optional): The user's message. Defaults to "". Returns: - dict: A dictionary formatted for the Gemini API. + dict: A dictionary formatted for the Gemini API. """ return {"role": "user", "parts": input_message} @@ -119,14 +119,15 @@ def convert_code_from_torch_to_jax(codeComponent, memory_list): Converts a single code component from PyTorch to JAX using the LLM agent. Args: - code_component (str): The Python code to be converted. - memory_list (list, optional): A list of previous chat messages to provide context. - Defaults to None. + code_component (str): The Python code to be converted. + memory_list (list, optional): A list of previous chat messages to provide context. + Defaults to None. Returns: - tuple: A tuple containing: - - str: The converted JAX code. - - list: The updated memory list. + tuple: A tuple containing + + * str: The converted JAX code. + * list: The updated memory list. """ if memory_list is None: memory_list = [] @@ -147,7 +148,7 @@ def parse_args(): Parses command-line arguments for file or folder processing. Returns: - argparse.Namespace: The parsed command-line arguments. + argparse.Namespace: The parsed command-line arguments. """ parser = argparse.ArgumentParser(description="Code Conversion and Test Case Generation Agent") group = parser.add_mutually_exclusive_group(required=True) diff --git a/src/MaxText/experimental/agent/code_generation_agent/make_pytorch_file.py b/src/MaxText/experimental/agent/code_generation_agent/make_pytorch_file.py index e25dbe2030..a5b6c0aca2 100644 --- a/src/MaxText/experimental/agent/code_generation_agent/make_pytorch_file.py +++ b/src/MaxText/experimental/agent/code_generation_agent/make_pytorch_file.py @@ -55,16 +55,17 @@ def is_torch_function_or_class(node): Checks if an AST node represents a PyTorch-related function or class. This is determined by: - - A class inheriting from 'nn.Module'. - - A function having 'torch' in its annotations. - - A function body containing references to 'torch'. + + * A class inheriting from 'nn.Module'. + * A function having 'torch' in its annotations. + * A function body containing references to 'torch'. Args: - node: An AST node (ast.FunctionDef or ast.ClassDef). + node: An AST node (ast.FunctionDef or ast.ClassDef). Returns: - bool: True if the node is a PyTorch-related function or class, - False otherwise. + bool: True if the node is a PyTorch-related function or class, + False otherwise. """ if isinstance(node, ast.FunctionDef): # Look for 'torch' in annotations or function body @@ -93,10 +94,10 @@ def file_uses_torch(tree): Checks if a file's AST contains any top-level imports of the 'torch' module. Args: - tree: The AST of the entire file. + tree: The AST of the entire file. Returns: - bool: True if 'torch' is imported, False otherwise. + bool: True if 'torch' is imported, False otherwise. """ for node in ast.walk(tree): if isinstance(node, (ast.Import, ast.ImportFrom)): @@ -118,16 +119,15 @@ def has_external_dependencies(code, removed_names=None, local_components=None): filtered out. Args: - code (str): The source code of the component to check. - removed_names (list, optional): A list of names (functions, classes) - that were removed by `remove_local_imports`. - Defaults to None. - local_components (list, optional): A list of names of other components - in the same file. Defaults to None. + code (str): The source code of the component to check. + removed_names (list, optional): A list of names (functions, classes) + that were removed by `remove_local_imports`. Defaults to None. + local_components (list, optional): A list of names of other components + in the same file. Defaults to None. Returns: - bool: True if a dependency on a removed or local name is found, - False otherwise. + bool: True if a dependency on a removed or local name is found, False + otherwise. """ if not removed_names and not local_components: return False diff --git a/src/MaxText/experimental/agent/integrative_rag_agent/database_operations.py b/src/MaxText/experimental/agent/integrative_rag_agent/database_operations.py index c46c54deb1..30cc6a20d8 100644 --- a/src/MaxText/experimental/agent/integrative_rag_agent/database_operations.py +++ b/src/MaxText/experimental/agent/integrative_rag_agent/database_operations.py @@ -69,15 +69,15 @@ def save_document(name, text, desc, file, embedding): """Insert a document and its embedding into the database. Args: - name (str): Logical name/identifier for the document. - text (str): Raw text content of the document. - desc (str): Short description or summary of the document. - file (str): File path or source identifier for the document. - embedding (numpy.ndarray): Dense vector representation of the document - with shape (dim,) and dtype convertible to float32. + name (str): Logical name/identifier for the document. + text (str): Raw text content of the document. + desc (str): Short description or summary of the document. + file (str): File path or source identifier for the document. + embedding (numpy.ndarray): Dense vector representation of the document + with shape (dim,) and dtype convertible to float32. Returns: - None + None """ conn = sqlite3.connect(rag_db_file) cur = conn.cursor() @@ -95,12 +95,13 @@ def load_all_documents(): """Load all documents and embeddings from the database. Returns: - tuple[list[int], list[str], list[str], list[str], numpy.ndarray]: - - ids: Row IDs for each document. - - names: Names for each document. - - texts: Text content for each document. - - files: Source file paths/identifiers. - - embeddings: Array of shape (num_docs, dim) with dtype float32. + tuple[list[int], list[str], list[str], list[str], numpy.ndarray]: + + * ids: Row IDs for each document. + * names: Names for each document. + * texts: Text content for each document. + * files: Source file paths/identifiers. + * embeddings: Array of shape (num_docs, dim) with dtype float32. """ conn = sqlite3.connect(rag_db_file) cur = conn.cursor() @@ -123,10 +124,10 @@ def build_faiss_index(embeddings): """Build a FAISS IndexFlatL2 from document embeddings. Args: - embeddings (numpy.ndarray): Array of shape (num_docs, dim), dtype float32. + embeddings (numpy.ndarray): Array of shape (num_docs, dim), dtype float32. Returns: - faiss.IndexFlatL2 or None: L2 index with the provided vectors added, or + faiss.IndexFlatL2 or None: L2 index with the provided vectors added, or None if the embeddings array is empty or not 2-dimensional. """ if embeddings.ndim != 2 or embeddings.shape[0] == 0: @@ -141,14 +142,14 @@ def search_embedding(query_embedding, index, texts, top_k=3): """Search the index for nearest neighbors to a query embedding. Args: - query_embedding (array-like): Vector of shape (dim,) convertible to float32. - index (faiss.Index): A FAISS index built over document embeddings. - texts (list[str]): Texts aligned with vectors in the index. - top_k (int): Number of nearest neighbors to retrieve. + query_embedding (array-like): Vector of shape (dim,) convertible to float32. + index (faiss.Index): A FAISS index built over document embeddings. + texts (list[str]): Texts aligned with vectors in the index. + top_k (int): Number of nearest neighbors to retrieve. Returns: - list[tuple[str, float, int]]: For each neighbor, a tuple of (text, distance, index_in_corpus). - Distances are squared L2 (Euclidean) norms; smaller values indicate greater similarity. + list[tuple[str, float, int]]: For each neighbor, a tuple of (text, distance, index_in_corpus). + Distances are squared L2 (Euclidean) norms; smaller values indicate greater similarity. """ if index is None: return [] @@ -163,8 +164,8 @@ def make_embedding_index(): """Load all documents and build a FAISS index over their embeddings. Returns: - tuple[list[int], list[str], list[str], list[str], faiss.Index]: - (ids, names, texts, files, index) + tuple[list[int], list[str], list[str], list[str], faiss.Index]: + (ids, names, texts, files, index) """ ids, names, texts, files, embeddings = load_all_documents() index = build_faiss_index(embeddings) diff --git a/src/MaxText/experimental/agent/integrative_rag_agent/get_model_info.py b/src/MaxText/experimental/agent/integrative_rag_agent/get_model_info.py index 60feaefb09..c13e858b17 100644 --- a/src/MaxText/experimental/agent/integrative_rag_agent/get_model_info.py +++ b/src/MaxText/experimental/agent/integrative_rag_agent/get_model_info.py @@ -38,13 +38,14 @@ def get_model_info(model_id: str): `transformers` package, and the model type token. Args: - model_id (str): A Hugging Face model identifier (e.g., "meta-llama/Llama-3-8B"). + model_id (str): A Hugging Face model identifier (e.g., "meta-llama/Llama-3-8B"). Returns: - dict: A dictionary with keys: - - "class_name" (str | None): The architecture class name, if available. - - "file_path" (str): The canonical `transformers` path to the modeling file. - - "model_type" (str): The model type string from the config. + dict: A dictionary with keys + + * "class_name" (str | None): The architecture class name, if available. + * "file_path" (str): The canonical `transformers` path to the modeling file. + * "model_type" (str): The model type string from the config. """ # Load config only (very lightweight, no weights) config = AutoConfig.from_pretrained(model_id) @@ -64,7 +65,7 @@ def parse_args(): Parses command-line arguments for file or folder processing. Returns: - argparse.Namespace: The parsed command-line arguments. + argparse.Namespace: The parsed command-line arguments. """ parser = argparse.ArgumentParser(description="Get model info from HuggingFace model id") parser.add_argument( diff --git a/src/MaxText/experimental/agent/integrative_rag_agent/llm_rag_code_conversion.py b/src/MaxText/experimental/agent/integrative_rag_agent/llm_rag_code_conversion.py index de88ebbfb5..3b13fe65cc 100644 --- a/src/MaxText/experimental/agent/integrative_rag_agent/llm_rag_code_conversion.py +++ b/src/MaxText/experimental/agent/integrative_rag_agent/llm_rag_code_conversion.py @@ -59,11 +59,12 @@ def arg_parser(): """Create and return the CLI argument parser for code conversion. Returns: - argparse.Namespace: Parsed arguments containing: - - number_of_maxtext_blocks (int) - - module_name (str) - - destination_base_directory (str) - - destination_source_url (str) + argparse.Namespace: Parsed arguments containing + + * number_of_maxtext_blocks (int) + * module_name (str) + * destination_base_directory (str) + * destination_source_url (str) """ parser = argparse.ArgumentParser(description="LLM code conversion utility.") parser.add_argument("--number-of-maxtext-blocks", type=int, default=5, help="Number of maxtext blocks to process.") @@ -100,7 +101,7 @@ def get_exisiting_jax_modules(): values are textual analyses used as LLM guidance. Returns: - dict[str, str]: Mapping of module key to analysis/description. + dict[str, str]: Mapping of module key to analysis/description. """ with open(maxtext_block_description, "rt", encoding="utf-8") as f: module_list = json.load(f) @@ -176,14 +177,14 @@ def convert_given_file(module, jax_modules) -> None | dict: returns a description for the generated module. Args: - module (dict): Component metadata with keys like "filepath", - "comp_name", and optional "JaxDependencies". - jax_modules (dict): Existing JAX module descriptions to provide context - to the LLM. + module (dict): Component metadata with keys like "filepath", + "comp_name", and optional "JaxDependencies". + jax_modules (dict): Existing JAX module descriptions to provide context + to the LLM. Returns: - dict | None: Mapping of fully-qualified package name to generated - description, or None if generation did not produce a detectable module. + dict | None: Mapping of fully-qualified package name to generated + description, or None if generation did not produce a detectable module. """ maxtext_blocks_code = read_code_blocks(maxtext_code_block, args.number_of_maxtext_blocks) module_code, file_code = get_modules_from_file(destination_source_url + module["filepath"], module=module["comp_name"]) diff --git a/src/MaxText/experimental/agent/integrative_rag_agent/llm_rag_embedding_generation.py b/src/MaxText/experimental/agent/integrative_rag_agent/llm_rag_embedding_generation.py index efac2c137a..41a54ab33b 100644 --- a/src/MaxText/experimental/agent/integrative_rag_agent/llm_rag_embedding_generation.py +++ b/src/MaxText/experimental/agent/integrative_rag_agent/llm_rag_embedding_generation.py @@ -173,11 +173,12 @@ def get_code_description_with_gemini(code_block, full_code_context, user_prompt= Analyzes a Python code block using the Gemini API to generate a structured description. Args: - code_block (str): The specific Python function or class to analyze. - full_code_context (str): The full source code of the file for context. - user_prompt (str): The prompt template for the user message. + code_block (str): The specific Python function or class to analyze. + full_code_context (str): The full source code of the file for context. + user_prompt (str): The prompt template for the user message. + Returns: - None | dict: A dictionary containing the structured analysis, or an error message and return `None`. + None | dict: A dictionary containing the structured analysis, or an error message and return `None`. """ llm_agent = GeminiAgent(system_instruction=Description_Prompt) resp = None diff --git a/src/MaxText/experimental/agent/integrative_rag_agent/scrap_all_python_blocks.py b/src/MaxText/experimental/agent/integrative_rag_agent/scrap_all_python_blocks.py index 5c65ffd07b..576d38cea1 100644 --- a/src/MaxText/experimental/agent/integrative_rag_agent/scrap_all_python_blocks.py +++ b/src/MaxText/experimental/agent/integrative_rag_agent/scrap_all_python_blocks.py @@ -46,11 +46,11 @@ def scrape_python_blocks(source_code, file_path_for_logging): any blocks nested within functions. Args: - source_code (str): The Python source code as a string. - file_path_for_logging (str): The path of the file being scraped, for logging purposes. + source_code (str): The Python source code as a string. + file_path_for_logging (str): The path of the file being scraped, for logging purposes. Returns: - list: A list of strings, where each string is a source code block. + list: A list of strings, where each string is a source code block. """ blocks = [] try: @@ -128,15 +128,16 @@ def find_and_scrape_from_github(owner, repo, paths, token=None): Finds Python files in GitHub paths and scrapes their blocks and full code. Args: - owner (str): The owner of the GitHub repository. - repo (str): The name of the GitHub repository. - paths (list): A list of file or directory paths within the repo. - token (str, optional): A GitHub Personal Access Token for authentication. + owner (str): The owner of the GitHub repository. + repo (str): The name of the GitHub repository. + paths (list): A list of file or directory paths within the repo. + token (str, optional): A GitHub Personal Access Token for authentication. Returns: - tuple: A tuple containing two dictionaries: - - A dictionary of scraped code blocks. - - A dictionary of the full source code for each file. + tuple: A tuple containing two dictionaries + + * A dictionary of scraped code blocks. + * A dictionary of the full source code for each file. """ all_scraped_blocks = {} all_full_codes = {} diff --git a/src/MaxText/experimental/agent/integrative_rag_agent/sort_components_in_hierarchical_order.py b/src/MaxText/experimental/agent/integrative_rag_agent/sort_components_in_hierarchical_order.py index 31dcb42fb1..f2cbd0f5c2 100644 --- a/src/MaxText/experimental/agent/integrative_rag_agent/sort_components_in_hierarchical_order.py +++ b/src/MaxText/experimental/agent/integrative_rag_agent/sort_components_in_hierarchical_order.py @@ -84,10 +84,11 @@ def search_similar_dependency(depend, base_path, project_root): project_root (str): The root directory of the project. Returns: - tuple: A tuple containing: - - A list of tuples, where each tuple contains (distance, name, file, code_text) - for similar dependencies found. Returns None if no similar dependencies are found. - - The code of the original module. + tuple: A tuple containing + + * A list of tuples, where each tuple contains (distance, name, file, code_text) + for similar dependencies found. Returns None if no similar dependencies are found. + * The code of the original module. """ if enable_cache: if os.path.exists(torch_jax_similar_dependency_cache_file): @@ -252,10 +253,10 @@ def load_status(file_path: str) -> Status: (like dequeues and sets) to allow for the continuation of a paused analysis. Args: - file_path (str): The path to the JSON file containing the saved status. + file_path (str): The path to the JSON file containing the saved status. Returns: - tuple: A tuple of restored variables representing the analysis state. + tuple: A tuple of restored variables representing the analysis state. """ with open(file_path, "rt", encoding="utf-8") as f: status = json.load(f) diff --git a/src/MaxText/experimental/agent/integrative_rag_agent/utils.py b/src/MaxText/experimental/agent/integrative_rag_agent/utils.py index 61df188a2a..875a30daa3 100644 --- a/src/MaxText/experimental/agent/integrative_rag_agent/utils.py +++ b/src/MaxText/experimental/agent/integrative_rag_agent/utils.py @@ -32,7 +32,7 @@ def read_code_blocks(file_path, number_of_blocks): Returns: str: A string containing the randomly selected code blocks, separated by - three newlines. + three newlines. """ with open(file_path, "rt", encoding="utf-8") as f: all_blocks = json.load(f) diff --git a/src/MaxText/experimental/agent/orchestration_agent/get_files_in_hierarchical_order.py b/src/MaxText/experimental/agent/orchestration_agent/get_files_in_hierarchical_order.py index 742eb1008c..b1bf80fd32 100644 --- a/src/MaxText/experimental/agent/orchestration_agent/get_files_in_hierarchical_order.py +++ b/src/MaxText/experimental/agent/orchestration_agent/get_files_in_hierarchical_order.py @@ -53,14 +53,13 @@ def find_file_dependencies(file_path_url, base_path_url, exclude_conditional_imp Finds all direct Python file dependencies for a given file. Args: - file_path_url (str): The full GitHub URL of the Python file to analyze. - base_path_url (str): The base URL of the GitHub repository's source directory. - exclude_conditional_imports (bool): If True, imports inside functions, - classes, or `if TYPE_CHECKING:` blocks - are ignored. + file_path_url (str): The full GitHub URL of the Python file to analyze. + base_path_url (str): The base URL of the GitHub repository's source directory. + exclude_conditional_imports (bool): If True, imports inside functions, + classes, or `if TYPE_CHECKING:` blocks are ignored. Returns: - set: A set of full GitHub URLs of the dependent Python files. + set: A set of full GitHub URLs of the dependent Python files. """ dependencies = set() flag, content = get_github_file_content(file_path_url) @@ -122,19 +121,19 @@ def get_dependency_sorted_files(entry_file_path, base_path, exclude_conditional_ topologically sorted list of all dependent Python files. Args: - entry_file_path (str): The full GitHub URL of the entry Python file. - base_path (str): The base URL of the GitHub repository's source directory. - exclude_conditional_imports (bool): If True, imports inside functions, - classes, or `if TYPE_CHECKING:` blocks - are ignored for dependency analysis. - returnDependencies (bool): If True, returns a tuple of (sorted_files, - dependency_graph), otherwise just sorted_files. + entry_file_path (str): The full GitHub URL of the entry Python file. + base_path (str): The base URL of the GitHub repository's source directory. + exclude_conditional_imports (bool): If True, imports inside functions, + classes, or `if TYPE_CHECKING:` blocks are ignored for dependency + analysis. + returnDependencies (bool): If True, returns a tuple of (sorted_files, + dependency_graph), otherwise just sorted_files. Returns: - list: A list of file paths (relative to base_path) in topological order. - Returns an empty list if a circular dependency is detected. - dict (optional): A dictionary representing the dependency graph if - returnDependencies is True. + list: A list of file paths (relative to base_path) in topological order. + Returns an empty list if a circular dependency is detected. + dict (optional): A dictionary representing the dependency graph if + returnDependencies is True. """ dependency_graph = {} reverse_graph = {} @@ -219,7 +218,7 @@ def parse_args(): Parses command-line arguments for file or folder processing. Returns: - argparse.Namespace: The parsed command-line arguments. + argparse.Namespace: The parsed command-line arguments. """ parser = argparse.ArgumentParser(description="Dependency sorter for Python files on GitHub.") parser.add_argument( diff --git a/src/MaxText/experimental/agent/orchestration_agent/split_python_file.py b/src/MaxText/experimental/agent/orchestration_agent/split_python_file.py index 955c76643d..df338674fd 100644 --- a/src/MaxText/experimental/agent/orchestration_agent/split_python_file.py +++ b/src/MaxText/experimental/agent/orchestration_agent/split_python_file.py @@ -170,12 +170,12 @@ def get_source_code(self): a local file, reads from disk. Returns: - str: The file contents. + str: The file contents. Raises: - FileNotFoundError: When the remote file does not exist or a local - path is missing. - IOError: When a remote file exists but cannot be read. + FileNotFoundError: When the remote file does not exist or a local + path is missing. + IOError: When a remote file exists but cannot be read. """ source_code = "" if self.file_path.startswith("https"): @@ -202,10 +202,10 @@ def convert_package_to_path(self, path): "utils": "src/MaxText/inference.py#utils"} Args: - path (str): A normalized absolute import string. + path (str): A normalized absolute import string. Returns: - dict[str, str]: Mapping of imported names to "file.py#name" anchors. + dict[str, str]: Mapping of imported names to "file.py#name" anchors. """ path_form, path_imports = path.removeprefix("from ").replace(".", os.path.sep).split(" import ") import_dict = {} @@ -548,7 +548,7 @@ def load_cache(self): """Load cached analysis result if caching is enabled. Returns: - tuple[str|None, dict]: A `(cache_key, search_cache)` pair. When + tuple[str|None, dict]: A `(cache_key, search_cache)` pair. When caching is disabled, returns `(None, {})`. """ search_cache = {} @@ -573,10 +573,11 @@ def get_sorted_structure(self) -> SortedStructure: """Compute (or load) and return the full sorted structure for the file. Returns: - dict: A dictionary with keys: - - "sorted_modules": Mapping of component name to source code. - - "component_dependencies": Adjacency lists by component. - - "warning": Optional warning message about cycles. + dict: A dictionary with keys + + * "sorted_modules": Mapping of component name to source code. + * "component_dependencies": Adjacency lists by component. + * "warning": Optional warning message about cycles. """ cache_key, search_cache = self.load_cache() if cache_key is not None and cache_key in search_cache: @@ -606,13 +607,13 @@ def get_module_code(self, module_name): Returns the source code for a given module/component name. Parameters: - module_name (str): The name of the module/component to retrieve. + module_name (str): The name of the module/component to retrieve. Returns: - str: The source code of the requested module/component. + str: The source code of the requested module/component. Raises: - KeyError: If no component with the provided name exists. + KeyError: If no component with the provided name exists. """ # Ensure analysis has been done if not hasattr(self, "sorted_components"): @@ -661,7 +662,7 @@ def parse_args(): Parses command-line arguments for file or folder processing. Returns: - argparse.Namespace: The parsed command-line arguments. + argparse.Namespace: The parsed command-line arguments. """ parser = argparse.ArgumentParser(description="Analyze Python file dependencies and split into components.") parser.add_argument( diff --git a/src/MaxText/experimental/agent/orchestration_agent/utils.py b/src/MaxText/experimental/agent/orchestration_agent/utils.py index d962b233f2..4d67d9e44a 100644 --- a/src/MaxText/experimental/agent/orchestration_agent/utils.py +++ b/src/MaxText/experimental/agent/orchestration_agent/utils.py @@ -33,10 +33,10 @@ def github_blob_to_raw(blob_url): Converts a GitHub blob URL to its raw content URL. Args: - blob_url (str): The URL of the GitHub blob. + blob_url (str): The URL of the GitHub blob. Returns: - str: The URL of the raw content of the GitHub blob. + str: The URL of the raw content of the GitHub blob. """ parsed = urlparse(blob_url) if "github.com" not in parsed.netloc or "/blob/" not in parsed.path: @@ -60,11 +60,11 @@ def check_github_file_exists(blob_url): Checks if a file exists on GitHub given its blob URL. Args: - blob_url (str): The URL of the GitHub blob. + blob_url (str): The URL of the GitHub blob. Returns: - bool: True if the file exists, False otherwise. - str: The raw URL of the file if it exists, or an error message. + bool: True if the file exists, False otherwise. + str: The raw URL of the file if it exists, or an error message. """ raw_url = github_blob_to_raw(blob_url) @@ -82,12 +82,13 @@ def get_github_file_content(blob_url): Retrieves the content of a file from GitHub given its blob URL. Args: - blob_url (str): The URL of the GitHub blob. + blob_url (str): The URL of the GitHub blob. Returns: - tuple: A tuple containing: - - bool: True if the content was retrieved successfully, False otherwise. - - str: The content of the file if successful, or an error message if not. + tuple: A tuple containing + + * bool: True if the content was retrieved successfully, False otherwise. + * str: The content of the file if successful, or an error message if not. """ exists, raw_url_or_error = check_github_file_exists(blob_url) @@ -112,8 +113,8 @@ def check_if_file_exists(url): Returns: tuple[bool, str]: - - True with the resolved raw URL/path if the file exists - - False with an error message otherwise + * True with the resolved raw URL/path if the file exists + * False with an error message otherwise """ if "http" in url and "github.com" in url: return check_github_file_exists(url) @@ -131,8 +132,8 @@ def get_file_content(url): Returns: tuple[bool, str]: - - (True, contents) on success - - (False, error_message) on failure + * (True, contents) on success + * (False, error_message) on failure """ if "http" in url and "github.com" in url: return get_github_file_content(url) @@ -167,11 +168,11 @@ def find_cycle(graph): Finds a cycle in a directed graph using DFS. Args: - graph (dict): A dictionary representing the graph where keys are nodes - and values are lists of their direct dependencies. + graph (dict): A dictionary representing the graph where keys are nodes + and values are lists of their direct dependencies. Returns: - list: A list of nodes forming a cycle if one is found, otherwise None. + list: A list of nodes forming a cycle if one is found, otherwise None. """ visited = set() @@ -216,15 +217,16 @@ def remove_local_imports(source_code, filepath=None): would be removed). Args: - source_code (str): The Python source code as a string. - filepath (str, optional): The path to the file containing the source - code. Used to determine the base module for - identifying local imports. Defaults to None. + source_code (str): The Python source code as a string. + filepath (str, optional): The path to the file containing the source code. + Used to determine the base module for identifying local imports. Defaults + to None. Returns: - tuple: A tuple containing: - - str: The modified source code with local imports removed. - - str: A newline-separated string of the names of the removed imports. + tuple: A tuple containing + + * str: The modified source code with local imports removed. + * str: A newline-separated string of the names of the removed imports. """ if filepath is not None: # Determine the base module from the filepath @@ -333,11 +335,11 @@ def parse_python_code(code): It supports both '```python' and generic '```' delimiters. Args: - code (str): The input string potentially containing Python code blocks. + code (str): The input string potentially containing Python code blocks. Returns: - str: The extracted Python code. Returns the original string if no - code blocks are found. + str: The extracted Python code. Returns the original string if no + code blocks are found. """ if "```python" in code: code = code.split("```python")[1] @@ -353,12 +355,12 @@ def have_module(target_name, file_url): Checks if a given module (function, class, or variable) exists in a Python file. Args: - target_name (str): The name of the module to search for. - file_url (str): The URL of the Python file to check. + target_name (str): The name of the module to search for. + file_url (str): The URL of the Python file to check. Returns: - bool: True if the module is found, False otherwise. - tuple: ("ImportFrom", full_module) if the target_name is an alias from an import statement. + bool: True if the module is found, False otherwise. + tuple: ("ImportFrom", full_module) if the target_name is an alias from an import statement. """ flag, content = get_file_content(file_url) if not flag: @@ -392,26 +394,26 @@ def resolve_complex_import(module_path_base_url, importPackage, base_url, curren refer to a file or a directory (package) with an __init__.py. Args: - module_path_base_url (str): The base URL for the module path - (ex. 'https://github.com/.../transformers/models/llama'). - importPackage (str): The specific name being imported (e.g., 'modeling_llama', 'configuration_llama'). - base_url (str): The base URL of the repository. - current_dir_url (str): The URL of the directory containing the original import statement. - num_try (int): Counter for recursion depth. - Message (str): Accumulates error messages for recursion depth. + module_path_base_url (str): The base URL for the module path + (ex. 'https://github.com/.../transformers/models/llama'). + importPackage (str): The specific name being imported (e.g., 'modeling_llama', 'configuration_llama'). + base_url (str): The base URL of the repository. + current_dir_url (str): The URL of the directory containing the original import statement. + num_try (int): Counter for recursion depth. + Message (str): Accumulates error messages for recursion depth. Returns: - str: The resolved full GitHub URL of the imported file, or None if not found. + str: The resolved full GitHub URL of the imported file, or None if not found. Example: - If `module_path_base_url` is "https://github.com/org/repo/blob/main/src/transformers/models/llama", - `importPackage` is "modeling_llama", `base_url` is "https://github.com/org/repo/blob/main/src/", - and `current_dir_url` is "https://github.com/org/repo/blob/main/src/transformers/models/llama", - this function would first check for "https://github.com/org/repo/blob/main/src/transformers/models/llama.py". - If not found, it would then check for - "https://github.com/org/repo/blob/main/src/transformers/models/llama/__init__.py". - If `__init__.py` exists and contains `from . import modeling_llama`, it would then check for - "https://github.com/org/repo/blob/main/src/transformers/models/llama/modeling_llama.py". + If `module_path_base_url` is "https://github.com/org/repo/blob/main/src/transformers/models/llama", + `importPackage` is "modeling_llama", `base_url` is "https://github.com/org/repo/blob/main/src/", + and `current_dir_url` is "https://github.com/org/repo/blob/main/src/transformers/models/llama", + this function would first check for "https://github.com/org/repo/blob/main/src/transformers/models/llama.py". + If not found, it would then check for + "https://github.com/org/repo/blob/main/src/transformers/models/llama/__init__.py". + If `__init__.py` exists and contains `from . import modeling_llama`, it would then check for + "https://github.com/org/repo/blob/main/src/transformers/models/llama/modeling_llama.py". """ message = "" @@ -449,14 +451,14 @@ def resolve_import_path(importer_url, module_name, level, base_url, importPackag Resolves an import statement to a full GitHub URL. Args: - importer_url (str): The URL of the file containing the import statement. - module_name (str): The name of the module being imported (e.g., 'os', 'transformers.models.llama'). - level (int): The level of the import (0 for absolute, 1+ for relative). - base_url (str): The base URL of the repository (e.g., 'https://github.com/huggingface/transformers/blob/main/src/'). - importPackage (str, optional): The specific package or module being imported from a 'from ... import ...' statement. + importer_url (str): The URL of the file containing the import statement. + module_name (str): The name of the module being imported (e.g., 'os', 'transformers.models.llama'). + level (int): The level of the import (0 for absolute, 1+ for relative). + base_url (str): The base URL of the repository (e.g., 'https://github.com/huggingface/transformers/blob/main/src/'). + importPackage (str, optional): The specific package or module being imported from a 'from ... import ...' statement. Returns: - str: The resolved full GitHub URL of the imported file, or None if not found. + str: The resolved full GitHub URL of the imported file, or None if not found. """ current_dir_url = importer_url[: importer_url.rfind("/")] if level > 0: # Relative import @@ -497,7 +499,7 @@ def get_absolute_imports(import_line, file_url, project_root="transformers"): Returns: str | None: The converted absolute import line(s), the original line if - unchanged, or None if resolution failed. + unchanged, or None if resolution failed. """ if not import_line.startswith("from "): return import_line diff --git a/src/MaxText/experimental/agent/self_debugging_agent/self_debugging_agent.py b/src/MaxText/experimental/agent/self_debugging_agent/self_debugging_agent.py index ec751f2f81..807d6225c5 100644 --- a/src/MaxText/experimental/agent/self_debugging_agent/self_debugging_agent.py +++ b/src/MaxText/experimental/agent/self_debugging_agent/self_debugging_agent.py @@ -90,14 +90,15 @@ def get_file_pairs(module_name, pytorch_path, jax_path): creates corresponding file paths for the JAX directory. Args: - module_name: Iterable of module name - pytorch_path: The path to the directory containing PyTorch files. - jax_path: The path to the directory where JAX files will be stored. + module_name: Iterable of module name + pytorch_path: The path to the directory containing PyTorch files. + jax_path: The path to the directory where JAX files will be stored. Returns: - A tuple containing two lists of strings: - - The first list contains the full paths to the PyTorch files. - - The second list contains the corresponding full paths for the JAX files. + A tuple containing two lists of strings + + * The first list contains the full paths to the PyTorch files. + * The second list contains the corresponding full paths for the JAX files. """ pytorch_files = list(filter(lambda x: x.endswith(".py"), os.listdir(pytorch_path))) if module_name is not None: @@ -116,15 +117,15 @@ def generate_test_case(python_file, entry_module, python_code, jax_code, jax_fil then saved to a specified file path. Args: - python_file: The path to the original PyTorch code file. - entry_module: The name of the main module (function or class) to be tested. - python_code: The content of the PyTorch code file. - jax_code: The content of the JAX code file. - jax_file: The path where the JAX code file is or will be saved. - test_file_path: The path where the generated test case should be saved. + python_file: The path to the original PyTorch code file. + entry_module: The name of the main module (function or class) to be tested. + python_code: The content of the PyTorch code file. + jax_code: The content of the JAX code file. + jax_file: The path where the JAX code file is or will be saved. + test_file_path: The path where the generated test case should be saved. Returns: - The generated test case code as a string. + The generated test case code as a string. """ prompt = CodeEvaluation["TESTCASE"] python_code = ( @@ -157,18 +158,19 @@ def save_and_run_test_case(jax_code, test_code, jax_file, test_file_path): and test results. Args: - jax_code: The JAX code to be saved. - test_code: The test case code to be saved. - jax_file: The path to the file where the JAX code will be written. - test_file_path: The path to the file where the test case will be written. + jax_code: The JAX code to be saved. + test_code: The test case code to be saved. + jax_file: The path to the file where the JAX code will be written. + test_file_path: The path to the file where the test case will be written. Returns: - A tuple containing: - - The captured output from the pytest execution. - - The exit code of the pytest process. - - A boolean indicating if a dependency error occurred. - - The number of passed tests. - - The number of failed tests. + A tuple containing + + * The captured output from the pytest execution. + * The exit code of the pytest process. + * A boolean indicating if a dependency error occurred. + * The number of passed tests. + * The number of failed tests. """ with open(jax_file, "wt", encoding="utf-8") as f: f.write(jax_code) @@ -188,20 +190,21 @@ def code_debugging(args, python_file, jax_file, test_file_path, last_output, cod multiple times until the tests pass or the retry limit is reached. Args: - c python_file: The path to the PyTorch reference code. - jax_file: The path to the JAX code file being debugged. - test_file_path: The path to the test case file. - last_output: The output from the last failed test run (stack trace). - code_history: A list of dictionaries containing previous code states and - test results. - base_try: The current attempt number for debugging. + python_file: The path to the PyTorch reference code. + jax_file: The path to the JAX code file being debugged. + test_file_path: The path to the test case file. + last_output: The output from the last failed test run (stack trace). + code_history: A list of dictionaries containing previous code states and + test results. + base_try: The current attempt number for debugging. Returns: - A tuple containing: - - An integer exit code (0 for success, 1 for failure). - - The number of passed tests from the final attempt. - - The number of failed tests from the final attempt. - - The updated code history list. + A tuple containing + + * An integer exit code (0 for success, 1 for failure). + * The number of passed tests from the final attempt. + * The number of failed tests from the final attempt. + * The updated code history list. """ try: memory_list = [] @@ -266,13 +269,13 @@ def make_code_and_debug(args, python_file, jax_file): debugging attempts. Args: - args (argparse.Namespace): CLI arguments - python_file: The path to the PyTorch code file. - jax_file: The path where the generated JAX code will be stored. + args (argparse.Namespace): CLI arguments + python_file: The path to the PyTorch code file. + jax_file: The path where the generated JAX code will be stored. Returns: - A tuple containing the number of passed and failed test cases - from the final successful or best-effort attempt. + A tuple containing the number of passed and failed test cases + from the final successful or best-effort attempt. """ assert os.path.exists(args.pytorch_path), f"python file {python_file} not exists" try: diff --git a/src/MaxText/experimental/agent/self_debugging_agent/utils.py b/src/MaxText/experimental/agent/self_debugging_agent/utils.py index 39200b68c5..8252a358a6 100644 --- a/src/MaxText/experimental/agent/self_debugging_agent/utils.py +++ b/src/MaxText/experimental/agent/self_debugging_agent/utils.py @@ -27,13 +27,14 @@ def check_code_syntax(file_path: str): If an error occurs during compilation, it means there is a syntax error. Args: - file_path: The path to the Python file to be checked. + file_path: The path to the Python file to be checked. Returns: - A tuple containing: - - An integer exit code (0 for success, 1 for error). - - A string message indicating the result (e.g., "Syntax OK" - or a detailed error message). + A tuple containing + + * An integer exit code (0 for success, 1 for error). + * A string message indicating the result (e.g., "Syntax OK" + or a detailed error message). """ try: py_compile.compile(file_path, doraise=True) @@ -50,13 +51,14 @@ def save_in_file_and_check_code_syntax(code, file_path): to validate the syntax of the newly written file. Args: - code: A string containing the Python code to be saved and checked. - file_path: The path where the code should be saved. + code: A string containing the Python code to be saved and checked. + file_path: The path where the code should be saved. Returns: - A tuple containing: - - An integer exit code (0 for success, 1 for error). - - A string message indicating the result. + A tuple containing + + * An integer exit code (0 for success, 1 for error). + * A string message indicating the result. """ with open(file_path, "wt", encoding="utf-8") as f: f.write(code) @@ -72,11 +74,11 @@ def parse_json_response(response): of that block as a JSON dictionary. Args: - Response: The string containing the JSON code block. + Response: The string containing the JSON code block. Returns: - A dictionary containing the parsed JSON data. Returns an empty - dictionary if no JSON block is found or if parsing fails. + A dictionary containing the parsed JSON data. Returns an empty + dictionary if no JSON block is found or if parsing fails. """ response_dict = {} if "```json" in response: @@ -97,15 +99,15 @@ def smartly_copy_code(filename, base_jax_path, base_testcase_path, dest_jax_path to reflect the new location. Args: - filename: Name of the Python file to be copied. - base_jax_path: Path to the source JAX module directory. - base_testcase_path: Path to the source test files directory. - dest_jax_path: Path to the destination JAX module directory. - dest_testcase_path: Path to the destination test files directory. + filename: Name of the Python file to be copied. + base_jax_path: Path to the source JAX module directory. + base_testcase_path: Path to the source test files directory. + dest_jax_path: Path to the destination JAX module directory. + dest_testcase_path: Path to the destination test files directory. Returns: - True if both the module file and the test file exist in the - destination after copying, otherwise False. + True if both the module file and the test file exist in the + destination after copying, otherwise False. """ base_jax_package = base_jax_path.removeprefix("../").removeprefix("./").replace(os.path.sep, ".") dest_jax_package = dest_jax_path.removeprefix("../").removeprefix("./").replace(os.path.sep, ".") diff --git a/src/MaxText/experimental/rl/grpo_trainer.py b/src/MaxText/experimental/rl/grpo_trainer.py index 926f1e1b67..7adde4ed5d 100644 --- a/src/MaxText/experimental/rl/grpo_trainer.py +++ b/src/MaxText/experimental/rl/grpo_trainer.py @@ -114,9 +114,10 @@ def _split_grpo_state(state): key within its `params` attribute. Returns: - A tuple containing: - - new_state: The training state with 'reference_params' removed. - - reference_params: The extracted reference parameters. + A tuple containing + + * new_state: The training state with 'reference_params' removed. + * reference_params: The extracted reference parameters. """ reference_params = state.params["reference_params"] new_state = state.replace(params={k: v for k, v in state.params.items() if k != "reference_params"}) @@ -359,10 +360,11 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat dropout_rng: JAX PRNG key for dropout. Returns: - A tuple containing: - - new_state: The updated training state after applying gradients. - - metrics: A dictionary of metrics for logging, including loss, reward, - and gradient norms. + A tuple containing + + * new_state: The updated training state after applying gradients. + * metrics: A dictionary of metrics for logging, including loss, reward, + and gradient norms. """ state, reference_params = _split_grpo_state(state) state_mesh_shardings, reference_params_sharding = _split_grpo_state(state_mesh_shardings) @@ -529,19 +531,20 @@ def setup_train_loop( recorder: A GoodputRecorder for performance tracking. Returns: - A tuple containing: - - init_rng: The initial JAX PRNG key. - - checkpoint_manager: The Orbax checkpoint manager. - - state_mesh_shardings: Sharding specifications for the training state. - - inference_state_mesh_shardings: Sharding specs for the inference state. - - model: The training model instance. - - inference_model: The inference model instance. - - mesh: The device mesh for training. - - inference_mesh: The device mesh for inference. - - learning_rate_schedule: The learning rate schedule function. - - data_iterator: The iterator for the input prompt dataset. - - eval_data_iterator: The iterator for the evaluation dataset (or None). - - state: The initialized training state. + A tuple containing + + * init_rng: The initial JAX PRNG key. + * checkpoint_manager: The Orbax checkpoint manager. + * state_mesh_shardings: Sharding specifications for the training state. + * inference_state_mesh_shardings: Sharding specs for the inference state. + * model: The training model instance. + * inference_model: The inference model instance. + * mesh: The device mesh for training. + * inference_mesh: The device mesh for inference. + * learning_rate_schedule: The learning rate schedule function. + * data_iterator: The iterator for the input prompt dataset. + * eval_data_iterator: The iterator for the evaluation dataset (or None). + * state: The initialized training state. """ with maybe_record_goodput(recorder, GoodputEvent.TPU_INIT): max_logging.log("Training mesh used for the workload") diff --git a/src/MaxText/experimental/rl/grpo_utils.py b/src/MaxText/experimental/rl/grpo_utils.py index fb6b748a5c..9653544d16 100644 --- a/src/MaxText/experimental/rl/grpo_utils.py +++ b/src/MaxText/experimental/rl/grpo_utils.py @@ -69,11 +69,12 @@ def compute_log_probs( rngs: JAX PRNG keys for dropout. Returns: - A tuple containing: - - token_log_probs: A [B, L-1] array of log-probabilities for each token - in the completion. - - intermediate_outputs: A dictionary of intermediate activations from the - model. + A tuple containing + + * token_log_probs: A [B, L-1] array of log-probabilities for each token + in the completion. + * intermediate_outputs: A dictionary of intermediate activations from the + model. """ if not is_train: params = jax.lax.stop_gradient(params) diff --git a/src/MaxText/gradient_accumulation.py b/src/MaxText/gradient_accumulation.py index 20844ce38e..7a5f7d2406 100644 --- a/src/MaxText/gradient_accumulation.py +++ b/src/MaxText/gradient_accumulation.py @@ -57,10 +57,11 @@ def gradient_accumulation_loss_and_grad( extra_dpo_args: A tuple of extra arguments to pass to the loss function. Returns: - A tuple containing: - - total_loss (Array): The mean loss, averaged over all microbatches. - - final_aux (PyTree): Auxiliary outputs, summed across microbatches. - - raw_grads (PyTree): The accumulated and averaged gradients. + A tuple containing + + * total_loss (Array): The mean loss, averaged over all microbatches. + * final_aux (PyTree): Auxiliary outputs, summed across microbatches. + * raw_grads (PyTree): The accumulated and averaged gradients. """ def _maybe_shard_with_name(inputs, sharding_names): diff --git a/src/MaxText/inference/offline_engine.py b/src/MaxText/inference/offline_engine.py index e04067056b..a7e655e2ca 100644 --- a/src/MaxText/inference/offline_engine.py +++ b/src/MaxText/inference/offline_engine.py @@ -417,10 +417,10 @@ def _init_engine(self, params): """Initialize the MaxEngine. Args: - params: Model parameters + params: Model parameters Returns: - tuple of (params, engine) + tuple of (params, engine) """ start_time = time.time() engine = MaxEngine(self.config, self.devices) @@ -432,7 +432,7 @@ def _init_tokenizer(self): """Initialize the tokenizer. Returns: - Initialized tokenizer + Initialized tokenizer """ if self.eos_ids is None and self.tokenizer is None: tokenizer_params = self.engine.get_tokenizer() @@ -726,13 +726,13 @@ def emit_token( determines if generation should terminate. Args: - prompt_id: ID of the prompt - result_token: Token to emit - log_prob: Log probability of the token - prompt_logp: Log probabilities for the prompt tokens + prompt_id: ID of the prompt + result_token: Token to emit + log_prob: Log probability of the token + prompt_logp: Log probabilities for the prompt tokens Returns: - True if this token signals the end of generation, False otherwise + True if this token signals the end of generation, False otherwise """ # Skip if sequence already completed if prompt_id in self.completed_sequences: @@ -857,13 +857,13 @@ def batch_inference( """Run inference on a batch of inputs. Args: - data: list of InputData objects, or JAX or numpy arrays. - If input is JAX or numpy array, it must not contain padding tokens. - desc: Description string for logging - rng: Random number generator key. If None, the previous key will be used. + data: list of InputData objects, or JAX or numpy arrays. + If input is JAX or numpy array, it must not contain padding tokens. + desc: Description string for logging + rng: Random number generator key. If None, the previous key will be used. Returns: - list of CompletionOutput objects, one for each input in data + list of CompletionOutput objects, one for each input in data """ data = self.prepare_data(data) @@ -873,10 +873,10 @@ def prepare_data(self, data: list[InputData | jax.Array | np.ndarray]) -> list[I """Pad and if batch prefill is enabled, sort data by length. Args: - data: list of InputData objects, or JAX or numpy arrays + data: list of InputData objects, or JAX or numpy arrays Returns: - list of prepared InputData objects + list of prepared InputData objects """ # Convert JAX arrays to numpy arrays if isinstance(data[0], jax.Array): @@ -906,10 +906,10 @@ def pad_data(self, data: list[InputData]) -> list[InputData]: that is greater than or equal to its true length. Args: - data: list of InputData objects + data: list of InputData objects Returns: - list of padded InputData objects + list of padded InputData objects """ padded_data = [] diff --git a/src/MaxText/inference/paged_attention.py b/src/MaxText/inference/paged_attention.py index 3698011c07..4828b6e516 100644 --- a/src/MaxText/inference/paged_attention.py +++ b/src/MaxText/inference/paged_attention.py @@ -391,12 +391,14 @@ def __call__( page_state: The current state of the page manager. Returns: - A tuple (output, exponentials_max, exponentials_sum) containing: - - The attention output tensor. - - The max of the exponentials (for prefill mode with dot-product attention). - - The sum of the exponentials (for prefill mode with dot-product attention). - The latter two are None for autoregressive mode, as this is handled - internally by the paged attention kernel. + A tuple (output, exponentials_max, exponentials_sum) containing + + * The attention output tensor. + * The max of the exponentials (for prefill mode with dot-product attention). + * The sum of the exponentials (for prefill mode with dot-product attention). + + The latter two are None for autoregressive mode, as this is handled + internally by the paged attention kernel. """ key_pages_cache, value_pages_cache = self.get_kv_pages() diff --git a/src/MaxText/inference/scripts/sharding_utils.py b/src/MaxText/inference/scripts/sharding_utils.py index 50a5ee4d4b..59ca8e9320 100644 --- a/src/MaxText/inference/scripts/sharding_utils.py +++ b/src/MaxText/inference/scripts/sharding_utils.py @@ -51,32 +51,33 @@ def calculate_matmul_resources( W (weights) has shape (G, K, F). Sharding strategy assumed: - - Data Parallelism: `sD` shards the M dim of A. - - Embedding Parallelism: `sK` shards on the embedding dim of A. - - Tensor Parallelism for W dim: `sK` shards the W dimension of W. - - Tensor Parallelism for F dim: `sF` shards the second weight dim of W. + + * Data Parallelism: `sD` shards the M dim of A. + * Embedding Parallelism: `sK` shards on the embedding dim of A. + * Tensor Parallelism for W dim: `sK` shards the W dimension of W. + * Tensor Parallelism for F dim: `sF` shards the second weight dim of W. Args: - activations_shape: Shape of the activations tensor (M, K). - weights_shape: Shape of the weights tensor (G, K, F). - G is the number of groups if this is a GMM (e.g in MoE layer). - sD: Number of data parallel shards (sD). Must be >= 1. - sK: Sharding factor for the activation embedding dimension. - sW: Sharding factor for the first weight dimension. - sF: Sharding factor for the second weight dimension. - sE: Sharding factor to split up expert weights. - activation_size_bytes: Size of a single element in bytes for the activations. - weight_size_bytes: Size of a single element in bytes for the weights. - ici_latency: The latency overhead of communicating between TPUs. - all_gather_axes: Optional additional output axes that need to be all-gathered (e.g. "M", "F"). - debug: Whether to print intermediate resource calculations. + activations_shape: Shape of the activations tensor (M, K). + weights_shape: Shape of the weights tensor (G, K, F). + G is the number of groups if this is a GMM (e.g in MoE layer). + sD: Number of data parallel shards (sD). Must be >= 1. + sK: Sharding factor for the activation embedding dimension. + sW: Sharding factor for the first weight dimension. + sF: Sharding factor for the second weight dimension. + sE: Sharding factor to split up expert weights. + activation_size_bytes: Size of a single element in bytes for the activations. + weight_size_bytes: Size of a single element in bytes for the weights. + ici_latency: The latency overhead of communicating between TPUs. + all_gather_axes: Optional additional output axes that need to be all-gathered (e.g. "M", "F"). + debug: Whether to print intermediate resource calculations. Returns: - A dictionary with keys: - "t_flops": Estimated FLOPs latency. - "t_comms": Estimated communication latency. - "memory": Estimated memory footprint per device for storing - local shards of activations, weights, and output (bytes). + A dictionary with keys + * "t_flops": Estimated FLOPs latency. + * "t_comms": Estimated communication latency. + * "memory": Estimated memory footprint per device for storing + local shards of activations, weights, and output (bytes). """ M, K_act = activations_shape[0], activations_shape[-1] diff --git a/src/MaxText/input_pipeline/_input_pipeline_utils.py b/src/MaxText/input_pipeline/_input_pipeline_utils.py index 3fad3e1a7a..640e950064 100644 --- a/src/MaxText/input_pipeline/_input_pipeline_utils.py +++ b/src/MaxText/input_pipeline/_input_pipeline_utils.py @@ -167,11 +167,12 @@ def apply_chat_template(example, tokenizer_model, data_column_name): Returns: The modified `example` dictionary. - - The `data_column_name` column will be updated to a list of - messages, each formatted according to the tokenizer's chat template. - - A new column named "is_prompt" will be added, where `True` - indicates a user message (prompt) and `False` indicates an assistant - message (completion). + + * The `data_column_name` column will be updated to a list of + messages, each formatted according to the tokenizer's chat template. + * A new column named "is_prompt" will be added, where `True` + indicates a user message (prompt) and `False` indicates an assistant + message (completion). """ messages = [] is_prompt = [] @@ -486,25 +487,26 @@ def _pad_image_and_mask( items than this maximum, it is padded with zeros. Args: - preprocessed_image (multimodal_utils.PreprocessorOutput): The input numpy arrays to pad. - - For masks, the expected shape is (num_masks, num_tiles). - - For standard images, the shape is (num_images, H, W, C). - - For tiled images, the shape is (num_images, num_tiles, H, W, C). + preprocessed_image (multimodal_utils.PreprocessorOutput): The input numpy arrays to pad. + + * For masks, the expected shape is (num_masks, num_tiles). + * For standard images, the shape is (num_images, H, W, C). + * For tiled images, the shape is (num_images, num_tiles, H, W, C). Returns: - np.ndarray: The tensor, padded with zeros up to the maximum number of + np.ndarray: The tensor, padded with zeros up to the maximum number of items along the first axis. Raises: - ValueError: If the input tensor's dimension is not 2, 4, or 5. - ValueError: If the number of items in the input tensor exceeds the + ValueError: If the input tensor's dimension is not 2, 4, or 5. + ValueError: If the number of items in the input tensor exceeds the allowed maximum. Notes: - - The computation of maximum images ensures that space is reserved in the sequence - for at least one text token. - - The dummy images used for padding are based on the image shape for initialization - of this model (ignoring batch size). + * The computation of maximum images ensures that space is reserved in the sequence + for at least one text token. + * The dummy images used for padding are based on the image shape for initialization + of this model (ignoring batch size). """ if not isinstance(preprocessed_image, multimodal_utils.PreprocessorOutput): raise TypeError(f"Input must be multimodal_utils.PreprocessorOutput, but got {type(preprocessed_image)}") diff --git a/src/MaxText/input_pipeline/_tfds_data_processing_c4_mlperf.py b/src/MaxText/input_pipeline/_tfds_data_processing_c4_mlperf.py index 361cd3ea75..c419d1326f 100644 --- a/src/MaxText/input_pipeline/_tfds_data_processing_c4_mlperf.py +++ b/src/MaxText/input_pipeline/_tfds_data_processing_c4_mlperf.py @@ -65,15 +65,18 @@ def rekey(ds, key_map=None): def _rekey(x, key_map=None): """Replace the feature keys according to the mapping in `key_map`. + For example, if the dataset returns examples of the format: {'foo': 'something', 'bar': 'something else', 'zoo': 'others'} and key_map = {'boo': 'foo', 'spar': 'bar', 'zoo': None} then this function will return examples with the format {'boo': 'something', 'spar': 'something else'} If a mapping is to None, then the key will be dropped. + Args: x: an example to process. key_map: dictionary mapping new keys to original keys + Returns: A preprocessed example with the format listed above. """ @@ -90,13 +93,16 @@ def reduce_concat_tokens( batch_size=128, ): """Token-preprocessor to concatenate multiple unrelated documents. + If we want to generate examples of exactly the right length, (to avoid wasting space on padding), then we use this function, followed by split_tokens. + Args: dataset: a tf.data.Dataset with dictionaries containing the key feature_key. feature_key: an string batch_size: an integer - how many documents to concatenate into one + Returns: a dataset """ @@ -118,14 +124,18 @@ def split_tokens( feature_key="targets", ): """Split examples into multiple examples each. + The intended use case is to break up long examples for use in unsupervised transfer-learning. + This function is generally preceded by select_random_chunk. + Args: dataset: a tf.data.Dataset with dictionaries containing the key feature_key. max_tokens_per_segment: an integer, the maximum number of tokens in each segment. Only the final segment may be shorter. feature_key: a string, the feature to split + Returns: a dataset """ diff --git a/src/MaxText/integration/vllm/maxtext_vllm_adapter/adapter.py b/src/MaxText/integration/vllm/maxtext_vllm_adapter/adapter.py index 271728255e..8ca4b19ac6 100644 --- a/src/MaxText/integration/vllm/maxtext_vllm_adapter/adapter.py +++ b/src/MaxText/integration/vllm/maxtext_vllm_adapter/adapter.py @@ -146,10 +146,11 @@ def __call__( **kwargs: Arbitrary keyword arguments. Returns: - A tuple containing: - - updated_kv_caches: A list of updated KV caches. - - hidden: The hidden states (Q, d_model). - - aux_hidden_states: A list of auxiliary hidden states. + A tuple containing + + * updated_kv_caches: A list of updated KV caches. + * hidden: The hidden states (Q, d_model). + * aux_hidden_states: A list of auxiliary hidden states. Raises: ValueError: If the model is not an instance of `nnx.Module`. @@ -245,10 +246,11 @@ def __call__( **kwargs: Arbitrary keyword arguments. Returns: - A tuple containing: - - updated_kv_caches: A list of updated KV caches. - - hidden: The hidden states. - - aux_hidden_states: A list of auxiliary hidden states. + A tuple containing + + * updated_kv_caches: A list of updated KV caches. + * hidden: The hidden states. + * aux_hidden_states: A list of auxiliary hidden states. """ with self.mesh: kv_caches, hidden, aux_hidden_states = self.model(kv_caches, input_ids, attention_metadata, *args, **kwargs) diff --git a/src/MaxText/layers/attention_mla.py b/src/MaxText/layers/attention_mla.py index 051396ffd4..a65b4d598e 100644 --- a/src/MaxText/layers/attention_mla.py +++ b/src/MaxText/layers/attention_mla.py @@ -628,9 +628,10 @@ def update_mla_kv_caches(self, low_rank_main, key_rope, decoder_segment_ids, mod chunked prefill. Returns: - A list containing two elements: - - The prefill key-value cache, reconstructed from the MLA cache, or None. - - The autoregressive key-value cache, reconstructed from the MLA cache, or None. + A list containing two elements + + * The prefill key-value cache, reconstructed from the MLA cache, or None. + * The autoregressive key-value cache, reconstructed from the MLA cache, or None. """ prefill_mla_cache, ar_mla_cache = self.MlaKVCache_0( diff --git a/src/MaxText/layers/attention_op.py b/src/MaxText/layers/attention_op.py index a637034250..9ccc795506 100644 --- a/src/MaxText/layers/attention_op.py +++ b/src/MaxText/layers/attention_op.py @@ -974,14 +974,14 @@ def wrap_ragged_attention( Wraps the GQA function with appropriate sharding. Args: - q: Query tensor. - k: Key tensor. - v: Value tensor. - lengths: Sequence lengths. - block_size: Block size for attention. + q: Query tensor. + k: Key tensor. + v: Value tensor. + lengths: Sequence lengths. + block_size: Block size for attention. Returns: - A tuple containing the output, max, and sum tensors. + A tuple containing the output, max, and sum tensors. """ # Use the original gqa function to get the attention output local_out, (local_sum, local_max) = gpu_pallas_decode_attention.gqa( @@ -1508,15 +1508,16 @@ def compute_local_attention( Based on https://github.com/google-research/google-research/blob/master/scaling_transformer_inference_efficiency/attention.py Args: - attn_weights (Array): Product of query and key - value (Array): Current value - aqt_rng (PRNGKey | None): Optional rng + attn_weights (Array): Product of query and key + value (Array): Current value + aqt_rng (PRNGKey | None): Optional rng Returns: - (local_out, local_max,): where - local_out is local unnormalized output - local_max is the local max of exponentials - local_sum is the sum of exponentials for this chunk, divided by exp(local_max). + (local_out, local_max,), where + + * local_out is local unnormalized output + * local_max is the local max of exponentials + * local_sum is the sum of exponentials for this chunk, divided by exp(local_max). """ b, n_kv, g, t, s = attn_weights.shape n_q = n_kv * g @@ -1735,13 +1736,13 @@ def normalize_cudnn_attention(self, local_outs, local_stats): """Normalize across two cuDNN attentions Args: - local_outs (list): List of outputs entries for each cudnn attention - in shape [b, t, n, d]. - local_stats (list): List of logsumexp entries for each cudnn attention - in shape [b, n, t]. + local_outs (list): List of outputs entries for each cudnn attention + in shape [b, t, n, d]. + local_stats (list): List of logsumexp entries for each cudnn attention + in shape [b, n, t]. Returns: - Array: Combined attention that has been normalized in shape [b, t, n, d]. + Array: Combined attention that has been normalized in shape [b, t, n, d]. """ # reshape stat to have shape [b, n, t, 1] stat0 = local_stats[0].reshape((*local_stats[0].shape, 1)) @@ -1757,12 +1758,12 @@ def normalize_attention(self, local_outs, local_maxes, local_sums): """Normalize across multiple localized attentions Args: - local_outs (list): List of unnormalized outputs entries for each local attention - local_maxes (list): List of max exponentials entries for each local attention - local_sums (list): List of exponential sum entries for each local attention + local_outs (list): List of unnormalized outputs entries for each local attention + local_maxes (list): List of max exponentials entries for each local attention + local_sums (list): List of exponential sum entries for each local attention Returns: - Array: Combined attention that has been normalized + Array: Combined attention that has been normalized """ # Based on https://github.com/google-research/google-research/blob/master/scaling_transformer_inference_efficiency/attention.py global_max = functools.reduce(jnp.maximum, local_maxes) diff --git a/src/MaxText/layers/attentions.py b/src/MaxText/layers/attentions.py index e3f1d19505..2bb3661aae 100644 --- a/src/MaxText/layers/attentions.py +++ b/src/MaxText/layers/attentions.py @@ -884,8 +884,9 @@ def update_kv_caches(self, key, value, decoder_segment_ids, model_mode, previous Returns: A list containing two elements: - - The prefill key-value cache, or None. - - The autoregressive key-value cache, or None. + + * The prefill key-value cache, or None. + * The autoregressive key-value cache, or None. """ prefill_kv_cache, ar_kv_cache = self.KVCache_0( key=key, diff --git a/src/MaxText/layers/decoders.py b/src/MaxText/layers/decoders.py index fc063415f4..b39fc863fb 100644 --- a/src/MaxText/layers/decoders.py +++ b/src/MaxText/layers/decoders.py @@ -390,7 +390,7 @@ def get_decoder_layers(self): """Retrieves a list of decoder layer classes based on the `decoder_block` config. Returns: - A list containing one or more `nn.Module` classes for the decoder. + A list containing one or more `nn.Module` classes for the decoder. """ match self.config.decoder_block: case DecoderBlockType.DEFAULT: diff --git a/src/MaxText/layers/embeddings.py b/src/MaxText/layers/embeddings.py index be06b67200..ed1f3577d4 100644 --- a/src/MaxText/layers/embeddings.py +++ b/src/MaxText/layers/embeddings.py @@ -803,15 +803,15 @@ def _find_correction_range( """Computes the range of correction dimensions for rotary positional embeddings. Args: - low_rot (float): Lower bound for the number of rotations. - high_rot (float): Upper bound for the number of rotations. - dim (int): Dimensionality of the embedding space. - base (float): Base value for the exponential computation. - max_position_embeddings (int): Maximum sequence length. - truncate (bool): Whether to floor lower bound and ceil upper bound. + low_rot (float): Lower bound for the number of rotations. + high_rot (float): Upper bound for the number of rotations. + dim (int): Dimensionality of the embedding space. + base (float): Base value for the exponential computation. + max_position_embeddings (int): Maximum sequence length. + truncate (bool): Whether to floor lower bound and ceil upper bound. Returns: - tuple[int, int]: The range of correction dimensions (low, high), clamped to valid indices. + tuple[int, int]: The range of correction dimensions (low, high), clamped to valid indices. """ low = self._find_correction_dim(low_rot, dim, base, max_position_embeddings) high = self._find_correction_dim(high_rot, dim, base, max_position_embeddings) @@ -1002,6 +1002,7 @@ class LlamaVisionRotaryEmbedding(nnx.Module): cast_as_fprop_dtype: bool = True whether to cast the output to the fprop dtype fprop_dtype: DType = jnp.bfloat16 the dtype of the output rngs: RNG state passed in by nnx.bridge.to_linen, not used in this module. + Returns: jax.Array of shape [batch_size_times_tiles, num_patches_incl_cls, num_heads, head_dim] where vision rotary position embeddings are applied. @@ -1375,9 +1376,10 @@ def _interpolate_single(self, t: int, h: int, w: int) -> tuple[Array, Array]: w: Target width in patches Returns: - Tuple of (indices, weights) where: - - indices: [4, h*w] indices into pos_embed for 4 corners - - weights: [4, h*w] bilinear weights for 4 corners + Tuple of (indices, weights) where + + * indices: [4, h*w] indices into pos_embed for 4 corners + * weights: [4, h*w] bilinear weights for 4 corners """ N = self.num_grid_per_side diff --git a/src/MaxText/layers/gemma3.py b/src/MaxText/layers/gemma3.py index 1906af5aa8..4480375649 100644 --- a/src/MaxText/layers/gemma3.py +++ b/src/MaxText/layers/gemma3.py @@ -665,8 +665,10 @@ def _get_posemb( def __call__(self, inputs, deterministic, train=False): """ViT model that transforms image inputs to image embeddings. + Args: inputs: jnp.array shaped [B, N, H, W, C], e.g. [4, 1, 896, 896, 3] + Returns: jnp.array for image embeddings, shaped [B, N, P, D], e.g. [4, 1, 256, 1152] """ diff --git a/src/MaxText/layers/llama4.py b/src/MaxText/layers/llama4.py index 04828881fe..089df2b422 100644 --- a/src/MaxText/layers/llama4.py +++ b/src/MaxText/layers/llama4.py @@ -286,8 +286,9 @@ def determine_is_moe_layer(layer_id: int, interleave_moe_layer_step: int) -> boo Determines whether the given layer at `layer_id` is MoE layer. This function implements a striding pattern. For example: - - If moe_layer_stride is 1, all layers are MoE layers. - - If moe_layer_stride is 2, layers with index 1, 3, 5, ... are MoE layers. + + * If moe_layer_stride is 1, all layers are MoE layers. + * If moe_layer_stride is 2, layers with index 1, 3, 5, ... are MoE layers. Args: layer_id: The 0-based index of the layer being checked. @@ -764,8 +765,8 @@ def __call__( deterministic: Whether to use deterministic mode (disables dropout) Returns: - Final hidden states from the vision encoder of shape: - [batch_size * num_images, num_tiles, num_patches, vision_output_dim_for_vit] + Final hidden states from the vision encoder of shape + `[batch_size * num_images, num_tiles, num_patches, vision_output_dim_for_vit]` """ # Reshape pixel values to combine batch and num_tiles dimensions b, t, c, h, w = pixel_values.shape diff --git a/src/MaxText/layers/multi_token_prediction.py b/src/MaxText/layers/multi_token_prediction.py index a3201de36e..54cb57bc96 100644 --- a/src/MaxText/layers/multi_token_prediction.py +++ b/src/MaxText/layers/multi_token_prediction.py @@ -140,15 +140,15 @@ def __call__( """Applies MTP combination, projection, and transformer processing. Args: - prev_hidden_state: Shape [batch, seq_len, hidden_size]. - target_token_embedding: Embedding for token t+k. Shape [batch, seq_len, embed_dim]. - position_ids: Shape [batch, seq_len]. - decoder_segment_ids: Shape [batch, seq_len] or None. - deterministic: Whether to disable dropout. - model_mode: Operational mode (train, eval, decode). + prev_hidden_state: Shape [batch, seq_len, hidden_size]. + target_token_embedding: Embedding for token t+k. Shape [batch, seq_len, embed_dim]. + position_ids: Shape [batch, seq_len]. + decoder_segment_ids: Shape [batch, seq_len] or None. + deterministic: Whether to disable dropout. + model_mode: Operational mode (train, eval, decode). Returns: - Processed hidden state. Shape [batch, seq_len, hidden_size]. + Processed hidden state. Shape [batch, seq_len, hidden_size]. """ embedding_norm = self.embedding_norm(target_token_embedding) hidden_state_norm = self.hidden_state_norm(prev_hidden_state) diff --git a/src/MaxText/layers/pipeline.py b/src/MaxText/layers/pipeline.py index d37827bb07..944c0781cf 100644 --- a/src/MaxText/layers/pipeline.py +++ b/src/MaxText/layers/pipeline.py @@ -270,7 +270,7 @@ def vmap_gather(self, xs, ids, ids_dim): Returns: The per-stage gathered values. The shape is xs.shape but with ids_dim size - replaced with [num_stages]. + replaced with [num_stages]. """ def _gather_one(x, i): diff --git a/src/MaxText/layers/qwen3.py b/src/MaxText/layers/qwen3.py index e0e20d11ca..840bbdd533 100644 --- a/src/MaxText/layers/qwen3.py +++ b/src/MaxText/layers/qwen3.py @@ -1236,10 +1236,10 @@ def __init__( def __call__(self, hidden: Array) -> Array: """ Args: - hidden: Input tensor of shape (batch, seq_len, base_hidden_size) after spatial reordering + hidden: Input tensor of shape (batch, seq_len, base_hidden_size) after spatial reordering Returns: - Output tensor of shape (batch, seq_len//merge_size**2, out_hidden_size) - spatially merged + Output tensor of shape (batch, seq_len//merge_size**2, out_hidden_size) - spatially merged """ # Get dimensions spatial_merge_size = self.config.spatial_merge_size_for_vit @@ -1334,10 +1334,10 @@ def __init__( def __call__(self, hidden_state: Array) -> Array: """ Args: - hidden_state: Input tensor of shape (..., hidden_size) - supports packed sequences + hidden_state: Input tensor of shape (..., hidden_size) - supports packed sequences Returns: - Output tensor of shape (..., hidden_size) + Output tensor of shape (..., hidden_size) """ hidden_state = self.linear_fc1(hidden_state) hidden_state = jax.nn.gelu(hidden_state) @@ -1401,9 +1401,10 @@ def __init__( def __call__(self, hidden_states: Array) -> Array: """ Args: - hidden_states: Input tensor of shape (batch, in_channels, temporal*patch_size, height*patch_size, width*patch_size) + hidden_states: Input tensor of shape (batch, in_channels, temporal*patch_size, height*patch_size, width*patch_size) + Returns: - Output tensor of shape (batch, T*H*W, embed_dim) where T, H, W are the number of patches + Output tensor of shape (batch, T*H*W, embed_dim) where T, H, W are the number of patches """ hidden_states = jnp.transpose(hidden_states, (0, 2, 3, 4, 1)) hidden_states = self.proj(hidden_states) @@ -1467,14 +1468,14 @@ def __call__( ) -> Array: """ Args: - hidden_states: Input tensor of shape (batch, T*H*W, hidden_size) - num_frames: Number of temporal frames (static) - height: Height in patches (static) - width: Width in patches (static) - deterministic: Whether to use deterministic mode (disable dropout) + hidden_states: Input tensor of shape (batch, T*H*W, hidden_size) + num_frames: Number of temporal frames (static) + height: Height in patches (static) + width: Width in patches (static) + deterministic: Whether to use deterministic mode (disable dropout) Returns: - Output tensor of shape (batch, T*H*W, hidden_size) + Output tensor of shape (batch, T*H*W, hidden_size) """ # Pass through attention with static dimensions via rope_kwargs rope_kwargs = { @@ -1541,13 +1542,13 @@ def __call__( ) -> Array: """ Args: - x: Input tensor of shape (batch, T*H*W, hidden_size) - num_frames: Number of temporal frames (static) - height: Height in patches (static)i - width: Width in patches (static) + x: Input tensor of shape (batch, T*H*W, hidden_size) + num_frames: Number of temporal frames (static) + height: Height in patches (static)i + width: Width in patches (static) Returns: - Output tensor of shape (batch, T*H*W, hidden_size) + Output tensor of shape (batch, T*H*W, hidden_size) """ x = x + self.attn(self.ln1(x), num_frames=num_frames, height=height, width=width) y = self.ln2(x) @@ -1614,13 +1615,13 @@ def __call__( ): """ Args: - hidden_states: Input visual tokens of shape (batch, in_channels, T*patch_size, H*patch_size, W*patch_size) - deterministic: Whether to use deterministic mode + hidden_states: Input visual tokens of shape (batch, in_channels, T*patch_size, H*patch_size, W*patch_size) + deterministic: Whether to use deterministic mode Returns: - Tuple of: - - encoder_output: shape (batch, T*H*W, hidden_size_for_vit) - - deep_features: List of intermediate features, each of shape (batch, T*H*W, out_hidden_size) + Tuple of: + - encoder_output: shape (batch, T*H*W, hidden_size_for_vit) + - deep_features: List of intermediate features, each of shape (batch, T*H*W, out_hidden_size) """ _, _, num_frames, height, width = hidden_states.shape num_frames = num_frames // self.config.temporal_patch_size_for_vit @@ -1672,10 +1673,10 @@ def __init__(self, config: Config, *, rngs: nnx.Rngs = None): def __call__(self, hidden_states: Array) -> Array: """ Args: - hidden_states: Encoder output of shape (batch, T*H*W, hidden_size_for_vit) + hidden_states: Encoder output of shape (batch, T*H*W, hidden_size_for_vit) Returns: - Projected output of shape (batch, T*H*W//merge_size**2, out_hidden_size_for_vit) + Projected output of shape (batch, T*H*W//merge_size**2, out_hidden_size_for_vit) """ output = self.merger(hidden_states) return output diff --git a/src/MaxText/max_utils.py b/src/MaxText/max_utils.py index 5b5811e6ba..ce9f9cc751 100644 --- a/src/MaxText/max_utils.py +++ b/src/MaxText/max_utils.py @@ -489,10 +489,10 @@ def unbox_logicallypartioned(boxed_pytree): """Unboxes the flax.LogicallyPartitioned pieces Args: - boxed_pytree: a pytree that includes LogicallyPartitioned + boxed_pytree: a pytree that includes `LogicallyPartitioned` leaves. Returns: - a pytree where all all LogicallyPartitioned leaves have been unboxed. + a pytree where all all `LogicallyPartitioned` leaves have been unboxed. """ return jax.tree_util.tree_map( lambda x: x.unbox() if isinstance(x, flax.linen.spmd.LogicallyPartitioned) else x, @@ -965,7 +965,7 @@ def get_batch_seq_len_for_mode(config, model_mode): (e.g., PREFILL, AUTOREGRESSIVE, TRAIN). Returns: - A tuple of (batch_size, seq_len). + A tuple of `(batch_size, seq_len)`. """ if model_mode == MODEL_MODE_PREFILL: # Prefill mode: Process one full-length prompt. diff --git a/src/MaxText/model_creation_utils.py b/src/MaxText/model_creation_utils.py index ade242d433..2e6d09fa4f 100644 --- a/src/MaxText/model_creation_utils.py +++ b/src/MaxText/model_creation_utils.py @@ -68,15 +68,16 @@ def from_config( This function loads a model from a checkpoint. Args: - config: Config object. - devices: Sequence of devices to use for the model. If None, use all - available devices. + config: Config object. + devices: Sequence of devices to use for the model. If None, use all + available devices. Returns: - Transformer: The loaded model instance (only the model) + Transformer: The loaded model instance (only the model) - Example: - model = from_config(config) + Example:: + + model = from_config(config) """ devices_array = maxtext_utils.create_device_mesh(config, devices) diff --git a/src/MaxText/multimodal_utils.py b/src/MaxText/multimodal_utils.py index 23d841315d..50fd93ba43 100644 --- a/src/MaxText/multimodal_utils.py +++ b/src/MaxText/multimodal_utils.py @@ -134,10 +134,10 @@ def get_factors(dividend: int): no remainder. For example, if dividend=12, it will return {1, 2, 3, 4, 6, 12}. Args: - dividend (int): The number to find factors for. + dividend (int): The number to find factors for. Returns: - set: A set containing all factors of the number. + set: A set containing all factors of the number. """ factors_set = set() @@ -176,13 +176,13 @@ def get_best_resolution( Get the best resolution for the image based on the possible resolutions. Args: - img_height (int): The height of the image. - image_width (int): The width of the image. - possible_resolutions (list): A list of possible resolutions. - resize_to_max_canvas (bool): Whether to resize to max canvas or not. + img_height (int): The height of the image. + image_width (int): The width of the image. + possible_resolutions (list): A list of possible resolutions. + resize_to_max_canvas (bool): Whether to resize to max canvas or not. Returns: - tuple: The best resolution for the image. + tuple: The best resolution for the image. """ if resize_to_max_canvas: return max(possible_resolutions, key=lambda x: x[0] * x[1]) @@ -202,18 +202,18 @@ def pad_to_best_fit_jax( If smaller, it's padded on the right and bottom. Args: - images (np.ndarray): - The images to process. Expected shape (..., H, W, C). - target_size (tuple[int, int]): - The target (height, width). - background_color (int | tuple[int, ...] | None): - The color to use for padding. - If int, it's used for the first channel and subsequent channels are padded with 0. - If tuple, its length must match the number of channels in the image. - Defaults to 0. + images (np.ndarray): + The images to process. Expected shape (..., H, W, C). + target_size (tuple[int, int]): + The target (height, width). + background_color (int | tuple[int, ...] | None): + The color to use for padding. + If int, it's used for the first channel and subsequent channels are padded with 0. + If tuple, its length must match the number of channels in the image. + Defaults to 0. Returns: - np.ndarray: The processed images of shape (..., target_height, target_width, C). + np.ndarray: The processed images of shape (..., target_height, target_width, C). """ original_shape = images.shape num_dims = len(original_shape) @@ -277,12 +277,12 @@ def pad_to_max_tiles(images: np.ndarray, max_num_tiles: int = LLAMA4_TILES_PAD_T Pads the image tiles to the maximum number of tiles using JAX. Args: - images: The input image tiles with shape (num_tiles, C, H, W). - max_num_tiles: The maximum number of tiles to pad to. + images: The input image tiles with shape (num_tiles, C, H, W). + max_num_tiles: The maximum number of tiles to pad to. Returns: - The padded image tiles with shape (max_num_tiles, C, H, W). - The mask indicating valid tiles with shape (max_num_tiles,). + The padded image tiles with shape (max_num_tiles, C, H, W). + The mask indicating valid tiles with shape (max_num_tiles,). """ num_tiles, num_channels, height, width = images.shape if num_tiles > max_num_tiles: @@ -307,13 +307,13 @@ def split_to_tiles(images: np.ndarray, num_tiles_height: int, num_tiles_width: i Splits an image tensor into tiles using JAX. Args: - images: The input image tensor with shape (batch_size, num_channels, height, width). - num_tiles_height: The number of tiles along the height dimension. - num_tiles_width: The number of tiles along the width dimension. + images: The input image tensor with shape (batch_size, num_channels, height, width). + num_tiles_height: The number of tiles along the height dimension. + num_tiles_width: The number of tiles along the width dimension. Returns: - The tiled image tensor with shape: - (batch_size * num_tiles_height * num_tiles_width, num_channels, height // num_tiles_height, width // num_tiles_width). + The tiled image tensor with shape: + (batch_size * num_tiles_height * num_tiles_width, num_channels, height // num_tiles_height, width // num_tiles_width). """ images = np.transpose(images, (2, 0, 1)) # Change to (num_channels, height, width) num_channels, height, width = images.shape @@ -392,7 +392,7 @@ def pre_process_llama4_image(image: np.ndarray | list[np.ndarray]) -> Preprocess Returns: The pre-processed image in np.array [N, NUM_TILES, C, TILE_SIZE, TILE_SIZE]. - Example: + Example:: image of (536, 640, 3), its best_resolution = (672, 672), image split into 4 tiles of (336, 336) Additional global tile of (336, 336) is added, and the final output image_tiles is (1, 5, 3, 336, 336). @@ -691,7 +691,7 @@ def get_num_tokens_for_this_image(this_aspect_ratio, num_patches_per_chunk): Args: aspect_ratio: A tuple (ratio_h, ratio_w) representing the number of tiles - along height and width. + along height and width. num_patches_per_chunk: The number of patch tokens per image tile. Returns: @@ -775,13 +775,13 @@ def insert_sequence( This function is fully vectorized and operates on a batch of token sequences. Args: - tokens: A 1D or 2D array of input tokens. - at: The token ID to find and replace with the sequence. - sequence: The list of new token IDs to insert. - max_num_images: The maximum number of times `at` can appear. + tokens: A 1D or 2D array of input tokens. + at: The token ID to find and replace with the sequence. + sequence: The list of new token IDs to insert. + max_num_images: The maximum number of times `at` can appear. Returns: - The modified token array with the sequences inserted. + The modified token array with the sequences inserted. """ # Ensure input is a 2D array (batch) original_dim = tokens.ndim diff --git a/src/MaxText/rl/evaluate_rl.py b/src/MaxText/rl/evaluate_rl.py index 29610ef582..ed4817d952 100644 --- a/src/MaxText/rl/evaluate_rl.py +++ b/src/MaxText/rl/evaluate_rl.py @@ -53,13 +53,13 @@ def generate_responses( Generate responses for a batch of prompts across potentially multiple passes. Args: - tmvp_config: Configuration object - prompts: List of prompts to generate responses for - rl_cluster: Model cluster for generation - num_passes: Number of generation passes + tmvp_config: Configuration object + prompts: List of prompts to generate responses for + rl_cluster: Model cluster for generation + num_passes: Number of generation passes Returns: - List of lists containing responses for each prompt across passes + List of lists containing responses for each prompt across passes """ multiple_call_responses = [[] for _ in range(len(prompts))] eval_strategy = tmvp_config.generation_configs[tmvp_config.eval_sampling_strategy] @@ -90,13 +90,13 @@ def score_responses(tmvp_config, question, responses, answer): Score a set of responses for a single question. Args: - tmvp_config: Configuration object - question: The evaluation question - responses: List of generated responses for this question - answer: The correct answer + tmvp_config: Configuration object + question: The evaluation question + responses: List of generated responses for this question + answer: The correct answer Returns: - Tuple of (is_correct, is_partially_correct, has_correct_format) + Tuple of (is_correct, is_partially_correct, has_correct_format) """ match_format = utils_rl.get_match_format_regex(tmvp_config) match_numbers = utils_rl.get_match_numbers_regex(tmvp_config) @@ -156,15 +156,15 @@ def evaluate( Computes accuracy and percentage of outputs matching the format. Args: - tmvp_config: Configuration object - dataset: The evaluation dataset - rl_cluster: Model cluster for generation. - num_passes: Number of generation passes - corr_lst: If True, only include correct responses in the list - make_lst: If True, return a list of (question, answer, responses) + tmvp_config: Configuration object + dataset: The evaluation dataset + rl_cluster: Model cluster for generation. + num_passes: Number of generation passes + corr_lst: If True, only include correct responses in the list + make_lst: If True, return a list of (question, answer, responses) Returns: - Tuple of statistics and optionally the response list + Tuple of statistics and optionally the response list """ response_lst = [] corr = 0 diff --git a/src/MaxText/sequence_packing.py b/src/MaxText/sequence_packing.py index 5a4387dd14..c99a9d7659 100644 --- a/src/MaxText/sequence_packing.py +++ b/src/MaxText/sequence_packing.py @@ -105,11 +105,14 @@ def _pack_with_tf_ops( dataset: tf.data.Dataset, keys: list[str], key2length: dict[str, int], pad_id: int ) -> tf.data.Dataset: """Helper-function for packing a dataset which has already been batched. - Helper for pack_dataset() Uses tf.while_loop. + + Helper for `pack_dataset()`. Uses `tf.while_loop`. + Args: dataset: a dataset containing padded batches of examples. keys: a list of strings key2length: an dict from feature-key to integer + Returns: a dataset. """ @@ -132,10 +135,13 @@ def write_packed_example(partial, outputs): def map_fn(x): """Internal function to flat_map over. + Consumes a batch of input examples and produces a variable number of output examples. + Args: x: a single example + Returns: a tf.data.Dataset """ @@ -149,10 +155,12 @@ def map_fn(x): def body_fn(i, partial, outputs): """Body function for while_loop. + Args: i: integer scalar partial: dictionary of Tensor (partially-constructed example) outputs: dictionary of TensorArray + Returns: A triple containing the new values of the inputs. """ diff --git a/src/MaxText/sharding.py b/src/MaxText/sharding.py index 0126489740..8d47f4ae69 100644 --- a/src/MaxText/sharding.py +++ b/src/MaxText/sharding.py @@ -164,11 +164,12 @@ def _analyze_sharding(params, mesh, valid_target_mesh_axes): valid_target_mesh_axes: A set of mesh axis names that are considered valid targets for sharding. Returns: - A tuple containing: - - unsharded_params_total_size (int): The total size (number of elements) of all parameters found to be - unsharded on the target axes. - - problematic_tensors_details (list): A list of dictionaries, where each - dictionary contains details about a tensor that is not sharded on any of the target axes. + A tuple containing + + * unsharded_params_total_size (int): The total size (number of elements) of all parameters found to be + unsharded on the target axes. + * problematic_tensors_details (list): A list of dictionaries, where each + dictionary contains details about a tensor that is not sharded on any of the target axes. """ unsharded_params_total_size = 0 # Initialize a counter for the size of unsharded parameters. problematic_tensors_details = [] # Initialize a list to store details of problematic tensors. @@ -355,10 +356,11 @@ def maybe_update_params_sharding_with_opt(config, state_mesh_shardings): state_mesh_shardings: Train state mesh shardings containing params and opt_state Returns: - A tuple of (prev_params_shardings, updated_state_mesh_shardings): - - prev_params_shardings: Original parameter shardings before the update - - updated_state_mesh_shardings: State mesh shardings with updated params field - (unchanged if shard_optimizer_over_data is False) + A tuple of (prev_params_shardings, updated_state_mesh_shardings) + + * prev_params_shardings: Original parameter shardings before the update + * updated_state_mesh_shardings: State mesh shardings with updated params field + (unchanged if shard_optimizer_over_data is False) """ prev_params_shardings = state_mesh_shardings.params if config.shard_optimizer_over_data: diff --git a/src/MaxText/tokenizer.py b/src/MaxText/tokenizer.py index 3d3b5c9637..9f79f66bd2 100644 --- a/src/MaxText/tokenizer.py +++ b/src/MaxText/tokenizer.py @@ -94,14 +94,14 @@ def encode( Encodes a string into a list of token IDs. Args: - s (str): The input string to be encoded. - bos (bool): Whether to prepend the beginning-of-sequence token. - eos (bool): Whether to append the end-of-sequence token. - allowed_tokens (`"all"|set[str]`): allowed special tokens in string - disallowed_tokens (`"all"|set[str]`): special tokens that raise an error when in string + s (str): The input string to be encoded. + bos (bool): Whether to prepend the beginning-of-sequence token. + eos (bool): Whether to append the end-of-sequence token. + allowed_tokens (`"all"|set[str]`): allowed special tokens in string + disallowed_tokens (`"all"|set[str]`): special tokens that raise an error when in string Returns: - list[int]: A list of token IDs. + list[int]: A list of token IDs. By default, setting disallowed_special=() encodes a string by ignoring special tokens. Specifically: @@ -150,10 +150,10 @@ def decode(self, t) -> str: Decodes a list of token IDs into a string. Args: - t (list[int]): The list of token IDs to be decoded. + t (list[int]): The list of token IDs to be decoded. Returns: - str: The decoded string. + str: The decoded string. """ # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence. return self.model.decode(t) diff --git a/src/MaxText/train_tokenizer.py b/src/MaxText/train_tokenizer.py index ba4af3c84b..8f2e65a98e 100644 --- a/src/MaxText/train_tokenizer.py +++ b/src/MaxText/train_tokenizer.py @@ -43,10 +43,12 @@ def _dump_chars_to_textfile(dataset: tf.data.Dataset, maxchars: int = int(1e7), data_keys=("text",)) -> tuple[str, int]: """Write part of a TFDS sentence dataset to lines in a text file. + Args: dataset: tf.dataset containing string-data. maxchars: int: approximate number of characters to save from dataset. data_keys: tuple[str]: what keys in dataset to dump from. + Returns: name of temp file with dataset bytes, exact number of characters dumped. """ @@ -74,6 +76,7 @@ def _train_sentencepiece( data_keys=("text",), ): """Train SentencePiece tokenizer from subset of tf dataset. + Args: dataset: tf.dataset vocab_size: int: size of vocab tokens to train. @@ -84,6 +87,7 @@ def _train_sentencepiece( are 0.9995 for languages with rich character set like Japanese or Chinese and 1.0 for other languages with small character set. data_keys: tuple[str]: keys of dataset to use for training. + Returns: path to the trained sentencepiece vocabulary model. """ diff --git a/src/MaxText/vocabulary_tiling.py b/src/MaxText/vocabulary_tiling.py index 61345ffe55..fc8d169b11 100644 --- a/src/MaxText/vocabulary_tiling.py +++ b/src/MaxText/vocabulary_tiling.py @@ -51,6 +51,7 @@ def vocab_tiling_linen_loss( model: The Linen model instance. params: The model parameters. is_train: A boolean indicating if the model is in training mode. + Returns: The total cross-entropy loss computed via vocab tiling. """ From 6f0547e36fea60d045d532320468e1bdab39bcfc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Melissa=20Weber=20Mendon=C3=A7a?= Date: Fri, 19 Dec 2025 17:51:11 -0300 Subject: [PATCH 3/3] Update docs build action --- .github/workflows/check_docs_build.yml | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/.github/workflows/check_docs_build.yml b/.github/workflows/check_docs_build.yml index 393bd98fd2..8ee67809aa 100644 --- a/.github/workflows/check_docs_build.yml +++ b/.github/workflows/check_docs_build.yml @@ -17,15 +17,20 @@ jobs: with: persist-credentials: false - - name: Set up Python - uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0 + - name: Install uv and set the Python version + uses: astral-sh/setup-uv@eb1897b8dc4b5d5bfe39a428a8f2304605e0983c # v7.0.0 with: python-version: '3.12' - cache: 'pip' # caching pip dependencies + enable-cache: true + + - name: Set venv + run: uv venv --python 3.12 $GITHUB_WORKSPACE/venv - name: Install dependencies - run: pip install -r dependencies/requirements/requirements_docs.txt + run: . $GITHUB_WORKSPACE/venv/bin/activate && uv pip install -r dependencies/requirements/requirements_docs.txt - name: Build documentation run: | + . $GITHUB_WORKSPACE/venv/bin/activate + uv pip install -e . --no-deps sphinx-build -W -b html docs docs/_build/html