Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions gns/learned_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,12 +253,64 @@ def _decoder_postprocessor(
new_position = most_recent_position + new_velocity # * dt = 1
return new_position

def _enforce_rigid_body_acceleration(
self,
predicted_acceleration: torch.tensor,
particle_types: torch.tensor,
rigid_particle_id: int,
) -> torch.tensor:
"""
Enforce that all rigid particles have the same acceleration (rigid body constraint).

This implements translational rigidity by averaging accelerations across all
rigid particles and assigning the same acceleration to each rigid particle.

Args:
predicted_acceleration: Predicted accelerations with shape (nparticles, dim).
particle_types: Particle types with shape (nparticles).
rigid_particle_id: The particle type ID for rigid particles.

Returns:
torch.tensor: Modified accelerations with rigid body constraint applied.
"""
# Create a mask for rigid particles
rigid_mask = (particle_types == rigid_particle_id) # shape: (nparticles,)

# If there are no rigid particles, return original accelerations
if rigid_mask.sum() == 0:
return predicted_acceleration

# Compute the average acceleration across all rigid particles
# rigid_mask needs to be expanded to match acceleration dimensions
rigid_mask_expanded = rigid_mask.unsqueeze(-1) # shape: (nparticles, 1)

# Sum accelerations of rigid particles
rigid_acc_sum = (predicted_acceleration * rigid_mask_expanded).sum(dim=0, keepdim=True)

# Count rigid particles
num_rigid = rigid_mask.sum()

# Compute average acceleration
avg_rigid_acceleration = rigid_acc_sum / num_rigid # shape: (1, dim)

# Replace all rigid particle accelerations with the average
# Use torch.where to conditionally replace
modified_acceleration = torch.where(
rigid_mask_expanded,
avg_rigid_acceleration.expand_as(predicted_acceleration),
predicted_acceleration
)

return modified_acceleration

def predict_positions(
self,
current_positions: torch.tensor,
nparticles_per_example: torch.tensor,
particle_types: torch.tensor,
material_property: torch.tensor = None,
rigid_particle_id: int = None,

) -> torch.tensor:
"""Predict position based on acceleration.

Expand Down Expand Up @@ -286,6 +338,13 @@ def predict_positions(
predicted_normalized_acceleration = self._encode_process_decode(
node_features, edge_index, edge_features
)

# Apply rigid body constraint if rigid_particle_id is provided
if rigid_particle_id is not None:
predicted_normalized_acceleration = self._enforce_rigid_body_acceleration(
predicted_normalized_acceleration, particle_types, rigid_particle_id
)

next_positions = self._decoder_postprocessor(
predicted_normalized_acceleration, current_positions
)
Expand All @@ -299,6 +358,7 @@ def predict_accelerations(
nparticles_per_example: torch.tensor,
particle_types: torch.tensor,
material_property: torch.tensor = None,
rigid_particle_id: int = None,
):
"""Produces normalized and predicted acceleration targets.

Expand Down Expand Up @@ -339,6 +399,12 @@ def predict_accelerations(
node_features, edge_index, edge_features
)

# Apply rigid body constraint if rigid_particle_id is provided
if rigid_particle_id is not None:
predicted_normalized_acceleration = self._enforce_rigid_body_acceleration(
predicted_normalized_acceleration, particle_types, rigid_particle_id
)

# Calculate the target acceleration, using an `adjusted_next_position `that
# is shifted by the noise in the last input position.
next_position_adjusted = next_positions + position_sequence_noise[:, -1]
Expand Down
11 changes: 8 additions & 3 deletions gns/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def rollout(
nparticles_per_example=[n_particles_per_example],
particle_types=particle_types,
material_property=material_property,
rigid_particle_id=cfg.data.rigid_particle_id,
)

# Update kinematic particles from prescribed trajectory.
Expand Down Expand Up @@ -202,7 +203,7 @@ def predict(device: str, cfg: DictConfig):
example_rollout["loss"] = loss.mean()
filename = f"{cfg.output.filename}_ex{example_i}.pkl"
filename_render = f"{cfg.output.filename}_ex{example_i}"
filename = os.path.join(cfg.output.path, f"{filename_render}.pkl")
filename = os.path.join(cfg.output.path, filename)
with open(filename, "wb") as f:
pickle.dump(example_rollout, f)
if cfg.rendering.mode:
Expand Down Expand Up @@ -347,9 +348,11 @@ def setup_simulator_and_optimizer(cfg, metadata, rank, world_size, device, use_d
rank,
)
if use_dist:
simulator = DDP(serial_simulator.to("cuda"), device_ids=[rank])
simulator = torch.compile(serial_simulator)
simulator = DDP(simulator.to("cuda"), device_ids=[rank])
else:
simulator = serial_simulator.to("cuda")
simulator = torch.compile(serial_simulator)
simulator = simulator.to("cuda")
optimizer = torch.optim.Adam(
simulator.parameters(), lr=cfg.training.learning_rate.initial * world_size
)
Expand All @@ -361,6 +364,7 @@ def setup_simulator_and_optimizer(cfg, metadata, rank, world_size, device, use_d
cfg.data.noise_std,
device,
)
simulator = torch.compile(simulator)
optimizer = torch.optim.Adam(
simulator.parameters(), lr=cfg.training.learning_rate.initial * world_size
)
Expand Down Expand Up @@ -612,6 +616,7 @@ def train(rank, cfg, world_size, device, verbose, use_dist):
if n_features == 3
else None
),
rigid_particle_id=cfg.data.rigid_particle_id,
)

if (
Expand Down