From 144f9e600d955845bd3a8a1853a5be50cefd7877 Mon Sep 17 00:00:00 2001 From: chhsiao Date: Tue, 11 Nov 2025 09:52:32 -0600 Subject: [PATCH 1/2] fixing statement for ddp and pkl path --- gns/train.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/gns/train.py b/gns/train.py index a730e75..4c64001 100644 --- a/gns/train.py +++ b/gns/train.py @@ -202,7 +202,7 @@ def predict(device: str, cfg: DictConfig): example_rollout["loss"] = loss.mean() filename = f"{cfg.output.filename}_ex{example_i}.pkl" filename_render = f"{cfg.output.filename}_ex{example_i}" - filename = os.path.join(cfg.output.path, filename_render) + filename = os.path.join(cfg.output.path, f"{filename_render}.pkl") with open(filename, "wb") as f: pickle.dump(example_rollout, f) if cfg.rendering.mode: @@ -628,6 +628,7 @@ def train(rank, cfg, world_size, device, verbose, use_dist): cfg, rank, device_id, + use_dist ) writer.add_scalar("Loss/valid", valid_loss.item(), step) @@ -698,9 +699,9 @@ def train(rank, cfg, world_size, device, verbose, use_dist): if cfg.training.validation_interval is not None: sampled_valid_example = next(iter(valid_dl)) epoch_valid_loss = validation( - simulator, sampled_valid_example, n_features, cfg, rank, device_id + simulator, sampled_valid_example, n_features, cfg, rank, device_id, use_dist ) - if device == torch.device("cuda"): + if use_dist: torch.distributed.reduce( epoch_valid_loss, dst=0, op=torch.distributed.ReduceOp.SUM ) @@ -807,7 +808,7 @@ def _get_simulator( return simulator -def validation(simulator, example, n_features, cfg, rank, device_id): +def validation(simulator, example, n_features, cfg, rank, device_id, use_dist): ( position, particle_type, @@ -830,7 +831,7 @@ def validation(simulator, example, n_features, cfg, rank, device_id): # Select the appropriate prediction function predict_accelerations = ( simulator.module.predict_accelerations - if isinstance(device_id, int) + if use_dist else simulator.predict_accelerations ) # Get the predictions and target accelerations From 3d4f43ae8637f2888e161eb0f5682770e4fa837d Mon Sep 17 00:00:00 2001 From: Cheng-Hsi Hsiao Date: Tue, 11 Nov 2025 11:26:02 -0600 Subject: [PATCH 2/2] Format with black --- gns/train.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/gns/train.py b/gns/train.py index 4c64001..52e3152 100644 --- a/gns/train.py +++ b/gns/train.py @@ -628,7 +628,7 @@ def train(rank, cfg, world_size, device, verbose, use_dist): cfg, rank, device_id, - use_dist + use_dist, ) writer.add_scalar("Loss/valid", valid_loss.item(), step) @@ -699,7 +699,13 @@ def train(rank, cfg, world_size, device, verbose, use_dist): if cfg.training.validation_interval is not None: sampled_valid_example = next(iter(valid_dl)) epoch_valid_loss = validation( - simulator, sampled_valid_example, n_features, cfg, rank, device_id, use_dist + simulator, + sampled_valid_example, + n_features, + cfg, + rank, + device_id, + use_dist, ) if use_dist: torch.distributed.reduce(