From d07fd8d176671dfe4762d7abab153b5b54b7d6cc Mon Sep 17 00:00:00 2001 From: Dela Houssou Date: Thu, 4 Dec 2025 12:36:11 -0600 Subject: [PATCH 1/2] implement torch.compile --- gns/train.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/gns/train.py b/gns/train.py index 52e3152..39c57b0 100644 --- a/gns/train.py +++ b/gns/train.py @@ -202,7 +202,7 @@ def predict(device: str, cfg: DictConfig): example_rollout["loss"] = loss.mean() filename = f"{cfg.output.filename}_ex{example_i}.pkl" filename_render = f"{cfg.output.filename}_ex{example_i}" - filename = os.path.join(cfg.output.path, f"{filename_render}.pkl") + filename = os.path.join(cfg.output.path, filename) with open(filename, "wb") as f: pickle.dump(example_rollout, f) if cfg.rendering.mode: @@ -347,9 +347,11 @@ def setup_simulator_and_optimizer(cfg, metadata, rank, world_size, device, use_d rank, ) if use_dist: - simulator = DDP(serial_simulator.to("cuda"), device_ids=[rank]) + simulator = torch.compile(serial_simulator) + simulator = DDP(simulator.to("cuda"), device_ids=[rank]) else: - simulator = serial_simulator.to("cuda") + simulator = torch.compile(serial_simulator) + simulator = simulator.to("cuda") optimizer = torch.optim.Adam( simulator.parameters(), lr=cfg.training.learning_rate.initial * world_size ) @@ -361,6 +363,7 @@ def setup_simulator_and_optimizer(cfg, metadata, rank, world_size, device, use_d cfg.data.noise_std, device, ) + simulator = torch.compile(simulator) optimizer = torch.optim.Adam( simulator.parameters(), lr=cfg.training.learning_rate.initial * world_size ) From 3a03e35bd959b0e423a15a70a7748015d94762c5 Mon Sep 17 00:00:00 2001 From: Dela Houssou Date: Sun, 11 Jan 2026 21:45:39 -0600 Subject: [PATCH 2/2] Update training and simulator logic for rigid body support --- gns/learned_simulator.py | 66 ++++++++++++++++++++++++++++++++++++++++ gns/train.py | 2 ++ 2 files changed, 68 insertions(+) diff --git a/gns/learned_simulator.py b/gns/learned_simulator.py index 8ef5a63..0839d73 100644 --- a/gns/learned_simulator.py +++ b/gns/learned_simulator.py @@ -253,12 +253,64 @@ def _decoder_postprocessor( new_position = most_recent_position + new_velocity # * dt = 1 return new_position + def _enforce_rigid_body_acceleration( + self, + predicted_acceleration: torch.tensor, + particle_types: torch.tensor, + rigid_particle_id: int, + ) -> torch.tensor: + """ + Enforce that all rigid particles have the same acceleration (rigid body constraint). + + This implements translational rigidity by averaging accelerations across all + rigid particles and assigning the same acceleration to each rigid particle. + + Args: + predicted_acceleration: Predicted accelerations with shape (nparticles, dim). + particle_types: Particle types with shape (nparticles). + rigid_particle_id: The particle type ID for rigid particles. + + Returns: + torch.tensor: Modified accelerations with rigid body constraint applied. + """ + # Create a mask for rigid particles + rigid_mask = (particle_types == rigid_particle_id) # shape: (nparticles,) + + # If there are no rigid particles, return original accelerations + if rigid_mask.sum() == 0: + return predicted_acceleration + + # Compute the average acceleration across all rigid particles + # rigid_mask needs to be expanded to match acceleration dimensions + rigid_mask_expanded = rigid_mask.unsqueeze(-1) # shape: (nparticles, 1) + + # Sum accelerations of rigid particles + rigid_acc_sum = (predicted_acceleration * rigid_mask_expanded).sum(dim=0, keepdim=True) + + # Count rigid particles + num_rigid = rigid_mask.sum() + + # Compute average acceleration + avg_rigid_acceleration = rigid_acc_sum / num_rigid # shape: (1, dim) + + # Replace all rigid particle accelerations with the average + # Use torch.where to conditionally replace + modified_acceleration = torch.where( + rigid_mask_expanded, + avg_rigid_acceleration.expand_as(predicted_acceleration), + predicted_acceleration + ) + + return modified_acceleration + def predict_positions( self, current_positions: torch.tensor, nparticles_per_example: torch.tensor, particle_types: torch.tensor, material_property: torch.tensor = None, + rigid_particle_id: int = None, + ) -> torch.tensor: """Predict position based on acceleration. @@ -286,6 +338,13 @@ def predict_positions( predicted_normalized_acceleration = self._encode_process_decode( node_features, edge_index, edge_features ) + + # Apply rigid body constraint if rigid_particle_id is provided + if rigid_particle_id is not None: + predicted_normalized_acceleration = self._enforce_rigid_body_acceleration( + predicted_normalized_acceleration, particle_types, rigid_particle_id + ) + next_positions = self._decoder_postprocessor( predicted_normalized_acceleration, current_positions ) @@ -299,6 +358,7 @@ def predict_accelerations( nparticles_per_example: torch.tensor, particle_types: torch.tensor, material_property: torch.tensor = None, + rigid_particle_id: int = None, ): """Produces normalized and predicted acceleration targets. @@ -339,6 +399,12 @@ def predict_accelerations( node_features, edge_index, edge_features ) + # Apply rigid body constraint if rigid_particle_id is provided + if rigid_particle_id is not None: + predicted_normalized_acceleration = self._enforce_rigid_body_acceleration( + predicted_normalized_acceleration, particle_types, rigid_particle_id + ) + # Calculate the target acceleration, using an `adjusted_next_position `that # is shifted by the noise in the last input position. next_position_adjusted = next_positions + position_sequence_noise[:, -1] diff --git a/gns/train.py b/gns/train.py index 39c57b0..ea6ca35 100644 --- a/gns/train.py +++ b/gns/train.py @@ -63,6 +63,7 @@ def rollout( nparticles_per_example=[n_particles_per_example], particle_types=particle_types, material_property=material_property, + rigid_particle_id=cfg.data.rigid_particle_id, ) # Update kinematic particles from prescribed trajectory. @@ -615,6 +616,7 @@ def train(rank, cfg, world_size, device, verbose, use_dist): if n_features == 3 else None ), + rigid_particle_id=cfg.data.rigid_particle_id, ) if (