
Implement 3D rigid body prediction (no Kabsch alignment) #102

@sarahnator

Description and motivation

Integrate rigid-body updates into GNS to enable simulation of complex multi-body scenes. The goal is to give each rigid body a stable ID across a trajectory and to apply rigid-motion updates without Kabsch alignment. Each body's angular velocity is estimated from its total angular momentum and from the average per-particle angular rotation.

Rigid Body Motion Integration Plan

1. Data format and metadata

  • Add body_id per particle in dataset files. For each trajectory, store an array of shape (nparticles,) with integer IDs identifying rigid bodies. Non-rigid particles can share a sentinel ID (e.g., -1) or each get unique IDs.
  • Update data readers in gns/particle_data_loader.py to return body_id alongside positions, particle_type, and optional material_property.
  • Update config/metadata to declare whether body_id exists in the dataset and what sentinel (if any) denotes non-rigid particles. This keeps data loading robust for mixed datasets.

Example layout for one trajectory item (conceptual):

# positions: (timesteps, nparticles, dim)
# particle_type: scalar or (nparticles,)
# material_property: scalar or (nparticles,)
# body_id: (nparticles,)
data.append((positions, particle_type, material_property, body_id))
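A minimal sketch of building and validating a body_id array for one trajectory. It assumes the particles of each rigid body are stored contiguously and that free (non-rigid) particles come last. The helpers `build_body_ids` / `validate_body_ids` and the metadata keys `has_body_id` / `non_rigid_body_id` are hypothetical illustrations, not existing GNS fields:

```python
import numpy as np

# Hypothetical metadata keys (not part of the existing GNS metadata schema).
METADATA = {"has_body_id": True, "non_rigid_body_id": -1}


def build_body_ids(body_sizes, n_free, non_rigid_id=-1):
    """Assign IDs 0..len(body_sizes)-1 to contiguous rigid bodies, then the
    sentinel to the trailing free (non-rigid) particles."""
    sizes = np.asarray(body_sizes, dtype=np.int64)
    rigid = np.repeat(np.arange(sizes.size), sizes)
    free = np.full(n_free, non_rigid_id, dtype=np.int64)
    return np.concatenate([rigid, free]).astype(np.int64)


def validate_body_ids(body_id, nparticles, non_rigid_id=-1):
    """Check shape (nparticles,), integer dtype, and non-negative rigid IDs."""
    body_id = np.asarray(body_id)
    if body_id.shape != (nparticles,):
        raise ValueError(f"body_id shape {body_id.shape} != ({nparticles},)")
    if not np.issubdtype(body_id.dtype, np.integer):
        raise TypeError("body_id must be an integer array")
    rigid = body_id[body_id != non_rigid_id]
    if np.any(rigid < 0):
        raise ValueError("negative body IDs other than the sentinel")
    return body_id
```

For example, two bodies of 3 and 2 particles plus 2 free particles give `[0, 0, 0, 1, 1, -1, -1]`, which `validate_body_ids(ids, 7, non_rigid_id=METADATA["non_rigid_body_id"])` accepts.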

2. Data loader changes (gns/particle_data_loader.py)

Extend samples/trajectories and collation to carry body_id:

# inside ParticleDataset._get_sample
positions = self.data[trajectory_idx][0][
    time_idx - self.input_sequence_length : time_idx
]
positions = np.transpose(positions, (1, 0, 2))
particle_type = np.full(positions.shape[0], self.data[trajectory_idx][1], dtype=int)
body_id = np.asarray(self.data[trajectory_idx][3], dtype=int)  # shape (nparticles,)

if self.material_property_as_feature:
    material_property = np.full(positions.shape[0], self.data[trajectory_idx][2], dtype=float)
    features = (positions, particle_type, material_property, body_id, positions.shape[0])
else:
    features = (positions, particle_type, body_id, positions.shape[0])

# inside ParticleDataset._get_trajectory
positions, particle_type, material_property, body_id = self.data[idx]
positions = np.transpose(positions, (1, 0, 2))
particle_type = np.full(positions.shape[0], particle_type, dtype=int)
# Broadcast a scalar (or per-particle) material property to shape (nparticles,).
material_property = np.full(positions.shape[0], material_property, dtype=float)
body_id = np.asarray(body_id, dtype=int)
n_particles_per_example = positions.shape[0]

trajectory = (
    torch.tensor(positions).to(torch.float32).contiguous(),
    torch.tensor(particle_type).contiguous(),
    torch.tensor(material_property).to(torch.float32).contiguous(),
    torch.tensor(body_id).contiguous(),
    n_particles_per_example,
)

# inside collate_fn_sample
position_list, particle_type_list, material_property_list, body_id_list = [], [], [], []
n_particles_per_example_list = []

for feature in features:
    position_list.append(feature[0])
    particle_type_list.append(feature[1])
    if len(feature) == 5:
        # (positions, particle_type, material_property, body_id, n_particles)
        material_property_list.append(feature[2])
        body_id_list.append(feature[3])
        n_particles_per_example_list.append(feature[4])
    else:
        # (positions, particle_type, body_id, n_particles)
        body_id_list.append(feature[2])
        n_particles_per_example_list.append(feature[3])

collated_features = (
    torch.tensor(np.vstack(position_list)).to(torch.float32).contiguous(),
    torch.tensor(np.concatenate(particle_type_list)).contiguous(),
)
if material_property_list:
    # Keep the material property that was collected above instead of dropping it.
    collated_features += (
        torch.tensor(np.concatenate(material_property_list)).to(torch.float32).contiguous(),
    )
collated_features += (
    torch.tensor(np.concatenate(body_id_list)).contiguous(),
    torch.tensor(n_particles_per_example_list).contiguous(),
)
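Concatenating body_id arrays from several examples in collate_fn_sample reuses IDs: body 0 of the first example and body 0 of the second would be treated as a single rigid body if the update were ever applied to a batch. A minimal sketch that remaps IDs per example, assuming -1 as the non-rigid sentinel (`offset_body_ids` is a hypothetical helper), could replace the plain `np.concatenate(body_id_list)`:

```python
import numpy as np


def offset_body_ids(body_id_list, non_rigid_id=-1):
    """Concatenate per-example body IDs so rigid bodies from different
    examples never share an ID; sentinel (non-rigid) entries are kept."""
    out, next_id = [], 0
    for ids in body_id_list:
        ids = np.asarray(ids, dtype=np.int64)
        rigid = ids != non_rigid_id
        shifted = ids.copy()
        if np.any(rigid):
            # Remap this example's rigid IDs to 0..k-1, then shift them past
            # every ID already used by earlier examples in the batch.
            uniq, inverse = np.unique(ids[rigid], return_inverse=True)
            shifted[rigid] = inverse + next_id
            next_id += uniq.size
        out.append(shifted)
    return np.concatenate(out) if out else np.zeros(0, dtype=np.int64)
```

For instance, `[[0, 0, 1, -1], [0, 1, 1]]` becomes `[0, 0, 1, -1, 2, 3, 3]`. Single-trajectory rollouts do not need this, since only one example is present.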

3. Simulator interface (gns/learned_simulator.py)

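Accept body_id but keep it out of the learned node features. Adding it as a trailing keyword argument avoids shifting existing positional arguments such as material_property:

def _encoder_preprocessor(..., particle_types, material_property=None, body_id=None):
    # unchanged feature construction for the network; body_id is only passed through
    ...
    return node_features, edge_index, edge_features, body_id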

4. Rigid-body update hook (gns/learned_simulator.py)

Add a helper to enforce rigid motion after the network predicts velocities. It detaches and round-trips through NumPy, so it is intended for rollouts rather than for training with gradients:

import torch

from utils.rigid_body_motion import step_rigid_bodies

def _apply_rigid_body_update(self, positions, velocities, body_id, masses, dt):
    """Project the predicted motion of each rigid body onto a rigid transform."""
    if body_id is None:
        return positions, velocities
    body_id_np = body_id.detach().cpu().numpy()
    pos_np = positions.detach().cpu().numpy()
    vel_np = velocities.detach().cpu().numpy()
    if torch.is_tensor(masses):
        masses = masses.detach().cpu().numpy()
    pos_new, vel_new = step_rigid_bodies(pos_np, vel_np, body_id_np, masses, dt)
    return (
        torch.as_tensor(pos_new, dtype=positions.dtype, device=positions.device),
        torch.as_tensor(vel_new, dtype=velocities.dtype, device=velocities.device),
    )

5. Rollout integration (gns/render_rollout.py)

Pass body_id and call the rigid-body post-step:

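# positions: current positions, shape (nparticles, dim).
# If the network predicts next positions, convert them to velocities first:
#     velocities = (predicted_next_positions - positions) / dt
positions, velocities = simulator._apply_rigid_body_update(
    positions, velocities, body_id, masses=1.0, dt=dt
)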
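The post-step can be sketched as a standalone NumPy rollout loop. Here `predict_next_positions` is a hypothetical stand-in for the learned simulator (the real GNS simulator consumes a window of past positions, not only the current frame), and `rigid_update` is any function with the step_rigid_bodies signature:

```python
import numpy as np


def rollout_with_rigid_bodies(predict_next_positions, rigid_update,
                              positions, body_id, nsteps, dt, masses=1.0):
    """Roll out a simulator and project rigid bodies after every step.

    predict_next_positions(positions) -> next positions (hypothetical model).
    rigid_update(positions, velocities, body_id, masses, dt)
        -> (new_positions, new_velocities), e.g. step_rigid_bodies.
    Returns an array of shape (nsteps + 1, nparticles, dim).
    """
    current = np.asarray(positions, dtype=float)
    trajectory = [current]
    for _ in range(nsteps):
        predicted = predict_next_positions(current)
        # Finite-difference velocity implied by the network's prediction.
        velocities = (predicted - current) / dt
        current, _ = rigid_update(current, velocities, body_id, masses, dt)
        trajectory.append(current)
    return np.stack(trajectory)
```

Keeping the model and the rigid projection as injected callables makes it easy to compare rollouts with and without the rigid-body constraint.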

6. Tests / sanity checks (test/)

Example shape-preservation test (every pairwise distance within the body should be unchanged):

import itertools

import numpy as np

from utils.rigid_body_motion import step_rigid_bodies


def test_rigid_body_preserves_distances():
    positions = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0]], dtype=float)
    velocities = np.array([[0, 1, 0], [0, 1, 0], [-1, 0, 0]], dtype=float)
    body_id = np.array([0, 0, 0], dtype=int)
    new_positions, _ = step_rigid_bodies(positions, velocities, body_id, masses=1.0, dt=0.1)
    for i, j in itertools.combinations(range(len(positions)), 2):
        d0 = np.linalg.norm(positions[i] - positions[j])
        d1 = np.linalg.norm(new_positions[i] - new_positions[j])
        assert np.isclose(d0, d1)

7. Example usage

positions, velocities = step_rigid_bodies(
    positions, velocities, body_id, masses=1.0, dt=0.01
)

Notes / design choices

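  • This plan avoids Kabsch alignment entirely. Each body's angular velocity is the average of two estimates: one from its total angular momentum and inertia tensor, and one from the mass-weighted average of per-particle angular rotations, as in utils/rigid_body_motion.py.
  • For mixed scenes, we recommend using body_id = -1 for non-rigid particles and excluding them from the rigid-body update, so they are not lumped together into one spurious rigid body.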
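The two angular-velocity estimates can be written out as follows, with $\mathbf r_i = \mathbf x_i - \mathbf x_{\mathrm{com}}$, $\mathbf E$ the 3×3 identity, and velocities taken relative to the center of mass. Because $\sum_i m_i \mathbf r_i = 0$, using $\mathbf v_i$ or $\mathbf v_i - \mathbf v_{\mathrm{com}}$ gives the same $\mathbf L$, but the per-particle estimate needs the relative velocity: with the raw $\mathbf v_i$, the term $\mathbf r_i \times \mathbf v_{\mathrm{com}} / \lVert\mathbf r_i\rVert^2$ does not cancel under the $1/\lVert\mathbf r_i\rVert^2$ weighting. The middle line (a standard vector identity, not part of the original plan) shows that for exactly rigid motion the per-particle estimate recovers only the component of $\boldsymbol\omega$ perpendicular to $\mathbf r_i$, so it is exact for planar rotation but underestimates $\boldsymbol\omega$ for general 3D rotation, while $\boldsymbol\omega_L$ is exact:

```latex
\begin{aligned}
\mathbf L &= \sum_i m_i\,\mathbf r_i \times (\mathbf v_i - \mathbf v_{\mathrm{com}}), \qquad
\mathbf I = \sum_i m_i\left(\lVert\mathbf r_i\rVert^2\,\mathbf E - \mathbf r_i\mathbf r_i^{\top}\right), \qquad
\boldsymbol\omega_L = \mathbf I^{-1}\mathbf L \\
\boldsymbol\omega_i &= \frac{\mathbf r_i \times (\mathbf v_i - \mathbf v_{\mathrm{com}})}{\lVert\mathbf r_i\rVert^2}
\;=\; \boldsymbol\omega - \hat{\mathbf r}_i\,(\hat{\mathbf r}_i\cdot\boldsymbol\omega)
\quad\text{when } \mathbf v_i = \mathbf v_{\mathrm{com}} + \boldsymbol\omega\times\mathbf r_i \\
\boldsymbol\omega_{\mathrm{avg}} &= \frac{\sum_i m_i\,\boldsymbol\omega_i}{\sum_i m_i}, \qquad
\boldsymbol\omega = \tfrac12\left(\boldsymbol\omega_L + \boldsymbol\omega_{\mathrm{avg}}\right)
\end{aligned}
```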

Full example implementation (utils/rigid_body_motion.py)

import numpy as np


def _skew(vec):
    """Return the skew-symmetric matrix for a 3D vector."""
    x, y, z = vec
    return np.array([[0.0, -z, y], [z, 0.0, -x], [-y, x, 0.0]], dtype=float)


def _rotation_from_omega(omega, dt, eps=1e-12):
    """Rodrigues rotation for angular velocity over dt."""
    angle = np.linalg.norm(omega) * dt
    if angle < eps:
        return np.eye(3) + _skew(omega * dt)
    axis = omega / np.linalg.norm(omega)
    k = _skew(axis)
    return np.eye(3) + np.sin(angle) * k + (1.0 - np.cos(angle)) * (k @ k)


def step_rigid_bodies(positions, velocities, body_ids, masses, dt, eps=1e-8, non_rigid_id=-1):
    """
    Advance particles as rigid bodies.

    - No Kabsch alignment: angular velocity is estimated from the total angular
      momentum and from the average per-particle angular rotation.
    - body_ids identifies which particles belong to which rigid body. Particles
      whose ID equals non_rigid_id are not constrained and simply move with
      their own velocity (x + v * dt).
    """
    positions = np.asarray(positions, dtype=float)
    velocities = np.asarray(velocities, dtype=float)
    body_ids = np.asarray(body_ids)
    if np.isscalar(masses):
        masses = np.full((positions.shape[0],), float(masses))
    else:
        masses = np.asarray(masses, dtype=float)

    # Default update for unconstrained particles; rigid bodies overwrite it below.
    new_positions = positions + velocities * dt
    new_velocities = velocities.copy()

    for body_id in np.unique(body_ids):
        if non_rigid_id is not None and body_id == non_rigid_id:
            continue
        mask = body_ids == body_id
        x = positions[mask]
        v = velocities[mask]
        m = masses[mask]

        total_m = np.sum(m)
        if total_m <= 0:
            continue

        com = np.sum(x * m[:, None], axis=0) / total_m
        v_com = np.sum(v * m[:, None], axis=0) / total_m
        r = x - com
        v_rel = v - v_com  # velocity relative to the center of mass

        # Estimate 1: omega_L = I^{-1} L from total angular momentum and inertia tensor.
        L = np.sum(np.cross(r, v_rel * m[:, None]), axis=0)
        inertia = np.sum(m * np.sum(r * r, axis=1)) * np.eye(3) - (r * m[:, None]).T @ r
        omega_L = np.linalg.solve(inertia + eps * np.eye(3), L)

        # Estimate 2: mass-weighted average of per-particle r x v_rel / |r|^2.
        # For exact rigid motion this recovers only the part of omega
        # perpendicular to r, so it is exact for planar rotation only.
        denom = np.sum(r * r, axis=1) + eps
        omega_particles = np.cross(r, v_rel) / denom[:, None]
        omega_avg = np.sum(omega_particles * m[:, None], axis=0) / total_m

        # Combine both estimates.
        omega = 0.5 * (omega_L + omega_avg)

        R = _rotation_from_omega(omega, dt, eps=eps)
        r_new = (R @ r.T).T
        com_new = com + v_com * dt

        new_positions[mask] = com_new + r_new
        new_velocities[mask] = v_com + np.cross(omega, r_new)

    return new_positions, new_velocities

Files to touch

  • gns/particle_data_loader.py (add body_id)
  • gns/learned_simulator.py (pass and apply body_id)
  • gns/render_rollout.py (pass body_id during rollouts)
  • test/ (add a small rigid-body unit test)
  • utils/rigid_body_motion.py (new module with step_rigid_bodies and its rotation helpers)
  • docs/rigid_body_motion_integration.md (this integration plan)
