201 changes: 176 additions & 25 deletions scripts/fsdp_hybrid.py
@@ -1,9 +1,12 @@
import contextlib
import gc
import itertools
import math
import os
import time
from dataclasses import dataclass
from datetime import timedelta
from functools import partial
from timeit import default_timer as timer
from typing import Any, Type

@@ -17,10 +20,12 @@
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.elastic.multiprocessing.errors import record
from torch.distributed.tensor.parallel import loss_parallel
from torch.distributed.device_mesh import DeviceMesh
from transformers import AutoConfig

from maester.checkpoint import CheckpointManager
from maester.datasets import build_hf_data_loader, create_tokenizer, MosaicDataset
from maester.datasets import (MosaicDataset, build_hf_data_loader,
create_tokenizer)
from maester.datasets.experimental_otf import build_experimental_data_loader
from maester.datasets.mosaic_dataset import MosaicDataLoader
from maester.log_utils import init_logger, logger
@@ -29,10 +34,89 @@
from maester.metrics import build_gpu_memory_monitor, build_metric_logger
from maester.models import (model_name_to_cls, model_name_to_tokenizer,
models_config)
from maester.parallelize_llama import ParallelDims, parallelize_llama
from maester.parallelisms import ParallelDims, parallelize_llama
from maester.profiling import maybe_enable_profiling
from maester.utils import (dist_max, dist_mean, get_num_flop_per_token, get_num_params, get_peak_flops,
init_distributed, set_pg_timeouts)
from maester.utils import (dist_max, dist_mean, get_num_flop_per_token,
get_num_params, get_peak_flops, init_distributed,
set_pg_timeouts)

# not yet merged into PyTorch, so define it here
# from torch.distributed.utils import _sync_module_states_with_mesh
from torch.distributed.utils import _verify_param_shape_across_processes
def _sync_module_states_with_mesh(module: torch.nn.Module, mesh: "DeviceMesh") -> None:
"""
Broadcast the module states from the first rank of ``mesh`` to the other ranks
within the same ``mesh``.
This API is similar to ``_sync_module_states`` but is designed for DeviceMesh.
Instead of extending ``_sync_module_states``, a separate API keeps the semantics
simpler (e.g., the meaning of ``src`` differs between a ProcessGroup and a
DeviceMesh).
"""
module_states: list[torch.Tensor] = []

# Lazy import to avoid circular dependency
from torch.distributed._tensor import DTensor

for state in itertools.chain(module.parameters(), module.buffers()):
module_states.append(state.to_local() if isinstance(state, DTensor) else state)

with torch.no_grad():
pg = mesh.get_group()
src = dist.get_process_group_ranks(pg)[0]
_verify_param_shape_across_processes(pg, module_states)

for state in module_states:
# `dist._broadcast_coalesced` would increase peak memory usage due to recordStream.
# Broadcast coalescing could be implemented later to speed this up.
dist.broadcast(state, src, group=pg)
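
For illustration, a minimal usage sketch of the helper above on a 1-D mesh; the mesh size, module, and names are assumptions for this note, not part of the PR. After the call, every rank in the mesh group holds the first rank's parameters and buffers.

# Illustrative sketch only; assumes torch.distributed is initialized across 4 CUDA ranks.
from torch.distributed.device_mesh import init_device_mesh

replicate_mesh = init_device_mesh("cuda", (4,), mesh_dim_names=("replicate",))
toy_model = torch.nn.Linear(16, 16).cuda()                 # each rank starts from a different random init
_sync_module_states_with_mesh(toy_model, replicate_mesh)   # all ranks now hold rank 0's weights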

def reshape(self, new_shape: tuple[int, ...], new_dim_names: tuple[str, ...] | None = None) -> DeviceMesh:
"""
Reshape the DeviceMesh to a new shape while preserving the total number of devices.

Args:
new_shape (Tuple[int, ...]): The new shape for the DeviceMesh.
new_dim_names (Optional[Tuple[str, ...]]): New names for the dimensions of the reshaped mesh.
If provided, must have the same length as new_shape. If not provided, dimension names will be kept if compatible, otherwise reset.

Returns:
The same DeviceMesh object with the reshaped structure.

Raises:
ValueError: If the new shape is incompatible with the total number of devices,
or if new_dim_names is provided but has a different length than new_shape.
"""
if math.prod(new_shape) != self.mesh.numel():
raise ValueError("New shape must have the same number of elements as the current mesh.")

if new_dim_names is not None:
if len(new_dim_names) != len(new_shape):
raise ValueError("new_dim_names must have the same length as new_shape.")
if len(set(new_dim_names)) != len(new_dim_names):
raise ValueError("Each name in new_dim_names must be unique.")

# Reshape the mesh tensor in-place
self.mesh = self.mesh.reshape(new_shape)

# Update mesh_dim_names
if new_dim_names is not None:
self.mesh_dim_names = new_dim_names
elif self.mesh_dim_names is not None and len(self.mesh_dim_names) != len(new_shape):
self.mesh_dim_names = None # Reset if incompatible

# Update the coordinate of the current rank on the reshaped mesh
rank_coords = (self.mesh == self.get_rank()).nonzero()
self._coordinate_on_dim = rank_coords[0].tolist() if rank_coords.size(0) > 0 else None

# Only reinitialize process groups if the number of dimensions has changed
if len(new_shape) != self.ndim:
self._init_process_groups()

return self

# Add the reshape method to the DeviceMesh class
DeviceMesh.reshape = reshape
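
For illustration, a minimal sketch of the monkey-patched reshape, mirroring how main() below splits the dp mesh; the 8-rank shape and dimension names are assumptions for this note, not part of the PR.

# Illustrative sketch only; assumes an 8-rank job with a flat data-parallel mesh.
from torch.distributed.device_mesh import init_device_mesh

flat_mesh = init_device_mesh("cuda", (8,), mesh_dim_names=("dp",))
hsdp_mesh = flat_mesh.reshape((2, 4), ("replicate", "shard"))  # 2 replicas x 4 shards, same 8 devices
shard_mesh = hsdp_mesh["shard"]                                # 1-D sub-mesh of this rank's 4-way shard group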


class DatasetConfig(BaseModel):
data_logical_shards: int = 1024
@@ -48,11 +132,14 @@ class Config(BaseModel):
model_config = ConfigDict(frozen=True, protected_namespaces=(), arbitrary_types_allowed=True)

job_folder: str = "jobs/"
job_name: str = "fineweb-1B-llama2-v2"
job_name: str = "fineweb-1B-llama2-testing"

max_grad_norm: float = 1.0
gc_freq: int = 4
data_parallel_type: str = "fsdp"
data_parallel_degree: int = -1
data_parallel_replicate: int = 1 # for hsdp, not ready to use (set to 1)
context_parallel_degree: int = 1 # not ready for use
tensor_parallel_degree: int = 1
pipeline_parallel_degree: int = 1
train_batch_size: int = 2 # per device; 2 * 8 gpus * 32 nodes * 4096 seqlen = 2.1M tokens per batch
@@ -84,6 +171,7 @@ class Config(BaseModel):
# model
model_name: str = "llama3"
flavor: str = "1B-v2"
num_future_tokens: int = 4
seq_len: int = 4096
norm_type: str = "rmsnorm"

@@ -92,6 +180,7 @@ class Config(BaseModel):
opt_cfg: dict[str, Any] = dict( # TODO: don't use dict, not validatable
lr = 4e-4, # max lr, schedule reduces it at points
betas = (0.9, 0.95),
weight_decay=0.1,
# foreach=True, # foreach might work where fused doesn't
fused=True
)
@@ -108,6 +197,10 @@ class Config(BaseModel):
ac_mode: str = "none" # "full" | "selective" | "none"
selective_ac_option: str | int = "op"

# experimental
enable_async_tensor_parallel: bool = False
enable_compiled_autograd: bool = False

# profiling
enable_profiling: bool = False
traces_folder: str = "traces"
@@ -138,7 +231,7 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None:
@record
def main():
init_logger()
logger.info(f"Starting training.")
logger.info(f"Starting job: ")

cfg = Config() # TODO: enable configuring?

@@ -150,15 +243,44 @@ def main():
world_size = int(os.environ["WORLD_SIZE"])
parallel_dims = ParallelDims(
dp=cfg.data_parallel_degree,
cp=cfg.context_parallel_degree,
tp=cfg.tensor_parallel_degree,
pp=cfg.pipeline_parallel_degree,
world_size=world_size,
enable_loss_parallel=cfg.enable_loss_parallel,
dp_type=cfg.data_parallel_type,
dp_replicate=cfg.data_parallel_replicate,
)
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
init_distributed(cfg)

# build meshes
world_mesh = parallel_dims.build_mesh(device_type="cuda")
if parallel_dims.dp_enabled:
dp_mesh = world_mesh["dp"]
if parallel_dims.cp_enabled:
dp_mesh = dp_mesh.reshape(
(dp_mesh.size() // parallel_dims.cp, parallel_dims.cp),
("dp", "cp")
)["dp"]
dp_degree = dp_mesh.size()
dp_rank = dp_mesh.get_local_rank()
else:
dp_degree, dp_rank = 1, 0

# if parallel_dims.cp_enabled:
# cp_mesh = world_mesh["cp"]
# context_parallel_ctx = partial(
# context_parallel_buffers,
# cp_rank=cp_mesh.get_local_rank(),
# cp_world_size=cp_mesh.size(),
# )
# else:
# context_parallel_ctx = partial(
# context_parallel_buffers,
# cp_rank=0,
# cp_world_size=1,
# )

# hf_config = AutoConfig.from_pretrained(cfg.tokenizer_name) # for vocab size below TODO: fails on LUMI?

@@ -172,6 +294,7 @@ def main():
model_config.norm_type = cfg.norm_type
model_config.vocab_size = 32000 # hf_config.vocab_size
model_config.max_seq_len = cfg.seq_len
model_config.n_future_tokens = cfg.num_future_tokens
# del hf_config # only needed for vocab size

with torch.device("meta"):
@@ -211,14 +334,7 @@ def main():
f"({gpu_mem_stats.max_reserved_pct:.2f}%)"
)


# build dataloader
if parallel_dims.dp_enabled:
dp_mesh = world_mesh["dp"]
dp_degree = dp_mesh.size()
dp_rank = dp_mesh.get_local_rank()
else:
dp_degree, dp_rank = 1, 0
# data_loader = build_hf_data_loader(
# "c4_mini",
# "src/maester/datasets/c4_mini",
@@ -273,10 +389,17 @@ def loss_fn(pred, labels):
states={"train_state": train_state},
cfg=cfg,
)
checkpoint.load()
checkpoint_loaded = checkpoint.load()

# TODO: do we want to checkpoint metrics?

if not checkpoint_loaded and parallel_dims.dp_enabled and parallel_dims.dp_replicate > 1:
# sync params if hsdp
replicate_mesh = dp_mesh.reshape(
(parallel_dims.dp_replicate, dp_mesh.size() // parallel_dims.dp_replicate)
)
_sync_module_states_with_mesh(model, replicate_mesh)

data_iterator = iter(data_loader)

logger.info(f"Training starts at step {train_state.step}")
@@ -305,23 +428,51 @@ def loss_fn(pred, labels):
ntokens_since_last_log += labels.numel()
data_loading_times.append(timer() - data_load_start)

input_ids = input_ids.cuda()
labels = labels.cuda()

optimizer.zero_grad()

# non-pp loss parallel, pp is not implemented
with loss_parallel_ctx():
pred = sharded_model(input_ids)
loss = loss_fn(pred, labels)
# pred.shape=(bs, seq_len, vocab_size)
# need to free to before bwd to avoid peaking memory
del pred
loss.backward()
# with context_parallel_ctx(
# buffers=[input_ids, labels, model.freqs_cis],
# seq_dims=[1,1,0],
# keep_orig_buffers=[False, False, True]
# ):
with contextlib.nullcontext():
input_ids = input_ids.cuda()
labels = labels.cuda()

# non-pp loss parallel, pp is not implemented
with loss_parallel_ctx():
z = sharded_model.trunk(input_ids, sharded_model.freqs_cis) # (bsz, seq_len, dim)
d = z.detach()
d.requires_grad_(True)
# print(f"d shape: {d.shape}, requires_grad: {d.requires_grad}")

for i, head in enumerate(sharded_model.heads.values()):
h = head(d, sharded_model.freqs_cis)
# print(f"h shape after head {i}: {h.shape}, requires_grad: {h.requires_grad}")
h = sharded_model.norm(h)
# print(f"h shape after norm {i}: {h.shape}, requires_grad: {h.requires_grad}")
pred = sharded_model.output(h)
# print(f"pred shape for head {i}: {pred.shape}, requires_grad: {pred.requires_grad}")
loss = loss_fn(pred, labels) # TODO: labels for num_future_tokens > 1
# print(f"loss for head {i}: {loss.item()}, requires_grad: {loss.requires_grad}")

# pred.shape=(bs, seq_len, vocab_size)
# need to free before bwd to avoid peaking memory
del h, pred
loss.backward()
# print(f"d.grad after head {i}: {d.grad is not None}, shape: {d.grad.shape if d.grad is not None else None}")

# print(f"Final d.grad shape: {d.grad.shape if d.grad is not None else None}")
# print(f"z grad_fn: {z.grad_fn}")
z.backward(gradient=d.grad)
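
The block above amounts to a detach-and-accumulate pattern: the trunk runs forward once, each head backpropagates into d.grad, and a single backward pass then flows through the trunk. A self-contained toy sketch of the same pattern (toy modules and shapes, not the real model):

# Toy illustration only; runnable on CPU with made-up shapes.
import torch

trunk = torch.nn.Linear(8, 8)
heads = [torch.nn.Linear(8, 4) for _ in range(2)]
x, target = torch.randn(3, 8), torch.randn(3, 4)

z = trunk(x)                          # trunk forward runs once
d = z.detach().requires_grad_(True)   # cut the graph; grads from the heads accumulate here
for head in heads:
    loss = torch.nn.functional.mse_loss(head(d), target)
    loss.backward()                   # updates head grads and accumulates into d.grad
z.backward(gradient=d.grad)           # one backward pass through the trunk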


total_grad_norm = torch.nn.utils.clip_grad_norm_(
sharded_model.parameters(), cfg.max_grad_norm, foreach=True
)

# optimizer step
checkpoint.wait_for_staging()
optimizer.step()
scheduler.step()

6 changes: 6 additions & 0 deletions src/maester/checkpoint.py
@@ -228,3 +228,9 @@ def load(self, step: int = -1) -> bool:
f"Finished loading the checkpoint in {time.monotonic() - begin:.2f} seconds"
)
return True

def wait_for_staging(self) -> None:
"""
To be implemented as part of async checkpointing.
"""
pass
9 changes: 9 additions & 0 deletions src/maester/models/llama/__init__.py
@@ -57,6 +57,15 @@
multiple_of=1024,
rope_theta=500000,
),
"1B-v2": ModelArgs(
dim=1536,
n_layers=24,
n_heads=32,
n_kv_heads=8,
ffn_dim_multiplier=1.2,
multiple_of=512,
rope_theta=500000,
),
"8B": ModelArgs(
dim=4096,
n_layers=32,