diff --git a/llama.py b/llama.py
index 1a7a61f..46b5f48 100644
--- a/llama.py
+++ b/llama.py
@@ -967,7 +967,7 @@ def main(llama: Path, tokenizer: Path, tp_world: int, pp_world: int, save_to: Pa
     print("done", flush=True)
     torch.distributed.barrier()
-    if rank < (tensor_model_parallel_size * pipeline_model_parallel_size):
+    if dp_rank == 0:
         # save the state dict to sharded files
         os.makedirs(save_to, exist_ok=True)
         torch.save(