Changes from all commits
26 commits
2beebcf
adding necessary files
gulsumgudukbay Nov 10, 2025
f5755bb
add decoupling logic config patch to tests
gulsumgudukbay Nov 10, 2025
ef9c62e
add correct ICI parallelism to tests for decoupled mode
gulsumgudukbay Nov 10, 2025
b6be265
Add tpu_only marker to train compile test
gulsumgudukbay Nov 10, 2025
583eee9
fixing more UTs
gulsumgudukbay Nov 10, 2025
91e51d5
adding decoupling logic, biggest change
gulsumgudukbay Nov 11, 2025
4647288
add tensorboardX stub
gulsumgudukbay Nov 11, 2025
4e7afd6
adding tokamax changes along with some UT fixes
gulsumgudukbay Nov 11, 2025
ebfef26
fixing little UT issues:
gulsumgudukbay Nov 11, 2025
9bf8105
fixing train_tests
gulsumgudukbay Nov 11, 2025
4d6d2b0
removing CI workflows for now to upstream decoupling changes
gulsumgudukbay Nov 12, 2025
ad66362
making jetstream and tunix optional and add is_stub variables
gulsumgudukbay Nov 13, 2025
0e97a31
removing tunix from decoupling logic
gulsumgudukbay Nov 14, 2025
227b980
removing tunix from decoupled mode logic
gulsumgudukbay Nov 14, 2025
06365cb
addressing PR comments
gulsumgudukbay Nov 15, 2025
fe14ece
fixing pylint issues
gulsumgudukbay Nov 17, 2025
58b135b
renaming datasets to local_datasets to avoid confusion with HF datase…
gulsumgudukbay Nov 18, 2025
c6d31bd
pyink fixes
gulsumgudukbay Nov 18, 2025
a9bfd32
Rename GCE_MARKERS to GCP_MARKERS
gulsumgudukbay Nov 19, 2025
b9c7dbf
updating dataset paths and fix gcloud_Stub
gulsumgudukbay Nov 19, 2025
7476b28
making decoupled mode work with upstream updates
gulsumgudukbay Nov 21, 2025
c8d7898
updates for upstream sync UTs
gulsumgudukbay Dec 20, 2025
71f862a
remove context_parallel_strategy config param, todo: add it back later
gulsumgudukbay Dec 21, 2025
8c27a9e
make jax_remove_size_one_mesh_axis_from_type param setting in try blo…
gulsumgudukbay Dec 21, 2025
c5dbfdc
Fix decoupled rampup and mesh configs for bare-metal tests
gulsumgudukbay Dec 21, 2025
be66d5b
remove some stuff from this PR that were meant to be for downstream
gulsumgudukbay Dec 21, 2025
2 changes: 1 addition & 1 deletion README.md
@@ -35,7 +35,7 @@ Check out our [Read The Docs site](https://maxtext.readthedocs.io/en/latest/) or
See our installation guide to [install MaxText with pip from PyPI](https://maxtext.readthedocs.io/en/latest/install_maxtext.html#from-pypi-recommended).

## Decoupled mode
See our guide on running MaxText in decoupled mode, without any GCP dependencies in [Decoupled Mode Guide](https://maxtext.readthedocs.io/en/latest/guides/run_maxtext/decoupled_mode.html).
See our guide on running MaxText in decoupled mode, without any GCP dependencies in [Decoupled Mode Guide](https://maxtext.readthedocs.io/en/latest/run_maxtext/decoupled_mode.html).

<!-- NEWS START -->

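Note: the decoupled-mode guide linked above is driven by the DECOUPLE_GCLOUD flag used throughout this PR. A minimal sketch of turning it on from Python, assuming is_decoupled() simply reads that environment variable (see the gcloud_stub imports in the diffs below); this is illustrative, not part of the diff:

import os

os.environ["DECOUPLE_GCLOUD"] = "TRUE"  # set before importing MaxText modules

from MaxText.gcloud_stub import is_decoupled

print(is_decoupled())  # expected True: GCP-coupled dependencies are stubbed out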
14 changes: 14 additions & 0 deletions __init__.py
@@ -0,0 +1,14 @@
"""Top-level shim for importing test_utils

This shim lets test modules import `maxtext.tests`.

"""

from importlib import import_module as _imp

try:
  test_utils = _imp("maxtext.tests.test_utils")  # noqa: F401
except Exception:  # pragma: no cover - fail silently if tests not present
  pass

__all__ = ["test_utils"]
56 changes: 25 additions & 31 deletions local_datasets/generate_tfds_metadata.py
@@ -18,7 +18,6 @@
python local_datasets/generate_tfds_metadata.py \
--root local_datasets/c4_en_dataset_minimal \
--version 3.1.0 \
--source-version 3.0.1 \
--force

This script creates a tiny TFDS builder and outputs the ``dataset_info.json`` and
@@ -29,36 +28,22 @@
"""
from __future__ import annotations
import os
import json
import argparse
import tensorflow_datasets as tfds # type: ignore


def ensure_symlink(root: str, source_version: str, version: str) -> str:
  """Ensure a symlink exists from source_version to version under root/c4/en.

  Returns the target version directory path.
  """
  src = os.path.join(root, "c4", "en", source_version)
  dst = os.path.join(root, "c4", "en", version)
  if not os.path.isdir(src):
    raise FileNotFoundError(f"Source version directory not found: {src}")
  if not os.path.lexists(dst):
    try:
      os.symlink(src, dst)
      print(f"Created symlink {dst} -> {src}")
    except OSError as exc:
      print(f"Symlink creation failed (continuing): {exc}")
  else:
    print(f"Symlink already exists: {dst}")
  return dst


def write_metadata(root: str, version_dir: str, dataset_version: str, force: bool = False) -> None:
  """Write TFDS ``dataset_info.json`` and ``features.json`` for local C4 shards."""

  info_path = os.path.join(version_dir, "dataset_info.json")
  if os.path.exists(info_path) and not force:
    print("dataset_info.json already exists; skipping overwrite (use --force to regenerate).")
    return
  if os.path.exists(info_path):
    if force:
      os.remove(info_path)
      print("Removed existing dataset_info.json due to --force.")
    else:
      print("dataset_info.json already exists; skipping overwrite (use --force to regenerate).")
      return

  # Discover shards (we assume they exist and are correct; counts are fixed)
  num_shards_train = 8
@@ -107,6 +92,17 @@ def _generate_examples(self):  # type: ignore[override]
  info.write_to_directory(version_dir)
  print(f"Wrote TFDS dataset_info & features to {version_dir}")

  info_path = os.path.join(version_dir, "dataset_info.json")
  try:
    with open(info_path, "r") as f:
      data = json.load(f)
    if isinstance(data.get("splits"), dict):
      data["splits"] = list(data["splits"].values())
      with open(info_path, "w") as f:
        json.dump(data, f, indent=2)
  except Exception as e:
    print(f"Warning: Could not patch splits in dataset_info.json: {e}")


def main() -> None:
"""CLI entry point for generating TFDS metadata."""
@@ -121,20 +117,18 @@ def main() -> None:
default="3.1.0",
help="Target version to expose via TFDS",
)
ap.add_argument(
"--source-version",
default="3.0.1",
help="Existing version directory with shards",
)
ap.add_argument(
"--force",
action="store_true",
help="Overwrite existing dataset_info.json if present",
)
args = ap.parse_args()

target_dir = ensure_symlink(args.root, args.source_version, args.version)
write_metadata(args.root, target_dir, args.version, force=args.force)
# Use the version directory directly
version_dir = os.path.join(args.root, "c4", "en", args.version)
if not os.path.isdir(version_dir):
raise FileNotFoundError(f"Version directory not found: {version_dir}")
write_metadata(args.root, version_dir, args.version, force=args.force)
print("Done.")


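Note: the splits patch added to write_metadata above rewrites dataset_info.json so that "splits" is the list form TFDS readers expect rather than a mapping. A minimal illustration, with made-up shard counts:

# Illustration only; field values are invented for the example.
splits_as_dict = {
    "train": {"name": "train", "shardLengths": ["8"]},
    "validation": {"name": "validation", "shardLengths": ["1"]},
}
splits_as_list = list(splits_as_dict.values())  # what the patch writes back
assert splits_as_list[0]["name"] == "train"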
23 changes: 10 additions & 13 deletions src/MaxText/configs/decoupled_base_test.yml
@@ -1,9 +1,9 @@
# Decoupled base test config: used when DECOUPLE_GCLOUD=TRUE for tests that previously relied on base.yml.
# Inherit all model defaults from base.yml but override any cloud-coupled paths and disable optional cloud features.
base_config: base.yml
# Inherit all model defaults (Pydantic already does this) but override any cloud-coupled paths and disable
# optional cloud features.

# Output goes to a local relative directory so tests do not require GCS.
base_output_directory: ./maxtext_local_output
base_output_directory: ./local_datasets/gcloud_decoupled_test_logs
run_name: test_decoupled

# Disable checkpointing by default for speed unless a test explicitly enables it.
@@ -23,7 +23,10 @@ profile_periodically_period: 0
profiler_steps: 0

# Leave dataset-related keys to be overridden by individual tests.
dataset_type: ""
#dataset_type: ""
dataset_path: "local_datasets/c4_en_dataset_minimal/"
dataset_name: 'c4/en:3.1.0'
eval_dataset_name: 'c4/en:3.1.0'

# Use dot_product attention to avoid GPU Pallas shared memory limits on AMD GPUs
attention: "dot_product"
@@ -44,6 +47,8 @@ ici_tensor_sequence_parallelism: 1
ici_autoregressive_parallelism: 1
ici_fsdp_parallelism: 1
ici_fsdp_transpose_parallelism: 1
# Allow higher unsharded parameter percentage for small device count
sharding_tolerance: 0.3

# DCN dimensions to 1 (no multi-slice expectation locally).
dcn_data_parallelism: 1
@@ -68,12 +73,4 @@ goodput_upload_interval_seconds: 0
enable_pathways_goodput: false
enable_gcp_goodput_metrics: false

# Disable any cloud logging / BigQuery or external metric uploads.
enable_cloud_logging: false
upload_metrics_to_bigquery: false
bigquery_project: ""
bigquery_dataset: ""
bigquery_table: ""

# Force local-only behavior for tests: avoid accidental env pickup.
tensorboard_dir: "./maxtext_local_output/tensorboard"
tensorboard_dir: "./local_datasets/gcloud_decoupled_test_logs/tensorboard"
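For context, a sketch of how a test might load this config, assuming MaxText's usual pyconfig.initialize(argv) convention where argv[1] is the YAML path and later entries are key=value overrides; the override shown is illustrative, since the dataset keys above are left to individual tests:

from MaxText import pyconfig

config = pyconfig.initialize(
    [
        "test",
        "src/MaxText/configs/decoupled_base_test.yml",
        "dataset_type=tfds",  # illustrative override; real tests set their own
    ]
)
assert config.base_output_directory == "./local_datasets/gcloud_decoupled_test_logs"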
15 changes: 12 additions & 3 deletions src/MaxText/decode.py
@@ -15,13 +15,14 @@
"""CLI utility for running inference on a single/multi stream(s)."""

import os
from typing import Sequence
from typing import Sequence, Any
import jax
import jax.numpy as jnp

from absl import app

from jetstream.engine import engine_api
from MaxText.gcloud_stub import jetstream, is_decoupled
_config_lib, engine_api, _token_utils, _tokenizer_api, _token_params_ns = jetstream()

from MaxText import max_utils
from MaxText import maxengine
@@ -34,7 +35,7 @@
_NUM_STREAMS = 1


def _batch_first_result_token(first_tokens: list[engine_api.ResultTokens], batch_size: int):
def _batch_first_result_token(first_tokens: list[Any], batch_size: int):
"""Batches together a list of first result tokens from prefill calls.

This is needed because prefill currently returns the first token as a batch of size 1
@@ -112,6 +113,14 @@ def main(argv: Sequence[str]) -> None:

  metadata = engine.get_tokenizer()
  tokenizer_model = engine.build_tokenizer(metadata)
  token_params_is_stub = getattr(_token_params_ns, "_IS_STUB", False)
  engine_api_is_stub = getattr(engine_api, "_IS_STUB", False)
  if is_decoupled() and (token_params_is_stub or engine_api_is_stub):
    raise RuntimeError(
        "JetStream disabled by DECOUPLE_GCLOUD=TRUE or stubbed; decode requires the JetStream tokenizer. "
        "Unset DECOUPLE_GCLOUD or install JetStream to run decode."
    )

  try:
    # TODO: update jetstream.engine.tokenizer_api.Tokenizer to maintain tokenizer state.
    has_chat_template = getattr(tokenizer_model.tokenizer, "chat_template", False)  # pytype: disable=attribute-error
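For orientation, a minimal sketch of the decoupling pattern behind the MaxText.gcloud_stub imports above. Only DECOUPLE_GCLOUD, is_decoupled() and the _IS_STUB marker appear in this diff; the helper below is an assumption about how such a module could be written, not the actual implementation:

import importlib
import os
import types


def is_decoupled() -> bool:
  """True when GCP-coupled dependencies should be replaced with no-op stubs."""
  return os.environ.get("DECOUPLE_GCLOUD", "").upper() == "TRUE"


def _module_or_stub(name: str):
  """Import a module, or return a namespace marked _IS_STUB=True when decoupled or missing."""
  if not is_decoupled():
    try:
      return importlib.import_module(name)
    except ImportError:
      pass
  return types.SimpleNamespace(_IS_STUB=True)

Callers then unpack the returned objects exactly like the real modules and branch on getattr(module, "_IS_STUB", False), as decode.py does above.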
26 changes: 18 additions & 8 deletions src/MaxText/elastic_train.py
@@ -45,10 +45,8 @@

from absl import app

from cloud_tpu_diagnostics import diagnostic
from cloud_tpu_diagnostics.configuration import debug_configuration
from cloud_tpu_diagnostics.configuration import diagnostic_configuration
from cloud_tpu_diagnostics.configuration import stack_trace_configuration
from MaxText.gcloud_stub import cloud_diagnostics as _cloud_diag
diagnostic, debug_configuration, diagnostic_configuration, stack_trace_configuration = _cloud_diag()

import jax

@@ -80,7 +78,10 @@
    maybe_monitor_goodput,
    maybe_record_goodput,
)
from MaxText.vertex_tensorboard import VertexTensorboardManager
from MaxText.gcloud_stub import vertex_tensorboard_components, is_decoupled
from MaxText.gcloud_stub import cloud_diagnostics as _cloud_diag
VertexTensorboardManager, _vertex_tb_is_stub = vertex_tensorboard_components()
diagnostic, debug_configuration, diagnostic_configuration, stack_trace_configuration = _cloud_diag()

logging.basicConfig()
logging.getLogger("pathwaysutils.elastic.manager").setLevel(logging.INFO)
@@ -386,7 +387,10 @@ def main(argv: Sequence[str]) -> None:
os.environ["TFDS_DATA_DIR"] = config.dataset_path or ""
vertex_tensorboard_manager = VertexTensorboardManager()
if config.use_vertex_tensorboard or os.environ.get("UPLOAD_DATA_TO_TENSORBOARD"):
vertex_tensorboard_manager.configure_vertex_tensorboard(config)
if _vertex_tb_is_stub:
max_logging.log("[DECOUPLED NO-OP] skipping Vertex Tensorboard configuration.")
else:
vertex_tensorboard_manager.configure_vertex_tensorboard(config)

# Create the Goodput recorder
recorder = create_goodput_recorder(config)
@@ -401,9 +405,15 @@
  )
  diagnostic_config = diagnostic_configuration.DiagnosticConfig(debug_config)

  with diagnostic.diagnose(diagnostic_config):
    with maybe_record_goodput(recorder, GoodputEvent.JOB), maybe_monitor_goodput(config):
  # In decoupled mode or when diagnostics are stubbed, skip the diagnose wrapper
  if is_decoupled() or getattr(diagnostic, "__class__", None).__name__ == "_StubDiag":
    max_logging.log("[DECOUPLED NO-OP] skipping cloud diagnostics wrapper.")
    with maybe_record_goodput(recorder, GoodputEvent.JOB):
      train_loop(config, elastic_manager, recorder)
  else:
    with diagnostic.diagnose(diagnostic_config):
      with maybe_record_goodput(recorder, GoodputEvent.JOB):
        train_loop(config, elastic_manager, recorder)


if __name__ == "__main__":
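The diagnostics branch above checks for a class literally named _StubDiag. A plausible no-op stand-in for the cloud_tpu_diagnostics diagnostic module in decoupled mode could look like the following; this is illustrative only, since the real stub lives in gcloud_stub, which is not part of this diff:

import contextlib


class _StubDiag:
  """No-op replacement for the cloud_tpu_diagnostics diagnostic module."""

  @contextlib.contextmanager
  def diagnose(self, diagnostic_config):
    # Nothing to configure or upload when cloud diagnostics are disabled.
    del diagnostic_config
    yield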
22 changes: 14 additions & 8 deletions src/MaxText/experimental/rl/grpo_trainer.py
@@ -57,10 +57,8 @@
from flax import struct
from flax.nnx import TrainState

from cloud_tpu_diagnostics import diagnostic
from cloud_tpu_diagnostics.configuration import debug_configuration
from cloud_tpu_diagnostics.configuration import diagnostic_configuration
from cloud_tpu_diagnostics.configuration import stack_trace_configuration
from MaxText.gcloud_stub import cloud_diagnostics as _cloud_diag, vertex_tensorboard_components, is_decoupled
diagnostic, debug_configuration, diagnostic_configuration, stack_trace_configuration = _cloud_diag()

import transformers

@@ -92,7 +90,7 @@
    maybe_monitor_goodput,
    maybe_record_goodput,
)
from MaxText.vertex_tensorboard import VertexTensorboardManager
VertexTensorboardManager, _vertex_tb_is_stub = vertex_tensorboard_components()


# pylint: disable=too-many-positional-arguments
@@ -944,7 +942,10 @@ def main(argv: Sequence[str]) -> None:
os.environ["TFDS_DATA_DIR"] = config.dataset_path
vertex_tensorboard_manager = VertexTensorboardManager()
if config.use_vertex_tensorboard or os.environ.get("UPLOAD_DATA_TO_TENSORBOARD"):
vertex_tensorboard_manager.configure_vertex_tensorboard(config)
if _vertex_tb_is_stub:
max_logging.log("[DECOUPLED NO-OP] skipping Vertex Tensorboard configuration.")
else:
vertex_tensorboard_manager.configure_vertex_tensorboard(config)

# Create the Goodput recorder
recorder = create_goodput_recorder(config)
@@ -959,9 +960,14 @@
  )
  diagnostic_config = diagnostic_configuration.DiagnosticConfig(debug_config)

  with diagnostic.diagnose(diagnostic_config):
    with maybe_record_goodput(recorder, GoodputEvent.JOB), maybe_monitor_goodput(config):
  if is_decoupled() or getattr(diagnostic, "__class__", None).__name__ == "_StubDiag":
    max_logging.log("[DECOUPLED NO-OP] skipping cloud diagnostics wrapper.")
    with maybe_record_goodput(recorder, GoodputEvent.JOB):
      train_loop(config, config_inference, recorder)
  else:
    with diagnostic.diagnose(diagnostic_config):
      with maybe_record_goodput(recorder, GoodputEvent.JOB):
        train_loop(config, config_inference, recorder)


if __name__ == "__main__":