From 7ed97e910fed0c19f6cc1f8687b1522c0f43ba78 Mon Sep 17 00:00:00 2001 From: ppegolo Date: Wed, 10 Sep 2025 10:25:41 +0200 Subject: [PATCH 01/57] Remove net forces in ASE calculator --- .../metatomic_torch/metatomic/torch/ase_calculator.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/python/metatomic_torch/metatomic/torch/ase_calculator.py b/python/metatomic_torch/metatomic/torch/ase_calculator.py index 8c26477a..3aab1f17 100644 --- a/python/metatomic_torch/metatomic/torch/ase_calculator.py +++ b/python/metatomic_torch/metatomic/torch/ase_calculator.py @@ -465,6 +465,10 @@ def calculate( forces_values = ( outputs["non_conservative_forces"].block().values.detach() ) + # remove any spurious net force + forces_values = forces_values - forces_values.mean( + dim=0, keepdim=True + ) else: forces_values = -system.positions.grad forces_values = forces_values.reshape(-1, 3) @@ -587,6 +591,12 @@ def compute_energy( results_as_numpy_arrays["forces"], split_indices, axis=0 ) + # remove net forces + results_as_numpy_arrays["forces"] = [ + f - f.mean(axis=0, keepdims=True) + for f in results_as_numpy_arrays["forces"] + ] + if all(atoms.pbc.all() for atoms in atoms_list): results_as_numpy_arrays["stress"] = [ s From 159b174aed6b22c910baf617deed96671449b610 Mon Sep 17 00:00:00 2001 From: ppegolo Date: Wed, 10 Sep 2025 13:28:04 +0200 Subject: [PATCH 02/57] Add SO3 and O3 averaging calculators --- .../metatomic/torch/ase_calculator.py | 247 +++++++++++++++++- 1 file changed, 244 insertions(+), 3 deletions(-) diff --git a/python/metatomic_torch/metatomic/torch/ase_calculator.py b/python/metatomic_torch/metatomic/torch/ase_calculator.py index 3aab1f17..5fe45bff 100644 --- a/python/metatomic_torch/metatomic/torch/ase_calculator.py +++ b/python/metatomic_torch/metatomic/torch/ase_calculator.py @@ -2,13 +2,15 @@ import os import pathlib import warnings -from typing import Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union import metatensor.torch import numpy as np import torch import vesin from metatensor.torch import Labels, TensorBlock, TensorMap +from scipy.integrate import lebedev_rule +from scipy.spatial.transform import Rotation from torch.profiler import record_function from . import ( @@ -31,7 +33,6 @@ all_properties as ALL_ASE_PROPERTIES, ) - FilePath = Union[str, bytes, pathlib.PurePath] LOGGER = logging.getLogger(__name__) @@ -593,7 +594,7 @@ def compute_energy( # remove net forces results_as_numpy_arrays["forces"] = [ - f - f.mean(axis=0, keepdims=True) + f - f.mean(axis=0, keepdims=True) for f in results_as_numpy_arrays["forces"] ] @@ -824,3 +825,243 @@ def _full_3x3_to_voigt_6_stress(stress): (stress[0, 1] + stress[1, 0]) / 2.0, ] ) + + +class SO3AveragedCalculator(ase.calculators.calculator.Calculator): + """ + Take a MetatomicCalculator and average its predictions over a + Lebedev (S^2) x Uniform (S^1) grid of rotations in SO(3). + """ + + implemented_properties = ["energy", "forces", "stress"] + + def __init__( + self, + base_calculator: MetatomicCalculator, + lebedev_order: int = 3, + n_inplane_rotations: int = 4, + batch_size: Optional[int] = None, + **kwargs, + ): + super().__init__(**kwargs) + + self.base_calculator = base_calculator + self.lebedev_order = lebedev_order + self.n_inplane_rotations = n_inplane_rotations + + self.so3_quadrature_rotations = _get_so3_quadrature( + lebedev_order, n_inplane_rotations + ) + + self.batch_size = ( + batch_size if batch_size is not None else len(self.so3_quadrature_rotations) + ) + + def calculate(self, atoms, properties, system_changes): + super().calculate(atoms, properties, system_changes) + + compute_forces_and_stresses = "forces" in properties or "stress" in properties + + if len(self.so3_quadrature_rotations) > 0: + rotated_atoms_list = _rotate_atoms(atoms, self.so3_quadrature_rotations) + batch_size = ( + self.batch_size + if self.batch_size is not None + else len(rotated_atoms_list) + ) + batches = [ + rotated_atoms_list[i : i + batch_size] + for i in range(0, len(rotated_atoms_list), batch_size) + ] + results: Dict[str, Any] = {} + for batch in batches: + try: + batch_results = self.base_calculator.compute_energy( + batch, compute_forces_and_stresses + ) + for key, value in batch_results.items(): + results.setdefault(key, []) + results[key].extend( + [value] if isinstance(value, float) else value + ) + except torch.cuda.OutOfMemoryError as e: + raise RuntimeError( + "Out of memory error encountered during rotational averaging. " + "Please reduce the batch size or use lower rotational " + "averaging parameters. This can be done by setting the " + "`batch_size`, `lebedev_order`, and `n_inplane_rotations` " + "parameters while initializing the calculator." + f"Full error message: {e}" + ) + + results = _compute_rotational_average( + results, self.so3_quadrature_rotations + ) + self.results.update(results) + + +class O3AveragedCalculator(ase.calculators.calculator.Calculator): + """ + Take a MetatomicCalculator and average its predictions over a + Lebedev (S^2) x Uniform (S^1) grid of rotations in O(3). + """ + + implemented_properties = ["energy", "forces", "stress"] + + def __init__( + self, + base_calculator: MetatomicCalculator, + lebedev_order: int = 3, + n_inplane_rotations: int = 4, + batch_size: Optional[int] = None, + **kwargs, + ): + super().__init__(**kwargs) + + self.base_calculator = base_calculator + self.lebedev_order = lebedev_order + self.n_inplane_rotations = n_inplane_rotations + + self.o3_quadrature_rotations = _get_o3_quadrature( + lebedev_order, n_inplane_rotations + ) + + self.batch_size = ( + batch_size if batch_size is not None else len(self.o3_quadrature_rotations) + ) + + def calculate(self, atoms, properties, system_changes): + super().calculate(atoms, properties, system_changes) + + compute_forces_and_stresses = "forces" in properties or "stress" in properties + + if len(self.o3_quadrature_rotations) > 0: + rotated_atoms_list = _rotate_atoms(atoms, self.o3_quadrature_rotations) + batches = [ + rotated_atoms_list[i : i + self.batch_size] + for i in range(0, len(rotated_atoms_list), self.batch_size) + ] + results: Dict[str, Any] = {} + for batch in batches: + try: + batch_results = self.base_calculator.compute_energy( + batch, compute_forces_and_stresses + ) + for key, value in batch_results.items(): + results.setdefault(key, []) + results[key].extend( + [value] if isinstance(value, float) else value + ) + except torch.cuda.OutOfMemoryError as e: + raise RuntimeError( + "Out of memory error encountered during rotational averaging. " + "Please reduce the batch size or use lower rotational " + "averaging parameters. This can be done by setting the " + "`batch_size`, `lebedev_order`, and `n_inplane_rotations` " + "parameters while initializing the calculator." + f"Full error message: {e}" + ) + + results = _compute_rotational_average(results, self.o3_quadrature_rotations) + self.results.update(results) + + +def _rotate_atoms(atoms: ase.Atoms, rotations: List[np.ndarray]) -> List[ase.Atoms]: + rotated_atoms_list = [] + has_cell = atoms.cell is not None and atoms.cell.rank > 0 + for rot in rotations: + new_atoms = atoms.copy() + new_atoms.positions = new_atoms.positions @ rot.T + if has_cell: + new_atoms.cell = new_atoms.cell @ rot.T + rotated_atoms_list.append(new_atoms) + return rotated_atoms_list + + +def _get_so3_quadrature(lebedev_order: int, n_rotations: int): + """ + Lebedev(S^2) x uniform angle quadrature on SO(3). + """ + + # Lebedev nodes (X: (3, M)) + X, _ = lebedev_rule(lebedev_order) + + x, y, z = X + alpha = np.arctan2(y, x) # (M,) + beta = np.arccos(np.clip(z, -1.0, 1.0)) # (M,) + + K = int(n_rotations) + gamma = np.linspace(0.0, 2 * np.pi, K, endpoint=False) # (K,) + + # Build all combinations (alpha_i, beta_i, gamma_j) + A = np.repeat(alpha, K) # (N,) + B = np.repeat(beta, K) # (N,) + G = np.tile(gamma, alpha.size) # (N,) + + # Compose ZYZ rotations + Rot = ( + Rotation.from_euler("z", A) + * Rotation.from_euler("y", B) + * Rotation.from_euler("z", G) + ) + Rmats = Rot.as_matrix() # (N, 3, 3) + + return Rmats + + +def _get_o3_quadrature(lebedev_order: int, n_rotations: int): + """ + Lebedev(S^2) x uniform angle quadrature on O(3). + Returns an array of shape (2N, 3, 3) with orthogonal matrices, + the first N in SO(3), the next N in its coset with inversion. + """ + # Lebedev nodes (X: (3, M)) + X, _ = lebedev_rule(lebedev_order) + + x, y, z = X + alpha = np.arctan2(y, x) # (M,) + beta = np.arccos(np.clip(z, -1.0, 1.0)) # (M,) + + K = int(n_rotations) + gamma = np.linspace(0.0, 2 * np.pi, K, endpoint=False) # (K,) + + # Build all combinations (alpha_i, beta_i, gamma_j) + A = np.repeat(alpha, K) # (N,) + B = np.repeat(beta, K) # (N,) + G = np.tile(gamma, alpha.size) # (N,) + + # Compose ZYZ rotations in SO(3) + Rot = ( + Rotation.from_euler("z", A) + * Rotation.from_euler("y", B) + * Rotation.from_euler("z", G) + ) + R_so3 = Rot.as_matrix() # (N, 3, 3) + + # Extend to O(3) by appending inversion * R + P = -np.eye(3) + R_o3 = np.concatenate([R_so3, P @ R_so3], axis=0) # (2N, 3, 3) + + return R_o3 + + +def _compute_rotational_average(results, rotations): + R = np.asarray(rotations) # (B,3,3) + out = {} + if "energy" in results: + arr = np.asarray(results["energy"]) + out["energy"] = arr.mean() + out["energy_rot_std"] = arr.std() + if "forces" in results: + F = np.stack(results["forces"], axis=0) # (B,N,3) + F_back = np.einsum("bnj,bjk->bnk", F, R, optimize=True) + out["forces"] = F_back.mean(axis=0) + out["forces_rot_std"] = F_back.std(axis=0) + if "stress" in results: + S = np.stack(results["stress"], axis=0) # (B,3,3) + RT = np.swapaxes(R, 1, 2) + tmp = np.einsum("bij,bjk->bik", RT, S, optimize=True) + S_back = np.einsum("bik,bkl->bil", tmp, R, optimize=True) + out["stress"] = S_back.mean(axis=0) + out["stress_rot_std"] = S_back.std(axis=0) + return out From 484a4a2a934f3c8f4cde60aa06e77ea7ee354c9c Mon Sep 17 00:00:00 2001 From: ppegolo Date: Wed, 10 Sep 2025 13:38:01 +0200 Subject: [PATCH 03/57] Update --- .../metatomic/torch/ase_calculator.py | 25 +++++++++---------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/python/metatomic_torch/metatomic/torch/ase_calculator.py b/python/metatomic_torch/metatomic/torch/ase_calculator.py index 1e57dddc..08eabf9e 100644 --- a/python/metatomic_torch/metatomic/torch/ase_calculator.py +++ b/python/metatomic_torch/metatomic/torch/ase_calculator.py @@ -188,9 +188,9 @@ def __init__( for name, output in additional_outputs.items(): assert isinstance(name, str) assert isinstance(output, torch.ScriptObject) - assert "explicit_gradients_setter" in output._method_names(), ( - "outputs must be ModelOutput instances" - ) + assert ( + "explicit_gradients_setter" in output._method_names() + ), "outputs must be ModelOutput instances" self._additional_output_requests = additional_outputs @@ -863,7 +863,15 @@ def _full_3x3_to_voigt_6_stress(stress): ) -<<<<<<< HEAD +def _get_energy_uncertainty_output(): + return ModelOutput( + quantity="energy", + unit="eV", + per_atom=True, + explicit_gradients=[], + ) + + class SO3AveragedCalculator(ase.calculators.calculator.Calculator): """ Take a MetatomicCalculator and average its predictions over a @@ -1102,12 +1110,3 @@ def _compute_rotational_average(results, rotations): out["stress"] = S_back.mean(axis=0) out["stress_rot_std"] = S_back.std(axis=0) return out -======= -def _get_energy_uncertainty_output(): - return ModelOutput( - quantity="energy", - unit="eV", - per_atom=True, - explicit_gradients=[], - ) ->>>>>>> main From 84706dc29e53820b5043a909f08b4aaf4c23be99 Mon Sep 17 00:00:00 2001 From: ppegolo Date: Wed, 10 Sep 2025 13:47:02 +0200 Subject: [PATCH 04/57] Fix typing --- .../metatomic/torch/ase_calculator.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/python/metatomic_torch/metatomic/torch/ase_calculator.py b/python/metatomic_torch/metatomic/torch/ase_calculator.py index 08eabf9e..8d50aa8e 100644 --- a/python/metatomic_torch/metatomic/torch/ase_calculator.py +++ b/python/metatomic_torch/metatomic/torch/ase_calculator.py @@ -2,7 +2,7 @@ import os import pathlib import warnings -from typing import Any, Dict, List, Optional, Union +from typing import Dict, List, Optional, Union import metatensor.torch import numpy as np @@ -188,9 +188,9 @@ def __init__( for name, output in additional_outputs.items(): assert isinstance(name, str) assert isinstance(output, torch.ScriptObject) - assert ( - "explicit_gradients_setter" in output._method_names() - ), "outputs must be ModelOutput instances" + assert "explicit_gradients_setter" in output._method_names(), ( + "outputs must be ModelOutput instances" + ) self._additional_output_requests = additional_outputs @@ -918,7 +918,7 @@ def calculate(self, atoms, properties, system_changes): rotated_atoms_list[i : i + batch_size] for i in range(0, len(rotated_atoms_list), batch_size) ] - results: Dict[str, Any] = {} + results: Dict[str, np.ndarray] = {} for batch in batches: try: batch_results = self.base_calculator.compute_energy( @@ -986,7 +986,7 @@ def calculate(self, atoms, properties, system_changes): rotated_atoms_list[i : i + self.batch_size] for i in range(0, len(rotated_atoms_list), self.batch_size) ] - results: Dict[str, Any] = {} + results: Dict[str, np.ndarray] = {} for batch in batches: try: batch_results = self.base_calculator.compute_energy( From ea98c0d84d087bb0f871a3eec64c9a50e2cdca81 Mon Sep 17 00:00:00 2001 From: ppegolo Date: Wed, 10 Sep 2025 17:19:12 +0200 Subject: [PATCH 05/57] Make scipy an optional dependency --- .../metatomic/torch/ase_calculator.py | 35 +++++++++++++++++-- 1 file changed, 33 insertions(+), 2 deletions(-) diff --git a/python/metatomic_torch/metatomic/torch/ase_calculator.py b/python/metatomic_torch/metatomic/torch/ase_calculator.py index 8d50aa8e..3cbd8107 100644 --- a/python/metatomic_torch/metatomic/torch/ase_calculator.py +++ b/python/metatomic_torch/metatomic/torch/ase_calculator.py @@ -9,8 +9,6 @@ import torch import vesin from metatensor.torch import Labels, TensorBlock, TensorMap -from scipy.integrate import lebedev_rule -from scipy.spatial.transform import Rotation from torch.profiler import record_function from . import ( @@ -888,6 +886,14 @@ def __init__( batch_size: Optional[int] = None, **kwargs, ): + try: + from scipy.integrate import lebedev_rule # noqa: F401 + except ImportError as e: + raise ImportError( + "scipy is required to use the SO3AveragedCalculator, please install " + "it with `pip install scipy` or `conda install scipy`" + ) from e + super().__init__(**kwargs) self.base_calculator = base_calculator @@ -939,6 +945,12 @@ def calculate(self, atoms, properties, system_changes): f"Full error message: {e}" ) + # Clean up + try: + torch.cuda.empty_cache() + except Exception: + pass + results = _compute_rotational_average( results, self.so3_quadrature_rotations ) @@ -961,6 +973,14 @@ def __init__( batch_size: Optional[int] = None, **kwargs, ): + try: + from scipy.integrate import lebedev_rule # noqa: F401 + except ImportError as e: + raise ImportError( + "scipy is required to use the SO3AveragedCalculator, please install " + "it with `pip install scipy` or `conda install scipy`" + ) from e + super().__init__(**kwargs) self.base_calculator = base_calculator @@ -1007,6 +1027,12 @@ def calculate(self, atoms, properties, system_changes): f"Full error message: {e}" ) + # Clean up + try: + torch.cuda.empty_cache() + except Exception: + pass + results = _compute_rotational_average(results, self.o3_quadrature_rotations) self.results.update(results) @@ -1027,6 +1053,8 @@ def _get_so3_quadrature(lebedev_order: int, n_rotations: int): """ Lebedev(S^2) x uniform angle quadrature on SO(3). """ + from scipy.integrate import lebedev_rule + from scipy.spatial.transform import Rotation # Lebedev nodes (X: (3, M)) X, _ = lebedev_rule(lebedev_order) @@ -1060,6 +1088,9 @@ def _get_o3_quadrature(lebedev_order: int, n_rotations: int): Returns an array of shape (2N, 3, 3) with orthogonal matrices, the first N in SO(3), the next N in its coset with inversion. """ + from scipy.integrate import lebedev_rule + from scipy.spatial.transform import Rotation + # Lebedev nodes (X: (3, M)) X, _ = lebedev_rule(lebedev_order) From d375984f870660567e5f937e120fb302b306ed2c Mon Sep 17 00:00:00 2001 From: ppegolo Date: Tue, 16 Sep 2025 14:31:43 +0200 Subject: [PATCH 06/57] Update rotation routines --- .../metatomic/torch/ase_calculator.py | 124 ++++++++++++------ 1 file changed, 85 insertions(+), 39 deletions(-) diff --git a/python/metatomic_torch/metatomic/torch/ase_calculator.py b/python/metatomic_torch/metatomic/torch/ase_calculator.py index 3cbd8107..4a4113e6 100644 --- a/python/metatomic_torch/metatomic/torch/ase_calculator.py +++ b/python/metatomic_torch/metatomic/torch/ase_calculator.py @@ -971,6 +971,7 @@ def __init__( lebedev_order: int = 3, n_inplane_rotations: int = 4, batch_size: Optional[int] = None, + return_o3_samples=False, **kwargs, ): try: @@ -987,7 +988,7 @@ def __init__( self.lebedev_order = lebedev_order self.n_inplane_rotations = n_inplane_rotations - self.o3_quadrature_rotations = _get_o3_quadrature( + self.o3_quadrature_rotations, self.o3_quadrature_weights = _get_o3_quadrature( lebedev_order, n_inplane_rotations ) @@ -995,6 +996,8 @@ def __init__( batch_size if batch_size is not None else len(self.o3_quadrature_rotations) ) + self.return_o3_samples = return_o3_samples + def calculate(self, atoms, properties, system_changes): super().calculate(atoms, properties, system_changes) @@ -1033,8 +1036,13 @@ def calculate(self, atoms, properties, system_changes): except Exception: pass - results = _compute_rotational_average(results, self.o3_quadrature_rotations) - self.results.update(results) + self.results.update( + _compute_rotational_average( + results, self.o3_quadrature_rotations, self.o3_quadrature_weights + ) + ) + if self.return_o3_samples: + self.results["o3_samples"] = results def _rotate_atoms(atoms: ase.Atoms, rotations: List[np.ndarray]) -> List[ase.Atoms]: @@ -1044,7 +1052,10 @@ def _rotate_atoms(atoms: ase.Atoms, rotations: List[np.ndarray]) -> List[ase.Ato new_atoms = atoms.copy() new_atoms.positions = new_atoms.positions @ rot.T if has_cell: - new_atoms.cell = new_atoms.cell @ rot.T + new_atoms.set_cell( + new_atoms.cell.array @ rot.T, scale_atoms=False, apply_constraint=False + ) + new_atoms.wrap() rotated_atoms_list.append(new_atoms) return rotated_atoms_list @@ -1054,7 +1065,6 @@ def _get_so3_quadrature(lebedev_order: int, n_rotations: int): Lebedev(S^2) x uniform angle quadrature on SO(3). """ from scipy.integrate import lebedev_rule - from scipy.spatial.transform import Rotation # Lebedev nodes (X: (3, M)) X, _ = lebedev_rule(lebedev_order) @@ -1066,19 +1076,13 @@ def _get_so3_quadrature(lebedev_order: int, n_rotations: int): K = int(n_rotations) gamma = np.linspace(0.0, 2 * np.pi, K, endpoint=False) # (K,) - # Build all combinations (alpha_i, beta_i, gamma_j) - A = np.repeat(alpha, K) # (N,) - B = np.repeat(beta, K) # (N,) - G = np.tile(gamma, alpha.size) # (N,) - - # Compose ZYZ rotations - Rot = ( - Rotation.from_euler("z", A) - * Rotation.from_euler("y", B) - * Rotation.from_euler("z", G) - ) + Rot = _rotations_from_angles(alpha, beta, gamma) Rmats = Rot.as_matrix() # (N, 3, 3) + # Re-orthogonalize the rotation matrices to avoid numerical issues + U, _, Vt = np.linalg.svd(Rmats, full_matrices=False) + Rmats = U @ Vt + return Rmats @@ -1089,11 +1093,9 @@ def _get_o3_quadrature(lebedev_order: int, n_rotations: int): the first N in SO(3), the next N in its coset with inversion. """ from scipy.integrate import lebedev_rule - from scipy.spatial.transform import Rotation # Lebedev nodes (X: (3, M)) - X, _ = lebedev_rule(lebedev_order) - + X, w = lebedev_rule(lebedev_order) # w sums to 4*pi x, y, z = X alpha = np.arctan2(y, x) # (M,) beta = np.arccos(np.clip(z, -1.0, 1.0)) # (M,) @@ -1101,9 +1103,26 @@ def _get_o3_quadrature(lebedev_order: int, n_rotations: int): K = int(n_rotations) gamma = np.linspace(0.0, 2 * np.pi, K, endpoint=False) # (K,) + Rot = _rotations_from_angles(alpha, beta, gamma) + R_so3 = Rot.as_matrix() # (N, 3, 3) + + # SO(3) Haar–probability weights: w_i/(4*pi*K), repeated over gamma + w_so3 = np.repeat(w / (4 * np.pi * K), repeats=gamma.size) # (N,) + + # Extend to O(3) by appending inversion * R + P = -np.eye(3) + R_o3 = np.concatenate([R_so3, P @ R_so3], axis=0) # (2N, 3, 3) + w_o3 = np.concatenate([0.5 * w_so3, 0.5 * w_so3], axis=0) + + return R_o3, w_o3 + + +def _rotations_from_angles(alpha, beta, gamma): + from scipy.spatial.transform import Rotation + # Build all combinations (alpha_i, beta_i, gamma_j) - A = np.repeat(alpha, K) # (N,) - B = np.repeat(beta, K) # (N,) + A = np.repeat(alpha, gamma.size) # (N,) + B = np.repeat(beta, gamma.size) # (N,) G = np.tile(gamma, alpha.size) # (N,) # Compose ZYZ rotations in SO(3) @@ -1112,32 +1131,59 @@ def _get_o3_quadrature(lebedev_order: int, n_rotations: int): * Rotation.from_euler("y", B) * Rotation.from_euler("z", G) ) - R_so3 = Rot.as_matrix() # (N, 3, 3) - # Extend to O(3) by appending inversion * R - P = -np.eye(3) - R_o3 = np.concatenate([R_so3, P @ R_so3], axis=0) # (2N, 3, 3) + return Rot + + +def _compute_rotational_average(results, rotations, weights): + R = rotations + B = R.shape[0] + w = weights + w = w / w.sum() + + def _wreshape(x): + return w.reshape((B,) + (1,) * (x.ndim - 1)) - return R_o3 + def _wmean(x): + return np.sum(_wreshape(x) * x, axis=0) + def _wstd(x): + mu = _wmean(x) + return np.sqrt(np.sum(_wreshape(x) * (x - mu) ** 2, axis=0)) -def _compute_rotational_average(results, rotations): - R = np.asarray(rotations) # (B,3,3) out = {} + + # Energy (B,) if "energy" in results: - arr = np.asarray(results["energy"]) - out["energy"] = arr.mean() - out["energy_rot_std"] = arr.std() + E = np.asarray(results["energy"], dtype=float) + if E.shape != (B,): + raise ValueError(f"energy must be shape ({B},), got {E.shape}") + out["energy"] = _wmean(E) + out["energy_rot_std"] = _wstd(E) + + # Forces (B,N,3) from rotated structures: back-rotate with R^T F' if "forces" in results: - F = np.stack(results["forces"], axis=0) # (B,N,3) - F_back = np.einsum("bnj,bjk->bnk", F, R, optimize=True) - out["forces"] = F_back.mean(axis=0) - out["forces_rot_std"] = F_back.std(axis=0) + F = np.asarray(results["forces"], dtype=float) # (B,N,3) + if F.ndim == 2: + F = F[np.newaxis, ...] + if F.shape[0] != B or F.shape[-1] != 3: + raise ValueError(f"forces must be (B,N,3); got {F.shape} with B={B}") + RT = np.swapaxes(R, 1, 2) + F_back = np.einsum("bnj,bjk->bnk", F, RT, optimize=True) # R^T * F' + out["forces"] = _wmean(F_back) # (N,3) + out["forces_rot_std"] = _wstd(F_back) # (N,3) + + # Stress (B,3,3) from rotated structures: back-rotate with R^T S' R if "stress" in results: - S = np.stack(results["stress"], axis=0) # (B,3,3) + S = np.asarray(results["stress"], dtype=float) # (B,3,3) + if S.ndim == 2: + S = S[np.newaxis, ...] + if S.shape != (B, 3, 3): + raise ValueError(f"stress must be (B,3,3); got {S.shape} with B={B}") RT = np.swapaxes(R, 1, 2) tmp = np.einsum("bij,bjk->bik", RT, S, optimize=True) - S_back = np.einsum("bik,bkl->bil", tmp, R, optimize=True) - out["stress"] = S_back.mean(axis=0) - out["stress_rot_std"] = S_back.std(axis=0) + S_back = np.einsum("bik,bkl->bil", tmp, R, optimize=True) # R^T S' R + out["stress"] = _wmean(S_back) # (3,3) + out["stress_rot_std"] = _wstd(S_back) # (3,3) + return out From 57c953e4f3e2ee8a89d114f592ce1f92928bbe88 Mon Sep 17 00:00:00 2001 From: ppegolo Date: Wed, 17 Sep 2025 09:39:53 +0200 Subject: [PATCH 07/57] Simplify args --- .../metatomic/torch/ase_calculator.py | 53 +++++++++++++++++-- 1 file changed, 49 insertions(+), 4 deletions(-) diff --git a/python/metatomic_torch/metatomic/torch/ase_calculator.py b/python/metatomic_torch/metatomic/torch/ase_calculator.py index 4a4113e6..d38979ed 100644 --- a/python/metatomic_torch/metatomic/torch/ase_calculator.py +++ b/python/metatomic_torch/metatomic/torch/ase_calculator.py @@ -968,8 +968,7 @@ class O3AveragedCalculator(ase.calculators.calculator.Calculator): def __init__( self, base_calculator: MetatomicCalculator, - lebedev_order: int = 3, - n_inplane_rotations: int = 4, + l_max: int = 3, batch_size: Optional[int] = None, return_o3_samples=False, **kwargs, @@ -985,9 +984,13 @@ def __init__( super().__init__(**kwargs) self.base_calculator = base_calculator - self.lebedev_order = lebedev_order - self.n_inplane_rotations = n_inplane_rotations + if l_max > 131: + raise ValueError( + f"l_max={l_max} is too large, the maximum supported value is 131" + ) + self.l_max = l_max + lebedev_order, n_inplane_rotations = choose_quadrature(l_max) self.o3_quadrature_rotations, self.o3_quadrature_weights = _get_o3_quadrature( lebedev_order, n_inplane_rotations ) @@ -1045,6 +1048,48 @@ def calculate(self, atoms, properties, system_changes): self.results["o3_samples"] = results +def choose_quadrature(L_max): + available = [ + 3, + 5, + 7, + 9, + 11, + 13, + 15, + 17, + 19, + 21, + 23, + 25, + 27, + 29, + 31, + 35, + 41, + 47, + 53, + 59, + 65, + 71, + 77, + 83, + 89, + 95, + 101, + 107, + 113, + 119, + 125, + 131, + ] + # pick smallest order >= L_max + n = min(o for o in available if o >= L_max) + # minimal gamma count + K = 2 * L_max + 1 + return n, K + + def _rotate_atoms(atoms: ase.Atoms, rotations: List[np.ndarray]) -> List[ase.Atoms]: rotated_atoms_list = [] has_cell = atoms.cell is not None and atoms.cell.rank > 0 From b6354b73c8303b031ef79288378cf0363f077830 Mon Sep 17 00:00:00 2001 From: ppegolo Date: Wed, 17 Sep 2025 10:57:24 +0200 Subject: [PATCH 08/57] Fix bug --- .../metatomic/torch/ase_calculator.py | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/python/metatomic_torch/metatomic/torch/ase_calculator.py b/python/metatomic_torch/metatomic/torch/ase_calculator.py index d38979ed..ffd5dfc5 100644 --- a/python/metatomic_torch/metatomic/torch/ase_calculator.py +++ b/python/metatomic_torch/metatomic/torch/ase_calculator.py @@ -1201,30 +1201,19 @@ def _wstd(x): # Energy (B,) if "energy" in results: E = np.asarray(results["energy"], dtype=float) - if E.shape != (B,): - raise ValueError(f"energy must be shape ({B},), got {E.shape}") out["energy"] = _wmean(E) out["energy_rot_std"] = _wstd(E) # Forces (B,N,3) from rotated structures: back-rotate with R^T F' if "forces" in results: F = np.asarray(results["forces"], dtype=float) # (B,N,3) - if F.ndim == 2: - F = F[np.newaxis, ...] - if F.shape[0] != B or F.shape[-1] != 3: - raise ValueError(f"forces must be (B,N,3); got {F.shape} with B={B}") - RT = np.swapaxes(R, 1, 2) - F_back = np.einsum("bnj,bjk->bnk", F, RT, optimize=True) # R^T * F' + F_back = np.einsum("bnj,bjk->bnk", F, R, optimize=True) # F' R out["forces"] = _wmean(F_back) # (N,3) out["forces_rot_std"] = _wstd(F_back) # (N,3) # Stress (B,3,3) from rotated structures: back-rotate with R^T S' R if "stress" in results: S = np.asarray(results["stress"], dtype=float) # (B,3,3) - if S.ndim == 2: - S = S[np.newaxis, ...] - if S.shape != (B, 3, 3): - raise ValueError(f"stress must be (B,3,3); got {S.shape} with B={B}") RT = np.swapaxes(R, 1, 2) tmp = np.einsum("bij,bjk->bik", RT, S, optimize=True) S_back = np.einsum("bik,bkl->bil", tmp, R, optimize=True) # R^T S' R From 6a09b260998cb919bc5cd206a1c4b7013cd7b2f8 Mon Sep 17 00:00:00 2001 From: ppegolo Date: Wed, 17 Sep 2025 13:07:19 +0200 Subject: [PATCH 09/57] Add group symmetrization --- .../metatomic/torch/ase_calculator.py | 142 +++++++++++++++++- 1 file changed, 141 insertions(+), 1 deletion(-) diff --git a/python/metatomic_torch/metatomic/torch/ase_calculator.py b/python/metatomic_torch/metatomic/torch/ase_calculator.py index ffd5dfc5..16d3a515 100644 --- a/python/metatomic_torch/metatomic/torch/ase_calculator.py +++ b/python/metatomic_torch/metatomic/torch/ase_calculator.py @@ -2,7 +2,7 @@ import os import pathlib import warnings -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional, Tuple, Union import metatensor.torch import numpy as np @@ -971,6 +971,7 @@ def __init__( l_max: int = 3, batch_size: Optional[int] = None, return_o3_samples=False, + apply_group_symmetry=False, **kwargs, ): try: @@ -1000,10 +1001,14 @@ def __init__( ) self.return_o3_samples = return_o3_samples + self.apply_group_symmetry = apply_group_symmetry def calculate(self, atoms, properties, system_changes): super().calculate(atoms, properties, system_changes) + if self.apply_group_symmetry: + Q_list, P_list = _get_group_operations(atoms) + compute_forces_and_stresses = "forces" in properties or "stress" in properties if len(self.o3_quadrature_rotations) > 0: @@ -1044,6 +1049,9 @@ def calculate(self, atoms, properties, system_changes): results, self.o3_quadrature_rotations, self.o3_quadrature_weights ) ) + + if self.apply_group_symmetry: + self.results.update(_average_over_group(self.results, Q_list, P_list)) if self.return_o3_samples: self.results["o3_samples"] = results @@ -1221,3 +1229,135 @@ def _wstd(x): out["stress_rot_std"] = _wstd(S_back) # (3,3) return out + + +def _get_group_operations( + atoms: ase.Atoms, symprec: float = 1e-6, angle_tolerance: float = -1.0 +) -> Tuple[List[np.ndarray], List[np.ndarray]]: + """ + Extract point-group rotations Q_g (Cartesian, 3x3) and the corresponding + atom-index permutations P_g (N x N) induced by the space-group operations. + Returns Q_list, Cartesian rotation matrices of the point group, + and P_list, permutation matrices mapping original indexing -> indexing after (R,t), + """ + try: + import spglib + except ImportError as e: + raise ImportError( + "spglib is required to use the O3AveragedCalculator with " + "`apply_group_symmetry=True`. Please install it with " + "`pip install spglib` or `conda install -c conda-forge spglib`" + ) from e + + # Lattice with column vectors a1,a2,a3 (spglib expects (cell, frac, Z)) + A = atoms.cell.array.T # (3,3) + frac = atoms.get_scaled_positions() # (N,3) in [0,1) + numbers = atoms.numbers + N = len(atoms) + + data = spglib.get_symmetry_dataset( + (atoms.cell.array, frac, numbers), + symprec=symprec, + angle_tolerance=angle_tolerance, + ) + R_frac = data.rotations # (n_ops, 3,3), integer + t_frac = data.translations # (n_ops, 3) + Z = numbers + + # Match fractional coords modulo 1 within a tolerance, respecting chemical species + def _match_index(x_new, frac_ref, Z_ref, Z_i, tol=1e-6): + d = np.abs(frac_ref - x_new) # (N,3) + d = np.minimum(d, 1.0 - d) # periodic distance + # Mask by identical species + mask = Z_ref == Z_i + if not np.any(mask): + raise RuntimeError("No matching species found while building permutation.") + # Choose argmin over max-norm within species + idx = np.where(mask)[0] + j = idx[np.argmin(np.max(d[idx], axis=1))] + + # Sanity check + if np.max(d[j]) > tol: + pass + return j + + Q_list, P_list = [], [] + seen = set() + Ainv = np.linalg.inv(A) + + for Rf, tf in zip(R_frac, t_frac): + # Cartesian rotation: Q = A Rf A^{-1} + Q = A @ Rf @ Ainv + # Deduplicate rotations (point group) by rounding + key = tuple(np.round(Q.flatten(), 12)) + if key in seen: + continue + seen.add(key) + + # Build the permutation P from i to j + P = np.zeros((N, N), dtype=int) + new_frac = (frac @ Rf.T + tf) % 1.0 # images after (Rf,tf) + for i in range(N): + j = _match_index(new_frac[i], frac, Z, Z[i]) + P[j, i] = 1 # column i maps to row j + + Q_list.append(Q.astype(float)) + P_list.append(P) + + return Q_list, P_list + + +def _average_over_group( + results: dict, Q_list: List[np.ndarray], P_list: List[np.ndarray] +) -> dict: + """ + Apply the point-group projector in output space. + """ + m = len(Q_list) + if m == 0: + # No symmetry found; return copies + out = {} + if "energy" in results: + out["energy_pg"] = float(results["energy"]) + if "forces" in results: + out["forces_pg"] = np.array(results["forces"], float, copy=True) + if "stress" in results: + S = np.array(results["stress"], float, copy=True) + S = 0.5 * (S + S.T) + out["stress_pg"] = S + out["stress_iso_pg"] = np.eye(3) * (np.trace(S) / 3.0) + out["stress_dev_pg"] = S - out["stress_iso_pg"] + return out + + out = {} + # Energy: unchanged by the projector (scalar invariant) + if "energy" in results: + out["energy"] = float(results["energy"]) + + # Forces: (N,3) row-vectors; projector: (1/|G|) \sum_g P_g^T F Q_g + if "forces" in results: + F = np.asarray(results["forces"], float) + if F.ndim != 2 or F.shape[1] != 3: + raise ValueError(f"'forces' must be (N,3), got {F.shape}") + acc = np.zeros_like(F) + for Q, P in zip(Q_list, P_list): + acc += P.T @ (F @ Q) + out["forces"] = acc / m + + # Stress: (3,3); projector: (1/|G|) \sum_g Q_g^T S Q_g + if "stress" in results: + S = np.asarray(results["stress"], float) + if S.shape != (3, 3): + raise ValueError(f"'stress' must be (3,3), got {S.shape}") + S = 0.5 * (S + S.T) # symmetrize just in case + acc = np.zeros_like(S) + for Q in Q_list: + acc += Q.T @ S @ Q + S_pg = acc / m + out["stress"] = S_pg + # # Expose L=0 projection and deviatoric part for debugging + # S_iso = np.trace(S_pg) / 3.0 + # out["stress_iso_pg"] = np.eye(3) * S_iso + # out["stress_dev_pg"] = S_pg - out["stress_iso_pg"] + + return out From 704b923b39556cf7dc8230c9ae338834bd72dc6f Mon Sep 17 00:00:00 2001 From: ppegolo Date: Sat, 20 Sep 2025 18:48:25 +0200 Subject: [PATCH 10/57] small change --- .../metatomic/torch/ase_calculator.py | 61 +++++++++++++------ 1 file changed, 43 insertions(+), 18 deletions(-) diff --git a/python/metatomic_torch/metatomic/torch/ase_calculator.py b/python/metatomic_torch/metatomic/torch/ase_calculator.py index 16d3a515..8fe059a4 100644 --- a/python/metatomic_torch/metatomic/torch/ase_calculator.py +++ b/python/metatomic_torch/metatomic/torch/ase_calculator.py @@ -991,10 +991,15 @@ def __init__( ) self.l_max = l_max - lebedev_order, n_inplane_rotations = choose_quadrature(l_max) - self.o3_quadrature_rotations, self.o3_quadrature_weights = _get_o3_quadrature( - lebedev_order, n_inplane_rotations - ) + if l_max > 0: + lebedev_order, n_inplane_rotations = choose_quadrature(l_max) + self.o3_quadrature_rotations, self.o3_quadrature_weights = ( + _get_o3_quadrature(lebedev_order, n_inplane_rotations) + ) + else: + # no quadrature + self.o3_quadrature_rotations = np.array([np.eye(3)]) + self.o3_quadrature_weights = np.array([1.0]) self.batch_size = ( batch_size if batch_size is not None else len(self.o3_quadrature_rotations) @@ -1050,11 +1055,12 @@ def calculate(self, atoms, properties, system_changes): ) ) - if self.apply_group_symmetry: - self.results.update(_average_over_group(self.results, Q_list, P_list)) if self.return_o3_samples: self.results["o3_samples"] = results + if self.apply_group_symmetry: + self.results.update(_average_over_group(self.results, Q_list, P_list)) + def choose_quadrature(L_max): available = [ @@ -1120,23 +1126,21 @@ def _get_so3_quadrature(lebedev_order: int, n_rotations: int): from scipy.integrate import lebedev_rule # Lebedev nodes (X: (3, M)) - X, _ = lebedev_rule(lebedev_order) - + X, w = lebedev_rule(lebedev_order) # w sums to 4*pi x, y, z = X alpha = np.arctan2(y, x) # (M,) - beta = np.arccos(np.clip(z, -1.0, 1.0)) # (M,) + beta = np.arccos(z) # (M,) + # beta = np.arccos(np.clip(z, -1.0, 1.0)) # (M,) K = int(n_rotations) gamma = np.linspace(0.0, 2 * np.pi, K, endpoint=False) # (K,) Rot = _rotations_from_angles(alpha, beta, gamma) - Rmats = Rot.as_matrix() # (N, 3, 3) + R_so3 = Rot.as_matrix() # (N, 3, 3) - # Re-orthogonalize the rotation matrices to avoid numerical issues - U, _, Vt = np.linalg.svd(Rmats, full_matrices=False) - Rmats = U @ Vt + w_so3 = np.repeat(w / (4 * np.pi * K), repeats=gamma.size) # (N,) - return Rmats + return R_so3, w_so3 def _get_o3_quadrature(lebedev_order: int, n_rotations: int): @@ -1151,7 +1155,8 @@ def _get_o3_quadrature(lebedev_order: int, n_rotations: int): X, w = lebedev_rule(lebedev_order) # w sums to 4*pi x, y, z = X alpha = np.arctan2(y, x) # (M,) - beta = np.arccos(np.clip(z, -1.0, 1.0)) # (M,) + beta = np.arccos(z) # (M,) + # beta = np.arccos(np.clip(z, -1.0, 1.0)) # (M,) K = int(n_rotations) gamma = np.linspace(0.0, 2 * np.pi, K, endpoint=False) # (K,) @@ -1159,6 +1164,13 @@ def _get_o3_quadrature(lebedev_order: int, n_rotations: int): Rot = _rotations_from_angles(alpha, beta, gamma) R_so3 = Rot.as_matrix() # (N, 3, 3) + # rnd = np.random.uniform(size=(3, 3)) + # rnd = rnd - rnd.T + # import scipy.linalg + + # rnd = scipy.linalg.expm(-rnd) + # R_so3 = R_so3 @ rnd + # SO(3) Haar–probability weights: w_i/(4*pi*K), repeated over gamma w_so3 = np.repeat(w / (4 * np.pi * K), repeats=gamma.size) # (N,) @@ -1212,7 +1224,7 @@ def _wstd(x): out["energy"] = _wmean(E) out["energy_rot_std"] = _wstd(E) - # Forces (B,N,3) from rotated structures: back-rotate with R^T F' + # Forces (B,N,3) from rotated structures: back-rotate with F' R if "forces" in results: F = np.asarray(results["forces"], dtype=float) # (B,N,3) F_back = np.einsum("bnj,bjk->bnk", F, R, optimize=True) # F' R @@ -1312,6 +1324,19 @@ def _average_over_group( ) -> dict: """ Apply the point-group projector in output space. + + Parameters + ---------- + results : dict + Must contain 'energy' (scalar), and/or 'forces' (N,3), and/or 'stress' (3,3). + These are predictions for the *current* structure in the reference frame. + Q_list, P_list : outputs of _get_group_operations + + Returns + ------- + out : dict + Projected quantities with keys: 'energy_pg', 'forces_pg', 'stress_pg'. + For stress, also returns 'stress_iso_pg' (L=0) and 'stress_dev_pg'. """ m = len(Q_list) if m == 0: @@ -1323,7 +1348,7 @@ def _average_over_group( out["forces_pg"] = np.array(results["forces"], float, copy=True) if "stress" in results: S = np.array(results["stress"], float, copy=True) - S = 0.5 * (S + S.T) + # S = 0.5 * (S + S.T) out["stress_pg"] = S out["stress_iso_pg"] = np.eye(3) * (np.trace(S) / 3.0) out["stress_dev_pg"] = S - out["stress_iso_pg"] @@ -1349,7 +1374,7 @@ def _average_over_group( S = np.asarray(results["stress"], float) if S.shape != (3, 3): raise ValueError(f"'stress' must be (3,3), got {S.shape}") - S = 0.5 * (S + S.T) # symmetrize just in case + # S = 0.5 * (S + S.T) # symmetrize just in case acc = np.zeros_like(S) for Q in Q_list: acc += Q.T @ S @ Q From 971c22bbb13ff9069f474669361ac3fc1f82ff00 Mon Sep 17 00:00:00 2001 From: ppegolo Date: Thu, 9 Oct 2025 13:51:31 +0200 Subject: [PATCH 11/57] clean up and add tests --- .../metatomic/torch/ase_calculator.py | 286 ++++++-------- .../tests/symmetrized_ase_calculator.py | 354 ++++++++++++++++++ 2 files changed, 465 insertions(+), 175 deletions(-) create mode 100644 python/metatomic_torch/tests/symmetrized_ase_calculator.py diff --git a/python/metatomic_torch/metatomic/torch/ase_calculator.py b/python/metatomic_torch/metatomic/torch/ase_calculator.py index 8fe059a4..d3027f92 100644 --- a/python/metatomic_torch/metatomic/torch/ase_calculator.py +++ b/python/metatomic_torch/metatomic/torch/ase_calculator.py @@ -2,7 +2,7 @@ import os import pathlib import warnings -from typing import Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import metatensor.torch import numpy as np @@ -186,9 +186,9 @@ def __init__( for name, output in additional_outputs.items(): assert isinstance(name, str) assert isinstance(output, torch.ScriptObject) - assert "explicit_gradients_setter" in output._method_names(), ( - "outputs must be ModelOutput instances" - ) + assert ( + "explicit_gradients_setter" in output._method_names() + ), "outputs must be ModelOutput instances" self._additional_output_requests = additional_outputs @@ -870,97 +870,35 @@ def _get_energy_uncertainty_output(): ) -class SO3AveragedCalculator(ase.calculators.calculator.Calculator): - """ - Take a MetatomicCalculator and average its predictions over a - Lebedev (S^2) x Uniform (S^1) grid of rotations in SO(3). - """ - - implemented_properties = ["energy", "forces", "stress"] - - def __init__( - self, - base_calculator: MetatomicCalculator, - lebedev_order: int = 3, - n_inplane_rotations: int = 4, - batch_size: Optional[int] = None, - **kwargs, - ): - try: - from scipy.integrate import lebedev_rule # noqa: F401 - except ImportError as e: - raise ImportError( - "scipy is required to use the SO3AveragedCalculator, please install " - "it with `pip install scipy` or `conda install scipy`" - ) from e - - super().__init__(**kwargs) - - self.base_calculator = base_calculator - self.lebedev_order = lebedev_order - self.n_inplane_rotations = n_inplane_rotations - - self.so3_quadrature_rotations = _get_so3_quadrature( - lebedev_order, n_inplane_rotations - ) - - self.batch_size = ( - batch_size if batch_size is not None else len(self.so3_quadrature_rotations) - ) - - def calculate(self, atoms, properties, system_changes): - super().calculate(atoms, properties, system_changes) - - compute_forces_and_stresses = "forces" in properties or "stress" in properties - - if len(self.so3_quadrature_rotations) > 0: - rotated_atoms_list = _rotate_atoms(atoms, self.so3_quadrature_rotations) - batch_size = ( - self.batch_size - if self.batch_size is not None - else len(rotated_atoms_list) - ) - batches = [ - rotated_atoms_list[i : i + batch_size] - for i in range(0, len(rotated_atoms_list), batch_size) - ] - results: Dict[str, np.ndarray] = {} - for batch in batches: - try: - batch_results = self.base_calculator.compute_energy( - batch, compute_forces_and_stresses - ) - for key, value in batch_results.items(): - results.setdefault(key, []) - results[key].extend( - [value] if isinstance(value, float) else value - ) - except torch.cuda.OutOfMemoryError as e: - raise RuntimeError( - "Out of memory error encountered during rotational averaging. " - "Please reduce the batch size or use lower rotational " - "averaging parameters. This can be done by setting the " - "`batch_size`, `lebedev_order`, and `n_inplane_rotations` " - "parameters while initializing the calculator." - f"Full error message: {e}" - ) - - # Clean up - try: - torch.cuda.empty_cache() - except Exception: - pass - - results = _compute_rotational_average( - results, self.so3_quadrature_rotations - ) - self.results.update(results) - - -class O3AveragedCalculator(ase.calculators.calculator.Calculator): - """ - Take a MetatomicCalculator and average its predictions over a - Lebedev (S^2) x Uniform (S^1) grid of rotations in O(3). +class SymmetrizedCalculator(ase.calculators.calculator.Calculator): + r""" + Take a MetatomicCalculator and average its predictions to make it (approximately) + equivariant. + + The default is to average over a quadrature of the orthogonal group O(3) composed + this way: + + - Lebedev quadrature of the unit sphere (S^2) + - Equispaced sampling of the unit circle (S^1) + - Both proper and improper rotations are taken into account by including the + inversion operation (if ``include_inversion=True``) + + :param base_calculator: the MetatomicCalculator to be symmetrized + :param l_max: the maximum spherical harmonic degree that the model is expected to + be able to represent. This is used to choose the quadrature order. If ``0``, + no rotational averaging will be performed (it can be useful to average only over + the space group, see ``apply_group_symmetry``). + :param batch_size: number of rotated systems to evaluate at once. If ``None``, all + systems will be evaluated at once (this can lead to high memory usage). + :param include_inversion: if ``True``, the inversion operation will be included in + the averaging. This is required to average over the full orthogonal group O(3). + :param apply_group_symmetry: if ``True``, the results will be averaged over the + discrete space group of rotations for the input system. The group operations are + computed with spglib, and the average is performed after the O(3) averaging + (if any). + :param return_samples: if ``True``, the results of the base calculator on each + rotated system will be returned. Most useful for debugging. + :param \*\*kwargs: additional arguments passed to the ASE Calculator constructor """ implemented_properties = ["energy", "forces", "stress"] @@ -970,10 +908,11 @@ def __init__( base_calculator: MetatomicCalculator, l_max: int = 3, batch_size: Optional[int] = None, - return_o3_samples=False, - apply_group_symmetry=False, - **kwargs, - ): + include_inversion: bool = True, + apply_group_symmetry: bool = False, + return_samples: bool = False, + **kwargs: Any, + ) -> None: try: from scipy.integrate import lebedev_rule # noqa: F401 except ImportError as e: @@ -990,34 +929,43 @@ def __init__( f"l_max={l_max} is too large, the maximum supported value is 131" ) self.l_max = l_max + self.include_inversion = include_inversion if l_max > 0: - lebedev_order, n_inplane_rotations = choose_quadrature(l_max) - self.o3_quadrature_rotations, self.o3_quadrature_weights = ( - _get_o3_quadrature(lebedev_order, n_inplane_rotations) + lebedev_order, n_inplane_rotations = _choose_quadrature(l_max) + self.quadrature_rotations, self.quadrature_weights = _get_quadrature( + lebedev_order, n_inplane_rotations, include_inversion ) else: # no quadrature - self.o3_quadrature_rotations = np.array([np.eye(3)]) - self.o3_quadrature_weights = np.array([1.0]) + self.quadrature_rotations = np.array([np.eye(3)]) + self.quadrature_weights = np.array([1.0]) self.batch_size = ( - batch_size if batch_size is not None else len(self.o3_quadrature_rotations) + batch_size if batch_size is not None else len(self.quadrature_rotations) ) - self.return_o3_samples = return_o3_samples + self.return_samples = return_samples self.apply_group_symmetry = apply_group_symmetry - def calculate(self, atoms, properties, system_changes): - super().calculate(atoms, properties, system_changes) + def calculate( + self, atoms: ase.Atoms, properties: List[str], system_changes: List[str] + ) -> None: + """ + Perform the calculation for the given atoms and properties. - if self.apply_group_symmetry: - Q_list, P_list = _get_group_operations(atoms) + :param atoms: the :py:class:`ase.Atoms` on which to perform the calculation + :param properties: list of properties to compute, among ``energy``, ``forces``, + and ``stress`` + :param system_changes: list of changes to the system since the last call to + ``calculate`` + """ + super().calculate(atoms, properties, system_changes) compute_forces_and_stresses = "forces" in properties or "stress" in properties - if len(self.o3_quadrature_rotations) > 0: - rotated_atoms_list = _rotate_atoms(atoms, self.o3_quadrature_rotations) + if len(self.quadrature_rotations) > 0: + rotated_atoms_list = _rotate_atoms(atoms, self.quadrature_rotations) batches = [ rotated_atoms_list[i : i + self.batch_size] for i in range(0, len(rotated_atoms_list), self.batch_size) @@ -1040,29 +988,32 @@ def calculate(self, atoms, properties, system_changes): "averaging parameters. This can be done by setting the " "`batch_size`, `lebedev_order`, and `n_inplane_rotations` " "parameters while initializing the calculator." - f"Full error message: {e}" - ) - - # Clean up - try: - torch.cuda.empty_cache() - except Exception: - pass + ) from e self.results.update( _compute_rotational_average( - results, self.o3_quadrature_rotations, self.o3_quadrature_weights + results, self.quadrature_rotations, self.quadrature_weights ) ) - if self.return_o3_samples: - self.results["o3_samples"] = results + if self.return_samples: + sample_names = "o3_samples" if self.include_inversion else "so3_samples" + self.results[sample_names] = results if self.apply_group_symmetry: + # Apply the discrete space group of the system a posteriori + Q_list, P_list = _get_group_operations(atoms) self.results.update(_average_over_group(self.results, Q_list, P_list)) -def choose_quadrature(L_max): +def _choose_quadrature(L_max: int) -> Tuple[int, int]: + """ + Choose a Lebedev quadrature order and number of in-plane rotations to integrate + spherical harmonics up to degree ``L_max``. + + :param L_max: maximum spherical harmonic degree + :return: (lebedev_order, n_inplane_rotations) + """ available = [ 3, 5, @@ -1105,6 +1056,13 @@ def choose_quadrature(L_max): def _rotate_atoms(atoms: ase.Atoms, rotations: List[np.ndarray]) -> List[ase.Atoms]: + """ + Create a list of copies of ``atoms``, rotated by each of the given ``rotations``. + + :param atoms: the :py:class:`ase.Atoms` to be rotated + :param rotations: (N, 3, 3) array of orthogonal matrices + :return: list of N :py:class:`ase.Atoms`, each rotated by the corresponding matrix + """ rotated_atoms_list = [] has_cell = atoms.cell is not None and atoms.cell.rank > 0 for rot in rotations: @@ -1119,9 +1077,17 @@ def _rotate_atoms(atoms: ase.Atoms, rotations: List[np.ndarray]) -> List[ase.Ato return rotated_atoms_list -def _get_so3_quadrature(lebedev_order: int, n_rotations: int): +def _get_quadrature(lebedev_order: int, n_rotations: int, include_inversion: bool): """ Lebedev(S^2) x uniform angle quadrature on SO(3). + If include_inversion=True, extend to O(3) by adding inversion * R. + + :param lebedev_order: order of the Lebedev quadrature on the unit sphere + :param n_rotations: number of in-plane rotations per Lebedev node + :param include_inversion: if ``True``, include the inversion operation in the + quadrature + :return: (N, 3, 3) array of orthogonal matrices, and (N,) array of weights + associated to each matrix """ from scipy.integrate import lebedev_rule @@ -1138,42 +1104,12 @@ def _get_so3_quadrature(lebedev_order: int, n_rotations: int): Rot = _rotations_from_angles(alpha, beta, gamma) R_so3 = Rot.as_matrix() # (N, 3, 3) - w_so3 = np.repeat(w / (4 * np.pi * K), repeats=gamma.size) # (N,) - - return R_so3, w_so3 - - -def _get_o3_quadrature(lebedev_order: int, n_rotations: int): - """ - Lebedev(S^2) x uniform angle quadrature on O(3). - Returns an array of shape (2N, 3, 3) with orthogonal matrices, - the first N in SO(3), the next N in its coset with inversion. - """ - from scipy.integrate import lebedev_rule - - # Lebedev nodes (X: (3, M)) - X, w = lebedev_rule(lebedev_order) # w sums to 4*pi - x, y, z = X - alpha = np.arctan2(y, x) # (M,) - beta = np.arccos(z) # (M,) - # beta = np.arccos(np.clip(z, -1.0, 1.0)) # (M,) - - K = int(n_rotations) - gamma = np.linspace(0.0, 2 * np.pi, K, endpoint=False) # (K,) - - Rot = _rotations_from_angles(alpha, beta, gamma) - R_so3 = Rot.as_matrix() # (N, 3, 3) - - # rnd = np.random.uniform(size=(3, 3)) - # rnd = rnd - rnd.T - # import scipy.linalg - - # rnd = scipy.linalg.expm(-rnd) - # R_so3 = R_so3 @ rnd - # SO(3) Haar–probability weights: w_i/(4*pi*K), repeated over gamma w_so3 = np.repeat(w / (4 * np.pi * K), repeats=gamma.size) # (N,) + if not include_inversion: + return R_so3, w_so3 + # Extend to O(3) by appending inversion * R P = -np.eye(3) R_o3 = np.concatenate([R_so3, P @ R_so3], axis=0) # (2N, 3, 3) @@ -1251,12 +1187,19 @@ def _get_group_operations( atom-index permutations P_g (N x N) induced by the space-group operations. Returns Q_list, Cartesian rotation matrices of the point group, and P_list, permutation matrices mapping original indexing -> indexing after (R,t), + + :param atoms: input structure + :param symprec: tolerance for symmetry finding + :param angle_tolerance: tolerance for symmetry finding (in degrees). If less than 0, + a value depending on ``symprec`` will be chosen automatically by spglib. + :return: List of rotation matrices and permutation matrices. + """ try: import spglib except ImportError as e: raise ImportError( - "spglib is required to use the O3AveragedCalculator with " + "spglib is required to use the SymmetrizedCalculator with " "`apply_group_symmetry=True`. Please install it with " "`pip install spglib` or `conda install -c conda-forge spglib`" ) from e @@ -1325,17 +1268,14 @@ def _average_over_group( """ Apply the point-group projector in output space. - Parameters - ---------- - results : dict - Must contain 'energy' (scalar), and/or 'forces' (N,3), and/or 'stress' (3,3). - These are predictions for the *current* structure in the reference frame. - Q_list, P_list : outputs of _get_group_operations - - Returns - ------- - out : dict - Projected quantities with keys: 'energy_pg', 'forces_pg', 'stress_pg'. + :param results: Must contain 'energy' (scalar), and/or 'forces' (N,3), and/or + 'stress' (3,3). These are predictions for the current structure in the reference + frame. + :param Q_list: Rotation matrices of the point group, from + :py:func:`_get_group_operations` + :param P_list: Permutation matrices of the point group, from + :py:func:`_get_group_operations` + :return out: Projected quantities with keys: 'energy_pg', 'forces_pg', 'stress_pg'. For stress, also returns 'stress_iso_pg' (L=0) and 'stress_dev_pg'. """ m = len(Q_list) @@ -1355,7 +1295,7 @@ def _average_over_group( return out out = {} - # Energy: unchanged by the projector (scalar invariant) + # Energy: unchanged by the projector (scalar) if "energy" in results: out["energy"] = float(results["energy"]) @@ -1380,9 +1320,5 @@ def _average_over_group( acc += Q.T @ S @ Q S_pg = acc / m out["stress"] = S_pg - # # Expose L=0 projection and deviatoric part for debugging - # S_iso = np.trace(S_pg) / 3.0 - # out["stress_iso_pg"] = np.eye(3) * S_iso - # out["stress_dev_pg"] = S_pg - out["stress_iso_pg"] return out diff --git a/python/metatomic_torch/tests/symmetrized_ase_calculator.py b/python/metatomic_torch/tests/symmetrized_ase_calculator.py new file mode 100644 index 00000000..977b3a18 --- /dev/null +++ b/python/metatomic_torch/tests/symmetrized_ase_calculator.py @@ -0,0 +1,354 @@ +import numpy as np +import pytest +from ase import Atoms + +from metatomic.torch.ase_calculator import SymmetrizedCalculator, _get_quadrature + + +def _body_axis_from_atoms(atoms: Atoms) -> np.ndarray: + """ + Return the normalized vector connecting the two farthest atoms. + + :param atoms: Atomic configuration. + :return: Normalized 3D vector defining the body axis. + """ + pos = atoms.get_positions() + if len(pos) < 2: + return np.array([0.0, 0.0, 1.0]) + d2 = np.sum((pos[:, None, :] - pos[None, :, :]) ** 2, axis=-1) + i, j = np.unravel_index(np.argmax(d2), d2.shape) + b = pos[j] - pos[i] + nrm = np.linalg.norm(b) + return b / nrm if nrm > 0 else np.array([0.0, 0.0, 1.0]) + + +def _legendre_0_1_2_3(c: float) -> tuple[float, float, float, float]: + """ + Compute Legendre polynomials P0..P3(c). + + :param c: Cosine between the body axis and the lab z-axis. + :return: Tuple (P0, P1, P2, P3). + """ + P0 = 1.0 + P1 = c + P2 = 0.5 * (3 * c * c - 1.0) + P3 = 0.5 * (5 * c * c * c - 3 * c) + return P0, P1, P2, P3 + + +class MockAnisoCalculator: + """ + Deterministic, rotation-dependent mock for testing SymmetrizedCalculator. + + Components: + - Energy: E_true + a1*P1 + a2*P2 + a3*P3 + - Forces: F_true + (b1*P1 + b2*P2 + b3*P3)*ẑ + optional tensor L=2 term + - Stress: p_iso*I + (c2*P2 + c3*P3)*D + + :param a: Coefficients for Legendre P0..P3 in the energy. + :param b: Coefficients for P1..P3 in the forces (spurious vector parts). + :param c: Coefficients for P2,P3 in the stress (spurious deviators). + :param p_iso: Isotropic (true) part of the stress tensor. + :param tensor_forces: If True, add L=2 tensor-coupled force term. + :param tensor_amp: Amplitude of the tensor-coupled force component. + """ + + def __init__( + self, + a: tuple[float, float, float, float] = (0.0, 0.0, 0.0, 0.0), + b: tuple[float, float, float] = (0.0, 0.0, 0.0), + c: tuple[float, float] = (0.0, 0.0), + p_iso: float = 1.0, + tensor_forces: bool = False, + tensor_amp: float = 0.5, + ) -> None: + self.a0, self.a1, self.a2, self.a3 = a + self.b1, self.b2, self.b3 = b + self.c2, self.c3 = c + self.p_iso = p_iso + self.tensor_forces = tensor_forces + self.tensor_amp = tensor_amp + + def compute_energy( + self, + batch: list[Atoms], + compute_forces_and_stresses: bool = False, + ) -> dict[str, list[np.ndarray | float]]: + """ + Compute deterministic, rotation-dependent properties for each batch entry. + + :param batch: List of atomic configurations. + :param compute_forces_and_stresses: Unused flag for API compatibility. + :return: Dictionary with lists of energies, forces, and stresses. + """ + out: dict[str, list[np.ndarray | float]] = { + "energy": [], + "forces": [], + "stress": [], + } + zhat = np.array([0.0, 0.0, 1.0]) + D = np.diag([1.0, -1.0, 0.0]) + + for atoms in batch: + pos = atoms.get_positions() + b = _body_axis_from_atoms(atoms) + c = float(np.dot(b, zhat)) + P0, P1, P2, P3 = _legendre_0_1_2_3(c) + + # Energy + E_true = float(np.sum(pos**2)) + E = E_true + self.a0 * P0 + self.a1 * P1 + self.a2 * P2 + self.a3 * P3 + + # Forces + F_true = pos.copy() + F_spur = (self.b1 * P1 + self.b2 * P2 + self.b3 * P3) * zhat[None, :] + F = F_true + F_spur + + if self.tensor_forces: + # Build rotation R such that R ẑ = b + v = np.cross(zhat, b) + s = np.linalg.norm(v) + cth = np.dot(zhat, b) + if s < 1e-15: + R = np.eye(3) if cth > 0 else -np.eye(3) + else: + vx = np.array( + [[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]] + ) + R = np.eye(3) + vx + vx @ vx * ((1 - cth) / (s**2)) + T = R @ D @ R.T + F_tensor = self.tensor_amp * (T @ zhat) + F = F + F_tensor[None, :] + + # Stress + S = self.p_iso * np.eye(3) + (self.c2 * P2 + self.c3 * P3) * D + + out["energy"].append(E) + out["forces"].append(F) + out["stress"].append(S) + return out + + +@pytest.fixture +def dimer() -> Atoms: + """ + Create a small asymmetric geometry with a well-defined body axis. + + :return: ASE Atoms object with the H2 molecule. + """ + return Atoms("H2", positions=[[0, 0, 0], [0.3, 0.2, 1.0]]) + + +def test_quadrature_normalization() -> None: + """Verify normalization and determinant signs of the quadrature.""" + R, w = _get_quadrature(lebedev_order=11, n_rotations=5, include_inversion=True) + assert np.isclose(np.sum(w), 1.0) + dets = np.linalg.det(R) + assert np.all(np.isin(np.round(dets).astype(int), [-1, 1])) + + +@pytest.mark.parametrize("Lmax, expect_removed", [(0, False), (3, True)]) +def test_energy_L_components_removed( + dimer: Atoms, Lmax: int, expect_removed: bool +) -> None: + """ + Verify that spurious energy components vanish once rotational averaging is applied. + For Lmax>0, all use the same minimal Lebedev rule (order=3). + """ + a = (1.0, 1.0, 1.0, 1.0) + base = MockAnisoCalculator(a=a) + calc = SymmetrizedCalculator(base, l_max=Lmax) + dimer.calc = calc + e = dimer.get_potential_energy() + E_true = float(np.sum(dimer.positions**2)) + if expect_removed: + assert np.isclose(e, E_true + a[0], atol=1e-10) + else: + assert not np.isclose(e, E_true + a[0], atol=1e-10) + + +def test_force_backrotation_exact(dimer: Atoms) -> None: + """ + Check that forces are back-rotated exactly when no spurious terms are present. + + :param dimer: Test atomic structure. + """ + base = MockAnisoCalculator(b=(0, 0, 0)) + calc = SymmetrizedCalculator(base, l_max=3) + dimer.calc = calc + F = dimer.get_forces() + assert np.allclose(F, dimer.positions, atol=1e-12) + + +def test_tensorial_L2_force_cancellation(dimer: Atoms) -> None: + """ + Tensor-coupled (L=2) force components must vanish under O(3) averaging. + + Since the minimal Lebedev order used internally is 3, all quadratures + integrate L=2 components exactly; we only check for correct cancellation. + """ + base = MockAnisoCalculator(tensor_forces=True, tensor_amp=1.0) + + for Lmax in [1, 2, 3]: + calc = SymmetrizedCalculator(base, l_max=Lmax) + dimer.calc = calc + F = dimer.get_forces() + assert np.allclose(F, dimer.positions, atol=1e-10) + + +def test_stress_isotropization(dimer: Atoms) -> None: + """ + Check that stress deviatoric parts (L=2,3) vanish under full O(3) averaging. + + :param dimer: Test atomic structure. + """ + base = MockAnisoCalculator(c=(1.0, 1.0), p_iso=5.0) + calc = SymmetrizedCalculator(base, l_max=3, include_inversion=True) + dimer.calc = calc + S = dimer.get_stress(voigt=False) + iso = np.trace(S) / 3.0 + assert np.allclose(S, np.eye(3) * iso, atol=1e-10) + assert np.isclose(iso, 5.0, atol=1e-10) + + +def test_cancellation_vs_Lmax(dimer: Atoms) -> None: + """ + Residual anisotropy must vanish once rotational averaging is applied. + All quadratures with Lmax>0 are equivalent (Lebedev order=3). + """ + a = (0.0, 0.0, 1.0, 1.0) + base = MockAnisoCalculator(a=a) + E_true = float(np.sum(dimer.positions**2)) + + # No averaging + calc0 = SymmetrizedCalculator(base, l_max=0) + dimer.calc = calc0 + e0 = dimer.get_potential_energy() + + # Averaged + calc3 = SymmetrizedCalculator(base, l_max=3) + dimer.calc = calc3 + e3 = dimer.get_potential_energy() + + assert not np.isclose(e0, E_true, atol=1e-10) + assert np.isclose(e3, E_true, atol=1e-10) + + +def test_joint_energy_force_consistency(dimer: Atoms) -> None: + """ + Combined test: both energy and forces are consistent and invariant. + + :param dimer: Test atomic structure. + """ + base = MockAnisoCalculator(a=(1, 1, 1, 1), b=(0, 0, 0)) + calc = SymmetrizedCalculator(base, l_max=3) + dimer.calc = calc + e = dimer.get_potential_energy() + f = dimer.get_forces() + assert np.isclose(e, np.sum(dimer.positions**2) + 1.0, atol=1e-10) + assert np.allclose(f, dimer.positions, atol=1e-12) + + +def test_rotate_atoms_preserves_geometry(tmp_path): + """Check that _rotate_atoms applies rotations correctly and preserves distances.""" + from scipy.spatial.transform import Rotation + + from metatomic.torch.ase_calculator import _rotate_atoms + + # Build simple cubic cell with 2 atoms along x + atoms = Atoms("H2", positions=[[0, 0, 0], [1, 0, 0]], cell=np.eye(3)) + R = Rotation.from_euler("z", 90, degrees=True).as_matrix()[None, ...] # 90° about z + + rotated = _rotate_atoms(atoms, R)[0] + # Positions should now align along y + assert np.allclose( + rotated.positions[1] - rotated.positions[0], [0, 1, 0], atol=1e-12 + ) + # Cell rotated + assert np.allclose(rotated.cell[0], [0, 1, 0], atol=1e-12) + # Distances preserved + d0 = atoms.get_distance(0, 1) + d1 = rotated.get_distance(0, 1) + assert np.isclose(d0, d1, atol=1e-12) + + +def test_choose_quadrature_rules(): + """Check that _choose_quadrature selects appropriate rules.""" + from metatomic.torch.ase_calculator import _choose_quadrature + + for L in [0, 5, 17, 50]: + lebedev_order, n_gamma = _choose_quadrature(L) + assert lebedev_order >= L + assert n_gamma == 2 * L + 1 + + +def test_get_quadrature_properties(): + """Check properties of the quadrature returned by _get_quadrature.""" + from metatomic.torch.ase_calculator import _get_quadrature + + R, w = _get_quadrature(lebedev_order=11, n_rotations=5, include_inversion=False) + assert np.isclose(np.sum(w), 1.0) + assert np.allclose([np.dot(r.T, r) for r in R], np.eye(3), atol=1e-12) + assert np.allclose(np.linalg.det(R), 1.0, atol=1e-12) + + R_inv, w_inv = _get_quadrature( + lebedev_order=11, n_rotations=5, include_inversion=True + ) + assert len(R_inv) == 2 * len(R) + dets = np.linalg.det(R_inv) + assert np.all(np.isin(np.sign(dets).astype(int), [-1, 1])) + assert np.isclose(np.sum(w_inv), 1.0) + + +def test_compute_rotational_average_identity(): + """Check that _compute_rotational_average produces correct averages.""" + from metatomic.torch.ase_calculator import _compute_rotational_average + + R = np.repeat(np.eye(3)[None, :, :], 3, axis=0) + w = np.ones(3) / 3 + results = { + "energy": np.array([1.0, 2.0, 3.0]), + "forces": np.array([[[1, 0, 0]], [[0, 1, 0]], [[0, 0, 1]]]), + "stress": np.array([np.eye(3), 2 * np.eye(3), 3 * np.eye(3)]), + } + out = _compute_rotational_average(results, R, w) + assert np.isclose(out["energy"], np.mean(results["energy"])) + assert np.allclose(out["forces"], np.mean(results["forces"], axis=0)) + assert np.allclose(out["stress"], np.mean(results["stress"], axis=0)) + + +def test_average_over_fcc_group(): + """ + Check that averaging over the space group of an FCC crystal + produces an isotropic (scalar) stress tensor. + """ + from metatomic.torch.ase_calculator import ( + _average_over_group, + _get_group_operations, + ) + + # FCC conventional cubic cell (4 atoms) + a0 = 4.05 + atoms = Atoms( + "Cu4", + positions=[ + [0, 0, 0], + [0, 0.5, 0.5], + [0.5, 0, 0.5], + [0.5, 0.5, 0], + ], + cell=a0 * np.eye(3), + pbc=True, + ) + + # Create an intentionally anisotropic stress + stress = np.array([[10.0, 1.0, 0.0], [1.0, 5.0, 0.0], [0.0, 0.0, 1.0]]) + results = {"stress": stress} + + Q_list, P_list = _get_group_operations(atoms) + out = _average_over_group(results, Q_list, P_list) + S_pg = out["stress"] + + # The averaged stress must be isotropic: S_pg = (trace/3)*I + iso = np.trace(S_pg) / 3.0 + assert np.allclose(S_pg, np.eye(3) * iso, atol=1e-8) From 6eec538516020736c6f1abcead10162cde7b3ed2 Mon Sep 17 00:00:00 2001 From: ppegolo Date: Thu, 9 Oct 2025 14:04:00 +0200 Subject: [PATCH 12/57] lint --- .../metatomic_torch/metatomic/torch/ase_calculator.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/python/metatomic_torch/metatomic/torch/ase_calculator.py b/python/metatomic_torch/metatomic/torch/ase_calculator.py index 7e886d0f..0840bae0 100644 --- a/python/metatomic_torch/metatomic/torch/ase_calculator.py +++ b/python/metatomic_torch/metatomic/torch/ase_calculator.py @@ -217,9 +217,9 @@ def __init__( for name, output in additional_outputs.items(): assert isinstance(name, str) assert isinstance(output, torch.ScriptObject) - assert ( - "explicit_gradients_setter" in output._method_names() - ), "outputs must be ModelOutput instances" + assert "explicit_gradients_setter" in output._method_names(), ( + "outputs must be ModelOutput instances" + ) self._additional_output_requests = additional_outputs @@ -1221,7 +1221,7 @@ def _match_index(x_new, frac_ref, Z_ref, Z_i, tol=1e-6): seen = set() Ainv = np.linalg.inv(A) - for Rf, tf in zip(R_frac, t_frac): + for Rf, tf in zip(R_frac, t_frac, strict=False): # Cartesian rotation: Q = A Rf A^{-1} Q = A @ Rf @ Ainv # Deduplicate rotations (point group) by rounding @@ -1286,7 +1286,7 @@ def _average_over_group( if F.ndim != 2 or F.shape[1] != 3: raise ValueError(f"'forces' must be (N,3), got {F.shape}") acc = np.zeros_like(F) - for Q, P in zip(Q_list, P_list): + for Q, P in zip(Q_list, P_list, strict=False): acc += P.T @ (F @ Q) out["forces"] = acc / m From a935bf4da0cf53490266350b1dadb1846cb1b3d9 Mon Sep 17 00:00:00 2001 From: ppegolo Date: Thu, 9 Oct 2025 14:04:12 +0200 Subject: [PATCH 13/57] add deps for testing --- tox.ini | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tox.ini b/tox.ini index 86a3b3e8..795382c7 100644 --- a/tox.ini +++ b/tox.ini @@ -150,6 +150,9 @@ deps = # for metatensor-lj-test setuptools-scm cmake + # for symmetrized calculator + scipy + spglib changedir = python/metatomic_torch commands = From bb96a6eebe06c249309ab6387d2784d516df4af4 Mon Sep 17 00:00:00 2001 From: ppegolo Date: Thu, 9 Oct 2025 14:06:50 +0200 Subject: [PATCH 14/57] Add mention in the docs --- docs/src/engines/ase.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/src/engines/ase.rst b/docs/src/engines/ase.rst index da8bce2d..fed91cdd 100644 --- a/docs/src/engines/ase.rst +++ b/docs/src/engines/ase.rst @@ -23,6 +23,8 @@ Supported model outputs :py:meth:`ase.Atoms.get_forces`, …); - arbitrary outputs can be computed for any :py:class:`ase.Atoms` using :py:meth:`MetatomicCalculator.run_model`; +- for non-equivariant architectures like PET, rotatonally-averaged energies, forces, + and stresses can be computed using :py:class:`SymmetrizedCalculator`. How to install the code ^^^^^^^^^^^^^^^^^^^^^^^ From b0debd6c2e1adc9100088f3b40ee0965d9b46a37 Mon Sep 17 00:00:00 2001 From: ppegolo Date: Fri, 10 Oct 2025 16:07:32 +0200 Subject: [PATCH 15/57] Start implementing O3Sampler, TokenProjector --- .../metatomic/torch/rotational_utils.py | 753 ++++++++++++++++++ 1 file changed, 753 insertions(+) create mode 100644 python/metatomic_torch/metatomic/torch/rotational_utils.py diff --git a/python/metatomic_torch/metatomic/torch/rotational_utils.py b/python/metatomic_torch/metatomic/torch/rotational_utils.py new file mode 100644 index 00000000..ddc88825 --- /dev/null +++ b/python/metatomic_torch/metatomic/torch/rotational_utils.py @@ -0,0 +1,753 @@ +""" +Utilities for diagnosing rotational equivariance of models and for enforcing +rotational symmetry in data augmentation and model evaluation. +""" + +import warnings +from typing import Dict, List, Optional, Tuple + +import metatensor.torch as mts +import numpy as np +import torch +from metatensor.torch import TensorMap +from metatrain.utils.augmentation import ( + _apply_random_augmentations, + _complex_to_real_spherical_harmonics_transform, + _scipy_quaternion_to_quaternionic, +) + +from metatomic.torch import ModelEvaluationOptions, System, register_autograd_neighbors +from metatomic.torch.model import AtomisticModel + + +try: + from scipy.spatial.transform import Rotation +except ImportError as e: + raise ImportError( + "To perform data augmentation on spherical targets, please " + "install the `scipy` package with `pip install scipy`." + ) from e + + +def _choose_quadrature(L_max: int) -> Tuple[int, int]: + """ + Choose a Lebedev quadrature order and number of in-plane rotations to integrate + spherical harmonics up to degree ``L_max``. + + :param L_max: maximum spherical harmonic degree + :return: (lebedev_order, n_inplane_rotations) + """ + available = [ + 3, + 5, + 7, + 9, + 11, + 13, + 15, + 17, + 19, + 21, + 23, + 25, + 27, + 29, + 31, + 35, + 41, + 47, + 53, + 59, + 65, + 71, + 77, + 83, + 89, + 95, + 101, + 107, + 113, + 119, + 125, + 131, + ] + # pick smallest order >= L_max + n = min(o for o in available if o >= L_max) + # minimal gamma count + K = 2 * L_max + 1 + return n, K + + +def get_euler_angles_quadrature(lebedev_order: int, n_rotations: int): + """ + Get the Euler angles and weights for Lebedev quadrature. + + :param lebedev_order: order of the Lebedev quadrature on the unit sphere + :param n_rotations: number of in-plane rotations per Lebedev node + :return: alpha, beta, gamma, w arrays of shape (M,), (M,), (K,), (M,) + respectively, where M is the number of Lebedev nodes and K is the number of + in-plane rotations. + """ + from scipy.integrate import lebedev_rule + + # Lebedev nodes (X: (3, M)) + X, w = lebedev_rule(lebedev_order) # w sums to 4*pi + x, y, z = X + alpha = np.arctan2(y, x) # (M,) + beta = np.arccos(np.clip(z, -1.0, 1.0)) # (M,) + gamma = np.linspace(0.0, 2 * np.pi, n_rotations, endpoint=False) # (n_rotations,) + + w_so3 = np.repeat(w / (4 * np.pi * n_rotations), repeats=gamma.size) # (N,) + + return alpha, beta, gamma, w_so3 + + +def _rotations_from_angles(alpha, beta, gamma): + from scipy.spatial.transform import Rotation + + # Build all combinations (alpha_i, beta_i, gamma_j) + A = np.repeat(alpha, gamma.size) # (N,) + B = np.repeat(beta, gamma.size) # (N,) + G = np.tile(gamma, alpha.size) # (N,) + + # Compose ZYZ rotations in SO(3) + Rot = ( + Rotation.from_euler("z", A) + * Rotation.from_euler("y", B) + * Rotation.from_euler("z", G) + ) + + return Rot + + +def _transform_system(system: System, transformation: torch.Tensor) -> System: + transformed_system = System( + positions=system.positions @ transformation.T, + types=system.types, + cell=system.cell @ transformation.T, + pbc=system.pbc, + ) + for options in system.known_neighbor_lists(): + neighbors = mts.detach_block(system.get_neighbor_list(options)) + + neighbors.values[:] = ( + neighbors.values.squeeze(-1) @ transformation.T + ).unsqueeze(-1) + + register_autograd_neighbors(system, neighbors) + transformed_system.add_neighbor_list(options, neighbors) + return transformed_system + + +def evaluate_model_on_quadrature(model, systems, L_max: int, device="cpu"): + pass + + +############ + + +def _extract_euler_zyz( + R: torch.Tensor, + eps: float = 1e-6, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Extract Z-Y-Z Euler angles (alpha, beta, gamma) from rotation matrices, with + explicit handling of the gimbal-lock cases (beta≈0 and beta≈pi). + TODO: This function is extremely sensitive to eps and will be modified. + Parameters + ---------- + R : np.ndarray + Rotation matrices with arbitrary batch shape `(..., 3, 3)`. + eps : float + Tolerance used to detect gimbal lock via `sin(beta) < eps`. + + Returns + ------- + (alphas, betas, gammas) : Tuple[np.ndarray, np.ndarray, np.ndarray] + Each with the same batch shape as `R[..., 0, 0]` (i.e., `R.shape[:-2]`). + + Notes + ----- + Conventions: + - Base convention is Z-Y-Z (Rz(alpha) Ry(beta) Rz(gamma)). + - For beta≈0: set beta=0, gamma=0, alpha=atan2(R[1,0], R[0,0]). + - For beta≈pi: set beta=pi, alpha=0, gamma=atan2(R[1,0], -R[0,0]). + These conventions ensure a deterministic inverse where the standard formulas + are ill-conditioned. + """ + # Accept any batch shape. Flatten to (N, 3, 3) for clarity, then unflatten. + batch_shape = R.shape[:-2] + R_flat = R.reshape(-1, 3, 3) + + # Read commonly-used entries with explicit names for readability + R00 = R_flat[:, 0, 0] + R01 = R_flat[:, 0, 1] + R02 = R_flat[:, 0, 2] + R10 = R_flat[:, 1, 0] + R11 = R_flat[:, 1, 1] + R12 = R_flat[:, 1, 2] + R20 = R_flat[:, 2, 0] + R21 = R_flat[:, 2, 1] + R22 = R_flat[:, 2, 2] + + # Default (non-singular) extraction + zz = torch.clip(R22, -1.0, 1.0) + betas = torch.arccos(zz) + + # For Z-Y-Z, standard formulas away from the singular set + alphas = torch.arctan2(R12, R02) + gammas = torch.arctan2(R21, -R20) + + # Normalize into [0, 2π) + two_pi = 2.0 * torch.pi + alphas = torch.remainder(alphas, two_pi) + gammas = torch.remainder(gammas, two_pi) + + # Gimbal-lock detection via sin(beta) + sinb = torch.sin(betas) + near = torch.abs(sinb) < eps + if torch.any(near): + # Split the two singular bands using zz = cos(beta) + near_zero = near & (zz > 0) # beta≈0 + near_pi = near & (zz < 0) # beta≈pi + + if torch.any(near_zero): + # beta≈0: rotation ≈ Rz(alpha+gamma). Choose gamma=0, recover alpha from 2x2 block. + betas[near_zero] = 0.0 + gammas[near_zero] = 0.0 + alphas[near_zero] = torch.arctan2(R10[near_zero], R00[near_zero]) + alphas[near_zero] = torch.remainder(alphas[near_zero], two_pi) + + if torch.any(near_pi): + # beta≈pi: choose alpha=0, recover gamma from 2x2 block with sign flip on R00. + betas[near_pi] = torch.pi + alphas[near_pi] = 0.0 + gammas[near_pi] = torch.arctan2(R10[near_pi], -R00[near_pi]) + gammas[near_pi] = torch.remainder(gammas[near_pi], two_pi) + + # Unflatten back to the original batch shape + alphas = alphas.reshape(batch_shape) + betas = betas.reshape(batch_shape) + gammas = gammas.reshape(batch_shape) + return alphas, betas, gammas + + +def get_so3_character( + alphas: torch.Tensor, + betas: torch.Tensor, + gammas: torch.Tensor, + o3_lambda: int, + tol: float = 1e-7, +) -> torch.Tensor: + """ + Numerically stable evaluation of the character function χ_{o3_lambda}(R) over SO(3). + + Uses a small-angle Taylor expansion for χ_ℓ(ω) = sin((2ℓ+1)t)/sin(t) with t = ω/2 + when |t| is very small, and a guarded ratio otherwise. + """ + # Compute half-angle t = ω/2 via Z–Y–Z relation: cos t = cos(β/2) cos((α+γ)/2) + cos_t = torch.cos(betas / 2.0) * torch.cos((alphas + gammas) / 2.0) + cos_t = torch.clip(cos_t, -1.0, 1.0) + t = torch.arccos(cos_t) + + # Output array + chi = torch.empty_like(t) + + # Parameters for χ + L = o3_lambda + a = 2 * L + 1 + ll1 = L * (L + 1) + + small = torch.abs(t) < tol + if torch.any(small): + # Series up to t^4: χ ≈ a [1 - (2/3) ℓ(ℓ+1) t^2 + (1/45) ℓ(ℓ+1)(3ℓ^2+3ℓ-1) t^4] + ts = t[small] + t2 = ts * ts + coeff4 = ll1 * (3 * L * L + 3 * L - 1) + chi[small] = a * ( + 1.0 - (2.0 / 3.0) * ll1 * t2 + (1.0 / 45.0) * coeff4 * t2 * t2 + ) + + # Large-angle (or not-so-small) branch: safe ratio with guard + large = ~small + if torch.any(large): + tl = t[large] + sin_t = torch.sin(tl) + numer = torch.sin(a * tl) + mask = torch.abs(sin_t) >= tol + out = torch.empty_like(tl) + torch.div(numer, sin_t, out=out) # TODO figure out with numpy divide + out[~mask] = a # exact limit as t -> 0 + chi[large] = out + + return chi + + +def get_so3_characters_dict( + alphas: torch.Tensor, betas: torch.Tensor, gammas: torch.Tensor, o3_lambda_max: int +) -> Dict[int, torch.Tensor]: + """ + Returns a dictionary of the SO(3) characters for all o3_lambda in [0, o3_lambda_max]. + """ + characters = {} + for o3_lambda in range(o3_lambda_max + 1): + characters[o3_lambda] = get_so3_character(alphas, betas, gammas, o3_lambda) + return characters + + +def get_pso3_characters_dict( + so3_character: Dict[int, torch.Tensor], o3_lambda_max: int +) -> Dict[Tuple[int, int], torch.Tensor]: + """ + Returns a dictionary of the P⋅SO(3) characters for all (o3_lambda, o3_sigma) pairs + with o3_lambda in [0, o3_lambda_max] and o3_sigma in {-1, +1}. + Requires a pre-computed dictionary of SO(3) characters. + """ + characters = {} + for o3_lambda in range(o3_lambda_max + 1): + for o3_sigma in [-1, +1]: + characters[(o3_lambda, o3_sigma)] = ( + o3_sigma * ((-1) ** o3_lambda) * so3_character[o3_lambda] + ) + return characters + + +############ + + +class O3Sampler: + """ + Compute model predictions on a quadrature over the O(3) group. + + :param quad_l_max: maximum spherical harmonic degree for quadrature + :param project_l_max: maximum spherical harmonic degree to project onto + """ + + def __init__(self, quad_l_max: int, project_l_max: int, batch_size: int = 1): + try: + from scipy.spatial.transform import Rotation + except ImportError as e: + raise ImportError( + "To perform data augmentation on spherical targets, please " + "install the `scipy` package with `pip install scipy`." + ) from e + + self.quad_l_max = quad_l_max + """Maximum spherical harmonic degree for quadrature.""" + + self.project_l_max = project_l_max + """Maximum spherical harmonic degree to project onto.""" + if self.project_l_max + 2 > self.quad_l_max: + warnings.warn( + ( + f"Projecting up to l={self.project_l_max} with quadrature up " + f"to l={self.quad_l_max} may be inaccurate." + ), + stacklevel=2, + ) + + # Get the quadrature + self.lebedev_order: int + """Number of Lebedev nodes on the unit sphere.""" + + self.n_inplane_rotations: int + """Number of in-plane rotations per Lebedev node.""" + self.lebedev_order, self.n_inplane_rotations = _choose_quadrature( + self.quad_l_max + ) + + self.w_so3: torch.Tensor + """Weights associated to each rotation in the SO(3) Haar measure.""" + + alpha, beta, gamma, self.w_so3 = get_euler_angles_quadrature( + self.lebedev_order, self.n_inplane_rotations + ) + self.w_so3 = torch.from_numpy(self.w_so3) + + # For active rotation of systems + self.R_so3 = torch.from_numpy( + _rotations_from_angles(alpha, beta, gamma).as_matrix() + ) + """Rotation matrices.""" + + self.n_rotations = self.R_so3.size(0) + + # For inverse rotation of tensors + R_pso3 = _rotations_from_angles(np.pi - alpha, beta, np.pi - gamma) + self.wigner_D: Dict[int, torch.Tensor] = _compute_wigner_D_matrices( + self.project_l_max, R_pso3 + ) + """Dict mapping l to (N, (2l+1), (2l+1)) torch.Tensor of Wigner D matrices.""" + + self.R_pso3 = torch.from_numpy(R_pso3.as_matrix()) + """Inverse rotation matrices.""" + + self.batch_size = batch_size + + _external_product_euler_angles = _extract_euler_zyz( + (self.R_so3[:, None, :, :] @ self.R_pso3[None, :, :, :]), eps=1e-6 + ) + self.so3_characters = get_so3_characters_dict( + *_external_product_euler_angles, self.project_l_max + ) + self.pso3_characters = get_pso3_characters_dict( + self.so3_characters, self.project_l_max + ) + + def evaluate( + self, + model: AtomisticModel, + systems: List[System], + options: ModelEvaluationOptions, + check_consistency: bool = False, + ): + """ + Sample the model on the O(3) quadrature. + + :param systems: list of systems to evaluate + :param model: atomistic model to evaluate + :param device: device to use for computation + :return: list of list of model outputs, shape (len(systems), N) + where N is the number of quadrature points + """ + + device = systems[0].positions.device + dtype = systems[0].positions.dtype + + transformed_outputs = { + name: [{-1: None, 1: None} for _ in systems] + for name in options.outputs.keys() + } + backtransformed_outputs = { + name: [{-1: None, 1: None} for _ in systems] + for name in options.outputs.keys() + } + for i_sys, system in enumerate(systems): + for inversion in [-1, 1]: + rotation_outputs = [] + for batch in range(0, len(self.R_so3), self.batch_size): + transformed_systems = [ + _transform_system( + system, inversion * R.to(device=device, dtype=dtype) + ) + for R in self.R_so3[batch : batch + self.batch_size] + ] + outputs = model( + transformed_systems, + options=options, + check_consistency=check_consistency, + ) + rotation_outputs.append(outputs) + + for name in transformed_outputs: + tensor = mts.join( + [r[name] for r in rotation_outputs], + "samples", + remove_tensor_name=True, + ) + transformed_outputs[name][i_sys][inversion] = mts.rename_dimension( + tensor, "samples", "tensor", "o3_sample" + ) + + n_rot = self.R_so3.size(0) + for name in transformed_outputs: + for i_sys, system in enumerate(systems): + for inversion in [-1, 1]: + tensor = transformed_outputs[name][i_sys][inversion] + _, backtransformed, _ = _apply_random_augmentations( + [system] * n_rot, + {name: tensor}, + list( + ( + self.R_pso3.to(device=device, dtype=dtype) * inversion + ).unbind(0) + ), + self.wigner_D, + ) + backtransformed_outputs[name][i_sys][inversion] = backtransformed[ + name + ] + + return transformed_outputs, backtransformed_outputs + + +class TokenProjector(torch.nn.Module): + """ + Wrap an atomistic model to project its predictions onto spherical sectors. + + :param model: atomistic model to wrap + :param quad_l_max: maximum spherical harmonic degree for quadrature + :param project_l_max: maximum spherical harmonic degree to project onto + :param batch_size: number of rotations to process in a single batch + """ + + def __init__( + self, + model: AtomisticModel, + quad_l_max: int, + project_l_max: int, + batch_size: Optional[int] = None, + ) -> None: + super().__init__("TokenProjector") + self.model = model + """The underlying atomistic model.""" + self.o3_sampler = O3Sampler(quad_l_max, project_l_max, batch_size=batch_size) + """The projector onto spherical sectors.""" + + def forward( + self, + systems: List[System], + options: ModelEvaluationOptions, + check_consistency: bool = False, + ) -> torch.Tensor: + """ + :param systems: list of systems to evaluate + :param options: model evaluation options + :param check_consistency: whether to check model consistency + :return: TODO + """ + + transformed_outputs, _ = self.o3_sampler.evaluate( + systems, self.model, options, check_consistency + ) + + # TODO do projection operations + pass + + +class SymmetrizedAtomisticModel(torch.nn.Module): + """ + Wrap an atomistic model to symmetrize its predictions over a quadrature and compute + O(3) averages, variances, and equivariance score. + + :param model: atomistic model to wrap + :param quad_l_max: maximum spherical harmonic degree for quadrature + :param project_l_max: maximum spherical harmonic degree to project onto + :param batch_size: number of rotations to process in a single batch + """ + + def __init__( + self, + model: AtomisticModel, + quad_l_max: int, + project_l_max: int, + batch_size: Optional[int] = None, + ): + super().__init__("SymmetrizedAtomisticModel") + self.model = model + """The underlying atomistic model.""" + self.o3_sampler = O3Sampler(quad_l_max, project_l_max, batch_size=batch_size) + """The projector onto spherical sectors.""" + + def forward( + self, + systems: List[System], + options: ModelEvaluationOptions, + check_consistency: bool = False, + ) -> torch.Tensor: + """ + :param systems: list of systems to evaluate + :param options: model evaluation options + :param check_consistency: whether to check model consistency + :return: + """ + + transformed_outputs, _ = self.o3_sampler.evaluate( + systems, self.model, options, check_consistency + ) + + return compute_projections( + self.o3_sampler.project_l_max, + systems, + transformed_outputs, + self.o3_sampler.w_so3, + self.o3_sampler.so3_characters, + self.o3_sampler.pso3_characters, + ) + + +def _compute_wigner_D_matrices( + l_max: int, + rotations: List["Rotation"], + complex_to_real: Optional[np.ndarray] = None, +) -> dict: + """ + Compute Wigner D matrices for all l <= project_l_max. + + :param l_max: maximum spherical harmonic degree + :param rotations: list of scipy Rotation objects + :param complex_to_real: optional dict mapping l to (2l+1, (2l+1)) array to convert + complex spherical harmonics to real spherical harmonics + :return: dict mapping l to (N, (2l+1), (2l+1)) array of Wigner D matrices + """ + + try: + import spherical + except ImportError as e: + # quaternionic (used below) is a dependency of spherical + raise ImportError( + "To perform data augmentation on spherical targets, please " + "install the `spherical` package with `pip install spherical`." + ) from e + + wigner = spherical.Wigner(l_max) + scipy_quaternions = [r.as_quat() for r in rotations] + quaternionic_quaternions = [ + _scipy_quaternion_to_quaternionic(q) for q in scipy_quaternions + ] + wigner_D_matrices_complex = [wigner.D(q) for q in quaternionic_quaternions] + + if complex_to_real is None: + complex_to_real = { + ell: _complex_to_real_spherical_harmonics_transform(ell) + for ell in range(l_max + 1) + } + + wigner_D_matrices = {} + for ell in range(l_max + 1): + U = complex_to_real[ell] + wigner_D_matrices_l = [] + for wigner_D_matrix_complex in wigner_D_matrices_complex: + wigner_D_matrix = np.zeros((2 * ell + 1, 2 * ell + 1), dtype=np.complex128) + for mp in range(-ell, ell + 1): + for m in range(-ell, ell + 1): + wigner_D_matrix[m + ell, mp + ell] = ( + wigner_D_matrix_complex[wigner.Dindex(ell, m, mp)] + ).conj() + + wigner_D_matrix = U.conj() @ wigner_D_matrix @ U.T + assert np.allclose(wigner_D_matrix.imag, 0.0) + wigner_D_matrix = wigner_D_matrix.real + wigner_D_matrices_l.append(torch.from_numpy(wigner_D_matrix)) + wigner_D_matrices[ell] = wigner_D_matrices_l + + return wigner_D_matrices + + +# O3-integrals utilities + + +def compute_projections( + max_l: int, + systems: List[System], + transformed_outputs: Dict[str, List[TensorMap]], + weights: torch.Tensor, + so3_characters: Dict[int, torch.Tensor], + pso3_characters: Dict[Tuple[int, int], torch.Tensor], +) -> Tuple[ + Dict[str, List[Dict[int, TensorMap]]], + Dict[str, List[Dict[Tuple[int, int], TensorMap]]], + Dict[str, List[Dict[Tuple[int, int], TensorMap]]], +]: + """ + + TODO docstring, check type annotations + + - Take model outputs on a quadrature + - Manipulate dimensions + - Compute some integrals + - Return projections + + """ + + device = systems[0].positions.device + dtype = systems[0].positions.dtype + + weights = weights.to(device, dtype) + so3_characters = {k: v.to(device, dtype) for k, v in so3_characters.items()} + pso3_characters = {k: v.to(device, dtype) for k, v in pso3_characters.items()} + + n_rotations = len(weights) + norms = {} + convolution_integrals = {} + normalized_convolution_integrals = {} + # Loop over targets + for name, transformed_output in transformed_outputs.items(): + norms[name] = [] + convolution_integrals[name] = [] + normalized_convolution_integrals[name] = [] + for o3_output_for_system in transformed_output: + proper = o3_output_for_system[1] + improper = o3_output_for_system[-1] + + # Weighting the tensors + broadcasted_w = ( + weights[proper[0].samples.column("o3_sample")] / 16 / torch.pi**2 + ) + proper_weighted = proper.copy() + improper_weighted = improper.copy() + for k in proper_weighted.keys: + proper_block = proper_weighted[k] + improper_block = improper_weighted[k] + proper_block.values[:] *= broadcasted_w.view( + -1, *[1] * (proper_block.values.ndim - 1) + ) + improper_block.values[:] *= broadcasted_w.view( + -1, *[1] * (improper_block.values.ndim - 1) + ) + + # Compute norms + proper_norm = mts.multiply(proper, proper_weighted) + improper_norm = mts.multiply(improper, improper_weighted) + norm = mts.add(proper_norm, improper_norm) + norm = mts.sum_over_samples(norm, "o3_sample") + norms[name].append(norm) + + # Compute convolution integrals + convolution_integral = {} + normalized_convolution_integral = {} + for ell in range(max_l + 1): + so3_char = so3_characters[ell] + for sigma in [-1, 1]: + pso3_char = pso3_characters[(ell, sigma)] + + integral_blocks = [] + for k in proper.keys: + proper_block = proper[k].values.reshape( + -1, n_rotations, *proper[k].shape[1:] + ) + improper_block = improper[k].values.reshape( + -1, n_rotations, *improper[k].shape[1:] + ) + integral_values = ( + ( + 0.25 + * torch.einsum( + "ij...,nij...->n...", + so3_char, + proper_block[:, :, None, ...] + * proper_block[:, None, :, ...] + + improper_block[:, :, None, ...] + * improper_block[:, None, :, ...], + ) + + 0.5 + * torch.einsum( + "ij...,nij...->n...", + pso3_char, + proper_block[:, :, None, ...] + * improper_block[:, None, :, ...], + ) + ) + * (2 * ell + 1) + / (8 * torch.pi**2) ** 2 + ) + integral_blocks.append( + mts.TensorBlock( + samples=norm[k].samples, + components=norm[k].components, + properties=norm[k].properties, + values=integral_values, + ) + ) + convolution_integral[(ell, sigma)] = mts.TensorMap( + keys=norm.keys, blocks=integral_blocks + ) + normalized_convolution_integral[(ell, sigma)] = mts.divide( + convolution_integral[(ell, sigma)], norm + ) + convolution_integrals[name].append(convolution_integral) + normalized_convolution_integrals[name].append( + normalized_convolution_integral + ) + + return norms, convolution_integrals, normalized_convolution_integrals From d6065084d795b98fc9b63a046689da3ef9b94aba Mon Sep 17 00:00:00 2001 From: Joseph Abbott Date: Sun, 12 Oct 2025 15:37:58 +0200 Subject: [PATCH 16/57] Importable version --- .../metatomic/torch/__init__.py | 5 ++ .../metatomic/torch/rotational_utils.py | 61 +++++++++++++++---- 2 files changed, 53 insertions(+), 13 deletions(-) diff --git a/python/metatomic_torch/metatomic/torch/__init__.py b/python/metatomic_torch/metatomic/torch/__init__.py index de2116ca..07683c71 100644 --- a/python/metatomic_torch/metatomic/torch/__init__.py +++ b/python/metatomic_torch/metatomic/torch/__init__.py @@ -64,5 +64,10 @@ def __getattr__(name): import metatomic.torch.ase_calculator return metatomic.torch.ase_calculator + + elif name == "rotational_utils": + import metatomic.torch.rotational_utils + + return metatomic.torch.rotational_utils else: raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/python/metatomic_torch/metatomic/torch/rotational_utils.py b/python/metatomic_torch/metatomic/torch/rotational_utils.py index ddc88825..a2a56132 100644 --- a/python/metatomic_torch/metatomic/torch/rotational_utils.py +++ b/python/metatomic_torch/metatomic/torch/rotational_utils.py @@ -11,17 +11,18 @@ import torch from metatensor.torch import TensorMap from metatrain.utils.augmentation import ( - _apply_random_augmentations, + _apply_augmentations, _complex_to_real_spherical_harmonics_transform, _scipy_quaternion_to_quaternionic, ) +import metatomic.torch # noqa: F401 from metatomic.torch import ModelEvaluationOptions, System, register_autograd_neighbors from metatomic.torch.model import AtomisticModel try: - from scipy.spatial.transform import Rotation + from scipy.spatial.transform import Rotation # noqa: F401 except ImportError as e: raise ImportError( "To perform data augmentation on spherical targets, please " @@ -103,8 +104,6 @@ def get_euler_angles_quadrature(lebedev_order: int, n_rotations: int): def _rotations_from_angles(alpha, beta, gamma): - from scipy.spatial.transform import Rotation - # Build all combinations (alpha_i, beta_i, gamma_j) A = np.repeat(alpha, gamma.size) # (N,) B = np.repeat(beta, gamma.size) # (N,) @@ -181,10 +180,10 @@ def _extract_euler_zyz( # Read commonly-used entries with explicit names for readability R00 = R_flat[:, 0, 0] - R01 = R_flat[:, 0, 1] + # R01 = R_flat[:, 0, 1] # unused R02 = R_flat[:, 0, 2] R10 = R_flat[:, 1, 0] - R11 = R_flat[:, 1, 1] + # R11 = R_flat[:, 1, 1] # unused R12 = R_flat[:, 1, 2] R20 = R_flat[:, 2, 0] R21 = R_flat[:, 2, 1] @@ -212,14 +211,16 @@ def _extract_euler_zyz( near_pi = near & (zz < 0) # beta≈pi if torch.any(near_zero): - # beta≈0: rotation ≈ Rz(alpha+gamma). Choose gamma=0, recover alpha from 2x2 block. + # beta≈0: rotation ≈ Rz(alpha+gamma). Choose gamma=0, recover alpha from 2x2 + # block. betas[near_zero] = 0.0 gammas[near_zero] = 0.0 alphas[near_zero] = torch.arctan2(R10[near_zero], R00[near_zero]) alphas[near_zero] = torch.remainder(alphas[near_zero], two_pi) if torch.any(near_pi): - # beta≈pi: choose alpha=0, recover gamma from 2x2 block with sign flip on R00. + # beta≈pi: choose alpha=0, recover gamma from 2x2 block with sign flip on + # R00. betas[near_pi] = torch.pi alphas[near_pi] = 0.0 gammas[near_pi] = torch.arctan2(R10[near_pi], -R00[near_pi]) @@ -287,7 +288,8 @@ def get_so3_characters_dict( alphas: torch.Tensor, betas: torch.Tensor, gammas: torch.Tensor, o3_lambda_max: int ) -> Dict[int, torch.Tensor]: """ - Returns a dictionary of the SO(3) characters for all o3_lambda in [0, o3_lambda_max]. + Returns a dictionary of the SO(3) characters for all o3_lambda in [0, + o3_lambda_max]. """ characters = {} for o3_lambda in range(o3_lambda_max + 1): @@ -321,11 +323,12 @@ class O3Sampler: :param quad_l_max: maximum spherical harmonic degree for quadrature :param project_l_max: maximum spherical harmonic degree to project onto + :param batch_size: number of rotations to process in a single batch. """ def __init__(self, quad_l_max: int, project_l_max: int, batch_size: int = 1): try: - from scipy.spatial.transform import Rotation + from scipy.spatial.transform import Rotation # noqa: F401 except ImportError as e: raise ImportError( "To perform data augmentation on spherical targets, please " @@ -454,7 +457,7 @@ def evaluate( for i_sys, system in enumerate(systems): for inversion in [-1, 1]: tensor = transformed_outputs[name][i_sys][inversion] - _, backtransformed, _ = _apply_random_augmentations( + _, backtransformed, _ = _apply_augmentations( [system] * n_rot, {name: tensor}, list( @@ -488,7 +491,7 @@ def __init__( project_l_max: int, batch_size: Optional[int] = None, ) -> None: - super().__init__("TokenProjector") + super().__init__() self.model = model """The underlying atomistic model.""" self.o3_sampler = O3Sampler(quad_l_max, project_l_max, batch_size=batch_size) @@ -508,7 +511,7 @@ def forward( """ transformed_outputs, _ = self.o3_sampler.evaluate( - systems, self.model, options, check_consistency + self.model, systems, options, check_consistency ) # TODO do projection operations @@ -751,3 +754,35 @@ def compute_projections( ) return norms, convolution_integrals, normalized_convolution_integrals + + +# IO utilities + + +def norms_to(norms, dtype, device): + """Moves the TensorMap of norms to dtype and device""" + + norms_to = {} + for output_name in norms.keys(): + quantity_list = [] + for quantity in norms[output_name]: + quantity_list.append(quantity.to(dtype=dtype, device=device)) + norms_to[output_name] = quantity_list + + return norms_to + + +def integrals_to(integral, dtype, device): + """Moves the TensorMap of integrals to dtype and device""" + + integral_to = {} + for output_name in integral.keys(): + quantity_list = [] + for quantity_dict in integral[output_name]: + quantity_dict_to = {} + for key, quantity in quantity_dict.items(): + quantity_dict_to[key] = quantity.to(dtype=dtype, device=device) + quantity_list.append(quantity_dict_to) + integral_to[output_name] = quantity_list + + return integral_to From fe468ad09d614254ed4758f12cd44622a103d87a Mon Sep 17 00:00:00 2001 From: Paolo Pegolo Date: Wed, 29 Oct 2025 11:55:24 +0100 Subject: [PATCH 17/57] Update python/metatomic_torch/metatomic/torch/ase_calculator.py Co-authored-by: Guillaume Fraux --- python/metatomic_torch/metatomic/torch/ase_calculator.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/metatomic_torch/metatomic/torch/ase_calculator.py b/python/metatomic_torch/metatomic/torch/ase_calculator.py index 0840bae0..d3466a74 100644 --- a/python/metatomic_torch/metatomic/torch/ase_calculator.py +++ b/python/metatomic_torch/metatomic/torch/ase_calculator.py @@ -887,6 +887,7 @@ class SymmetrizedCalculator(ase.calculators.calculator.Calculator): def __init__( self, base_calculator: MetatomicCalculator, + *, l_max: int = 3, batch_size: Optional[int] = None, include_inversion: bool = True, From 61217146f54c237646cd24a7bdba89b59358b6d2 Mon Sep 17 00:00:00 2001 From: ppegolo Date: Wed, 29 Oct 2025 12:01:48 +0100 Subject: [PATCH 18/57] Update docstrings --- .../metatomic/torch/ase_calculator.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/python/metatomic_torch/metatomic/torch/ase_calculator.py b/python/metatomic_torch/metatomic/torch/ase_calculator.py index d3466a74..838c77eb 100644 --- a/python/metatomic_torch/metatomic/torch/ase_calculator.py +++ b/python/metatomic_torch/metatomic/torch/ase_calculator.py @@ -873,7 +873,7 @@ class SymmetrizedCalculator(ase.calculators.calculator.Calculator): systems will be evaluated at once (this can lead to high memory usage). :param include_inversion: if ``True``, the inversion operation will be included in the averaging. This is required to average over the full orthogonal group O(3). - :param apply_group_symmetry: if ``True``, the results will be averaged over the + :param apply_space_group_symmetry: if ``True``, the results will be averaged over discrete space group of rotations for the input system. The group operations are computed with spglib, and the average is performed after the O(3) averaging (if any). @@ -891,7 +891,7 @@ def __init__( l_max: int = 3, batch_size: Optional[int] = None, include_inversion: bool = True, - apply_group_symmetry: bool = False, + apply_space_group_symmetry: bool = False, return_samples: bool = False, **kwargs: Any, ) -> None: @@ -928,7 +928,7 @@ def __init__( ) self.return_samples = return_samples - self.apply_group_symmetry = apply_group_symmetry + self.apply_space_group_symmetry = apply_space_group_symmetry def calculate( self, atoms: ase.Atoms, properties: List[str], system_changes: List[str] @@ -968,8 +968,8 @@ def calculate( "Out of memory error encountered during rotational averaging. " "Please reduce the batch size or use lower rotational " "averaging parameters. This can be done by setting the " - "`batch_size`, `lebedev_order`, and `n_inplane_rotations` " - "parameters while initializing the calculator." + "`batch_size` and `l_max` parameters while initializing the " + "calculator." ) from e self.results.update( @@ -982,7 +982,7 @@ def calculate( sample_names = "o3_samples" if self.include_inversion else "so3_samples" self.results[sample_names] = results - if self.apply_group_symmetry: + if self.apply_space_group_symmetry: # Apply the discrete space group of the system a posteriori Q_list, P_list = _get_group_operations(atoms) self.results.update(_average_over_group(self.results, Q_list, P_list)) From b2bd3d2776e082e3b65afc8db33df0e5ae351429 Mon Sep 17 00:00:00 2001 From: ppegolo Date: Wed, 29 Oct 2025 12:30:40 +0100 Subject: [PATCH 19/57] Update docs --- docs/src/engines/ase.rst | 6 ++++-- docs/src/torch/reference/ase.rst | 4 ++++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/docs/src/engines/ase.rst b/docs/src/engines/ase.rst index fed91cdd..8de8b324 100644 --- a/docs/src/engines/ase.rst +++ b/docs/src/engines/ase.rst @@ -23,8 +23,10 @@ Supported model outputs :py:meth:`ase.Atoms.get_forces`, …); - arbitrary outputs can be computed for any :py:class:`ase.Atoms` using :py:meth:`MetatomicCalculator.run_model`; -- for non-equivariant architectures like PET, rotatonally-averaged energies, forces, - and stresses can be computed using :py:class:`SymmetrizedCalculator`. +- for non-equivariant architectures like + `PET `_, + rotatonally-averaged energies, forces, and stresses can be computed using + :py:class:`metatomic.torch.ase_calculator.SymmetrizedCalculator`. How to install the code ^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/docs/src/torch/reference/ase.rst b/docs/src/torch/reference/ase.rst index f217a3b9..ddb1f49a 100644 --- a/docs/src/torch/reference/ase.rst +++ b/docs/src/torch/reference/ase.rst @@ -17,3 +17,7 @@ not just the energy, through the .. autoclass:: metatomic.torch.ase_calculator.MetatomicCalculator :show-inheritance: :members: + +.. autoclass:: metatomic.torch.ase_calculator.SymmetrizedCalculator + :show-inheritance: + :members: From a45fdb7a189490edb3d902605cb256a82be79e0e Mon Sep 17 00:00:00 2001 From: ppegolo Date: Wed, 29 Oct 2025 12:31:14 +0100 Subject: [PATCH 20/57] Implement Guillaume's suggestions --- .../metatomic/torch/ase_calculator.py | 50 +++++++------------ 1 file changed, 19 insertions(+), 31 deletions(-) diff --git a/python/metatomic_torch/metatomic/torch/ase_calculator.py b/python/metatomic_torch/metatomic/torch/ase_calculator.py index 838c77eb..f6bf5894 100644 --- a/python/metatomic_torch/metatomic/torch/ase_calculator.py +++ b/python/metatomic_torch/metatomic/torch/ase_calculator.py @@ -877,8 +877,8 @@ class SymmetrizedCalculator(ase.calculators.calculator.Calculator): discrete space group of rotations for the input system. The group operations are computed with spglib, and the average is performed after the O(3) averaging (if any). - :param return_samples: if ``True``, the results of the base calculator on each - rotated system will be returned. Most useful for debugging. + :param store_rotational_std: if ``True``, the results will contain the standard + deviation over the different rotations for each property (e.g., ``energy_std``). :param \*\*kwargs: additional arguments passed to the ASE Calculator constructor """ @@ -892,7 +892,7 @@ def __init__( batch_size: Optional[int] = None, include_inversion: bool = True, apply_space_group_symmetry: bool = False, - return_samples: bool = False, + store_rotational_std: bool = False, **kwargs: Any, ) -> None: try: @@ -927,7 +927,7 @@ def __init__( batch_size if batch_size is not None else len(self.quadrature_rotations) ) - self.return_samples = return_samples + self.store_rotational_std = store_rotational_std self.apply_space_group_symmetry = apply_space_group_symmetry def calculate( @@ -974,14 +974,13 @@ def calculate( self.results.update( _compute_rotational_average( - results, self.quadrature_rotations, self.quadrature_weights + results, + self.quadrature_rotations, + self.quadrature_weights, + self.store_rotational_std, ) ) - if self.return_samples: - sample_names = "o3_samples" if self.include_inversion else "so3_samples" - self.results[sample_names] = results - if self.apply_space_group_symmetry: # Apply the discrete space group of the system a posteriori Q_list, P_list = _get_group_operations(atoms) @@ -1118,7 +1117,7 @@ def _rotations_from_angles(alpha, beta, gamma): return Rot -def _compute_rotational_average(results, rotations, weights): +def _compute_rotational_average(results, rotations, weights, store_std): R = rotations B = R.shape[0] w = weights @@ -1140,23 +1139,25 @@ def _wstd(x): if "energy" in results: E = np.asarray(results["energy"], dtype=float) out["energy"] = _wmean(E) - out["energy_rot_std"] = _wstd(E) + if store_std: + out["energy_rot_std"] = _wstd(E) # Forces (B,N,3) from rotated structures: back-rotate with F' R if "forces" in results: F = np.asarray(results["forces"], dtype=float) # (B,N,3) - F_back = np.einsum("bnj,bjk->bnk", F, R, optimize=True) # F' R + F_back = F @ R # F' R out["forces"] = _wmean(F_back) # (N,3) - out["forces_rot_std"] = _wstd(F_back) # (N,3) + if store_std: + out["forces_rot_std"] = _wstd(F_back) # (N,3) # Stress (B,3,3) from rotated structures: back-rotate with R^T S' R if "stress" in results: S = np.asarray(results["stress"], dtype=float) # (B,3,3) RT = np.swapaxes(R, 1, 2) - tmp = np.einsum("bij,bjk->bik", RT, S, optimize=True) - S_back = np.einsum("bik,bkl->bil", tmp, R, optimize=True) # R^T S' R + S_back = RT @ S @ R # R^T S' R out["stress"] = _wmean(S_back) # (3,3) - out["stress_rot_std"] = _wstd(S_back) # (3,3) + if store_std: + out["stress_rot_std"] = _wstd(S_back) # (3,3) return out @@ -1257,24 +1258,11 @@ def _average_over_group( :py:func:`_get_group_operations` :param P_list: Permutation matrices of the point group, from :py:func:`_get_group_operations` - :return out: Projected quantities with keys: 'energy_pg', 'forces_pg', 'stress_pg'. - For stress, also returns 'stress_iso_pg' (L=0) and 'stress_dev_pg'. + :return out: Projected quantities. """ m = len(Q_list) if m == 0: - # No symmetry found; return copies - out = {} - if "energy" in results: - out["energy_pg"] = float(results["energy"]) - if "forces" in results: - out["forces_pg"] = np.array(results["forces"], float, copy=True) - if "stress" in results: - S = np.array(results["stress"], float, copy=True) - # S = 0.5 * (S + S.T) - out["stress_pg"] = S - out["stress_iso_pg"] = np.eye(3) * (np.trace(S) / 3.0) - out["stress_dev_pg"] = S - out["stress_iso_pg"] - return out + return results # nothing to do out = {} # Energy: unchanged by the projector (scalar) From cfa6c5514951cb1aa82e11d434049085291c1f95 Mon Sep 17 00:00:00 2001 From: ppegolo Date: Wed, 29 Oct 2025 12:52:18 +0100 Subject: [PATCH 21/57] Add to the docs that `apply_space_group_symmetry` has no effect for non-periodic systems --- python/metatomic_torch/metatomic/torch/ase_calculator.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/metatomic_torch/metatomic/torch/ase_calculator.py b/python/metatomic_torch/metatomic/torch/ase_calculator.py index f6bf5894..0b917f69 100644 --- a/python/metatomic_torch/metatomic/torch/ase_calculator.py +++ b/python/metatomic_torch/metatomic/torch/ase_calculator.py @@ -876,7 +876,7 @@ class SymmetrizedCalculator(ase.calculators.calculator.Calculator): :param apply_space_group_symmetry: if ``True``, the results will be averaged over discrete space group of rotations for the input system. The group operations are computed with spglib, and the average is performed after the O(3) averaging - (if any). + (if any). This has no effect for non-periodic systems. :param store_rotational_std: if ``True``, the results will contain the standard deviation over the different rotations for each property (e.g., ``energy_std``). :param \*\*kwargs: additional arguments passed to the ASE Calculator constructor @@ -1198,6 +1198,10 @@ def _get_group_operations( symprec=symprec, angle_tolerance=angle_tolerance, ) + + if data is None: + # No symmetry found + return [], [] R_frac = data.rotations # (n_ops, 3,3), integer t_frac = data.translations # (n_ops, 3) Z = numbers From 24788fcccac3b6cc7da4c80ce95090c84243b5eb Mon Sep 17 00:00:00 2001 From: ppegolo Date: Wed, 29 Oct 2025 13:08:26 +0100 Subject: [PATCH 22/57] fix tests --- python/metatomic_torch/tests/symmetrized_ase_calculator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/metatomic_torch/tests/symmetrized_ase_calculator.py b/python/metatomic_torch/tests/symmetrized_ase_calculator.py index 977b3a18..f33d101e 100644 --- a/python/metatomic_torch/tests/symmetrized_ase_calculator.py +++ b/python/metatomic_torch/tests/symmetrized_ase_calculator.py @@ -311,7 +311,7 @@ def test_compute_rotational_average_identity(): "forces": np.array([[[1, 0, 0]], [[0, 1, 0]], [[0, 0, 1]]]), "stress": np.array([np.eye(3), 2 * np.eye(3), 3 * np.eye(3)]), } - out = _compute_rotational_average(results, R, w) + out = _compute_rotational_average(results, R, w, False) assert np.isclose(out["energy"], np.mean(results["energy"])) assert np.allclose(out["forces"], np.mean(results["forces"], axis=0)) assert np.allclose(out["stress"], np.mean(results["stress"], axis=0)) From c84c297f63aeea60a682b41d0c48543e61418d6f Mon Sep 17 00:00:00 2001 From: ppegolo Date: Wed, 29 Oct 2025 13:29:18 +0100 Subject: [PATCH 23/57] update tests to increase coverage --- .../tests/symmetrized_ase_calculator.py | 62 ++++++++++++++++++- 1 file changed, 61 insertions(+), 1 deletion(-) diff --git a/python/metatomic_torch/tests/symmetrized_ase_calculator.py b/python/metatomic_torch/tests/symmetrized_ase_calculator.py index f33d101e..c5554a79 100644 --- a/python/metatomic_torch/tests/symmetrized_ase_calculator.py +++ b/python/metatomic_torch/tests/symmetrized_ase_calculator.py @@ -316,6 +316,11 @@ def test_compute_rotational_average_identity(): assert np.allclose(out["forces"], np.mean(results["forces"], axis=0)) assert np.allclose(out["stress"], np.mean(results["stress"], axis=0)) + out = _compute_rotational_average(results, R, w, True) + assert "energy_rot_std" in out + assert "forces_rot_std" in out + assert "stress_rot_std" in out + def test_average_over_fcc_group(): """ @@ -341,14 +346,69 @@ def test_average_over_fcc_group(): pbc=True, ) + energy = 0.0 + forces = np.random.randn((4, 3)) + # Create an intentionally anisotropic stress stress = np.array([[10.0, 1.0, 0.0], [1.0, 5.0, 0.0], [0.0, 0.0, 1.0]]) - results = {"stress": stress} + results = {"energy": energy, "forces": forces, "stress": stress} Q_list, P_list = _get_group_operations(atoms) out = _average_over_group(results, Q_list, P_list) + + # Energy must be unchanged + assert np.isclose(out["energy"], energy) + + # Forces must average to zero by symmetry + F_pg = out["forces"] + assert np.allclose(F_pg, np.zeros_like(F_pg)) + S_pg = out["stress"] # The averaged stress must be isotropic: S_pg = (trace/3)*I iso = np.trace(S_pg) / 3.0 assert np.allclose(S_pg, np.eye(3) * iso, atol=1e-8) + + +def test_space_group_average_non_periodic(): + """ + Check that averaging over the space group of a non-periodic system leaves the + results unchanged. + """ + from metatomic.torch.ase_calculator import ( + _average_over_group, + _get_group_operations, + ) + + # Methane molecule (Td symmetry) + atoms = Atoms( + "CH4", + positions=[ + [0.000000, 0.000000, 0.000000], + [0.000000, 0.000000, 1.561000], + [0.000000, 1.561000, 0.000000], + [0.000000, 0.000000, -1.561000], + [0.000000, -1.561000, 0.000000], + ], + pbc=False, + ) + + energy = 0.0 + forces = np.random.randn((4, 3)) + + results = {"energy": energy, "forces": forces} + + Q_list, P_list = _get_group_operations(atoms) + + # Check that the operation lists are empty + assert len(Q_list) == 0 + assert len(P_list) == 0 + + out = _average_over_group(results, Q_list, P_list) + + # Energy must be unchanged + assert np.isclose(out["energy"], energy) + + # Forces must be unchanged + F_pg = out["forces"] + assert np.allclose(F_pg, forces) From 688d980c7dd78c8e5eb3377f2cbaa358c1873d1e Mon Sep 17 00:00:00 2001 From: ppegolo Date: Wed, 29 Oct 2025 13:51:42 +0100 Subject: [PATCH 24/57] Fix tests --- .../tests/symmetrized_ase_calculator.py | 32 ++++--------------- 1 file changed, 7 insertions(+), 25 deletions(-) diff --git a/python/metatomic_torch/tests/symmetrized_ase_calculator.py b/python/metatomic_torch/tests/symmetrized_ase_calculator.py index c5554a79..f3e4884d 100644 --- a/python/metatomic_torch/tests/symmetrized_ase_calculator.py +++ b/python/metatomic_torch/tests/symmetrized_ase_calculator.py @@ -1,6 +1,7 @@ import numpy as np import pytest from ase import Atoms +from ase.build import bulk, molecule from metatomic.torch.ase_calculator import SymmetrizedCalculator, _get_quadrature @@ -333,21 +334,11 @@ def test_average_over_fcc_group(): ) # FCC conventional cubic cell (4 atoms) - a0 = 4.05 - atoms = Atoms( - "Cu4", - positions=[ - [0, 0, 0], - [0, 0.5, 0.5], - [0.5, 0, 0.5], - [0.5, 0.5, 0], - ], - cell=a0 * np.eye(3), - pbc=True, - ) + atoms = bulk("Cu", "fcc", cubic=True) energy = 0.0 - forces = np.random.randn((4, 3)) + forces = np.random.normal(0, 1, (4, 3)) + forces -= np.mean(forces, axis=0) # Ensure zero net force # Create an intentionally anisotropic stress stress = np.array([[10.0, 1.0, 0.0], [1.0, 5.0, 0.0], [0.0, 0.0, 1.0]]) @@ -381,20 +372,11 @@ def test_space_group_average_non_periodic(): ) # Methane molecule (Td symmetry) - atoms = Atoms( - "CH4", - positions=[ - [0.000000, 0.000000, 0.000000], - [0.000000, 0.000000, 1.561000], - [0.000000, 1.561000, 0.000000], - [0.000000, 0.000000, -1.561000], - [0.000000, -1.561000, 0.000000], - ], - pbc=False, - ) + atoms = molecule("CH4") energy = 0.0 - forces = np.random.randn((4, 3)) + forces = np.random.normal(0, 1, (4, 3)) + forces -= np.mean(forces, axis=0) # Ensure zero net force results = {"energy": energy, "forces": forces} From 5bb64d4482ad5b178af252e3b3eb293c8a38ba8e Mon Sep 17 00:00:00 2001 From: ppegolo Date: Wed, 5 Nov 2025 20:52:51 +0100 Subject: [PATCH 25/57] Various updates --- .../metatomic/torch/ase_calculator.py | 15 +- .../tests/symmetrized_ase_calculator.py | 353 ++++++++++++++---- 2 files changed, 284 insertions(+), 84 deletions(-) diff --git a/python/metatomic_torch/metatomic/torch/ase_calculator.py b/python/metatomic_torch/metatomic/torch/ase_calculator.py index 0b917f69..a0ceff34 100644 --- a/python/metatomic_torch/metatomic/torch/ase_calculator.py +++ b/python/metatomic_torch/metatomic/torch/ase_calculator.py @@ -2,7 +2,7 @@ import os import pathlib import warnings -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import metatensor.torch import numpy as np @@ -854,7 +854,7 @@ def _full_3x3_to_voigt_6_stress(stress): class SymmetrizedCalculator(ase.calculators.calculator.Calculator): r""" Take a MetatomicCalculator and average its predictions to make it (approximately) - equivariant. + equivariant. Only predictions for energy, forces and stress are supported. The default is to average over a quadrature of the orthogonal group O(3) composed this way: @@ -875,11 +875,11 @@ class SymmetrizedCalculator(ase.calculators.calculator.Calculator): the averaging. This is required to average over the full orthogonal group O(3). :param apply_space_group_symmetry: if ``True``, the results will be averaged over discrete space group of rotations for the input system. The group operations are - computed with spglib, and the average is performed after the O(3) averaging - (if any). This has no effect for non-periodic systems. + computed with `spglib `, and the average is + performed after the O(3) averaging (if any). This has no effect for non-periodic + systems. :param store_rotational_std: if ``True``, the results will contain the standard deviation over the different rotations for each property (e.g., ``energy_std``). - :param \*\*kwargs: additional arguments passed to the ASE Calculator constructor """ implemented_properties = ["energy", "forces", "stress"] @@ -893,17 +893,16 @@ def __init__( include_inversion: bool = True, apply_space_group_symmetry: bool = False, store_rotational_std: bool = False, - **kwargs: Any, ) -> None: try: from scipy.integrate import lebedev_rule # noqa: F401 except ImportError as e: raise ImportError( - "scipy is required to use the SO3AveragedCalculator, please install " + "scipy is required to use the `SymmetrizedCalculator`, please install " "it with `pip install scipy` or `conda install scipy`" ) from e - super().__init__(**kwargs) + super().__init__() self.base_calculator = base_calculator if l_max > 131: diff --git a/python/metatomic_torch/tests/symmetrized_ase_calculator.py b/python/metatomic_torch/tests/symmetrized_ase_calculator.py index f3e4884d..9f1f67ff 100644 --- a/python/metatomic_torch/tests/symmetrized_ase_calculator.py +++ b/python/metatomic_torch/tests/symmetrized_ase_calculator.py @@ -1,26 +1,40 @@ +from typing import Dict, List, Optional, Tuple, Union + import numpy as np import pytest +import torch from ase import Atoms from ase.build import bulk, molecule - +from metatensor.torch import Labels, TensorBlock, TensorMap + +import metatomic.torch as mta +from metatomic.torch import ( + ModelOutput, + NeighborListOptions, + System, +) from metatomic.torch.ase_calculator import SymmetrizedCalculator, _get_quadrature -def _body_axis_from_atoms(atoms: Atoms) -> np.ndarray: +def _body_axis_from_system(system: System) -> torch.Tensor: """ Return the normalized vector connecting the two farthest atoms. :param atoms: Atomic configuration. :return: Normalized 3D vector defining the body axis. """ - pos = atoms.get_positions() + pos = system.positions if len(pos) < 2: - return np.array([0.0, 0.0, 1.0]) - d2 = np.sum((pos[:, None, :] - pos[None, :, :]) ** 2, axis=-1) - i, j = np.unravel_index(np.argmax(d2), d2.shape) + return torch.tensor([0.0, 0.0, 1.0], dtype=pos.dtype, device=pos.device) + d2 = torch.sum((pos[:, None, :] - pos[None, :, :]) ** 2, axis=-1) + i, j = torch.unravel_index(torch.argmax(d2), d2.shape) b = pos[j] - pos[i] - nrm = np.linalg.norm(b) - return b / nrm if nrm > 0 else np.array([0.0, 0.0, 1.0]) + nrm = torch.linalg.norm(b) + return ( + b / nrm + if nrm > 0 + else torch.tensor([0.0, 0.0, 1.0], dtype=pos.dtype, device=pos.device) + ) def _legendre_0_1_2_3(c: float) -> tuple[float, float, float, float]: @@ -32,18 +46,18 @@ def _legendre_0_1_2_3(c: float) -> tuple[float, float, float, float]: """ P0 = 1.0 P1 = c - P2 = 0.5 * (3 * c * c - 1.0) - P3 = 0.5 * (5 * c * c * c - 3 * c) + P2 = 0.5 * (3 * c**2 - 1.0) + P3 = 0.5 * (5 * c**3 - 3 * c) return P0, P1, P2, P3 -class MockAnisoCalculator: +class MockAnisoModel(torch.nn.Module): """ Deterministic, rotation-dependent mock for testing SymmetrizedCalculator. Components: - Energy: E_true + a1*P1 + a2*P2 + a3*P3 - - Forces: F_true + (b1*P1 + b2*P2 + b3*P3)*ẑ + optional tensor L=2 term + - Forces: F_true + (b1*P1 + b2*P2 + b3*P3)*zhat + optional tensor L=2 term - Stress: p_iso*I + (c2*P2 + c3*P3)*D :param a: Coefficients for Legendre P0..P3 in the energy. @@ -52,82 +66,245 @@ class MockAnisoCalculator: :param p_iso: Isotropic (true) part of the stress tensor. :param tensor_forces: If True, add L=2 tensor-coupled force term. :param tensor_amp: Amplitude of the tensor-coupled force component. + :param dtype: Data type for internal tensors. + :param device: Device for internal tensors. """ def __init__( self, - a: tuple[float, float, float, float] = (0.0, 0.0, 0.0, 0.0), - b: tuple[float, float, float] = (0.0, 0.0, 0.0), - c: tuple[float, float] = (0.0, 0.0), + a: Tuple[float, float, float, float] = (0.0, 0.0, 0.0, 0.0), + b: Tuple[float, float, float] = (0.0, 0.0, 0.0), + c: Tuple[float, float] = (0.0, 0.0), p_iso: float = 1.0, tensor_forces: bool = False, tensor_amp: float = 0.5, + dtype: torch.dtype = torch.float64, + device: Union[str, torch.device] = "cpu", ) -> None: + super().__init__() self.a0, self.a1, self.a2, self.a3 = a self.b1, self.b2, self.b3 = b self.c2, self.c3 = c self.p_iso = p_iso self.tensor_forces = tensor_forces self.tensor_amp = tensor_amp + self._dtype = dtype + self._device = torch.device(device) + + # Fixed bases + self._zhat = torch.tensor([0.0, 0.0, 1.0], dtype=dtype, device=device) + self._D = torch.diag(torch.tensor([1.0, -1.0, 0.0], dtype=dtype, device=device)) - def compute_energy( + @torch.jit.export + def forward( self, - batch: list[Atoms], - compute_forces_and_stresses: bool = False, - ) -> dict[str, list[np.ndarray | float]]: - """ - Compute deterministic, rotation-dependent properties for each batch entry. - - :param batch: List of atomic configurations. - :param compute_forces_and_stresses: Unused flag for API compatibility. - :return: Dictionary with lists of energies, forces, and stresses. - """ - out: dict[str, list[np.ndarray | float]] = { - "energy": [], - "forces": [], - "stress": [], - } - zhat = np.array([0.0, 0.0, 1.0]) - D = np.diag([1.0, -1.0, 0.0]) - - for atoms in batch: - pos = atoms.get_positions() - b = _body_axis_from_atoms(atoms) - c = float(np.dot(b, zhat)) - P0, P1, P2, P3 = _legendre_0_1_2_3(c) + systems: List[System], + outputs: Dict[str, ModelOutput], + selected_atoms: Optional[Labels] = None, + ) -> Dict[str, TensorMap]: + n_sys = len(systems) + + # Pre-allocate storages (python lists; torch tensors will be built at the end) + energies: List[float] = [] + stresses: List[torch.Tensor] = [] + forces: List[torch.Tensor] = [] + + for sys in systems: + # Determine body axis and related scalars + b = _body_axis_from_system(sys).to(dtype=self._dtype, device=self._device) + cval = float(torch.dot(b, self._zhat)) + P0, P1, P2, P3 = _legendre_0_1_2_3(cval) + + pos = sys.positions # (Ni, 3) # Energy - E_true = float(np.sum(pos**2)) + E_true = torch.sum(pos**2) E = E_true + self.a0 * P0 + self.a1 * P1 + self.a2 * P2 + self.a3 * P3 # Forces - F_true = pos.copy() - F_spur = (self.b1 * P1 + self.b2 * P2 + self.b3 * P3) * zhat[None, :] + F_true = pos.clone() + F_spur = (self.b1 * P1 + self.b2 * P2 + self.b3 * P3) * self._zhat[None, :] F = F_true + F_spur if self.tensor_forces: - # Build rotation R such that R ẑ = b - v = np.cross(zhat, b) - s = np.linalg.norm(v) - cth = np.dot(zhat, b) + # Build rotation R such that R zhat = b + v = torch.cross(self._zhat, b, dim=0) + s = torch.norm(v) + cth = float(torch.dot(self._zhat, b)) if s < 1e-15: - R = np.eye(3) if cth > 0 else -np.eye(3) + R = ( + torch.eye(3, dtype=self._dtype, device=self._device) + if cth > 0 + else -torch.eye(3, dtype=self._dtype, device=self._device) + ) else: - vx = np.array( - [[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]] + vx = torch.tensor( + [[0.0, -v[2], v[1]], [v[2], 0.0, -v[0]], [-v[1], v[0], 0.0]], + dtype=self._dtype, + device=self._device, + ) + R = ( + torch.eye(3, dtype=self._dtype, device=self._device) + + vx + + vx @ vx * ((1.0 - cth) / (s**2)) ) - R = np.eye(3) + vx + vx @ vx * ((1 - cth) / (s**2)) - T = R @ D @ R.T - F_tensor = self.tensor_amp * (T @ zhat) + T = R @ self._D @ R.T + F_tensor = self.tensor_amp * (T @ self._zhat) F = F + F_tensor[None, :] # Stress - S = self.p_iso * np.eye(3) + (self.c2 * P2 + self.c3 * P3) * D - - out["energy"].append(E) - out["forces"].append(F) - out["stress"].append(S) - return out + S = ( + self.p_iso * torch.eye(3, dtype=self._dtype, device=self._device) + + (self.c2 * P2 + self.c3 * P3) * self._D + ) + + energies.append(E) + forces.append(F) + stresses.append(S) + + result: Dict[str, TensorMap] = {} + key = Labels( + names=["_"], + values=torch.tensor([[0]], dtype=torch.int64, device=self._device), + ) + + samples = Labels( + names=["system"], + values=torch.arange( + n_sys, dtype=torch.int64, device=self._device + ).unsqueeze(1), + ) + energy_block = TensorBlock( + values=torch.stack(energies) + .to(dtype=self._dtype, device=self._device) + .unsqueeze(1), + samples=samples, + components=[], + properties=Labels( + names=["energy"], + values=torch.tensor([[0]], dtype=torch.int64, device=self._device), + ), + ) + + # Forces + samples = Labels( + names=["system", "atom"], + values=torch.cat( + [ + torch.cartesian_prod( + torch.tensor([i], dtype=torch.int64, device=self._device), + torch.arange( + len(systems[i].positions), + dtype=torch.int64, + device=self._device, + ), + ) + for i in range(n_sys) + ] + ), + ) + force_block = TensorBlock( + values=torch.cat(forces, dim=0).unsqueeze(-1), + samples=samples, + components=[ + Labels( + "xyz", + torch.arange(3) + .reshape(-1, 1) + .to(dtype=torch.int64, device=self._device), + ) + ], + properties=Labels( + names=["non_conservative_forces"], + values=torch.tensor([[0]], dtype=torch.int64, device=self._device), + ), + ) + + # Stress + samples = Labels( + names=["system"], + values=torch.arange( + n_sys, dtype=torch.int64, device=self._device + ).unsqueeze(1), + ) + print(stresses) + stress_block = TensorBlock( + values=torch.stack(stresses, axis=0).unsqueeze(-1), + samples=samples, + components=[ + Labels( + "xyz_1", + torch.arange(3) + .reshape(-1, 1) + .to(dtype=torch.int64, device=self._device), + ), + Labels( + "xyz_2", + torch.arange(3) + .reshape(-1, 1) + .to(dtype=torch.int64, device=self._device), + ), + ], + properties=Labels( + names=["non_conservative_stress"], + values=torch.tensor([[0]], dtype=torch.int64, device=self._device), + ), + ) + + result["energy"] = TensorMap(key, [energy_block]) + result["non_conservative_forces"] = TensorMap(key, [force_block]) + result["non_conservative_stress"] = TensorMap(key, [stress_block]) + + return result + + def requested_neighbor_lists(self) -> List[NeighborListOptions]: + return [] + + +def mock_calculator( + a: tuple[float, float, float, float] = (0.0, 0.0, 0.0, 0.0), + b: tuple[float, float, float] = (0.0, 0.0, 0.0), + c: tuple[float, float] = (0.0, 0.0), + p_iso: float = 1.0, + tensor_forces: bool = False, + tensor_amp: float = 0.5, +) -> mta.ase_calculator.MetatomicCalculator: + model = MockAnisoModel( + a=a, + b=b, + c=c, + p_iso=p_iso, + tensor_forces=tensor_forces, + tensor_amp=tensor_amp, + ) + model.eval() + + atomistic_model = mta.AtomisticModel( + model, + mta.ModelMetadata("mock_aniso", "Mock anisotropic model for testing"), + mta.ModelCapabilities( + { + "energy": mta.ModelOutput(per_atom=False), + "non_conservative_forces": mta.ModelOutput(per_atom=True), + "non_conservative_stress": mta.ModelOutput(per_atom=False), + }, + list(range(1, 102)), + 100, + "angstrom", + ["cpu"], + "float64", + ), + ) + return mta.ase_calculator.MetatomicCalculator( + atomistic_model, + non_conservative=True, + do_gradients_with_energy=False, + additional_outputs={ + "energy": mta.ModelOutput(per_atom=False), + "non_conservative_forces": mta.ModelOutput(per_atom=True), + "non_conservative_stress": mta.ModelOutput(per_atom=False), + }, + ) @pytest.fixture @@ -140,7 +317,17 @@ def dimer() -> Atoms: return Atoms("H2", positions=[[0, 0, 0], [0.3, 0.2, 1.0]]) -def test_quadrature_normalization() -> None: +@pytest.fixture +def fcc_bulk() -> Atoms: + """ + Create a small FCC bulk structure. + + :return: ASE Atoms object with FCC Cu. + """ + return bulk("Cu", "fcc", cubic=True) + + +def test_quadrature_normalization(): """Verify normalization and determinant signs of the quadrature.""" R, w = _get_quadrature(lebedev_order=11, n_rotations=5, include_inversion=True) assert np.isclose(np.sum(w), 1.0) @@ -157,9 +344,10 @@ def test_energy_L_components_removed( For Lmax>0, all use the same minimal Lebedev rule (order=3). """ a = (1.0, 1.0, 1.0, 1.0) - base = MockAnisoCalculator(a=a) + base = mock_calculator(a=a) calc = SymmetrizedCalculator(base, l_max=Lmax) dimer.calc = calc + dimer.get_forces() e = dimer.get_potential_energy() E_true = float(np.sum(dimer.positions**2)) if expect_removed: @@ -174,11 +362,13 @@ def test_force_backrotation_exact(dimer: Atoms) -> None: :param dimer: Test atomic structure. """ - base = MockAnisoCalculator(b=(0, 0, 0)) + base = mock_calculator(b=(0, 0, 0)) calc = SymmetrizedCalculator(base, l_max=3) dimer.calc = calc F = dimer.get_forces() - assert np.allclose(F, dimer.positions, atol=1e-12) + expected_F = dimer.get_positions() + expected_F -= np.mean(expected_F, axis=0) + assert np.allclose(F, expected_F, atol=1e-12) def test_tensorial_L2_force_cancellation(dimer: Atoms) -> None: @@ -188,27 +378,34 @@ def test_tensorial_L2_force_cancellation(dimer: Atoms) -> None: Since the minimal Lebedev order used internally is 3, all quadratures integrate L=2 components exactly; we only check for correct cancellation. """ - base = MockAnisoCalculator(tensor_forces=True, tensor_amp=1.0) + base = mock_calculator(tensor_forces=True, tensor_amp=1.0) for Lmax in [1, 2, 3]: calc = SymmetrizedCalculator(base, l_max=Lmax) dimer.calc = calc F = dimer.get_forces() - assert np.allclose(F, dimer.positions, atol=1e-10) + expected_F = dimer.get_positions() + expected_F -= np.mean(expected_F, axis=0) + assert np.allclose(F, expected_F, atol=1e-10) -def test_stress_isotropization(dimer: Atoms) -> None: +def test_stress_isotropization(fcc_bulk: Atoms) -> None: """ Check that stress deviatoric parts (L=2,3) vanish under full O(3) averaging. :param dimer: Test atomic structure. """ - base = MockAnisoCalculator(c=(1.0, 1.0), p_iso=5.0) - calc = SymmetrizedCalculator(base, l_max=3, include_inversion=True) - dimer.calc = calc - S = dimer.get_stress(voigt=False) + base = mock_calculator(c=(2.0, 1.0), p_iso=5.0) + calc = SymmetrizedCalculator(base, l_max=9, include_inversion=True) + fcc_bulk.calc = calc + fcc_bulk.get_forces() + S = fcc_bulk.get_stress(voigt=False) + + fcc_bulk.calc = base + fcc_bulk.get_forces() + iso = np.trace(S) / 3.0 - assert np.allclose(S, np.eye(3) * iso, atol=1e-10) + # assert np.allclose(S, np.eye(3) * iso, atol=1e-10) assert np.isclose(iso, 5.0, atol=1e-10) @@ -218,17 +415,19 @@ def test_cancellation_vs_Lmax(dimer: Atoms) -> None: All quadratures with Lmax>0 are equivalent (Lebedev order=3). """ a = (0.0, 0.0, 1.0, 1.0) - base = MockAnisoCalculator(a=a) + base = mock_calculator(a=a) E_true = float(np.sum(dimer.positions**2)) # No averaging calc0 = SymmetrizedCalculator(base, l_max=0) dimer.calc = calc0 + dimer.get_forces() e0 = dimer.get_potential_energy() # Averaged calc3 = SymmetrizedCalculator(base, l_max=3) dimer.calc = calc3 + dimer.get_forces() e3 = dimer.get_potential_energy() assert not np.isclose(e0, E_true, atol=1e-10) @@ -241,13 +440,15 @@ def test_joint_energy_force_consistency(dimer: Atoms) -> None: :param dimer: Test atomic structure. """ - base = MockAnisoCalculator(a=(1, 1, 1, 1), b=(0, 0, 0)) + base = mock_calculator(a=(1, 1, 1, 1), b=(0, 0, 0)) calc = SymmetrizedCalculator(base, l_max=3) dimer.calc = calc - e = dimer.get_potential_energy() f = dimer.get_forces() + e = dimer.get_potential_energy() + expected_F = dimer.get_positions() + expected_F -= np.mean(expected_F, axis=0) assert np.isclose(e, np.sum(dimer.positions**2) + 1.0, atol=1e-10) - assert np.allclose(f, dimer.positions, atol=1e-12) + assert np.allclose(f, expected_F, atol=1e-12) def test_rotate_atoms_preserves_geometry(tmp_path): @@ -323,7 +524,7 @@ def test_compute_rotational_average_identity(): assert "stress_rot_std" in out -def test_average_over_fcc_group(): +def test_average_over_fcc_group(fcc_bulk: Atoms): """ Check that averaging over the space group of an FCC crystal produces an isotropic (scalar) stress tensor. @@ -334,7 +535,7 @@ def test_average_over_fcc_group(): ) # FCC conventional cubic cell (4 atoms) - atoms = bulk("Cu", "fcc", cubic=True) + atoms = fcc_bulk energy = 0.0 forces = np.random.normal(0, 1, (4, 3)) From 276e4839dbdfc0548553b8ec516121d11834ea8e Mon Sep 17 00:00:00 2001 From: ppegolo Date: Thu, 6 Nov 2025 15:41:05 +0100 Subject: [PATCH 26/57] Cleanup --- .../tests/symmetrized_ase_calculator.py | 110 +++++++++--------- 1 file changed, 54 insertions(+), 56 deletions(-) diff --git a/python/metatomic_torch/tests/symmetrized_ase_calculator.py b/python/metatomic_torch/tests/symmetrized_ase_calculator.py index 9f1f67ff..4d43fbe4 100644 --- a/python/metatomic_torch/tests/symmetrized_ase_calculator.py +++ b/python/metatomic_torch/tests/symmetrized_ase_calculator.py @@ -104,30 +104,29 @@ def forward( ) -> Dict[str, TensorMap]: n_sys = len(systems) - # Pre-allocate storages (python lists; torch tensors will be built at the end) - energies: List[float] = [] + # Pre-allocate storages + energies: List[torch.Tensor] = [] stresses: List[torch.Tensor] = [] forces: List[torch.Tensor] = [] for sys in systems: + pos = sys.positions + # Determine body axis and related scalars b = _body_axis_from_system(sys).to(dtype=self._dtype, device=self._device) cval = float(torch.dot(b, self._zhat)) P0, P1, P2, P3 = _legendre_0_1_2_3(cval) - pos = sys.positions # (Ni, 3) - # Energy E_true = torch.sum(pos**2) E = E_true + self.a0 * P0 + self.a1 * P1 + self.a2 * P2 + self.a3 * P3 + energies.append(E) # Forces F_true = pos.clone() F_spur = (self.b1 * P1 + self.b2 * P2 + self.b3 * P3) * self._zhat[None, :] F = F_true + F_spur - if self.tensor_forces: - # Build rotation R such that R zhat = b v = torch.cross(self._zhat, b, dim=0) s = torch.norm(v) cth = float(torch.dot(self._zhat, b)) @@ -139,27 +138,25 @@ def forward( ) else: vx = torch.tensor( - [[0.0, -v[2], v[1]], [v[2], 0.0, -v[0]], [-v[1], v[0], 0.0]], + [ + [0.0, -v[2], v[1]], + [v[2], 0.0, -v[0]], + [-v[1], v[0], 0.0], + ], dtype=self._dtype, device=self._device, ) - R = ( - torch.eye(3, dtype=self._dtype, device=self._device) - + vx - + vx @ vx * ((1.0 - cth) / (s**2)) - ) + R = torch.eye(3) + vx + vx @ vx * ((1.0 - cth) / (s**2)) T = R @ self._D @ R.T F_tensor = self.tensor_amp * (T @ self._zhat) F = F + F_tensor[None, :] + forces.append(F) # Stress S = ( self.p_iso * torch.eye(3, dtype=self._dtype, device=self._device) + (self.c2 * P2 + self.c3 * P3) * self._D ) - - energies.append(E) - forces.append(F) stresses.append(S) result: Dict[str, TensorMap] = {} @@ -168,17 +165,18 @@ def forward( values=torch.tensor([[0]], dtype=torch.int64, device=self._device), ) - samples = Labels( - names=["system"], - values=torch.arange( - n_sys, dtype=torch.int64, device=self._device - ).unsqueeze(1), - ) + # Energy + print(torch.stack(energies, dim=0).shape) energy_block = TensorBlock( - values=torch.stack(energies) + values=torch.stack(energies, dim=0) .to(dtype=self._dtype, device=self._device) - .unsqueeze(1), - samples=samples, + .unsqueeze(-1), + samples=Labels( + names=["system"], + values=torch.arange( + n_sys, dtype=torch.int64, device=self._device + ).unsqueeze(1), + ), components=[], properties=Labels( names=["energy"], @@ -187,25 +185,20 @@ def forward( ) # Forces - samples = Labels( - names=["system", "atom"], - values=torch.cat( - [ - torch.cartesian_prod( - torch.tensor([i], dtype=torch.int64, device=self._device), - torch.arange( - len(systems[i].positions), - dtype=torch.int64, - device=self._device, - ), - ) - for i in range(n_sys) - ] - ), - ) + print(torch.cat(forces, dim=0).shape) force_block = TensorBlock( - values=torch.cat(forces, dim=0).unsqueeze(-1), - samples=samples, + values=torch.cat(forces, dim=0) + .to(dtype=self._dtype, device=self._device) + .unsqueeze(-1), + samples=Labels( + names=["system", "atom"], + values=torch.cat( + [ + torch.cartesian_prod(torch.tensor([i]), torch.arange(len(sys))) + for i, sys in enumerate(systems) + ] + ).to(dtype=torch.int64, device=self._device), + ), components=[ Labels( "xyz", @@ -213,7 +206,7 @@ def forward( .reshape(-1, 1) .to(dtype=torch.int64, device=self._device), ) - ], + ], # vector components properties=Labels( names=["non_conservative_forces"], values=torch.tensor([[0]], dtype=torch.int64, device=self._device), @@ -221,16 +214,17 @@ def forward( ) # Stress - samples = Labels( - names=["system"], - values=torch.arange( - n_sys, dtype=torch.int64, device=self._device - ).unsqueeze(1), - ) - print(stresses) + print(torch.stack(stresses, dim=0).shape) stress_block = TensorBlock( - values=torch.stack(stresses, axis=0).unsqueeze(-1), - samples=samples, + values=torch.stack(stresses, dim=0) + .to(dtype=self._dtype, device=self._device) + .unsqueeze(-1), + samples=Labels( + names=["system"], + values=torch.arange( + n_sys, dtype=torch.int64, device=self._device + ).unsqueeze(1), + ), components=[ Labels( "xyz_1", @@ -251,9 +245,14 @@ def forward( ), ) - result["energy"] = TensorMap(key, [energy_block]) - result["non_conservative_forces"] = TensorMap(key, [force_block]) - result["non_conservative_stress"] = TensorMap(key, [stress_block]) + if "energy" in outputs: + result["energy"] = TensorMap(key, [energy_block]) + + if "non_conservative_forces" in outputs: + result["non_conservative_forces"] = TensorMap(key, [force_block]) + + if "non_conservative_stress" in outputs: + result["non_conservative_stress"] = TensorMap(key, [stress_block]) return result @@ -405,7 +404,6 @@ def test_stress_isotropization(fcc_bulk: Atoms) -> None: fcc_bulk.get_forces() iso = np.trace(S) / 3.0 - # assert np.allclose(S, np.eye(3) * iso, atol=1e-10) assert np.isclose(iso, 5.0, atol=1e-10) From 712c2aeb27bd101b19a951a41b45d02dbfb95888 Mon Sep 17 00:00:00 2001 From: ppegolo Date: Thu, 6 Nov 2025 18:08:18 +0100 Subject: [PATCH 27/57] avoid using new torch functions --- python/metatomic_torch/tests/symmetrized_ase_calculator.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/metatomic_torch/tests/symmetrized_ase_calculator.py b/python/metatomic_torch/tests/symmetrized_ase_calculator.py index 4d43fbe4..d7723438 100644 --- a/python/metatomic_torch/tests/symmetrized_ase_calculator.py +++ b/python/metatomic_torch/tests/symmetrized_ase_calculator.py @@ -27,7 +27,10 @@ def _body_axis_from_system(system: System) -> torch.Tensor: if len(pos) < 2: return torch.tensor([0.0, 0.0, 1.0], dtype=pos.dtype, device=pos.device) d2 = torch.sum((pos[:, None, :] - pos[None, :, :]) ** 2, axis=-1) - i, j = torch.unravel_index(torch.argmax(d2), d2.shape) + # i, j = torch.unravel_index(torch.argmax(d2), d2.shape) # for newer PyTorch + idx = torch.argmax(d2) + i = idx // d2.shape[1] + j = idx % d2.shape[1] b = pos[j] - pos[i] nrm = torch.linalg.norm(b) return ( From b309c3108d929ae5a0737bdb88b3bdab2538d165 Mon Sep 17 00:00:00 2001 From: ppegolo Date: Thu, 6 Nov 2025 18:10:16 +0100 Subject: [PATCH 28/57] allow for per-atom energies maybe --- .../metatomic/torch/ase_calculator.py | 22 +++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/python/metatomic_torch/metatomic/torch/ase_calculator.py b/python/metatomic_torch/metatomic/torch/ase_calculator.py index 3627b39e..508f2cb6 100644 --- a/python/metatomic_torch/metatomic/torch/ase_calculator.py +++ b/python/metatomic_torch/metatomic/torch/ase_calculator.py @@ -892,7 +892,7 @@ class SymmetrizedCalculator(ase.calculators.calculator.Calculator): deviation over the different rotations for each property (e.g., ``energy_std``). """ - implemented_properties = ["energy", "forces", "stress"] + implemented_properties = ["energy", "energies", "forces", "stress", "stresses"] def __init__( self, @@ -1146,10 +1146,16 @@ def _wstd(x): # Energy (B,) if "energy" in results: - E = np.asarray(results["energy"], dtype=float) - out["energy"] = _wmean(E) + E = np.asarray(results["energy"], dtype=float) # (B,) + out["energy"] = _wmean(E) # () if store_std: - out["energy_rot_std"] = _wstd(E) + out["energy_rot_std"] = _wstd(E) # () + + if "energies" in results: + E = np.asarray(results["energies"], dtype=float) # (B,N) + out["energies"] = _wmean(E) # (N,) + if store_std: + out["energies_rot_std"] = _wstd(E) # (N,) # Forces (B,N,3) from rotated structures: back-rotate with F' R if "forces" in results: @@ -1168,6 +1174,14 @@ def _wstd(x): if store_std: out["stress_rot_std"] = _wstd(S_back) # (3,3) + if "stresses" in results: + S = np.asarray(results["stresses"], dtype=float) # (B,N,3,3) + RT = np.swapaxes(R, 1, 2) + S_back = RT[:, None, :, :] @ S @ R[:, None, :, :] # R^T S' R + out["stresses"] = _wmean(S_back) # (N,3,3) + if store_std: + out["stresses_rot_std"] = _wstd(S_back) # (N,3,3) + return out From 4219570701b6302ff061c38390c1c00c5fa3e8ac Mon Sep 17 00:00:00 2001 From: ppegolo Date: Fri, 7 Nov 2025 15:34:24 +0100 Subject: [PATCH 29/57] Allow for per-atom predictions --- .../metatomic/torch/ase_calculator.py | 44 ++++++++++++++++--- .../metatomic_torch/tests/ase_calculator.py | 36 +++++++++++---- 2 files changed, 64 insertions(+), 16 deletions(-) diff --git a/python/metatomic_torch/metatomic/torch/ase_calculator.py b/python/metatomic_torch/metatomic/torch/ase_calculator.py index 508f2cb6..cd222323 100644 --- a/python/metatomic_torch/metatomic/torch/ase_calculator.py +++ b/python/metatomic_torch/metatomic/torch/ase_calculator.py @@ -574,6 +574,8 @@ def compute_energy( self, atoms: Union[ase.Atoms, List[ase.Atoms]], compute_forces_and_stresses: bool = False, + *, + compute_energies: bool = False, ) -> Dict[str, Union[Union[float, np.ndarray], List[Union[float, np.ndarray]]]]: """ Compute the energy of the given ``atoms``. @@ -600,8 +602,14 @@ def compute_energy( atoms_list = atoms was_single = False + properties = ["energy"] + energy_per_atom = False + if compute_energies: + energy_per_atom = True + properties.append("energies") + outputs = self._ase_properties_to_metatensor_outputs( - properties=["energy"], + properties=properties, calculate_forces=compute_forces_and_stresses, calculate_stress=compute_forces_and_stresses, calculate_stresses=False, @@ -648,9 +656,26 @@ def compute_energy( ) energies = predictions[self._energy_key] - results_as_numpy_arrays = { - "energy": energies.block().values.detach().cpu().numpy().flatten().tolist() - } + if energy_per_atom: + results_as_numpy_arrays = { + "energies": energies.block().values.squeeze(-1).detach().cpu().numpy(), + "energy": metatensor.torch.sum_over_samples(energies, ["atom"]) + .block() + .values.detach() + .cpu() + .numpy() + .flatten() + .tolist(), + } + split_sizes = [len(system) for system in systems] + split_indices = np.cumsum(split_sizes[:-1]) + results_as_numpy_arrays["energies"] = np.split( + results_as_numpy_arrays["energies"], split_indices, axis=0 + ) + else: + results_as_numpy_arrays = { + "energy": energies.block().values.squeeze(-1).detach().cpu().numpy(), + } if compute_forces_and_stresses: if self.parameters["non_conservative"]: results_as_numpy_arrays["forces"] = ( @@ -663,8 +688,9 @@ def compute_energy( ) # all the forces are concatenated in a single array, so we need to # split them into the original systems - split_sizes = [len(system) for system in systems] - split_indices = np.cumsum(split_sizes[:-1]) + if not energy_per_atom: + split_sizes = [len(system) for system in systems] + split_indices = np.cumsum(split_sizes[:-1]) results_as_numpy_arrays["forces"] = np.split( results_as_numpy_arrays["forces"], split_indices, axis=0 ) @@ -952,8 +978,10 @@ def calculate( ``calculate`` """ super().calculate(atoms, properties, system_changes) + self.base_calculator.calculate(atoms, properties, system_changes) compute_forces_and_stresses = "forces" in properties or "stress" in properties + compute_energies = "energies" in properties if len(self.quadrature_rotations) > 0: rotated_atoms_list = _rotate_atoms(atoms, self.quadrature_rotations) @@ -965,7 +993,9 @@ def calculate( for batch in batches: try: batch_results = self.base_calculator.compute_energy( - batch, compute_forces_and_stresses + batch, + compute_forces_and_stresses, + compute_energies=compute_energies, ) for key, value in batch_results.items(): results.setdefault(key, []) diff --git a/python/metatomic_torch/tests/ase_calculator.py b/python/metatomic_torch/tests/ase_calculator.py index abc5df43..9d200997 100644 --- a/python/metatomic_torch/tests/ase_calculator.py +++ b/python/metatomic_torch/tests/ase_calculator.py @@ -277,8 +277,11 @@ def test_run_model(tmpdir, model, atoms): assert outputs["non_conservative_stress"].block().values.shape == (2, 3, 3, 1) -@pytest.mark.parametrize("non_conservative", [True, False]) -def test_compute_energy(tmpdir, model, atoms, non_conservative): +@pytest.mark.parametrize( + "non_conservative, compute_energies", + [(True, True), (False, False), (True, False), (False, True)], +) +def test_compute_energy(tmpdir, model, atoms, non_conservative, compute_energies): ref = atoms.copy() ref.calc = ase.calculators.lj.LennardJones( sigma=SIGMA, epsilon=EPSILON, rc=CUTOFF, ro=CUTOFF, smooth=False @@ -292,23 +295,35 @@ def test_compute_energy(tmpdir, model, atoms, non_conservative): non_conservative=non_conservative, ) - energy = calculator.compute_energy(atoms)["energy"] - assert np.allclose(ref.get_potential_energy(), energy) + results = calculator.compute_energy(atoms, compute_energies=compute_energies) + if compute_energies: + energies = results["energies"] + assert np.allclose(ref.get_potential_energies(), energies) + assert np.allclose(ref.get_potential_energy(), results["energy"]) - results = calculator.compute_energy(atoms, compute_forces_and_stresses=True) + results = calculator.compute_energy( + atoms, compute_forces_and_stresses=True, compute_energies=compute_energies + ) assert np.allclose(ref.get_potential_energy(), results["energy"]) if not non_conservative: assert np.allclose(ref.get_forces(), results["forces"]) assert np.allclose( ref.get_stress(), _full_3x3_to_voigt_6_stress(results["stress"]) ) + if compute_energies: + assert np.allclose(ref.get_potential_energies(), results["energies"]) - energies = calculator.compute_energy([atoms, atoms])["energy"] - assert np.allclose(ref.get_potential_energy(), energies[0]) - assert np.allclose(ref.get_potential_energy(), energies[1]) + results = calculator.compute_energy([atoms, atoms]) + assert np.allclose(ref.get_potential_energy(), results["energy"][0]) + assert np.allclose(ref.get_potential_energy(), results["energy"][1]) + if compute_energies: + assert np.allclose(ref.get_potential_energies(), results["energies"][0]) + assert np.allclose(ref.get_potential_energies(), results["energies"][1]) results = calculator.compute_energy( - [atoms, atoms], compute_forces_and_stresses=True + [atoms, atoms], + compute_forces_and_stresses=True, + compute_energies=compute_energies, ) assert np.allclose(ref.get_potential_energy(), results["energy"][0]) assert np.allclose(ref.get_potential_energy(), results["energy"][1]) @@ -321,6 +336,9 @@ def test_compute_energy(tmpdir, model, atoms, non_conservative): assert np.allclose( ref.get_stress(), _full_3x3_to_voigt_6_stress(results["stress"][1]) ) + if compute_energies: + assert np.allclose(ref.get_potential_energies(), results["energies"][0]) + assert np.allclose(ref.get_potential_energies(), results["energies"][1]) atoms_no_pbc = atoms.copy() atoms_no_pbc.pbc = [False, False, False] From 60135e354fe2696470028080e1c279b57b9e858c Mon Sep 17 00:00:00 2001 From: ppegolo Date: Mon, 10 Nov 2025 16:42:22 +0100 Subject: [PATCH 30/57] Fix per atom energies --- .../metatomic/torch/ase_calculator.py | 36 ++++++++++++++----- .../metatomic_torch/tests/ase_calculator.py | 4 ++- 2 files changed, 30 insertions(+), 10 deletions(-) diff --git a/python/metatomic_torch/metatomic/torch/ase_calculator.py b/python/metatomic_torch/metatomic/torch/ase_calculator.py index cd222323..b4eaad3a 100644 --- a/python/metatomic_torch/metatomic/torch/ase_calculator.py +++ b/python/metatomic_torch/metatomic/torch/ase_calculator.py @@ -588,6 +588,8 @@ def compute_energy( :param compute_forces_and_stresses: if ``True``, the model will also compute forces and stresses. IMPORTANT: stresses will only be computed if all provided systems have periodic boundary conditions in all directions. + :param compute_energies: if ``True``, the per-atom energies will also be + computed. :return: A dictionary with the computed properties. The dictionary will contain the ``energy`` as a float, and, if requested, the ``forces`` and ``stress`` @@ -657,8 +659,28 @@ def compute_energy( energies = predictions[self._energy_key] if energy_per_atom: + # Get per-atom energies + sorted_block = metatensor.torch.sort_block(energies.block()) + energies_values = ( + sorted_block.values.detach() + .reshape(-1) + .to(device="cpu") + .to(dtype=torch.float64) + ) + + split_sizes = [len(system) for system in systems] + atom_indices = sorted_block.samples.column("atom") + energies_values = torch.split(energies_values, split_sizes, dim=0) + split_atom_indices = torch.split(atom_indices, split_sizes, dim=0) + split_energies = [] + for atom_indices, values in zip( + split_atom_indices, energies_values, strict=True + ): + split_energy = torch.zeros(len(atom_indices), dtype=values.dtype) + split_energy.index_add_(0, atom_indices, values) + split_energies.append(split_energy) + results_as_numpy_arrays = { - "energies": energies.block().values.squeeze(-1).detach().cpu().numpy(), "energy": metatensor.torch.sum_over_samples(energies, ["atom"]) .block() .values.detach() @@ -666,16 +688,13 @@ def compute_energy( .numpy() .flatten() .tolist(), + "energies": [e.numpy() for e in split_energies], } - split_sizes = [len(system) for system in systems] - split_indices = np.cumsum(split_sizes[:-1]) - results_as_numpy_arrays["energies"] = np.split( - results_as_numpy_arrays["energies"], split_indices, axis=0 - ) else: results_as_numpy_arrays = { "energy": energies.block().values.squeeze(-1).detach().cpu().numpy(), } + if compute_forces_and_stresses: if self.parameters["non_conservative"]: results_as_numpy_arrays["forces"] = ( @@ -688,9 +707,8 @@ def compute_energy( ) # all the forces are concatenated in a single array, so we need to # split them into the original systems - if not energy_per_atom: - split_sizes = [len(system) for system in systems] - split_indices = np.cumsum(split_sizes[:-1]) + split_sizes = [len(system) for system in systems] + split_indices = np.cumsum(split_sizes[:-1]) results_as_numpy_arrays["forces"] = np.split( results_as_numpy_arrays["forces"], split_indices, axis=0 ) diff --git a/python/metatomic_torch/tests/ase_calculator.py b/python/metatomic_torch/tests/ase_calculator.py index 9d200997..6f2d4d5e 100644 --- a/python/metatomic_torch/tests/ase_calculator.py +++ b/python/metatomic_torch/tests/ase_calculator.py @@ -313,7 +313,9 @@ def test_compute_energy(tmpdir, model, atoms, non_conservative, compute_energies if compute_energies: assert np.allclose(ref.get_potential_energies(), results["energies"]) - results = calculator.compute_energy([atoms, atoms]) + results = calculator.compute_energy( + [atoms, atoms], compute_energies=compute_energies + ) assert np.allclose(ref.get_potential_energy(), results["energy"][0]) assert np.allclose(ref.get_potential_energy(), results["energy"][1]) if compute_energies: From 59114f62cb8ced768d6b65053b1bd0729b70c6aa Mon Sep 17 00:00:00 2001 From: ppegolo Date: Mon, 10 Nov 2025 18:13:00 +0100 Subject: [PATCH 31/57] Fix a few things --- .../metatomic/torch/ase_calculator.py | 9 +++++++-- .../tests/symmetrized_ase_calculator.py | 15 ++++++++------- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/python/metatomic_torch/metatomic/torch/ase_calculator.py b/python/metatomic_torch/metatomic/torch/ase_calculator.py index b4eaad3a..36e8895c 100644 --- a/python/metatomic_torch/metatomic/torch/ase_calculator.py +++ b/python/metatomic_torch/metatomic/torch/ase_calculator.py @@ -929,7 +929,7 @@ class SymmetrizedCalculator(ase.calculators.calculator.Calculator): the averaging. This is required to average over the full orthogonal group O(3). :param apply_space_group_symmetry: if ``True``, the results will be averaged over discrete space group of rotations for the input system. The group operations are - computed with `spglib `, and the average is + computed with `spglib `_, and the average is performed after the O(3) averaging (if any). This has no effect for non-periodic systems. :param store_rotational_std: if ``True``, the results will contain the standard @@ -1291,7 +1291,12 @@ def _match_index(x_new, frac_ref, Z_ref, Z_i, tol=1e-6): # Sanity check if np.max(d[j]) > tol: - pass + raise RuntimeError( + ( + f"Sanity check failed in _match_index: max distance {np.max(d[j])} " + f"exceeds tolerance {tol}." + ) + ) return j Q_list, P_list = [], [] diff --git a/python/metatomic_torch/tests/symmetrized_ase_calculator.py b/python/metatomic_torch/tests/symmetrized_ase_calculator.py index d7723438..3ebe2580 100644 --- a/python/metatomic_torch/tests/symmetrized_ase_calculator.py +++ b/python/metatomic_torch/tests/symmetrized_ase_calculator.py @@ -20,7 +20,7 @@ def _body_axis_from_system(system: System) -> torch.Tensor: """ Return the normalized vector connecting the two farthest atoms. - :param atoms: Atomic configuration. + :param system: System. :return: Normalized 3D vector defining the body axis. """ pos = system.positions @@ -149,7 +149,11 @@ def forward( dtype=self._dtype, device=self._device, ) - R = torch.eye(3) + vx + vx @ vx * ((1.0 - cth) / (s**2)) + R = ( + torch.eye(3, dtype=self._dtype, device=self._device) + + vx + + vx @ vx * ((1.0 - cth) / (s**2)) + ) T = R @ self._D @ R.T F_tensor = self.tensor_amp * (T @ self._zhat) F = F + F_tensor[None, :] @@ -169,7 +173,6 @@ def forward( ) # Energy - print(torch.stack(energies, dim=0).shape) energy_block = TensorBlock( values=torch.stack(energies, dim=0) .to(dtype=self._dtype, device=self._device) @@ -188,7 +191,6 @@ def forward( ) # Forces - print(torch.cat(forces, dim=0).shape) force_block = TensorBlock( values=torch.cat(forces, dim=0) .to(dtype=self._dtype, device=self._device) @@ -217,7 +219,6 @@ def forward( ) # Stress - print(torch.stack(stresses, dim=0).shape) stress_block = TensorBlock( values=torch.stack(stresses, dim=0) .to(dtype=self._dtype, device=self._device) @@ -395,7 +396,7 @@ def test_stress_isotropization(fcc_bulk: Atoms) -> None: """ Check that stress deviatoric parts (L=2,3) vanish under full O(3) averaging. - :param dimer: Test atomic structure. + :param fcc_bulk: Test atomic structure. """ base = mock_calculator(c=(2.0, 1.0), p_iso=5.0) calc = SymmetrizedCalculator(base, l_max=9, include_inversion=True) @@ -577,7 +578,7 @@ def test_space_group_average_non_periodic(): atoms = molecule("CH4") energy = 0.0 - forces = np.random.normal(0, 1, (4, 3)) + forces = np.random.normal(0, 1, (5, 3)) forces -= np.mean(forces, axis=0) # Ensure zero net force results = {"energy": energy, "forces": forces} From d6ce60a760efaec0116a83d2eae3fd87d6f45d7f Mon Sep 17 00:00:00 2001 From: Paolo Pegolo Date: Tue, 11 Nov 2025 10:53:32 +0100 Subject: [PATCH 32/57] Update docs/src/engines/ase.rst Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- docs/src/engines/ase.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/engines/ase.rst b/docs/src/engines/ase.rst index 8de8b324..2bcdd840 100644 --- a/docs/src/engines/ase.rst +++ b/docs/src/engines/ase.rst @@ -25,7 +25,7 @@ Supported model outputs :py:meth:`MetatomicCalculator.run_model`; - for non-equivariant architectures like `PET `_, - rotatonally-averaged energies, forces, and stresses can be computed using + rotationally-averaged energies, forces, and stresses can be computed using :py:class:`metatomic.torch.ase_calculator.SymmetrizedCalculator`. How to install the code From 51292c68be677922626ee1c4f958be21950712f3 Mon Sep 17 00:00:00 2001 From: ppegolo Date: Tue, 11 Nov 2025 16:53:58 +0100 Subject: [PATCH 33/57] Start implementing O(3) mean/std --- .../metatomic/torch/symmetrized_model.py | 401 ++++++++++++++++++ 1 file changed, 401 insertions(+) create mode 100644 python/metatomic_torch/metatomic/torch/symmetrized_model.py diff --git a/python/metatomic_torch/metatomic/torch/symmetrized_model.py b/python/metatomic_torch/metatomic/torch/symmetrized_model.py new file mode 100644 index 00000000..1334d1e1 --- /dev/null +++ b/python/metatomic_torch/metatomic/torch/symmetrized_model.py @@ -0,0 +1,401 @@ +from typing import Dict, List, Optional, Tuple + +import metatensor.torch as mts +import numpy as np +import torch +from metatensor.torch import Labels, TensorMap +from metatrain.utils.augmentation import _apply_augmentations + +from metatomic.torch import ModelOutput, System, register_autograd_neighbors + + +try: + from scipy.integrate import lebedev_rule # noqa: F401 + from scipy.spatial.transform import Rotation # noqa: F401 +except ImportError as e: + raise ImportError( + "To perform data augmentation on spherical targets, please " + "install the `scipy` package with `pip install scipy`." + ) from e +try: + import spherical # noqa: F401 +except ImportError as e: + raise ImportError( + "To perform data augmentation on spherical targets, please " + "install the `spherical` package with `pip install spherical`." + ) from e +try: + import quaternionic # noqa: F401 +except ImportError as e: + raise ImportError( + "To perform data augmentation on spherical targets, please " + "install the `quaternionic` package with `pip install quaternionic`." + ) from e + + +def _choose_quadrature(L_max: int) -> Tuple[int, int]: + """ + Choose a Lebedev quadrature order and number of in-plane rotations to integrate + spherical harmonics up to degree ``L_max``. + + :param L_max: maximum spherical harmonic degree + :return: (lebedev_order, n_inplane_rotations) + """ + available = [ + 3, + 5, + 7, + 9, + 11, + 13, + 15, + 17, + 19, + 21, + 23, + 25, + 27, + 29, + 31, + 35, + 41, + 47, + 53, + 59, + 65, + 71, + 77, + 83, + 89, + 95, + 101, + 107, + 113, + 119, + 125, + 131, + ] + # pick smallest order >= L_max + n = min(o for o in available if o >= L_max) + # minimal gamma count + K = L_max + 1 + return n, K + + +def get_euler_angles_quadrature(lebedev_order: int, n_rotations: int): + """ + Get the Euler angles and weights for a Lebedev quadrature combined with in-plane + rotations for SO(3) integration. + + :param lebedev_order: order of the Lebedev quadrature on the unit sphere + :param n_rotations: number of in-plane rotations per Lebedev node + :return: alpha, beta, gamma, w arrays of shape (M,), (M,), (K,), (M,) + respectively, where M is the number of Lebedev nodes and K is the number of + in-plane rotations. + """ + + # Lebedev nodes (X: (3, M)) + X, w = lebedev_rule(lebedev_order) # w sums to 4*pi + x, y, z = X + alpha = np.arctan2(y, x) # (M,) + beta = np.arccos(np.clip(z, -1.0, 1.0)) # (M,) + gamma = np.linspace(0.0, 2 * np.pi, n_rotations, endpoint=False) # (K,) + + w_so3 = np.repeat(w / (4 * np.pi * n_rotations), repeats=gamma.size) # (M*K,) + + A = np.repeat(alpha, gamma.size) # (N,) + B = np.repeat(beta, gamma.size) # (N,) + G = np.tile(gamma, alpha.size) # (N,) + + return A, B, G, w_so3 + + +def _rotations_from_angles( + alpha: np.ndarray, beta: np.ndarray, gamma: np.ndarray +) -> Rotation: + """ + Compose rotations from ZYZ Euler angles. + + :param alpha: array of alpha angles (M,) + :param beta: array of beta angles (M,) + :param gamma: array of gamma angles (K,) + :return: Rotation object containing all (M*K,) rotations + """ + + # Compose ZYZ rotations in SO(3) + Rot = ( + Rotation.from_euler("z", alpha) + * Rotation.from_euler("y", beta) + * Rotation.from_euler("z", gamma) + ) + + return Rot + + +def _transform_system(system: System, transformation: torch.Tensor) -> System: + transformed_system = System( + positions=system.positions @ transformation.T, + types=system.types, + cell=system.cell @ transformation.T, + pbc=system.pbc, + ) + for options in system.known_neighbor_lists(): + neighbors = mts.detach_block(system.get_neighbor_list(options)) + + neighbors.values[:] = ( + neighbors.values.squeeze(-1) @ transformation.T + ).unsqueeze(-1) + + register_autograd_neighbors(system, neighbors) + transformed_system.add_neighbor_list(options, neighbors) + return transformed_system + + +def _complex_to_real_spherical_harmonics_transform(ell: int) -> np.ndarray: + """ + Generate the transformation matrix from complex spherical harmonics + to real spherical harmonics for a given l. + Returns a transformation matrix of shape ((2l+1), (2l+1)). + """ + if ell < 0 or not isinstance(ell, int): + raise ValueError("l must be a non-negative integer.") + + # The size of the transformation matrix is (2l+1) x (2l+1) + size = 2 * ell + 1 + T = np.zeros((size, size), dtype=complex) + + for m in range(-ell, ell + 1): + m_index = m + ell # Index in the matrix + if m > 0: + # Real part of Y_{l}^{m} + T[m_index, ell + m] = 1 / np.sqrt(2) * (-1) ** m + T[m_index, ell - m] = 1 / np.sqrt(2) + elif m < 0: + # Imaginary part of Y_{l}^{|m|} + T[m_index, ell + abs(m)] = -1j / np.sqrt(2) * (-1) ** m + T[m_index, ell - abs(m)] = 1j / np.sqrt(2) + else: # m == 0 + # Y_{l}^{0} remains unchanged + T[m_index, ell] = 1 + + # Return the transformation matrix to convert complex to real spherical harmonics + return T + + +def _compute_real_wigner_matrices( + o3_lambda_max: int, + angles: Tuple[np.ndarray, np.ndarray, np.ndarray], # alpha, beta, gamma +) -> Dict[int, np.ndarray]: + wigner = spherical.Wigner(o3_lambda_max) + R = quaternionic.array.from_euler_angles(*angles) + D = wigner.D(R) + wigner_D_matrices = {} + for ell in range(o3_lambda_max + 1): + wigner_D_matrices[ell] = np.zeros( + angles[0].shape + (2 * ell + 1, 2 * ell + 1), dtype=np.complex128 + ) + for mp in range(-ell, ell + 1): + for m in range(-ell, ell + 1): + # There is an unexplained conjugation factor in the definition given in + # the quaternionic library. + wigner_D_matrices[ell][..., mp + ell, m + ell] = ( + D[..., wigner.Dindex(ell, mp, m)] + ).conj() + U = _complex_to_real_spherical_harmonics_transform(ell) + wigner_D_matrices[ell] = np.einsum( + "ij,...jk,kl->...il", U.conj(), wigner_D_matrices[ell], U.T + ) + assert np.allclose(wigner_D_matrices[ell].imag, 0) + wigner_D_matrices[ell] = [ + torch.from_numpy(D) for D in wigner_D_matrices[ell].real + ] + + return wigner_D_matrices + + +class SymmetrizedModel(torch.nn.Module): + + def __init__(self, base_model, max_o3_lambda, batch_size: int = 32): + super().__init__() + self.base_model = base_model + self.max_o3_lambda = max_o3_lambda + self.batch_size = batch_size + + # Compute grid + lebedev_order, n_inplane_rotations = _choose_quadrature(self.max_o3_lambda) + alpha, beta, gamma, w_so3 = get_euler_angles_quadrature( + lebedev_order, n_inplane_rotations + ) + self.so3_weights = torch.from_numpy(w_so3) + + # Active rotations + self.so3_rotations = torch.from_numpy( + _rotations_from_angles(alpha, beta, gamma).as_matrix() + ) + self.n_so3_rotations = self.so3_rotations.size(0) + + # Compute inverse Wigner D representations + angles_inverse_rotations = (np.pi - gamma, beta, np.pi - alpha) + self.wigner_D_inverse_rotations = _compute_real_wigner_matrices( + self.max_o3_lambda, angles_inverse_rotations + ) + + # Compute characters + # TODO + + def forward( + self, + systems: List[System], + outputs: Dict[str, ModelOutput], + selected_atoms: Optional[Labels] = None, + ) -> Dict[str, TensorMap]: + + # Evaluate the model over the grid + transformed_outputs, backtransformed_outputs = self._eval_over_grid( + systems, outputs, selected_atoms + ) + + mean_std = self._compute_mean_and_variance(backtransformed_outputs) + + return mean_std + + def _compute_mean_and_variance(self, backtransformed_outputs): + mean_std_outputs = {} + for target_name in backtransformed_outputs: + output_so3 = [] + output_pso3 = [] + for i_sys in range(len(backtransformed_outputs[target_name])): + tensor_so3 = backtransformed_outputs[target_name][i_sys][1] + tensor_pso3 = backtransformed_outputs[target_name][i_sys][-1] + output_blocks_so3 = [] + output_blocks_pso3 = [] + for key in tensor_so3.keys: + tensor_so3_block = tensor_so3.block(key) + tensor_pso3_block = tensor_pso3.block(key) + w = self.so3_weights.view( + self.n_so3_rotations, *[1] * (tensor_so3_block.values.ndim - 1) + ) + output_blocks_so3.append( + mts.TensorBlock( + samples=tensor_so3_block.samples, + components=tensor_so3_block.components, + properties=tensor_so3_block.properties, + values=w * tensor_so3_block.values, + ) + ) + output_blocks_pso3.append( + mts.TensorBlock( + samples=tensor_pso3_block.samples, + components=tensor_pso3_block.components, + properties=tensor_pso3_block.properties, + values=w * tensor_pso3_block.values, + ) + ) + output_so3.append( + mts.TensorMap( + backtransformed_outputs[target_name][0][1].keys, + output_blocks_so3, + ) + ) + output_pso3.append( + mts.TensorMap( + backtransformed_outputs[target_name][0][-1].keys, + output_blocks_pso3, + ) + ) + output_so3 = mts.join( + output_so3, "samples", add_dimension="physical_system" + ) + output_pso3 = mts.join( + output_pso3, "samples", add_dimension="physical_system" + ) + output = mts.join( + [output_so3, output_pso3], "samples", add_dimension="o3_sigma" + ) + mean = mts.mean_over_samples(output, ["system", "o3_sigma"]) + std = mts.std_over_samples(output, ["system", "o3_sigma"]) + if "physical_system" in mean[0].samples.names: + mean = mts.rename_dimension( + mean, "samples", "physical_system", "system" + ) + std = mts.rename_dimension(std, "samples", "physical_system", "system") + else: + mean = mts.rename_dimension(mean, "samples", "_", "system") + std = mts.rename_dimension(std, "samples", "_", "system") + + mean_std_outputs[target_name + "_mean"] = mean + mean_std_outputs[target_name + "_std"] = std + + return mean_std_outputs + + def _eval_over_grid( + self, + systems: List[System], + outputs: Dict[str, ModelOutput], + selected_atoms: Optional[Labels], + ): + """ + Sample the model on the O(3) quadrature. + + :param systems: list of systems to evaluate + :param model: atomistic model to evaluate + :param device: device to use for computation + :return: list of list of model outputs, shape (len(systems), N) + where N is the number of quadrature points + """ + + device = systems[0].positions.device + dtype = systems[0].positions.dtype + + transformed_outputs = { + name: [{-1: None, 1: None} for _ in systems] for name in outputs + } + backtransformed_outputs = { + name: [{-1: None, 1: None} for _ in systems] for name in outputs + } + for i_sys, system in enumerate(systems): + for inversion in [-1, 1]: + rotation_outputs = [] + for batch in range(0, len(self.so3_rotations), self.batch_size): + transformed_systems = [ + _transform_system( + system, inversion * R.to(device=device, dtype=dtype) + ) + for R in self.so3_rotations[batch : batch + self.batch_size] + ] + out = self.base_model( + transformed_systems, + outputs, + selected_atoms, + ) + rotation_outputs.append(out) + + for name in transformed_outputs: + tensor = mts.join( + [r[name] for r in rotation_outputs], + "samples", + different_keys="error", + ) + transformed_outputs[name][i_sys][inversion] = tensor + + n_rot = self.so3_rotations.size(0) + for name in transformed_outputs: + for i_sys, system in enumerate(systems): + for inversion in [-1, 1]: + tensor = transformed_outputs[name][i_sys][inversion] + _, backtransformed, _ = _apply_augmentations( + [system] * n_rot, + {name: tensor}, + list( + ( + self.so3_rotations.to(device=device, dtype=dtype) + * inversion + ).unbind(0) + ), + self.wigner_D_inverse_rotations, + ) + backtransformed_outputs[name][i_sys][inversion] = backtransformed[ + name + ] + + return transformed_outputs, backtransformed_outputs From 66c4bf76d4f2b0f45064ab8225a090a3e4036b9e Mon Sep 17 00:00:00 2001 From: ppegolo Date: Tue, 11 Nov 2025 18:19:46 +0100 Subject: [PATCH 34/57] Fixes --- .../metatomic/torch/symmetrized_model.py | 119 ++++++++---------- 1 file changed, 54 insertions(+), 65 deletions(-) diff --git a/python/metatomic_torch/metatomic/torch/symmetrized_model.py b/python/metatomic_torch/metatomic/torch/symmetrized_model.py index 1334d1e1..ec935866 100644 --- a/python/metatomic_torch/metatomic/torch/symmetrized_model.py +++ b/python/metatomic_torch/metatomic/torch/symmetrized_model.py @@ -3,7 +3,7 @@ import metatensor.torch as mts import numpy as np import torch -from metatensor.torch import Labels, TensorMap +from metatensor.torch import Labels, TensorBlock, TensorMap from metatrain.utils.augmentation import _apply_augmentations from metatomic.torch import ModelOutput, System, register_autograd_neighbors @@ -214,7 +214,6 @@ def _compute_real_wigner_matrices( class SymmetrizedModel(torch.nn.Module): - def __init__(self, base_model, max_o3_lambda, batch_size: int = 32): super().__init__() self.base_model = base_model @@ -236,6 +235,9 @@ def __init__(self, base_model, max_o3_lambda, batch_size: int = 32): # Compute inverse Wigner D representations angles_inverse_rotations = (np.pi - gamma, beta, np.pi - alpha) + self.so3_inverse_rotations = torch.from_numpy( + _rotations_from_angles(*angles_inverse_rotations).as_matrix() + ) self.wigner_D_inverse_rotations = _compute_real_wigner_matrices( self.max_o3_lambda, angles_inverse_rotations ) @@ -249,80 +251,64 @@ def forward( outputs: Dict[str, ModelOutput], selected_atoms: Optional[Labels] = None, ) -> Dict[str, TensorMap]: - # Evaluate the model over the grid - transformed_outputs, backtransformed_outputs = self._eval_over_grid( + _, backtransformed_outputs = self._eval_over_grid( systems, outputs, selected_atoms ) mean_std = self._compute_mean_and_variance(backtransformed_outputs) - return mean_std def _compute_mean_and_variance(self, backtransformed_outputs): - mean_std_outputs = {} + mean_std_outputs: Dict[str, TensorMap] = {} + # Iterate over targets for target_name in backtransformed_outputs: - output_so3 = [] - output_pso3 = [] + mean_tensors: List[TensorMap] = [] + std_tensors: List[TensorMap] = [] + # Iterate over systems for i_sys in range(len(backtransformed_outputs[target_name])): tensor_so3 = backtransformed_outputs[target_name][i_sys][1] tensor_pso3 = backtransformed_outputs[target_name][i_sys][-1] - output_blocks_so3 = [] - output_blocks_pso3 = [] - for key in tensor_so3.keys: - tensor_so3_block = tensor_so3.block(key) - tensor_pso3_block = tensor_pso3.block(key) + + mean_blocks: List[TensorBlock] = [] + std_blocks: List[TensorBlock] = [] + # Iterate over blocks + for block_so3, block_pso3 in zip(tensor_so3, tensor_pso3, strict=True): w = self.so3_weights.view( - self.n_so3_rotations, *[1] * (tensor_so3_block.values.ndim - 1) + self.n_so3_rotations, *[1] * (block_so3.values.ndim - 1) ) - output_blocks_so3.append( - mts.TensorBlock( - samples=tensor_so3_block.samples, - components=tensor_so3_block.components, - properties=tensor_so3_block.properties, - values=w * tensor_so3_block.values, - ) + mean_block = torch.sum( + (block_so3.values + block_pso3.values) * 0.5 * w, dim=0 ) - output_blocks_pso3.append( - mts.TensorBlock( - samples=tensor_pso3_block.samples, - components=tensor_pso3_block.components, - properties=tensor_pso3_block.properties, - values=w * tensor_pso3_block.values, - ) + second_moment_block = torch.sum( + (block_so3.values**2 + block_pso3.values**2) * 0.5 * w, dim=0 ) - output_so3.append( - mts.TensorMap( - backtransformed_outputs[target_name][0][1].keys, - output_blocks_so3, + std_block = torch.sqrt( + torch.clamp(second_moment_block - mean_block**2, min=0.0) ) - ) - output_pso3.append( - mts.TensorMap( - backtransformed_outputs[target_name][0][-1].keys, - output_blocks_pso3, + mean_blocks.append( + TensorBlock( + samples=Labels("system", torch.tensor([[i_sys]])), + components=block_so3.components, + properties=block_so3.properties, + values=mean_block.unsqueeze(0), + ) ) - ) - output_so3 = mts.join( - output_so3, "samples", add_dimension="physical_system" - ) - output_pso3 = mts.join( - output_pso3, "samples", add_dimension="physical_system" - ) - output = mts.join( - [output_so3, output_pso3], "samples", add_dimension="o3_sigma" - ) - mean = mts.mean_over_samples(output, ["system", "o3_sigma"]) - std = mts.std_over_samples(output, ["system", "o3_sigma"]) - if "physical_system" in mean[0].samples.names: - mean = mts.rename_dimension( - mean, "samples", "physical_system", "system" - ) - std = mts.rename_dimension(std, "samples", "physical_system", "system") - else: - mean = mts.rename_dimension(mean, "samples", "_", "system") - std = mts.rename_dimension(std, "samples", "_", "system") + std_blocks.append( + TensorBlock( + samples=Labels("system", torch.tensor([[i_sys]])), + components=block_so3.components, + properties=block_so3.properties, + values=std_block.unsqueeze(0), + ) + ) + mean_tensors.append(TensorMap(tensor_so3.keys, mean_blocks)) + std_tensors.append(TensorMap(tensor_so3.keys, std_blocks)) + + mean = mts.join(mean_tensors, "samples") + std = mts.join(std_tensors, "samples") + # Store results mean_std_outputs[target_name + "_mean"] = mean mean_std_outputs[target_name + "_std"] = std @@ -347,15 +333,15 @@ def _eval_over_grid( device = systems[0].positions.device dtype = systems[0].positions.dtype - transformed_outputs = { + transformed_outputs: Dict[str, List[Dict[int, Optional[TensorMap]]]] = { name: [{-1: None, 1: None} for _ in systems] for name in outputs } - backtransformed_outputs = { + backtransformed_outputs: Dict[str, List[Dict[int, Optional[TensorMap]]]] = { name: [{-1: None, 1: None} for _ in systems] for name in outputs } for i_sys, system in enumerate(systems): for inversion in [-1, 1]: - rotation_outputs = [] + rotation_outputs: List[Dict[str, TensorMap]] = [] for batch in range(0, len(self.so3_rotations), self.batch_size): transformed_systems = [ _transform_system( @@ -370,13 +356,14 @@ def _eval_over_grid( ) rotation_outputs.append(out) + # Combine batch outputs for name in transformed_outputs: - tensor = mts.join( - [r[name] for r in rotation_outputs], + combined: List[TensorMap] = [r[name] for r in rotation_outputs] + transformed_outputs[name][i_sys][inversion] = mts.join( + combined, "samples", - different_keys="error", + add_dimension="batch_rotation", ) - transformed_outputs[name][i_sys][inversion] = tensor n_rot = self.so3_rotations.size(0) for name in transformed_outputs: @@ -388,7 +375,9 @@ def _eval_over_grid( {name: tensor}, list( ( - self.so3_rotations.to(device=device, dtype=dtype) + self.so3_inverse_rotations.to( + device=device, dtype=dtype + ) * inversion ).unbind(0) ), From 372772a7c70beb8563bfbb5970e825d6a5196c21 Mon Sep 17 00:00:00 2001 From: ppegolo Date: Wed, 12 Nov 2025 14:08:26 +0100 Subject: [PATCH 35/57] allow per atom --- .../metatomic/torch/symmetrized_model.py | 111 +++++++++++++----- 1 file changed, 83 insertions(+), 28 deletions(-) diff --git a/python/metatomic_torch/metatomic/torch/symmetrized_model.py b/python/metatomic_torch/metatomic/torch/symmetrized_model.py index ec935866..dc587e31 100644 --- a/python/metatomic_torch/metatomic/torch/symmetrized_model.py +++ b/python/metatomic_torch/metatomic/torch/symmetrized_model.py @@ -256,63 +256,113 @@ def forward( systems, outputs, selected_atoms ) - mean_std = self._compute_mean_and_variance(backtransformed_outputs) - return mean_std + mean_var = self._compute_mean_and_variance(backtransformed_outputs) + return mean_var def _compute_mean_and_variance(self, backtransformed_outputs): - mean_std_outputs: Dict[str, TensorMap] = {} + mean_var_outputs: Dict[str, TensorMap] = {} # Iterate over targets for target_name in backtransformed_outputs: mean_tensors: List[TensorMap] = [] - std_tensors: List[TensorMap] = [] + var_tensors: List[TensorMap] = [] # Iterate over systems for i_sys in range(len(backtransformed_outputs[target_name])): tensor_so3 = backtransformed_outputs[target_name][i_sys][1] tensor_pso3 = backtransformed_outputs[target_name][i_sys][-1] mean_blocks: List[TensorBlock] = [] - std_blocks: List[TensorBlock] = [] + var_blocks: List[TensorBlock] = [] # Iterate over blocks for block_so3, block_pso3 in zip(tensor_so3, tensor_pso3, strict=True): - w = self.so3_weights.view( - self.n_so3_rotations, *[1] * (block_so3.values.ndim - 1) + split_by_transformation = torch.bincount( + block_so3.samples.values[:, 0] ) - mean_block = torch.sum( - (block_so3.values + block_pso3.values) * 0.5 * w, dim=0 + w = torch.repeat_interleave( + self.so3_weights, split_by_transformation ) - second_moment_block = torch.sum( - (block_so3.values**2 + block_pso3.values**2) * 0.5 * w, dim=0 - ) - std_block = torch.sqrt( - torch.clamp(second_moment_block - mean_block**2, min=0.0) + w = w.view(w.shape[0], *[1] * (block_so3.values.ndim - 1)) + mean_block = (block_so3.values + block_pso3.values) * 0.5 * w + second_moment_block = ( + (block_so3.values**2 + block_pso3.values**2) * 0.5 * w ) mean_blocks.append( TensorBlock( - samples=Labels("system", torch.tensor([[i_sys]])), + samples=block_so3.samples, components=block_so3.components, properties=block_so3.properties, - values=mean_block.unsqueeze(0), + values=mean_block, ) ) - std_blocks.append( + var_blocks.append( TensorBlock( - samples=Labels("system", torch.tensor([[i_sys]])), + samples=block_so3.samples, components=block_so3.components, properties=block_so3.properties, - values=std_block.unsqueeze(0), + values=second_moment_block, ) ) - mean_tensors.append(TensorMap(tensor_so3.keys, mean_blocks)) - std_tensors.append(TensorMap(tensor_so3.keys, std_blocks)) - - mean = mts.join(mean_tensors, "samples") - std = mts.join(std_tensors, "samples") + mean_tensor = mts.sum_over_samples( + TensorMap(tensor_so3.keys, mean_blocks), "system" + ) + second_moment_tensor = mts.sum_over_samples( + TensorMap(tensor_so3.keys, var_blocks), "system" + ) + var_tensor = mts.subtract(second_moment_tensor, mts.pow(mean_tensor, 2)) + mean_tensors.append(mean_tensor) + var_tensors.append(var_tensor) + + mean = mts.join(mean_tensors, "samples", add_dimension="system") + var = mts.join(var_tensors, "samples", add_dimension="system") + + if "system" not in mean[0].samples.names: + mean = mts.insert_dimension( + mean, + "samples", + 0, + "system", + torch.zeros(mean[0].samples.values.shape[0], dtype=torch.long), + ) + var = mts.insert_dimension( + var, + "samples", + 0, + "system", + torch.zeros(var[0].samples.values.shape[0], dtype=torch.long), + ) + else: + num_dims = len(mean[0].samples.names) + mean = mts.permute_dimensions( + mean, + "samples", + [num_dims - 1] + list(range(num_dims - 1)), + ) + var = mts.permute_dimensions( + var, + "samples", + [num_dims - 1] + list(range(num_dims - 1)), + ) + if "_" in mean[0].samples.names: + mean = mts.remove_dimension(mean, "samples", "_") + var = mts.remove_dimension(var, "samples", "_") # Store results - mean_std_outputs[target_name + "_mean"] = mean - mean_std_outputs[target_name + "_std"] = std + mean_var_outputs[target_name + "_mean"] = mean + ncomp = len(var[0].components) + var = TensorMap( + var.keys, + [ + TensorBlock( + samples=block.samples, + components=[], + properties=block.properties, + values=block.values.sum(dim=list(range(1, ncomp + 1))), + ) + for block in var + ], + ) + mean_var_outputs[target_name + "_var"] = var - return mean_std_outputs + return mean_var_outputs def _eval_over_grid( self, @@ -381,7 +431,12 @@ def _eval_over_grid( * inversion ).unbind(0) ), - self.wigner_D_inverse_rotations, + { + ell: self.wigner_D_inverse_rotations[ell] + .to(device=device, dtype=dtype) + .unbind(0) + for ell in self.wigner_D_inverse_rotations + }, ) backtransformed_outputs[name][i_sys][inversion] = backtransformed[ name From b84a73b7d74deaa5e7553bb539b8817d1febe256 Mon Sep 17 00:00:00 2001 From: ppegolo Date: Wed, 12 Nov 2025 16:41:57 +0100 Subject: [PATCH 36/57] Seemingly working but slow implementation --- .../metatomic/torch/symmetrized_model.py | 336 +++++++++++++++++- 1 file changed, 330 insertions(+), 6 deletions(-) diff --git a/python/metatomic_torch/metatomic/torch/symmetrized_model.py b/python/metatomic_torch/metatomic/torch/symmetrized_model.py index dc587e31..2c6c2331 100644 --- a/python/metatomic_torch/metatomic/torch/symmetrized_model.py +++ b/python/metatomic_torch/metatomic/torch/symmetrized_model.py @@ -206,13 +206,255 @@ def _compute_real_wigner_matrices( "ij,...jk,kl->...il", U.conj(), wigner_D_matrices[ell], U.T ) assert np.allclose(wigner_D_matrices[ell].imag, 0) - wigner_D_matrices[ell] = [ - torch.from_numpy(D) for D in wigner_D_matrices[ell].real - ] + wigner_D_matrices[ell] = torch.from_numpy(wigner_D_matrices[ell].real) return wigner_D_matrices +def _angles_from_rotations( + R: np.ndarray, + eps: float = 1e-6, +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Extract Z-Y-Z Euler angles (alpha, beta, gamma) from rotation matrices, with + explicit handling of the gimbal-lock cases (beta≈0 and beta≈pi). + TODO: This function is extremely sensitive to eps and will be modified. + Parameters + ---------- + R : np.ndarray + Rotation matrices with arbitrary batch shape `(..., 3, 3)`. + eps : float + Tolerance used to detect gimbal lock via `sin(beta) < eps`. + + Returns + ------- + (alphas, betas, gammas) : Tuple[np.ndarray, np.ndarray, np.ndarray] + Each with the same batch shape as `R[..., 0, 0]` (i.e., `R.shape[:-2]`). + + Notes + ----- + Conventions: + - Base convention is Z-Y-Z (Rz(alpha) Ry(beta) Rz(gamma)). + - For beta≈0: set beta=0, gamma=0, alpha=atan2(R[1,0], R[0,0]). + - For beta≈pi: set beta=pi, alpha=0, gamma=atan2(R[1,0], -R[0,0]). + These conventions ensure a deterministic inverse where the standard formulas + are ill-conditioned. + """ + # Accept any batch shape. Flatten to (N, 3, 3) for clarity, then unflatten. + batch_shape = R.shape[:-2] + R_flat = R.reshape(-1, 3, 3) + + # Read commonly-used entries with explicit names for readability + R00 = R_flat[:, 0, 0] + R01 = R_flat[:, 0, 1] + R02 = R_flat[:, 0, 2] + R10 = R_flat[:, 1, 0] + R11 = R_flat[:, 1, 1] + R12 = R_flat[:, 1, 2] + R20 = R_flat[:, 2, 0] + R21 = R_flat[:, 2, 1] + R22 = R_flat[:, 2, 2] + + # Default (non-singular) extraction + zz = np.clip(R22, -1.0, 1.0) + betas = np.arccos(zz) + + # For Z–Y–Z, standard formulas away from the singular set + alphas = np.arctan2(R12, R02) + gammas = np.arctan2(R21, -R20) + + # Normalize into [0, 2π) + two_pi = 2.0 * np.pi + alphas = np.mod(alphas, two_pi) + gammas = np.mod(gammas, two_pi) + + # Gimbal-lock detection via sin(beta) + sinb = np.sin(betas) + near = np.abs(sinb) < eps + if np.any(near): + # Split the two singular bands using zz = cos(beta) + near_zero = near & (zz > 0) # beta≈0 + near_pi = near & (zz < 0) # beta≈pi + + if np.any(near_zero): + # beta≈0: rotation ≈ Rz(alpha+gamma). Choose gamma=0, recover alpha from + # 2x2 block. + betas[near_zero] = 0.0 + gammas[near_zero] = 0.0 + alphas[near_zero] = np.arctan2(R10[near_zero], R00[near_zero]) + alphas[near_zero] = np.mod(alphas[near_zero], two_pi) + + if np.any(near_pi): + # beta≈pi: choose alpha=0, recover gamma from 2x2 block with sign flip on + # R00. + betas[near_pi] = np.pi + alphas[near_pi] = 0.0 + gammas[near_pi] = np.arctan2(R10[near_pi], -R00[near_pi]) + gammas[near_pi] = np.mod(gammas[near_pi], two_pi) + + # Unflatten back to the original batch shape + alphas = alphas.reshape(batch_shape) + betas = betas.reshape(batch_shape) + gammas = gammas.reshape(batch_shape) + return alphas, betas, gammas + + +def _euler_angles_of_combined_rotation( + angles1: Tuple[np.ndarray, np.ndarray, np.ndarray], + angles2: Tuple[np.ndarray, np.ndarray, np.ndarray], +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Given two sets of Euler angles (alpha, beta, gamma), returns the Euler angles + of all pairwise compositions + """ + + R1 = _rotations_from_angles(*angles1).as_matrix() # (N1, 3, 3) + R2 = _rotations_from_angles(*angles2).as_matrix() # (N2, 3, 3) + + # Broadcasted pairwise multiplication to shape (N1, N2, 3, 3): R1[p] @ R2[a] + R_product = R1[:, None, :, :] @ R2[None, :, :, :] + + # Extract Euler angles from the combined rotation matrices (robust to gimbal lock) + alpha, beta, gamma = _angles_from_rotations(R_product, eps=1e-6) + return alpha, beta, gamma + + +def _get_so3_character( + alphas: np.ndarray, + betas: np.ndarray, + gammas: np.ndarray, + o3_lambda: int, + tol: float = 1e-7, +) -> np.ndarray: + """ + Numerically stable evaluation of the character function χ_{o3_lambda}(R) over SO(3). + + Uses a small-angle Taylor expansion for χ_l(ω) = sin((2l+1)t)/sin(t) with t = ω/2 + when |t| is very small, and a guarded ratio otherwise. + """ + # Compute half-angle t = ω/2 via Z–Y–Z relation: cos t = cos(β/2) cos((α+γ)/2) + cos_t = np.cos(betas / 2.0) * np.cos((alphas + gammas) / 2.0) + cos_t = np.clip(cos_t, -1.0, 1.0) + t = np.arccos(cos_t) + + # Output array + chi = np.empty_like(t) + + # Parameters for χ + L = o3_lambda + a = 2 * L + 1 + ll1 = L * (L + 1) + + small = np.abs(t) < tol + if np.any(small): + # Series up to t^4: χ ≈ a [1 - (2/3) ℓ(ℓ+1) t^2 + (1/45) ℓ(ℓ+1)(3ℓ^2+3ℓ-1) t^4] + ts = t[small] + t2 = ts * ts + coeff4 = ll1 * (3 * L * L + 3 * L - 1) + chi[small] = a * ( + 1.0 - (2.0 / 3.0) * ll1 * t2 + (1.0 / 45.0) * coeff4 * t2 * t2 + ) + + # Large-angle (or not-so-small) branch: safe ratio with guard + large = ~small + if np.any(large): + tl = t[large] + sin_t = np.sin(tl) + numer = np.sin(a * tl) + mask = np.abs(sin_t) >= tol + out = np.empty_like(tl) + np.divide(numer, sin_t, out=out, where=mask) + out[~mask] = a # exact limit as t -> 0 + chi[large] = out + + return chi + + +def _get_o3_character( + alphas: np.ndarray, + betas: np.ndarray, + gammas: np.ndarray, + o3_lambda: int, + o3_sigma: int, + tol: float = 1e-13, +) -> np.ndarray: + """ + Numerically stable evaluation of the character function χ_{o3_lambda}(R) over O(3). + """ + return ( + o3_sigma + * ((-1) ** o3_lambda) + * _get_so3_character(alphas, betas, gammas, o3_lambda, tol) + ) + + +def compute_characters( + o3_lambda_max: int, + angles: Tuple[np.ndarray, np.ndarray, np.ndarray], + inverse_angles: Tuple[np.ndarray, np.ndarray, np.ndarray], +) -> Dict[int, torch.Tensor]: + alpha, beta, gamma = _euler_angles_of_combined_rotation(angles, inverse_angles) + + so3_characters = { + o3_lambda: _get_so3_character(alpha, beta, gamma, o3_lambda) + for o3_lambda in range(o3_lambda_max + 1) + } + + pso3_characters = {} + for o3_lambda in range(o3_lambda_max + 1): + for o3_sigma in [-1, +1]: + pso3_characters[(o3_lambda, o3_sigma)] = ( + o3_sigma * ((-1) ** o3_lambda) * so3_characters[o3_lambda] + ) + + so3_characters = { + key: torch.from_numpy(value) for key, value in so3_characters.items() + } + pso3_characters = { + key: torch.from_numpy(value) for key, value in pso3_characters.items() + } + + return so3_characters, pso3_characters + + +def _integrate_with_character( + tensor_so3: torch.Tensor, + tensor_pso3: torch.Tensor, + so3_characters: Dict[int, torch.Tensor], + pso3_characters: Dict[Tuple[int, int], torch.Tensor], + o3_lambda_max: int, +): + integral = {} + for o3_lambda in range(o3_lambda_max + 1): + so3_character = so3_characters[o3_lambda] + for o3_sigma in [-1, 1]: + pso3_character = pso3_characters[o3_lambda, o3_sigma] + integral[o3_lambda, o3_sigma] = (1 / 4) * ( + torch.einsum( + "i...,i...->...", + tensor_so3, + torch.einsum("ij,j...->i...", so3_character, tensor_so3), + ) + + torch.einsum( + "i...,i...->...", + tensor_pso3, + torch.einsum("ij,j...->i...", pso3_character, tensor_pso3), + ) + ) + (1 / 2) * ( + torch.einsum( + "i...,i...->...", + tensor_so3, + torch.einsum("ij,j...->i...", pso3_character, tensor_pso3), + ) + ) + + # Normalize by Haar measure + integral[(o3_lambda, o3_sigma)] *= (2 * o3_lambda + 1) / ( + 8 * torch.pi**2 + ) ** 2 + return integral + + class SymmetrizedModel(torch.nn.Module): def __init__(self, base_model, max_o3_lambda, batch_size: int = 32): super().__init__() @@ -243,7 +485,11 @@ def __init__(self, base_model, max_o3_lambda, batch_size: int = 32): ) # Compute characters - # TODO + self.so3_characters, self.pso3_characters = compute_characters( + self.max_o3_lambda, + (alpha, beta, gamma), + angles_inverse_rotations, + ) def forward( self, @@ -252,12 +498,90 @@ def forward( selected_atoms: Optional[Labels] = None, ) -> Dict[str, TensorMap]: # Evaluate the model over the grid - _, backtransformed_outputs = self._eval_over_grid( + transformed_outputs, backtransformed_outputs = self._eval_over_grid( systems, outputs, selected_atoms ) mean_var = self._compute_mean_and_variance(backtransformed_outputs) - return mean_var + character_projections = self._compute_character_projections( + transformed_outputs, mean_var, systems + ) + return mean_var, character_projections + + def _compute_character_projections(self, transformed_outputs, mean_var, systems): + integrals = {} + for name in transformed_outputs: + integrals[name] = [] + for i_sys, tensor_dict in enumerate(transformed_outputs[name]): + integrals[name].append({}) + for (key, block_so3), block_pso3 in zip( + tensor_dict[1].items(), + tensor_dict[-1], + ): + split_by_transformation = torch.bincount( + block_so3.samples.values[:, 0] + ) + w = torch.repeat_interleave( + self.so3_weights, split_by_transformation + ) + w = w.view(w.shape[0], *[1] * (block_so3.values.ndim - 1)) + + integral = _integrate_with_character( + block_so3.values * w, + block_pso3.values * w, + self.so3_characters, + self.pso3_characters, + self.max_o3_lambda, + ) + key_dict = tuple(int(k) for k in key.values) + integrals[name][i_sys][key_dict] = integral + + tensors = {} + for name in integrals: + tensors[name] = [] + original_keys = mean_var[name + "_mean"].keys + sample_names = mean_var[name + "_mean"][0].samples.names + for i_sys, integral_per_system in enumerate(integrals[name]): + if "atom" in sample_names: + samples = torch.cartesian_prod( + torch.tensor([i_sys]), + torch.arange(len(systems[i_sys].positions)), + ) + else: + samples = torch.tensor([[i_sys]]) + blocks = {} + for old_key, integral_dict in integral_per_system.items(): + for new_key, integral_values in integral_dict.items(): + full_key = old_key + new_key + blocks[full_key] = integral_values + blocks = TensorMap( + Labels( + original_keys.names + ["ell", "sigma"], + torch.tensor(list(blocks.keys())), + ), + [ + TensorBlock( + values=blocks[key].unsqueeze(0), + samples=Labels(sample_names, samples), + components=mean_var[name + "_mean"] + .block( + {_k: key[i] for i, _k in enumerate(original_keys.names)} + ) + # .block({"o3_lambda": key[0], "o3_sigma": key[1]}) + .components, + properties=mean_var[name + "_mean"] + .block( + {_k: key[i] for i, _k in enumerate(original_keys.names)} + ) + .properties, + ) + for key in blocks + ], + ) + tensors[name].append(blocks) + tensors[name] = mts.join(tensors[name], "samples") + + return tensors def _compute_mean_and_variance(self, backtransformed_outputs): mean_var_outputs: Dict[str, TensorMap] = {} From aeb713b84a6db8700563b2e657cd49a110a4575d Mon Sep 17 00:00:00 2001 From: ppegolo Date: Fri, 14 Nov 2025 10:50:34 +0100 Subject: [PATCH 37/57] A few changes --- .../metatomic/torch/symmetrized_model.py | 482 ++++++++++++++---- 1 file changed, 388 insertions(+), 94 deletions(-) diff --git a/python/metatomic_torch/metatomic/torch/symmetrized_model.py b/python/metatomic_torch/metatomic/torch/symmetrized_model.py index 2c6c2331..d9cafad4 100644 --- a/python/metatomic_torch/metatomic/torch/symmetrized_model.py +++ b/python/metatomic_torch/metatomic/torch/symmetrized_model.py @@ -517,6 +517,7 @@ def _compute_character_projections(self, transformed_outputs, mean_var, systems) for (key, block_so3), block_pso3 in zip( tensor_dict[1].items(), tensor_dict[-1], + strict=True, ): split_by_transformation = torch.bincount( block_so3.samples.values[:, 0] @@ -563,8 +564,7 @@ def _compute_character_projections(self, transformed_outputs, mean_var, systems) TensorBlock( values=blocks[key].unsqueeze(0), samples=Labels(sample_names, samples), - components=mean_var[name + "_mean"] - .block( + components=mean_var[name + "_mean"].block( {_k: key[i] for i, _k in enumerate(original_keys.names)} ) # .block({"o3_lambda": key[0], "o3_sigma": key[1]}) @@ -583,116 +583,98 @@ def _compute_character_projections(self, transformed_outputs, mean_var, systems) return tensors - def _compute_mean_and_variance(self, backtransformed_outputs): - mean_var_outputs: Dict[str, TensorMap] = {} - # Iterate over targets - for target_name in backtransformed_outputs: - mean_tensors: List[TensorMap] = [] - var_tensors: List[TensorMap] = [] - # Iterate over systems - for i_sys in range(len(backtransformed_outputs[target_name])): - tensor_so3 = backtransformed_outputs[target_name][i_sys][1] - tensor_pso3 = backtransformed_outputs[target_name][i_sys][-1] - - mean_blocks: List[TensorBlock] = [] - var_blocks: List[TensorBlock] = [] - # Iterate over blocks - for block_so3, block_pso3 in zip(tensor_so3, tensor_pso3, strict=True): - split_by_transformation = torch.bincount( - block_so3.samples.values[:, 0] - ) - w = torch.repeat_interleave( - self.so3_weights, split_by_transformation - ) - w = w.view(w.shape[0], *[1] * (block_so3.values.ndim - 1)) - mean_block = (block_so3.values + block_pso3.values) * 0.5 * w - second_moment_block = ( - (block_so3.values**2 + block_pso3.values**2) * 0.5 * w - ) - mean_blocks.append( - TensorBlock( - samples=block_so3.samples, - components=block_so3.components, - properties=block_so3.properties, - values=mean_block, - ) - ) - var_blocks.append( - TensorBlock( - samples=block_so3.samples, - components=block_so3.components, - properties=block_so3.properties, - values=second_moment_block, - ) - ) - mean_tensor = mts.sum_over_samples( - TensorMap(tensor_so3.keys, mean_blocks), "system" + def _compute_mean_and_variance( + self, tensor_dict: Dict[str, TensorMap], contract_components: Dict[str, bool] + ) -> Tuple[Dict[str, TensorMap], Dict[str, TensorMap]]: + mean_var = {} + for name in contract_components: + tensor = tensor_dict[name] + mean_blocks = [] + second_moment_blocks = [] + mean_norm_blocks = [] + for block in tensor: + rot_ids = block.samples.column("so3_rotation") + + values = block.values + values_norm = ( + torch.norm(values, dim=tuple(range(1, values.ndim - 1))) + if values.ndim > 2 + else torch.abs(values) ) - second_moment_tensor = mts.sum_over_samples( - TensorMap(tensor_so3.keys, var_blocks), "system" - ) - var_tensor = mts.subtract(second_moment_tensor, mts.pow(mean_tensor, 2)) - mean_tensors.append(mean_tensor) - var_tensors.append(var_tensor) + values_squared = values_norm**2 - mean = mts.join(mean_tensors, "samples", add_dimension="system") - var = mts.join(var_tensors, "samples", add_dimension="system") + view = (values.size(0), *[1] * (values.ndim - 1)) + values = 0.5 * self.so3_weights[rot_ids].view(view) * values - if "system" not in mean[0].samples.names: - mean = mts.insert_dimension( - mean, - "samples", - 0, - "system", - torch.zeros(mean[0].samples.values.shape[0], dtype=torch.long), + view = (values_squared.size(0), *[1] * (values_squared.ndim - 1)) + values_squared = ( + 0.5 * self.so3_weights[rot_ids].view(view) * values_squared ) - var = mts.insert_dimension( - var, - "samples", - 0, - "system", - torch.zeros(var[0].samples.values.shape[0], dtype=torch.long), - ) - else: - num_dims = len(mean[0].samples.names) - mean = mts.permute_dimensions( - mean, - "samples", - [num_dims - 1] + list(range(num_dims - 1)), + + view = (values_norm.size(0), *[1] * (values_norm.ndim - 1)) + values_norm = 0.5 * self.so3_weights[rot_ids].view(view) * values_norm + + mean_blocks.append( + TensorBlock( + values=values, + samples=block.samples, + components=block.components, + properties=block.properties, + ) ) - var = mts.permute_dimensions( - var, - "samples", - [num_dims - 1] + list(range(num_dims - 1)), + mean_norm_blocks.append( + TensorBlock( + values=values_norm, + samples=block.samples, + components=[], + properties=block.properties, + ) ) - if "_" in mean[0].samples.names: - mean = mts.remove_dimension(mean, "samples", "_") - var = mts.remove_dimension(var, "samples", "_") - - # Store results - mean_var_outputs[target_name + "_mean"] = mean - ncomp = len(var[0].components) - var = TensorMap( - var.keys, - [ + second_moment_blocks.append( TensorBlock( + values=values_squared, samples=block.samples, components=[], properties=block.properties, - values=block.values.sum(dim=list(range(1, ncomp + 1))), ) - for block in var - ], + ) + + # Mean + tensor_mean = TensorMap(tensor.keys, mean_blocks) + tensor_mean = mts.sum_over_samples( + tensor_mean.keys_to_samples("inversion"), ["inversion", "so3_rotation"] ) - mean_var_outputs[target_name + "_var"] = var - return mean_var_outputs + # Mean norm + tensor_mean_norm = TensorMap(tensor.keys, mean_norm_blocks) + tensor_mean_norm = mts.sum_over_samples( + tensor_mean_norm.keys_to_samples("inversion"), + ["inversion", "so3_rotation"], + ) + + # Second moment + tensor_second_moment = TensorMap(tensor.keys, second_moment_blocks) + tensor_second_moment = mts.sum_over_samples( + tensor_second_moment.keys_to_samples("inversion"), + ["inversion", "so3_rotation"], + ) + + # Variance + tensor_variance = mts.subtract( + tensor_second_moment, mts.pow(tensor_mean_norm, 2) + ) + + mean_var[name + "_mean"] = tensor_mean + mean_var[name + "_mean_norm"] = tensor_mean_norm + mean_var[name + "_var"] = tensor_variance + return mean_var def _eval_over_grid( self, systems: List[System], outputs: Dict[str, ModelOutput], selected_atoms: Optional[Labels], + return_tensormaps: bool = True, ): """ Sample the model on the O(3) quadrature. @@ -733,11 +715,34 @@ def _eval_over_grid( # Combine batch outputs for name in transformed_outputs: combined: List[TensorMap] = [r[name] for r in rotation_outputs] - transformed_outputs[name][i_sys][inversion] = mts.join( + combined = mts.join( combined, "samples", add_dimension="batch_rotation", ) + if "batch_rotation" in combined[0].samples.names: + # Reindex + blocks = [] + for block in combined: + batch_id = block.samples.column("batch_rotation") + rot_id = block.samples.column("system") + new_sample_values = block.samples.values[:, :-1] + new_sample_values[:, 0] = ( + batch_id * self.batch_size + rot_id + ) + blocks.append( + TensorBlock( + values=block.values, + samples=Labels( + block.samples.names[:-1], + new_sample_values, + ), + components=block.components, + properties=block.properties, + ) + ) + combined = TensorMap(combined.keys, blocks) + transformed_outputs[name][i_sys][inversion] = combined n_rot = self.so3_rotations.size(0) for name in transformed_outputs: @@ -766,4 +771,293 @@ def _eval_over_grid( name ] + if return_tensormaps: + # Massage outputs to have desired shape + for name in transformed_outputs: + joined_plus = mts.join( + [ + transformed_outputs[name][i_sys][1] + for i_sys in range(len(systems)) + ], + "samples", + add_dimension="phys_system", + ) + joined_minus = mts.join( + [ + transformed_outputs[name][i_sys][1] + for i_sys in range(len(systems)) + ], + "samples", + add_dimension="phys_system", + ) + joined = mts.join( + [ + mts.append_dimension(joined_plus, "keys", "inversion", 1), + mts.append_dimension(joined_minus, "keys", "inversion", -1), + ], + "samples", + different_keys="union", + ) + joined = mts.rename_dimension( + joined, "samples", "system", "so3_rotation" + ) + + if "phys_system" in joined[0].samples.names: + joined = mts.rename_dimension( + joined, "samples", "phys_system", "system" + ) + else: + joined = mts.insert_dimension( + joined, + "samples", + 0, + "system", + torch.zeros( + joined[0].samples.values.shape[0], dtype=torch.long + ), + ) + transformed_outputs[name] = mts.permute_dimensions( + joined, + "samples", + [ + len(joined[0].samples.names) - 1, + *range(len(joined[0].samples.names) - 1), + ], + ) + + joined_plus = mts.join( + [ + backtransformed_outputs[name][i_sys][1] + for i_sys in range(len(systems)) + ], + "samples", + add_dimension="phys_system", + ) + joined_minus = mts.join( + [ + backtransformed_outputs[name][i_sys][1] + for i_sys in range(len(systems)) + ], + "samples", + add_dimension="phys_system", + ) + joined = mts.join( + [ + mts.append_dimension(joined_plus, "keys", "inversion", 1), + mts.append_dimension(joined_minus, "keys", "inversion", -1), + ], + "samples", + different_keys="union", + ) + joined = mts.rename_dimension( + joined, "samples", "system", "so3_rotation" + ) + if "phys_system" in joined[0].samples.names: + joined = mts.rename_dimension( + joined, "samples", "phys_system", "system" + ) + else: + joined = mts.insert_dimension( + joined, + "samples", + 0, + "system", + torch.zeros( + joined[0].samples.values.shape[0], dtype=torch.long + ), + ) + backtransformed_outputs[name] = mts.permute_dimensions( + joined, + "samples", + [ + len(joined[0].samples.names) - 1, + *range(len(joined[0].samples.names) - 1), + ], + ) + return transformed_outputs, backtransformed_outputs + + # def _compute_mean_and_variance(self, backtransformed_outputs): + # mean_var_outputs: Dict[str, TensorMap] = {} + # # Iterate over targets + # for target_name in backtransformed_outputs: + # mean_tensors: List[TensorMap] = [] + # var_tensors: List[TensorMap] = [] + # # Iterate over systems + # for i_sys in range(len(backtransformed_outputs[target_name])): + # tensor_so3 = backtransformed_outputs[target_name][i_sys][1] + # tensor_pso3 = backtransformed_outputs[target_name][i_sys][-1] + + # mean_blocks: List[TensorBlock] = [] + # var_blocks: List[TensorBlock] = [] + # # Iterate over blocks + # for block_so3, block_pso3 in zip(tensor_so3, tensor_pso3, strict=True): + # split_by_transformation = torch.bincount( + # block_so3.samples.values[:, 0] + # ) + # w = torch.repeat_interleave( + # self.so3_weights, split_by_transformation + # ) + # w = w.view(w.shape[0], *[1] * (block_so3.values.ndim - 1)) + # mean_block = (block_so3.values + block_pso3.values) * 0.5 * w + # second_moment_block = ( + # (block_so3.values**2 + block_pso3.values**2) * 0.5 * w + # ) + # mean_blocks.append( + # TensorBlock( + # samples=block_so3.samples, + # components=block_so3.components, + # properties=block_so3.properties, + # values=mean_block, + # ) + # ) + # var_blocks.append( + # TensorBlock( + # samples=block_so3.samples, + # components=block_so3.components, + # properties=block_so3.properties, + # values=second_moment_block, + # ) + # ) + # mean_tensor = mts.sum_over_samples( + # TensorMap(tensor_so3.keys, mean_blocks), "system" + # ) + # second_moment_tensor = mts.sum_over_samples( + # TensorMap(tensor_so3.keys, var_blocks), "system" + # ) + # var_tensor = mts.subtract(second_moment_tensor, mts.pow(mean_tensor, 2)) + # mean_tensors.append(mean_tensor) + # var_tensors.append(var_tensor) + + # mean = mts.join(mean_tensors, "samples", add_dimension="system") + # var = mts.join(var_tensors, "samples", add_dimension="system") + + # if "system" not in mean[0].samples.names: + # mean = mts.insert_dimension( + # mean, + # "samples", + # 0, + # "system", + # torch.zeros(mean[0].samples.values.shape[0], dtype=torch.long), + # ) + # var = mts.insert_dimension( + # var, + # "samples", + # 0, + # "system", + # torch.zeros(var[0].samples.values.shape[0], dtype=torch.long), + # ) + # else: + # num_dims = len(mean[0].samples.names) + # mean = mts.permute_dimensions( + # mean, + # "samples", + # [num_dims - 1] + list(range(num_dims - 1)), + # ) + # var = mts.permute_dimensions( + # var, + # "samples", + # [num_dims - 1] + list(range(num_dims - 1)), + # ) + # if "_" in mean[0].samples.names: + # mean = mts.remove_dimension(mean, "samples", "_") + # var = mts.remove_dimension(var, "samples", "_") + + # # Store results + # mean_var_outputs[target_name + "_mean"] = mean + # ncomp = len(var[0].components) + # var = TensorMap( + # var.keys, + # [ + # TensorBlock( + # samples=block.samples, + # components=[], + # properties=block.properties, + # values=block.values.sum(dim=list(range(1, ncomp + 1))), + # ) + # for block in var + # ], + # ) + # mean_var_outputs[target_name + "_var"] = var + + # return mean_var_outputs + + # def _eval_over_grid( + # self, + # systems: List[System], + # outputs: Dict[str, ModelOutput], + # selected_atoms: Optional[Labels], + # ): + # """ + # Sample the model on the O(3) quadrature. + + # :param systems: list of systems to evaluate + # :param model: atomistic model to evaluate + # :param device: device to use for computation + # :return: list of list of model outputs, shape (len(systems), N) + # where N is the number of quadrature points + # """ + + # device = systems[0].positions.device + # dtype = systems[0].positions.dtype + + # transformed_outputs: Dict[str, List[Dict[int, Optional[TensorMap]]]] = { + # name: [{-1: None, 1: None} for _ in systems] for name in outputs + # } + # backtransformed_outputs: Dict[str, List[Dict[int, Optional[TensorMap]]]] = { + # name: [{-1: None, 1: None} for _ in systems] for name in outputs + # } + # for i_sys, system in enumerate(systems): + # for inversion in [-1, 1]: + # rotation_outputs: List[Dict[str, TensorMap]] = [] + # for batch in range(0, len(self.so3_rotations), self.batch_size): + # transformed_systems = [ + # _transform_system( + # system, inversion * R.to(device=device, dtype=dtype) + # ) + # for R in self.so3_rotations[batch : batch + self.batch_size] + # ] + # out = self.base_model( + # transformed_systems, + # outputs, + # selected_atoms, + # ) + # rotation_outputs.append(out) + + # # Combine batch outputs + # for name in transformed_outputs: + # combined: List[TensorMap] = [r[name] for r in rotation_outputs] + # transformed_outputs[name][i_sys][inversion] = mts.join( + # combined, + # "samples", + # add_dimension="batch_rotation", + # ) + + # n_rot = self.so3_rotations.size(0) + # for name in transformed_outputs: + # for i_sys, system in enumerate(systems): + # for inversion in [-1, 1]: + # tensor = transformed_outputs[name][i_sys][inversion] + # _, backtransformed, _ = _apply_augmentations( + # [system] * n_rot, + # {name: tensor}, + # list( + # ( + # self.so3_inverse_rotations.to( + # device=device, dtype=dtype + # ) + # * inversion + # ).unbind(0) + # ), + # { + # ell: self.wigner_D_inverse_rotations[ell] + # .to(device=device, dtype=dtype) + # .unbind(0) + # for ell in self.wigner_D_inverse_rotations + # }, + # ) + # backtransformed_outputs[name][i_sys][inversion] = backtransformed[ + # name + # ] + + # return transformed_outputs, backtransformed_outputs From 5a4a55ac31107f5df45397969f643de60ae775d4 Mon Sep 17 00:00:00 2001 From: ppegolo Date: Fri, 14 Nov 2025 10:51:20 +0100 Subject: [PATCH 38/57] Add character projections (to be tested) --- .../metatomic/torch/symmetrized_model.py | 433 ++++++------------ 1 file changed, 137 insertions(+), 296 deletions(-) diff --git a/python/metatomic_torch/metatomic/torch/symmetrized_model.py b/python/metatomic_torch/metatomic/torch/symmetrized_model.py index d9cafad4..46285a9d 100644 --- a/python/metatomic_torch/metatomic/torch/symmetrized_model.py +++ b/python/metatomic_torch/metatomic/torch/symmetrized_model.py @@ -417,42 +417,34 @@ def compute_characters( return so3_characters, pso3_characters -def _integrate_with_character( - tensor_so3: torch.Tensor, - tensor_pso3: torch.Tensor, - so3_characters: Dict[int, torch.Tensor], - pso3_characters: Dict[Tuple[int, int], torch.Tensor], - o3_lambda_max: int, -): - integral = {} - for o3_lambda in range(o3_lambda_max + 1): - so3_character = so3_characters[o3_lambda] - for o3_sigma in [-1, 1]: - pso3_character = pso3_characters[o3_lambda, o3_sigma] - integral[o3_lambda, o3_sigma] = (1 / 4) * ( - torch.einsum( - "i...,i...->...", - tensor_so3, - torch.einsum("ij,j...->i...", so3_character, tensor_so3), - ) - + torch.einsum( - "i...,i...->...", - tensor_pso3, - torch.einsum("ij,j...->i...", pso3_character, tensor_pso3), - ) - ) + (1 / 2) * ( - torch.einsum( - "i...,i...->...", - tensor_so3, - torch.einsum("ij,j...->i...", pso3_character, tensor_pso3), - ) - ) +def _character_times_integrand(chi, samples, components, properties, values): + chi = chi.view(chi.shape + (1,) * (values.ndim - 1)) + broadcast_values = values.view(1, *values.shape) + broadcast = chi * broadcast_values + new_samples = Labels( + ["character"] + samples.names, + torch.hstack( + [ + torch.repeat_interleave( + torch.arange(chi.shape[0]), + repeats=chi.shape[1], + ).unsqueeze(1), + torch.tile(samples.values, dims=(chi.shape[0], 1)), + ] + ), + ) + new_block = TensorBlock( + values=broadcast.reshape(-1, *values.shape[1:]), + samples=new_samples.permute([2, 1, 0] + list(range(3, len(samples.names) + 1))), + components=components, + properties=properties, + ) + new_block = mts.sum_over_samples_block( + new_block, + "so3_rotation", + ) - # Normalize by Haar measure - integral[(o3_lambda, o3_sigma)] *= (2 * o3_lambda + 1) / ( - 8 * torch.pi**2 - ) ** 2 - return integral + return new_block class SymmetrizedModel(torch.nn.Module): @@ -502,92 +494,126 @@ def forward( systems, outputs, selected_atoms ) + # return transformed_outputs, backtransformed_outputs + mean_var = self._compute_mean_and_variance(backtransformed_outputs) - character_projections = self._compute_character_projections( - transformed_outputs, mean_var, systems - ) - return mean_var, character_projections + convolution_integrals = self._compute_conv_integral(transformed_outputs) + + # Compute the character projections + out_dict = mean_var + for name, integral in convolution_integrals.items(): + mean_tensor = mean_var[name + "_mean"] + for key, block in integral.items(): + mean_block = mean_tensor.block( + { + k: int(v) + for k, v in zip( + mean_tensor.keys.names, key.values, strict=False + ) + } + ) + assert block.values.shape == mean_block.values.shape + integral[key].values[:] = mean_block.values - block.values + out_dict[name + "_character_projections"] = integral + + return out_dict + # character_projections = self._compute_character_projections( + # transformed_outputs, mean_var, systems + # ) + # return mean_var, character_projections + + def _compute_conv_integral(self, tensor_dict): + new_tensors = {} + for name, tensor in tensor_dict.items(): + new_blocks = [] + new_keys = [] + for key, block in tensor.items(): + inversion = int(key["inversion"]) + rot_ids = block.samples.column("so3_rotation") - def _compute_character_projections(self, transformed_outputs, mean_var, systems): - integrals = {} - for name in transformed_outputs: - integrals[name] = [] - for i_sys, tensor_dict in enumerate(transformed_outputs[name]): - integrals[name].append({}) - for (key, block_so3), block_pso3 in zip( - tensor_dict[1].items(), - tensor_dict[-1], - strict=True, - ): - split_by_transformation = torch.bincount( - block_so3.samples.values[:, 0] - ) - w = torch.repeat_interleave( - self.so3_weights, split_by_transformation - ) - w = w.view(w.shape[0], *[1] * (block_so3.values.ndim - 1)) - - integral = _integrate_with_character( - block_so3.values * w, - block_pso3.values * w, - self.so3_characters, - self.pso3_characters, - self.max_o3_lambda, - ) - key_dict = tuple(int(k) for k in key.values) - integrals[name][i_sys][key_dict] = integral - - tensors = {} - for name in integrals: - tensors[name] = [] - original_keys = mean_var[name + "_mean"].keys - sample_names = mean_var[name + "_mean"][0].samples.names - for i_sys, integral_per_system in enumerate(integrals[name]): - if "atom" in sample_names: - samples = torch.cartesian_prod( - torch.tensor([i_sys]), - torch.arange(len(systems[i_sys].positions)), + key_values = key.values.tolist() + values = block.values + view = (values.size(0), *[1] * (values.ndim - 1)) + values = 0.5 * self.so3_weights[rot_ids].view(view) * values + + for o3_lambda in range(self.max_o3_lambda + 1): + chi_so3 = self.so3_characters[o3_lambda] + + chi_so3_times_integrand = _character_times_integrand( + chi_so3[:, rot_ids], + block.samples, + block.components, + block.properties, + values, ) - else: - samples = torch.tensor([[i_sys]]) - blocks = {} - for old_key, integral_dict in integral_per_system.items(): - for new_key, integral_values in integral_dict.items(): - full_key = old_key + new_key - blocks[full_key] = integral_values - blocks = TensorMap( - Labels( - original_keys.names + ["ell", "sigma"], - torch.tensor(list(blocks.keys())), - ), - [ + contracted = mts.sum_over_samples_block( TensorBlock( - values=blocks[key].unsqueeze(0), - samples=Labels(sample_names, samples), - components=mean_var[name + "_mean"].block( - {_k: key[i] for i, _k in enumerate(original_keys.names)} + samples=block.samples, + components=block.components, + properties=block.properties, + values=0.25 + * values + * chi_so3_times_integrand.values + * (2 * o3_lambda + 1) + / (8 * torch.pi**2) ** 2, + ), + "so3_rotation", + ) + new_keys.append(key_values + [o3_lambda, 1, 1]) + new_blocks.append(contracted) + new_keys.append(key_values + [o3_lambda, -1, 1]) + new_blocks.append(contracted) + + if inversion == -1: + for o3_sigma in [1, -1]: + chi_pso3 = self.pso3_characters[(o3_lambda, o3_sigma)] + chi_pso3_times_integrand = _character_times_integrand( + chi_pso3[:, rot_ids], + block.samples, + block.components, + block.properties, + values, ) - # .block({"o3_lambda": key[0], "o3_sigma": key[1]}) - .components, - properties=mean_var[name + "_mean"] - .block( - {_k: key[i] for i, _k in enumerate(original_keys.names)} + so3_key = { + k: int(v) + for k, v in zip(key.names, key.values, strict=True) + } + so3_key["inversion"] = 1 + so3_block = tensor.block(so3_key) + contracted = mts.sum_over_samples_block( + TensorBlock( + samples=so3_block.samples, + components=so3_block.components, + properties=so3_block.properties, + values=0.5 + * so3_block.values + * chi_pso3_times_integrand.values + * (2 * o3_lambda + 1) + / (8 * torch.pi**2) ** 2, + ), + "so3_rotation", ) - .properties, - ) - for key in blocks - ], - ) - tensors[name].append(blocks) - tensors[name] = mts.join(tensors[name], "samples") - - return tensors + new_keys.append(key_values + [o3_lambda, o3_sigma, 2]) + new_blocks.append(contracted) + + new_tensor = TensorMap( + Labels( + names=tensor.keys.names + ["chi_lambda", "chi_sigma", "term"], + values=torch.tensor(new_keys), + ), + new_blocks, + ) + new_tensors[name] = mts.sum_over_samples( + new_tensor.keys_to_samples(["term", "inversion"]), ["term", "inversion"] + ) + return new_tensors def _compute_mean_and_variance( - self, tensor_dict: Dict[str, TensorMap], contract_components: Dict[str, bool] + self, + tensor_dict: Dict[str, TensorMap], ) -> Tuple[Dict[str, TensorMap], Dict[str, TensorMap]]: mean_var = {} - for name in contract_components: + for name in tensor_dict: tensor = tensor_dict[name] mean_blocks = [] second_moment_blocks = [] @@ -876,188 +902,3 @@ def _eval_over_grid( ) return transformed_outputs, backtransformed_outputs - - # def _compute_mean_and_variance(self, backtransformed_outputs): - # mean_var_outputs: Dict[str, TensorMap] = {} - # # Iterate over targets - # for target_name in backtransformed_outputs: - # mean_tensors: List[TensorMap] = [] - # var_tensors: List[TensorMap] = [] - # # Iterate over systems - # for i_sys in range(len(backtransformed_outputs[target_name])): - # tensor_so3 = backtransformed_outputs[target_name][i_sys][1] - # tensor_pso3 = backtransformed_outputs[target_name][i_sys][-1] - - # mean_blocks: List[TensorBlock] = [] - # var_blocks: List[TensorBlock] = [] - # # Iterate over blocks - # for block_so3, block_pso3 in zip(tensor_so3, tensor_pso3, strict=True): - # split_by_transformation = torch.bincount( - # block_so3.samples.values[:, 0] - # ) - # w = torch.repeat_interleave( - # self.so3_weights, split_by_transformation - # ) - # w = w.view(w.shape[0], *[1] * (block_so3.values.ndim - 1)) - # mean_block = (block_so3.values + block_pso3.values) * 0.5 * w - # second_moment_block = ( - # (block_so3.values**2 + block_pso3.values**2) * 0.5 * w - # ) - # mean_blocks.append( - # TensorBlock( - # samples=block_so3.samples, - # components=block_so3.components, - # properties=block_so3.properties, - # values=mean_block, - # ) - # ) - # var_blocks.append( - # TensorBlock( - # samples=block_so3.samples, - # components=block_so3.components, - # properties=block_so3.properties, - # values=second_moment_block, - # ) - # ) - # mean_tensor = mts.sum_over_samples( - # TensorMap(tensor_so3.keys, mean_blocks), "system" - # ) - # second_moment_tensor = mts.sum_over_samples( - # TensorMap(tensor_so3.keys, var_blocks), "system" - # ) - # var_tensor = mts.subtract(second_moment_tensor, mts.pow(mean_tensor, 2)) - # mean_tensors.append(mean_tensor) - # var_tensors.append(var_tensor) - - # mean = mts.join(mean_tensors, "samples", add_dimension="system") - # var = mts.join(var_tensors, "samples", add_dimension="system") - - # if "system" not in mean[0].samples.names: - # mean = mts.insert_dimension( - # mean, - # "samples", - # 0, - # "system", - # torch.zeros(mean[0].samples.values.shape[0], dtype=torch.long), - # ) - # var = mts.insert_dimension( - # var, - # "samples", - # 0, - # "system", - # torch.zeros(var[0].samples.values.shape[0], dtype=torch.long), - # ) - # else: - # num_dims = len(mean[0].samples.names) - # mean = mts.permute_dimensions( - # mean, - # "samples", - # [num_dims - 1] + list(range(num_dims - 1)), - # ) - # var = mts.permute_dimensions( - # var, - # "samples", - # [num_dims - 1] + list(range(num_dims - 1)), - # ) - # if "_" in mean[0].samples.names: - # mean = mts.remove_dimension(mean, "samples", "_") - # var = mts.remove_dimension(var, "samples", "_") - - # # Store results - # mean_var_outputs[target_name + "_mean"] = mean - # ncomp = len(var[0].components) - # var = TensorMap( - # var.keys, - # [ - # TensorBlock( - # samples=block.samples, - # components=[], - # properties=block.properties, - # values=block.values.sum(dim=list(range(1, ncomp + 1))), - # ) - # for block in var - # ], - # ) - # mean_var_outputs[target_name + "_var"] = var - - # return mean_var_outputs - - # def _eval_over_grid( - # self, - # systems: List[System], - # outputs: Dict[str, ModelOutput], - # selected_atoms: Optional[Labels], - # ): - # """ - # Sample the model on the O(3) quadrature. - - # :param systems: list of systems to evaluate - # :param model: atomistic model to evaluate - # :param device: device to use for computation - # :return: list of list of model outputs, shape (len(systems), N) - # where N is the number of quadrature points - # """ - - # device = systems[0].positions.device - # dtype = systems[0].positions.dtype - - # transformed_outputs: Dict[str, List[Dict[int, Optional[TensorMap]]]] = { - # name: [{-1: None, 1: None} for _ in systems] for name in outputs - # } - # backtransformed_outputs: Dict[str, List[Dict[int, Optional[TensorMap]]]] = { - # name: [{-1: None, 1: None} for _ in systems] for name in outputs - # } - # for i_sys, system in enumerate(systems): - # for inversion in [-1, 1]: - # rotation_outputs: List[Dict[str, TensorMap]] = [] - # for batch in range(0, len(self.so3_rotations), self.batch_size): - # transformed_systems = [ - # _transform_system( - # system, inversion * R.to(device=device, dtype=dtype) - # ) - # for R in self.so3_rotations[batch : batch + self.batch_size] - # ] - # out = self.base_model( - # transformed_systems, - # outputs, - # selected_atoms, - # ) - # rotation_outputs.append(out) - - # # Combine batch outputs - # for name in transformed_outputs: - # combined: List[TensorMap] = [r[name] for r in rotation_outputs] - # transformed_outputs[name][i_sys][inversion] = mts.join( - # combined, - # "samples", - # add_dimension="batch_rotation", - # ) - - # n_rot = self.so3_rotations.size(0) - # for name in transformed_outputs: - # for i_sys, system in enumerate(systems): - # for inversion in [-1, 1]: - # tensor = transformed_outputs[name][i_sys][inversion] - # _, backtransformed, _ = _apply_augmentations( - # [system] * n_rot, - # {name: tensor}, - # list( - # ( - # self.so3_inverse_rotations.to( - # device=device, dtype=dtype - # ) - # * inversion - # ).unbind(0) - # ), - # { - # ell: self.wigner_D_inverse_rotations[ell] - # .to(device=device, dtype=dtype) - # .unbind(0) - # for ell in self.wigner_D_inverse_rotations - # }, - # ) - # backtransformed_outputs[name][i_sys][inversion] = backtransformed[ - # name - # ] - - # return transformed_outputs, backtransformed_outputs From 3065438308bd3bfa211c8e71f7fc881245c3aab8 Mon Sep 17 00:00:00 2001 From: ppegolo Date: Fri, 14 Nov 2025 15:44:11 +0100 Subject: [PATCH 39/57] Make faster --- .../metatomic/torch/symmetrized_model.py | 224 +++++++++--------- 1 file changed, 108 insertions(+), 116 deletions(-) diff --git a/python/metatomic_torch/metatomic/torch/symmetrized_model.py b/python/metatomic_torch/metatomic/torch/symmetrized_model.py index 46285a9d..b22b4ed9 100644 --- a/python/metatomic_torch/metatomic/torch/symmetrized_model.py +++ b/python/metatomic_torch/metatomic/torch/symmetrized_model.py @@ -418,41 +418,80 @@ def compute_characters( def _character_times_integrand(chi, samples, components, properties, values): - chi = chi.view(chi.shape + (1,) * (values.ndim - 1)) - broadcast_values = values.view(1, *values.shape) - broadcast = chi * broadcast_values - new_samples = Labels( - ["character"] + samples.names, - torch.hstack( - [ - torch.repeat_interleave( - torch.arange(chi.shape[0]), - repeats=chi.shape[1], - ).unsqueeze(1), - torch.tile(samples.values, dims=(chi.shape[0], 1)), - ] - ), + n_rot = chi.size(0) + + reshaped_values = values.reshape(n_rot, -1, *values.shape[1:]) + contracted_values = (chi @ reshaped_values).reshape( + chi.shape[0], *reshaped_values.shape[1:] ) + names = samples.names + names.pop("so3_rotation") new_block = TensorBlock( - values=broadcast.reshape(-1, *values.shape[1:]), - samples=new_samples.permute([2, 1, 0] + list(range(3, len(samples.names) + 1))), + samples=samples.view(names).to_owned(), components=components, properties=properties, + values=contracted_values, + ) + + return new_block + + +def _character_convolution(chi, block1, block2, w): + samples = block1.samples + components = block1.components + properties = block1.properties + values = block1.values + chi = chi.to(dtype=values.dtype, device=values.device) + n_rot = chi.size(1) + weight = 0.5 * w.to(dtype=values.dtype, device=values.device) + + reshaped_values = values.reshape(n_rot, -1, *values.shape[1:]) + # element-wise multiplication of weights and reshaped values, which have generic shape + weighted_values = ( + weight.view(-1, *([1] * (reshaped_values.dim() - 1))) * reshaped_values + ) + contracted_values = ( + chi @ weighted_values.reshape(weighted_values.shape[0], -1) + ).reshape(chi.shape[0], *weighted_values.shape[1:]) + + values2 = block2.values + reshaped_values2 = values2.reshape(n_rot, -1, *values2.shape[1:]) + weighted_values2 = ( + weight.view(-1, *([1] * (reshaped_values2.dim() - 1))) * reshaped_values2 ) - new_block = mts.sum_over_samples_block( - new_block, - "so3_rotation", + contracted_values = torch.einsum( + "i...,i...->...", + weighted_values2, + contracted_values, + ) + + names = samples.names + names = [name for name in names if name != "so3_rotation"] + new_block = TensorBlock( + samples=Labels(names, samples.values[samples.values[:, 0] == 0][:, 1:]), + components=components, + properties=properties, + values=contracted_values, ) return new_block class SymmetrizedModel(torch.nn.Module): - def __init__(self, base_model, max_o3_lambda, batch_size: int = 32): + def __init__( + self, + base_model, + max_o3_lambda, + batch_size: int = 32, + max_o3_lambda_character: Optional[int] = None, + ): super().__init__() self.base_model = base_model self.max_o3_lambda = max_o3_lambda self.batch_size = batch_size + if max_o3_lambda_character is None: + max_o3_lambda_character = max_o3_lambda + self.max_o3_lambda_character = max_o3_lambda_character # Compute grid lebedev_order, n_inplane_rotations = _choose_quadrature(self.max_o3_lambda) @@ -478,7 +517,7 @@ def __init__(self, base_model, max_o3_lambda, batch_size: int = 32): # Compute characters self.so3_characters, self.pso3_characters = compute_characters( - self.max_o3_lambda, + self.max_o3_lambda_character, (alpha, beta, gamma), angles_inverse_rotations, ) @@ -494,8 +533,6 @@ def forward( systems, outputs, selected_atoms ) - # return transformed_outputs, backtransformed_outputs - mean_var = self._compute_mean_and_variance(backtransformed_outputs) convolution_integrals = self._compute_conv_integral(transformed_outputs) @@ -514,98 +551,67 @@ def forward( ) assert block.values.shape == mean_block.values.shape integral[key].values[:] = mean_block.values - block.values + if "_" in integral.keys.names: + integral = mts.remove_dimension(integral, "keys", "_") out_dict[name + "_character_projections"] = integral return out_dict - # character_projections = self._compute_character_projections( - # transformed_outputs, mean_var, systems - # ) - # return mean_var, character_projections def _compute_conv_integral(self, tensor_dict): new_tensors = {} for name, tensor in tensor_dict.items(): + keys = tensor.keys + remaining_keys = Labels( + keys.names[:-1], + keys.values[keys.column("inversion") == 1][:, :-1], + ) new_blocks = [] new_keys = [] - for key, block in tensor.items(): - inversion = int(key["inversion"]) - rot_ids = block.samples.column("so3_rotation") - - key_values = key.values.tolist() - values = block.values - view = (values.size(0), *[1] * (values.ndim - 1)) - values = 0.5 * self.so3_weights[rot_ids].view(view) * values - - for o3_lambda in range(self.max_o3_lambda + 1): - chi_so3 = self.so3_characters[o3_lambda] - - chi_so3_times_integrand = _character_times_integrand( - chi_so3[:, rot_ids], - block.samples, - block.components, - block.properties, - values, + for key in remaining_keys: + key_vals = { + k: int(v) for k, v in zip(key.names, key.values, strict=True) + } + so3_block = tensor.block(key_vals | {"inversion": 1}) + pso3_block = tensor.block(key_vals | {"inversion": -1}) + + for o3_lambda in range(self.max_o3_lambda_character + 1): + so3_chi = self.so3_characters[o3_lambda] + first_term = _character_convolution( + so3_chi, so3_block, so3_block, self.so3_weights ) - contracted = mts.sum_over_samples_block( - TensorBlock( - samples=block.samples, - components=block.components, - properties=block.properties, - values=0.25 - * values - * chi_so3_times_integrand.values - * (2 * o3_lambda + 1) - / (8 * torch.pi**2) ** 2, - ), - "so3_rotation", + second_term = _character_convolution( + so3_chi, pso3_block, pso3_block, self.so3_weights ) - new_keys.append(key_values + [o3_lambda, 1, 1]) - new_blocks.append(contracted) - new_keys.append(key_values + [o3_lambda, -1, 1]) - new_blocks.append(contracted) - - if inversion == -1: - for o3_sigma in [1, -1]: - chi_pso3 = self.pso3_characters[(o3_lambda, o3_sigma)] - chi_pso3_times_integrand = _character_times_integrand( - chi_pso3[:, rot_ids], - block.samples, - block.components, - block.properties, - values, - ) - so3_key = { - k: int(v) - for k, v in zip(key.names, key.values, strict=True) - } - so3_key["inversion"] = 1 - so3_block = tensor.block(so3_key) - contracted = mts.sum_over_samples_block( - TensorBlock( - samples=so3_block.samples, - components=so3_block.components, - properties=so3_block.properties, - values=0.5 - * so3_block.values - * chi_pso3_times_integrand.values - * (2 * o3_lambda + 1) - / (8 * torch.pi**2) ** 2, - ), - "so3_rotation", - ) - new_keys.append(key_values + [o3_lambda, o3_sigma, 2]) - new_blocks.append(contracted) + for o3_sigma in [1, -1]: + pso3_chi = self.pso3_characters[(o3_lambda, o3_sigma)] + third_term = _character_convolution( + pso3_chi, so3_block, pso3_block, self.so3_weights + ) + block = TensorBlock( + samples=first_term.samples, + components=first_term.components, + properties=first_term.properties, + values=( + 0.25 * (first_term.values + second_term.values) + + 0.5 * third_term.values + ) + * (2 * o3_lambda + 1) + / (8 * torch.pi**2) ** 2, + ) + new_blocks.append(block) + new_keys.append( + torch.cat([key.values, torch.tensor([o3_lambda, o3_sigma])]) + ) + key_names = [name for name in tensor.keys.names if name != "inversion"] new_tensor = TensorMap( Labels( - names=tensor.keys.names + ["chi_lambda", "chi_sigma", "term"], - values=torch.tensor(new_keys), + key_names + ["chi_lambda", "chi_sigma"], + torch.stack(new_keys), ), new_blocks, ) - new_tensors[name] = mts.sum_over_samples( - new_tensor.keys_to_samples(["term", "inversion"]), ["term", "inversion"] - ) + new_tensors[name] = new_tensor return new_tensors def _compute_mean_and_variance( @@ -836,20 +842,13 @@ def _eval_over_grid( joined = mts.insert_dimension( joined, "samples", - 0, + 1, "system", torch.zeros( joined[0].samples.values.shape[0], dtype=torch.long ), ) - transformed_outputs[name] = mts.permute_dimensions( - joined, - "samples", - [ - len(joined[0].samples.names) - 1, - *range(len(joined[0].samples.names) - 1), - ], - ) + transformed_outputs[name] = joined joined_plus = mts.join( [ @@ -886,19 +885,12 @@ def _eval_over_grid( joined = mts.insert_dimension( joined, "samples", - 0, + 1, "system", torch.zeros( joined[0].samples.values.shape[0], dtype=torch.long ), ) - backtransformed_outputs[name] = mts.permute_dimensions( - joined, - "samples", - [ - len(joined[0].samples.names) - 1, - *range(len(joined[0].samples.names) - 1), - ], - ) + backtransformed_outputs[name] = joined return transformed_outputs, backtransformed_outputs From c9f6a516e83f0f17374740534b095ef18883f9b4 Mon Sep 17 00:00:00 2001 From: ppegolo Date: Tue, 18 Nov 2025 12:31:43 +0100 Subject: [PATCH 40/57] Make torchscriptable and add docstrings --- .../metatomic/torch/symmetrized_model.py | 498 +++++++++++------- 1 file changed, 321 insertions(+), 177 deletions(-) diff --git a/python/metatomic_torch/metatomic/torch/symmetrized_model.py b/python/metatomic_torch/metatomic/torch/symmetrized_model.py index b22b4ed9..a3f9ff86 100644 --- a/python/metatomic_torch/metatomic/torch/symmetrized_model.py +++ b/python/metatomic_torch/metatomic/torch/symmetrized_model.py @@ -246,10 +246,10 @@ def _angles_from_rotations( # Read commonly-used entries with explicit names for readability R00 = R_flat[:, 0, 0] - R01 = R_flat[:, 0, 1] + # R01 = R_flat[:, 0, 1] R02 = R_flat[:, 0, 2] R10 = R_flat[:, 1, 0] - R11 = R_flat[:, 1, 1] + # R11 = R_flat[:, 1, 1] R12 = R_flat[:, 1, 2] R20 = R_flat[:, 2, 0] R21 = R_flat[:, 2, 1] @@ -392,7 +392,7 @@ def compute_characters( o3_lambda_max: int, angles: Tuple[np.ndarray, np.ndarray, np.ndarray], inverse_angles: Tuple[np.ndarray, np.ndarray, np.ndarray], -) -> Dict[int, torch.Tensor]: +) -> Tuple[Dict[int, torch.Tensor], Dict[str, torch.Tensor]]: alpha, beta, gamma = _euler_angles_of_combined_rotation(angles, inverse_angles) so3_characters = { @@ -403,7 +403,7 @@ def compute_characters( pso3_characters = {} for o3_lambda in range(o3_lambda_max + 1): for o3_sigma in [-1, +1]: - pso3_characters[(o3_lambda, o3_sigma)] = ( + pso3_characters[f"{o3_lambda}_{o3_sigma}"] = ( o3_sigma * ((-1) ** o3_lambda) * so3_characters[o3_lambda] ) @@ -417,26 +417,13 @@ def compute_characters( return so3_characters, pso3_characters -def _character_times_integrand(chi, samples, components, properties, values): - n_rot = chi.size(0) - - reshaped_values = values.reshape(n_rot, -1, *values.shape[1:]) - contracted_values = (chi @ reshaped_values).reshape( - chi.shape[0], *reshaped_values.shape[1:] - ) - names = samples.names - names.pop("so3_rotation") - new_block = TensorBlock( - samples=samples.view(names).to_owned(), - components=components, - properties=properties, - values=contracted_values, - ) - - return new_block - - -def _character_convolution(chi, block1, block2, w): +def _character_convolution( + chi: torch.Tensor, block1: TensorBlock, block2: TensorBlock, w: torch.Tensor +) -> TensorBlock: + """ + Compute the character convolution between a block containing SO(3)-sampled tensors. + Then contract with another block. + """ samples = block1.samples components = block1.components properties = block1.properties @@ -445,28 +432,36 @@ def _character_convolution(chi, block1, block2, w): n_rot = chi.size(1) weight = 0.5 * w.to(dtype=values.dtype, device=values.device) - reshaped_values = values.reshape(n_rot, -1, *values.shape[1:]) - # element-wise multiplication of weights and reshaped values, which have generic shape - weighted_values = ( - weight.view(-1, *([1] * (reshaped_values.dim() - 1))) * reshaped_values - ) + new_shape = [n_rot, -1] + list(values.shape[1:]) + reshaped_values = values.reshape(new_shape) + view: List[int] = [] + view.append(-1) + for _ in range(reshaped_values.ndim - 1): + view.append(1) + weighted_values = weight.view(view) * reshaped_values + contracted_shape: List[int] = [chi.shape[0]] + list(weighted_values.shape[1:]) contracted_values = ( chi @ weighted_values.reshape(weighted_values.shape[0], -1) - ).reshape(chi.shape[0], *weighted_values.shape[1:]) + ).reshape(contracted_shape) values2 = block2.values - reshaped_values2 = values2.reshape(n_rot, -1, *values2.shape[1:]) - weighted_values2 = ( - weight.view(-1, *([1] * (reshaped_values2.dim() - 1))) * reshaped_values2 - ) + new_shape = [n_rot, -1] + list(values2.shape[1:]) + reshaped_values2 = values2.reshape(new_shape) + view: List[int] = [] + view.append(-1) + for _ in range(reshaped_values2.ndim - 1): + view.append(1) + weighted_values2 = weight.view(view) * reshaped_values2 contracted_values = torch.einsum( "i...,i...->...", weighted_values2, contracted_values, ) - names = samples.names - names = [name for name in names if name != "so3_rotation"] + names: List[str] = [] + for name in samples.names: + if name != "so3_rotation": + names.append(name) new_block = TensorBlock( samples=Labels(names, samples.values[samples.values[:, 0] == 0][:, 1:]), components=components, @@ -478,6 +473,81 @@ def _character_convolution(chi, block1, block2, w): class SymmetrizedModel(torch.nn.Module): + """ + Wrapper around an atomistic model that symmetrizes its outputs over :math:`O(3)` + and computes equivariance metrics. + + The model is evaluated over a quadrature grid on :math:`O(3)`, constructed from a + Lebedev grid supplemented by in-plane rotations. For each sampled group element, the + model outputs are "back-rotated" according to the known :math:`O(3)` action + appropriate for their tensorial type (scalar, vector, tensor, etc.). Averaging these + back-rotated predictions over the quadrature grid yields fully + :math:`O(3)`-symmetrized outputs. In addition, two complementary equivariance + metrics are computed: + + 1. Variance under :math:`O(3)` of the back-rotated outputs. + + For a perfectly equivariant model, the back-rotated output :math:`x(g)` is + independent of the group element :math:`g`. Deviations from perfect equivariance + are quantified by the difference between the average squared norm over + :math:`O(3)` and the squared norm of the :math:`O(3)`-averaged output: + + .. math:: + + \mathrm{Var}_{O(3)}[x] + = + \left\langle \,\| x(g) \|^{2} \,\right\rangle_{O(3)} + - + \left\| \left\langle x(g) \right\rangle_{O(3)} \right\|^{2} . + + Here, :math:`\|\cdot\|` denotes the Euclidean norm over the ``component`` axis, + and :math:`\langle \cdot \rangle_{O(3)}` denotes averaging over the quadrature + grid. This quantity is the squared norm of the component orthogonal to the + perfectly equivariant subspace and therefore provides a scalar measure of the + deviation from exact equivariance. + + 2. Decomposition into isotypical components of :math:`O(3)`. + + Each output component may be viewed as a scalar function on :math:`O(3)`, + which can be decomposed into isotypical components labeled by the irreducible + representations :math:`\ell,\sigma` of :math:`O(3)`. The projection onto the + :math:`(\ell,\sigma)`-th isotypical subspace is computed as a convolution with + the corresponding character :math:`\chi_{\ell}`: + + .. math:: + + (P_{\ell,\sigma} x)(g) + = + \int_{O(3)} \chi_{\ell,\sigma}(h^{-1} g)\, x(h)\, \mathrm{d}\mu(h). + + Its squared :math:`L^{2}` norm over :math:`O(3)` is + + .. math:: + + \| P_{\ell,\sigma} x \|^{2} + \;=\; + \left\langle \, | (P_{\ell,\sigma} x)(g) |^{2} \, \right\rangle_{O(3)} . + + These quantities describe how the model output is distributed across the + different :math:`O(3)` irreducible sectors. The complementary component, + orthogonal to all isotypical subspaces, is given by + + .. math:: + + \| x \|^{2} + - + \sum_{\ell,\sigma} \| P_{\ell,\sigma} x \|^{2} , + + and provides a refined measure of the deviation from lying entirely within any + prescribed set of :math:`O(3)` irreducible representations. + + :param base_model: atomistic model to symmetrize + :param max_o3_lambda: maximum O(3) angular momentum the grid integrates exactly + :param batch_size: number of rotations to evaluate in a single batch + :param max_o3_lambda_character: maximum O(3) angular momentum for character + projections. If None, set to ``max_o3_lambda``. + """ + def __init__( self, base_model, @@ -528,6 +598,15 @@ def forward( outputs: Dict[str, ModelOutput], selected_atoms: Optional[Labels] = None, ) -> Dict[str, TensorMap]: + """ + Symmetrize the model outputs over :math:`O(3)` and compute equivariance + metrics. + + :param systems: list of systems to evaluate + :param outputs: dictionary of model outputs to symmetrize + :param selected_atoms: optional Labels specifying which atoms to consider + :return: dictionary with symmetrized outputs and equivariance metrics + """ # Evaluate the model over the grid transformed_outputs, backtransformed_outputs = self._eval_over_grid( systems, outputs, selected_atoms @@ -557,22 +636,37 @@ def forward( return out_dict - def _compute_conv_integral(self, tensor_dict): - new_tensors = {} + def _compute_conv_integral( + self, tensor_dict: Dict[str, TensorMap] + ) -> Dict[str, TensorMap]: + """ + Compute the O(3)-convolution of each tensor in ``tensor_dict`` with O(3) + characters. + + :param tensor_dict: dictionary of TensorMaps to compute convolution integral for + :return: dictionary of TensorMaps with convolution integrals + """ + new_tensors: Dict[str, TensorMap] = {} for name, tensor in tensor_dict.items(): keys = tensor.keys remaining_keys = Labels( keys.names[:-1], keys.values[keys.column("inversion") == 1][:, :-1], ) - new_blocks = [] - new_keys = [] - for key in remaining_keys: - key_vals = { - k: int(v) for k, v in zip(key.names, key.values, strict=True) - } - so3_block = tensor.block(key_vals | {"inversion": 1}) - pso3_block = tensor.block(key_vals | {"inversion": -1}) + new_blocks: List[TensorBlock] = [] + new_keys: List[torch.Tensor] = [] + for key_names, key_values in zip( + remaining_keys.names, remaining_keys.values, strict=True + ): + key_to_match_plus: Dict[str, int] = {} + key_to_match_minus: Dict[str, int] = {} + for k, v in zip(key_names, key_values, strict=True): + key_to_match_plus[k] = int(v) + key_to_match_minus[k] = int(v) + key_to_match_plus["inversion"] = 1 + key_to_match_minus["inversion"] = -1 + so3_block = tensor.block(key_to_match_plus) + pso3_block = tensor.block(key_to_match_minus) for o3_lambda in range(self.max_o3_lambda_character + 1): so3_chi = self.so3_characters[o3_lambda] @@ -583,7 +677,7 @@ def _compute_conv_integral(self, tensor_dict): so3_chi, pso3_block, pso3_block, self.so3_weights ) for o3_sigma in [1, -1]: - pso3_chi = self.pso3_characters[(o3_lambda, o3_sigma)] + pso3_chi = self.pso3_characters[f"{o3_lambda}_{o3_sigma}"] third_term = _character_convolution( pso3_chi, so3_block, pso3_block, self.so3_weights ) @@ -601,9 +695,12 @@ def _compute_conv_integral(self, tensor_dict): ) new_blocks.append(block) new_keys.append( - torch.cat([key.values, torch.tensor([o3_lambda, o3_sigma])]) + torch.cat([key_values, torch.tensor([o3_lambda, o3_sigma])]) ) - key_names = [name for name in tensor.keys.names if name != "inversion"] + key_names: List[str] = [] + for key_name in tensor.keys.names: + if key_name != "inversion": + key_names.append(key_name) new_tensor = TensorMap( Labels( key_names + ["chi_lambda", "chi_sigma"], @@ -617,35 +714,44 @@ def _compute_conv_integral(self, tensor_dict): def _compute_mean_and_variance( self, tensor_dict: Dict[str, TensorMap], - ) -> Tuple[Dict[str, TensorMap], Dict[str, TensorMap]]: - mean_var = {} + ) -> Dict[str, TensorMap]: + """ + Compute the mean and variance of the outputs over O(3). + + :param tensor_dict: dictionary of TensorMaps to compute mean and variance for + :return: dictionary of TensorMaps with mean and variance + """ + mean_var: Dict[str, TensorMap] = {} for name in tensor_dict: tensor = tensor_dict[name] - mean_blocks = [] - second_moment_blocks = [] - mean_norm_blocks = [] - for block in tensor: + mean_blocks: List[TensorBlock] = [] + second_moment_blocks: List[TensorBlock] = [] + for block in tensor.blocks(): rot_ids = block.samples.column("so3_rotation") values = block.values - values_norm = ( - torch.norm(values, dim=tuple(range(1, values.ndim - 1))) - if values.ndim > 2 - else torch.abs(values) - ) - values_squared = values_norm**2 + if values.ndim > 2: + dims: List[int] = [] + for i in range(1, values.ndim - 1): + dims.append(i) + values_squared = torch.sum(values**2, dim=dims) + else: + values_squared = values**2 - view = (values.size(0), *[1] * (values.ndim - 1)) + view: List[int] = [] + view.append(values.size(0)) + for _ in range(values.ndim - 1): + view.append(1) values = 0.5 * self.so3_weights[rot_ids].view(view) * values - view = (values_squared.size(0), *[1] * (values_squared.ndim - 1)) + view: List[int] = [] + view.append(values_squared.size(0)) + for _ in range(values_squared.ndim - 1): + view.append(1) values_squared = ( 0.5 * self.so3_weights[rot_ids].view(view) * values_squared ) - view = (values_norm.size(0), *[1] * (values_norm.ndim - 1)) - values_norm = 0.5 * self.so3_weights[rot_ids].view(view) * values_norm - mean_blocks.append( TensorBlock( values=values, @@ -654,14 +760,6 @@ def _compute_mean_and_variance( properties=block.properties, ) ) - mean_norm_blocks.append( - TensorBlock( - values=values_norm, - samples=block.samples, - components=[], - properties=block.properties, - ) - ) second_moment_blocks.append( TensorBlock( values=values_squared, @@ -678,10 +776,26 @@ def _compute_mean_and_variance( ) # Mean norm - tensor_mean_norm = TensorMap(tensor.keys, mean_norm_blocks) - tensor_mean_norm = mts.sum_over_samples( - tensor_mean_norm.keys_to_samples("inversion"), - ["inversion", "so3_rotation"], + mean_norm_squared_blocks: List[TensorBlock] = [] + for block in tensor_mean.blocks(): + vals = block.values + if vals.ndim > 2: + dims: List[int] = [] + for i in range(1, vals.ndim - 1): + dims.append(i) + vals = torch.sum(vals**2, dim=dims) + else: + vals = vals**2 + mean_norm_squared_blocks.append( + TensorBlock( + values=vals, + samples=block.samples, + components=[], + properties=block.properties, + ) + ) + tensor_mean_norm_squared = TensorMap( + tensor_mean.keys, mean_norm_squared_blocks ) # Second moment @@ -693,11 +807,10 @@ def _compute_mean_and_variance( # Variance tensor_variance = mts.subtract( - tensor_second_moment, mts.pow(tensor_mean_norm, 2) + tensor_second_moment, tensor_mean_norm_squared ) mean_var[name + "_mean"] = tensor_mean - mean_var[name + "_mean_norm"] = tensor_mean_norm mean_var[name + "_var"] = tensor_variance return mean_var @@ -706,8 +819,7 @@ def _eval_over_grid( systems: List[System], outputs: Dict[str, ModelOutput], selected_atoms: Optional[Labels], - return_tensormaps: bool = True, - ): + ) -> Tuple[Dict[str, TensorMap], Dict[str, TensorMap]]: """ Sample the model on the O(3) quadrature. @@ -721,12 +833,25 @@ def _eval_over_grid( device = systems[0].positions.device dtype = systems[0].positions.dtype - transformed_outputs: Dict[str, List[Dict[int, Optional[TensorMap]]]] = { - name: [{-1: None, 1: None} for _ in systems] for name in outputs - } - backtransformed_outputs: Dict[str, List[Dict[int, Optional[TensorMap]]]] = { - name: [{-1: None, 1: None} for _ in systems] for name in outputs - } + transformed_outputs = torch.jit.annotate( + Dict[str, List[Dict[int, TensorMap]]], {} + ) + for name in outputs: + lst = torch.jit.annotate(List[Dict[int, TensorMap]], []) + for _ in systems: + d = torch.jit.annotate(Dict[int, TensorMap], {}) + lst.append(d) + transformed_outputs[name] = lst + backtransformed_outputs = torch.jit.annotate( + Dict[str, List[Dict[int, TensorMap]]], {} + ) + for name in outputs: + lst = torch.jit.annotate(List[Dict[int, TensorMap]], []) + for _ in systems: + d = torch.jit.annotate(Dict[int, TensorMap], {}) + lst.append(d) + backtransformed_outputs[name] = lst + for i_sys, system in enumerate(systems): for inversion in [-1, 1]: rotation_outputs: List[Dict[str, TensorMap]] = [] @@ -745,17 +870,17 @@ def _eval_over_grid( rotation_outputs.append(out) # Combine batch outputs - for name in transformed_outputs: - combined: List[TensorMap] = [r[name] for r in rotation_outputs] + for name in outputs: + combined_: List[TensorMap] = [r[name] for r in rotation_outputs] combined = mts.join( - combined, + combined_, "samples", add_dimension="batch_rotation", ) if "batch_rotation" in combined[0].samples.names: # Reindex - blocks = [] - for block in combined: + blocks: List[TensorBlock] = [] + for block in combined.blocks(): batch_id = block.samples.column("batch_rotation") rot_id = block.samples.column("system") new_sample_values = block.samples.values[:, :-1] @@ -803,94 +928,113 @@ def _eval_over_grid( name ] - if return_tensormaps: - # Massage outputs to have desired shape - for name in transformed_outputs: - joined_plus = mts.join( - [ - transformed_outputs[name][i_sys][1] - for i_sys in range(len(systems)) - ], - "samples", - add_dimension="phys_system", - ) - joined_minus = mts.join( - [ - transformed_outputs[name][i_sys][1] - for i_sys in range(len(systems)) - ], - "samples", - add_dimension="phys_system", - ) - joined = mts.join( - [ - mts.append_dimension(joined_plus, "keys", "inversion", 1), - mts.append_dimension(joined_minus, "keys", "inversion", -1), - ], - "samples", - different_keys="union", - ) + transformed_outputs_tensor: Dict[str, TensorMap] = {} + backtransformed_outputs_tensor: Dict[str, TensorMap] = {} + # Massage outputs to have desired shape + for name in transformed_outputs: + joined_plus = mts.join( + [transformed_outputs[name][i_sys][1] for i_sys in range(len(systems))], + "samples", + add_dimension="phys_system", + ) + joined_minus = mts.join( + [transformed_outputs[name][i_sys][1] for i_sys in range(len(systems))], + "samples", + add_dimension="phys_system", + ) + joined = mts.join( + [ + mts.append_dimension(joined_plus, "keys", "inversion", 1), + mts.append_dimension(joined_minus, "keys", "inversion", -1), + ], + "samples", + different_keys="union", + ) + joined = mts.rename_dimension(joined, "samples", "system", "so3_rotation") + + if "phys_system" in joined[0].samples.names: joined = mts.rename_dimension( - joined, "samples", "system", "so3_rotation" + joined, "samples", "phys_system", "system" ) - - if "phys_system" in joined[0].samples.names: - joined = mts.rename_dimension( - joined, "samples", "phys_system", "system" - ) - else: - joined = mts.insert_dimension( - joined, - "samples", - 1, - "system", - torch.zeros( - joined[0].samples.values.shape[0], dtype=torch.long - ), - ) - transformed_outputs[name] = joined - - joined_plus = mts.join( - [ - backtransformed_outputs[name][i_sys][1] - for i_sys in range(len(systems)) - ], + else: + joined = mts.insert_dimension( + joined, "samples", - add_dimension="phys_system", + 1, + "system", + torch.zeros(joined[0].samples.values.shape[0], dtype=torch.long), ) - joined_minus = mts.join( - [ - backtransformed_outputs[name][i_sys][1] - for i_sys in range(len(systems)) - ], - "samples", - add_dimension="phys_system", + if "atom" in joined[0].samples.names: + perm = _permute_system_before_atom(joined[0].samples.names) + joined = mts.permute_dimensions(joined, "samples", perm) + transformed_outputs_tensor[name] = joined + + joined_plus = mts.join( + [ + backtransformed_outputs[name][i_sys][1] + for i_sys in range(len(systems)) + ], + "samples", + add_dimension="phys_system", + ) + joined_minus = mts.join( + [ + backtransformed_outputs[name][i_sys][1] + for i_sys in range(len(systems)) + ], + "samples", + add_dimension="phys_system", + ) + joined = mts.join( + [ + mts.append_dimension(joined_plus, "keys", "inversion", 1), + mts.append_dimension(joined_minus, "keys", "inversion", -1), + ], + "samples", + different_keys="union", + ) + joined = mts.rename_dimension(joined, "samples", "system", "so3_rotation") + if "phys_system" in joined[0].samples.names: + joined = mts.rename_dimension( + joined, "samples", "phys_system", "system" ) - joined = mts.join( - [ - mts.append_dimension(joined_plus, "keys", "inversion", 1), - mts.append_dimension(joined_minus, "keys", "inversion", -1), - ], + else: + joined = mts.insert_dimension( + joined, "samples", - different_keys="union", - ) - joined = mts.rename_dimension( - joined, "samples", "system", "so3_rotation" + 1, + "system", + torch.zeros(joined[0].samples.values.shape[0], dtype=torch.long), ) - if "phys_system" in joined[0].samples.names: - joined = mts.rename_dimension( - joined, "samples", "phys_system", "system" - ) - else: - joined = mts.insert_dimension( - joined, - "samples", - 1, - "system", - torch.zeros( - joined[0].samples.values.shape[0], dtype=torch.long - ), - ) - backtransformed_outputs[name] = joined - - return transformed_outputs, backtransformed_outputs + if "atom" in joined[0].samples.names: + perm = _permute_system_before_atom(joined[0].samples.names) + joined = mts.permute_dimensions(joined, "samples", perm) + backtransformed_outputs_tensor[name] = joined + + return transformed_outputs_tensor, backtransformed_outputs_tensor + + +def _permute_system_before_atom(labels: List[str]) -> List[int]: + # find positions + sys_idx = -1 + atom_idx = -1 + for i in range(len(labels)): + if labels[i] == "system": + sys_idx = i + elif labels[i] == "atom": + atom_idx = i + + # identity permutation + perm = list(range(len(labels))) + + # reorder only if both present and system is after atom + if sys_idx != -1 and atom_idx != -1 and sys_idx > atom_idx: + v = perm[sys_idx] + # remove system + for k in range(sys_idx, len(perm) - 1): + perm[k] = perm[k + 1] + perm.pop() + # insert before atom + perm.insert(atom_idx, v) + + return perm From 7695dfb240be3427828da7fbe7d22abcc6bdd82a Mon Sep 17 00:00:00 2001 From: ppegolo Date: Wed, 19 Nov 2025 18:18:01 +0100 Subject: [PATCH 41/57] Fix bug --- .../metatomic/torch/ase_calculator.py | 6 +- .../metatomic/torch/symmetrized_model.py | 211 ++++++++++++++---- 2 files changed, 167 insertions(+), 50 deletions(-) diff --git a/python/metatomic_torch/metatomic/torch/ase_calculator.py b/python/metatomic_torch/metatomic/torch/ase_calculator.py index f5d00593..b28e2920 100644 --- a/python/metatomic_torch/metatomic/torch/ase_calculator.py +++ b/python/metatomic_torch/metatomic/torch/ase_calculator.py @@ -227,9 +227,9 @@ def __init__( for name, output in additional_outputs.items(): assert isinstance(name, str) assert isinstance(output, torch.ScriptObject) - assert ( - "explicit_gradients_setter" in output._method_names() - ), "outputs must be ModelOutput instances" + assert "explicit_gradients_setter" in output._method_names(), ( + "outputs must be ModelOutput instances" + ) self._additional_output_requests = additional_outputs diff --git a/python/metatomic_torch/metatomic/torch/symmetrized_model.py b/python/metatomic_torch/metatomic/torch/symmetrized_model.py index a3f9ff86..b04b5cc7 100644 --- a/python/metatomic_torch/metatomic/torch/symmetrized_model.py +++ b/python/metatomic_torch/metatomic/torch/symmetrized_model.py @@ -425,6 +425,7 @@ def _character_convolution( Then contract with another block. """ samples = block1.samples + assert samples.names[0] == "so3_rotation" components = block1.components properties = block1.properties values = block1.values @@ -432,26 +433,36 @@ def _character_convolution( n_rot = chi.size(1) weight = 0.5 * w.to(dtype=values.dtype, device=values.device) + # reshape the values to separate rotations from the other samples new_shape = [n_rot, -1] + list(values.shape[1:]) reshaped_values = values.reshape(new_shape) + + # broadcast weights to match reshaped_values view: List[int] = [] view.append(-1) for _ in range(reshaped_values.ndim - 1): view.append(1) weighted_values = weight.view(view) * reshaped_values + + # broadcast characters to match reshaped_values contracted_shape: List[int] = [chi.shape[0]] + list(weighted_values.shape[1:]) contracted_values = ( chi @ weighted_values.reshape(weighted_values.shape[0], -1) ).reshape(contracted_shape) values2 = block2.values + # reshape the values to separate rotations from the other samples new_shape = [n_rot, -1] + list(values2.shape[1:]) reshaped_values2 = values2.reshape(new_shape) + + # broadcast weights to match reshaped_values2 view: List[int] = [] view.append(-1) for _ in range(reshaped_values2.ndim - 1): view.append(1) weighted_values2 = weight.view(view) * reshaped_values2 + + # contract weighted_values2 with contracted_values contracted_values = torch.einsum( "i...,i...->...", weighted_values2, @@ -557,40 +568,130 @@ def __init__( ): super().__init__() self.base_model = base_model + + try: + ref_param = next(base_model.parameters()) + dtype = ref_param.dtype + device = ref_param.device + except StopIteration: + dtype = torch.get_default_dtype() + device = torch.device("cpu") + self.max_o3_lambda = max_o3_lambda self.batch_size = batch_size if max_o3_lambda_character is None: max_o3_lambda_character = max_o3_lambda self.max_o3_lambda_character = max_o3_lambda_character - # Compute grid + # Compute grid (unchanged) lebedev_order, n_inplane_rotations = _choose_quadrature(self.max_o3_lambda) alpha, beta, gamma, w_so3 = get_euler_angles_quadrature( lebedev_order, n_inplane_rotations ) - self.so3_weights = torch.from_numpy(w_so3) + so3_weights = torch.from_numpy(w_so3).to(device=device, dtype=dtype) + self.register_buffer("so3_weights", so3_weights) - # Active rotations - self.so3_rotations = torch.from_numpy( + so3_rotations = torch.from_numpy( _rotations_from_angles(alpha, beta, gamma).as_matrix() - ) + ).to(device=device, dtype=dtype) + self.register_buffer("so3_rotations", so3_rotations) self.n_so3_rotations = self.so3_rotations.size(0) - # Compute inverse Wigner D representations angles_inverse_rotations = (np.pi - gamma, beta, np.pi - alpha) - self.so3_inverse_rotations = torch.from_numpy( + so3_inverse_rotations = torch.from_numpy( _rotations_from_angles(*angles_inverse_rotations).as_matrix() - ) - self.wigner_D_inverse_rotations = _compute_real_wigner_matrices( + ).to(device=device, dtype=dtype) + self.register_buffer("so3_inverse_rotations", so3_inverse_rotations) + + self._wigner_D_inverse_jit: Dict[int, torch.Tensor] = {} + self._so3_characters_jit: Dict[int, torch.Tensor] = {} + self._pso3_characters_jit: Dict[str, torch.Tensor] = {} + # Since Wigner D matrices are stored in dicts, we need a bit of gymnastics to + # register the buffers + raw_wigner = _compute_real_wigner_matrices( self.max_o3_lambda, angles_inverse_rotations ) + self._wigner_D_inverse_names: Dict[int, str] = {} + for ell, D in raw_wigner.items(): + if isinstance(D, np.ndarray): + D = torch.from_numpy(D) + D = D.to(dtype=dtype, device=device) + name = f"wigner_D_inverse_rotations_l{ell}" + self.register_buffer(name, D) + self._wigner_D_inverse_names[ell] = name + # TorchScript dict view uses the same tensor + self._wigner_D_inverse_jit[ell] = D # Compute characters - self.so3_characters, self.pso3_characters = compute_characters( + so3_characters, pso3_characters = compute_characters( self.max_o3_lambda_character, (alpha, beta, gamma), angles_inverse_rotations, ) + self._so3_char_names: Dict[int, str] = {} + self._pso3_char_names: Dict[str, str] = {} + + # Since characters are stored in dicts, we need a bit of gymnastics to + # register the buffers + for ell, ch in so3_characters.items(): + if isinstance(ch, np.ndarray): + ch = torch.from_numpy(ch) + ch = ch.to(dtype=dtype, device=device) + name = f"so3_characters_l{ell}" + self.register_buffer(name, ch) + self._so3_char_names[ell] = name + self._so3_characters_jit[ell] = ch + + for ell, ch in pso3_characters.items(): + # here `ell` is your combined "lambda_sigma" string, e.g. "0_+1" + if isinstance(ch, np.ndarray): + ch = torch.from_numpy(ch) + ch = ch.to(dtype=dtype, device=device) + name = f"pso3_characters_l{ell}" + self.register_buffer(name, ch) + self._pso3_char_names[ell] = name + self._pso3_characters_jit[ell] = ch + + @torch.jit.ignore + def _wigner_D_inverse_dict(self) -> Dict[int, torch.Tensor]: + return { + ell: getattr(self, name) + for ell, name in self._wigner_D_inverse_names.items() + } + + @property + def wigner_D_inverse_rotations(self) -> Dict[int, torch.Tensor]: + # Python-only nice view + return self._wigner_D_inverse_dict() + + @torch.jit.ignore + def _so3_characters_dict(self) -> Dict[int, torch.Tensor]: + return {ell: getattr(self, name) for ell, name in self._so3_char_names.items()} + + @property + def so3_characters(self) -> Dict[int, torch.Tensor]: + # Python-only nice view + return self._so3_characters_dict() + + @torch.jit.ignore + def _pso3_characters_dict(self) -> Dict[str, torch.Tensor]: + return {key: getattr(self, name) for key, name in self._pso3_char_names.items()} + + @property + def pso3_characters(self) -> Dict[str, torch.Tensor]: + # Python-only nice view + return self._pso3_characters_dict() + + def _get_wigner_D_inverse(self, ell: int) -> torch.Tensor: + return self._wigner_D_inverse_jit[ell] + + def _get_so3_character(self, o3_lambda: int) -> torch.Tensor: + return self._so3_characters_jit[o3_lambda] + + def _get_pso3_character(self, o3_lambda: int, o3_sigma: int) -> torch.Tensor: + # TorchScript-safe label build + label = str(o3_lambda) + "_" + str(o3_sigma) + return self._pso3_characters_jit[label] def forward( self, @@ -612,27 +713,19 @@ def forward( systems, outputs, selected_atoms ) + # Compute the O(3) mean and variance mean_var = self._compute_mean_and_variance(backtransformed_outputs) - convolution_integrals = self._compute_conv_integral(transformed_outputs) + + # return mean_var, transformed_outputs # Compute the character projections - out_dict = mean_var + convolution_integrals = self._compute_conv_integral(transformed_outputs) + + out_dict: Dict[str, TensorMap] = {} + for name, tensor in mean_var.items(): + out_dict[name] = tensor for name, integral in convolution_integrals.items(): - mean_tensor = mean_var[name + "_mean"] - for key, block in integral.items(): - mean_block = mean_tensor.block( - { - k: int(v) - for k, v in zip( - mean_tensor.keys.names, key.values, strict=False - ) - } - ) - assert block.values.shape == mean_block.values.shape - integral[key].values[:] = mean_block.values - block.values - if "_" in integral.keys.names: - integral = mts.remove_dimension(integral, "keys", "_") - out_dict[name + "_character_projections"] = integral + out_dict[name] = integral return out_dict @@ -646,7 +739,9 @@ def _compute_conv_integral( :param tensor_dict: dictionary of TensorMaps to compute convolution integral for :return: dictionary of TensorMaps with convolution integrals """ + new_tensors: Dict[str, TensorMap] = {} + # loop over tensormaps for name, tensor in tensor_dict.items(): keys = tensor.keys remaining_keys = Labels( @@ -655,21 +750,22 @@ def _compute_conv_integral( ) new_blocks: List[TensorBlock] = [] new_keys: List[torch.Tensor] = [] - for key_names, key_values in zip( - remaining_keys.names, remaining_keys.values, strict=True - ): + # loop over keys in the final tensormap + for key_values in remaining_keys.values: key_to_match_plus: Dict[str, int] = {} key_to_match_minus: Dict[str, int] = {} - for k, v in zip(key_names, key_values, strict=True): + for k, v in zip(remaining_keys.names, key_values, strict=True): key_to_match_plus[k] = int(v) key_to_match_minus[k] = int(v) key_to_match_plus["inversion"] = 1 key_to_match_minus["inversion"] = -1 + # get the corresponding blocks for proper and improper rotations so3_block = tensor.block(key_to_match_plus) pso3_block = tensor.block(key_to_match_minus) + # loop over SO(3) irreps for o3_lambda in range(self.max_o3_lambda_character + 1): - so3_chi = self.so3_characters[o3_lambda] + so3_chi = self._get_so3_character(o3_lambda) first_term = _character_convolution( so3_chi, so3_block, so3_block, self.so3_weights ) @@ -677,11 +773,10 @@ def _compute_conv_integral( so3_chi, pso3_block, pso3_block, self.so3_weights ) for o3_sigma in [1, -1]: - pso3_chi = self.pso3_characters[f"{o3_lambda}_{o3_sigma}"] + pso3_chi = self._get_pso3_character(o3_lambda, o3_sigma) third_term = _character_convolution( - pso3_chi, so3_block, pso3_block, self.so3_weights + pso3_chi, pso3_block, so3_block, self.so3_weights ) - block = TensorBlock( samples=first_term.samples, components=first_term.components, @@ -695,7 +790,16 @@ def _compute_conv_integral( ) new_blocks.append(block) new_keys.append( - torch.cat([key_values, torch.tensor([o3_lambda, o3_sigma])]) + torch.cat( + [ + key_values, + torch.tensor( + [o3_lambda, o3_sigma], + device=key_values.device, + dtype=key_values.dtype, + ), + ] + ) ) key_names: List[str] = [] for key_name in tensor.keys.names: @@ -708,7 +812,9 @@ def _compute_conv_integral( ), new_blocks, ) - new_tensors[name] = new_tensor + if "_" in new_tensor.keys.names: + new_tensor = mts.remove_dimension(new_tensor, "keys", "_") + new_tensors[name + "_character_projection"] = new_tensor return new_tensors def _compute_mean_and_variance( @@ -906,6 +1012,14 @@ def _eval_over_grid( for i_sys, system in enumerate(systems): for inversion in [-1, 1]: tensor = transformed_outputs[name][i_sys][inversion] + wigner_dict = torch.jit.annotate(Dict[int, List[torch.Tensor]], {}) + for ell in self._wigner_D_inverse_jit: + wigner_dict[ell] = ( + self._get_wigner_D_inverse(ell) + .to(device=device, dtype=dtype) + .unbind(0) + ) + _, backtransformed, _ = _apply_augmentations( [system] * n_rot, {name: tensor}, @@ -917,12 +1031,7 @@ def _eval_over_grid( * inversion ).unbind(0) ), - { - ell: self.wigner_D_inverse_rotations[ell] - .to(device=device, dtype=dtype) - .unbind(0) - for ell in self.wigner_D_inverse_rotations - }, + wigner_dict, ) backtransformed_outputs[name][i_sys][inversion] = backtransformed[ name @@ -938,7 +1047,7 @@ def _eval_over_grid( add_dimension="phys_system", ) joined_minus = mts.join( - [transformed_outputs[name][i_sys][1] for i_sys in range(len(systems))], + [transformed_outputs[name][i_sys][-1] for i_sys in range(len(systems))], "samples", add_dimension="phys_system", ) @@ -962,7 +1071,11 @@ def _eval_over_grid( "samples", 1, "system", - torch.zeros(joined[0].samples.values.shape[0], dtype=torch.long), + torch.zeros( + joined[0].samples.values.shape[0], + dtype=torch.long, + device=joined[0].samples.values.device, + ), ) if "atom" in joined[0].samples.names: perm = _permute_system_before_atom(joined[0].samples.names) @@ -979,7 +1092,7 @@ def _eval_over_grid( ) joined_minus = mts.join( [ - backtransformed_outputs[name][i_sys][1] + backtransformed_outputs[name][i_sys][-1] for i_sys in range(len(systems)) ], "samples", @@ -1004,7 +1117,11 @@ def _eval_over_grid( "samples", 1, "system", - torch.zeros(joined[0].samples.values.shape[0], dtype=torch.long), + torch.zeros( + joined[0].samples.values.shape[0], + dtype=torch.long, + device=joined[0].samples.values.device, + ), ) if "atom" in joined[0].samples.names: perm = _permute_system_before_atom(joined[0].samples.names) From 6111a5b29d531ea87efc5917972cccd8935f1051 Mon Sep 17 00:00:00 2001 From: ppegolo Date: Fri, 21 Nov 2025 13:21:07 +0100 Subject: [PATCH 42/57] detach grads --- .../metatomic/torch/symmetrized_model.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/python/metatomic_torch/metatomic/torch/symmetrized_model.py b/python/metatomic_torch/metatomic/torch/symmetrized_model.py index b04b5cc7..8654ac47 100644 --- a/python/metatomic_torch/metatomic/torch/symmetrized_model.py +++ b/python/metatomic_torch/metatomic/torch/symmetrized_model.py @@ -536,7 +536,7 @@ class SymmetrizedModel(torch.nn.Module): .. math:: \| P_{\ell,\sigma} x \|^{2} - \;=\; + = \left\langle \, | (P_{\ell,\sigma} x)(g) |^{2} \, \right\rangle_{O(3)} . These quantities describe how the model output is distributed across the @@ -571,11 +571,11 @@ def __init__( try: ref_param = next(base_model.parameters()) - dtype = ref_param.dtype device = ref_param.device + dtype = ref_param.dtype except StopIteration: - dtype = torch.get_default_dtype() device = torch.device("cpu") + dtype = torch.get_default_dtype() self.max_o3_lambda = max_o3_lambda self.batch_size = batch_size @@ -968,11 +968,12 @@ def _eval_over_grid( ) for R in self.so3_rotations[batch : batch + self.batch_size] ] - out = self.base_model( - transformed_systems, - outputs, - selected_atoms, - ) + with torch.no_grad(): + out = self.base_model( + transformed_systems, + outputs, + selected_atoms, + ) rotation_outputs.append(out) # Combine batch outputs @@ -995,7 +996,7 @@ def _eval_over_grid( ) blocks.append( TensorBlock( - values=block.values, + values=block.values.detach(), samples=Labels( block.samples.names[:-1], new_sample_values, From 31473110b77359c123e6658ae05a78fb61cf6b4c Mon Sep 17 00:00:00 2001 From: ppegolo Date: Fri, 21 Nov 2025 14:42:17 +0100 Subject: [PATCH 43/57] update --- .../metatomic/torch/symmetrized_model.py | 50 ++++++++++++++++++- 1 file changed, 48 insertions(+), 2 deletions(-) diff --git a/python/metatomic_torch/metatomic/torch/symmetrized_model.py b/python/metatomic_torch/metatomic/torch/symmetrized_model.py index 8654ac47..211067c2 100644 --- a/python/metatomic_torch/metatomic/torch/symmetrized_model.py +++ b/python/metatomic_torch/metatomic/torch/symmetrized_model.py @@ -713,15 +713,18 @@ def forward( systems, outputs, selected_atoms ) + # Compute norms + norms = self._compute_norm_per_property(transformed_outputs) + # Compute the O(3) mean and variance mean_var = self._compute_mean_and_variance(backtransformed_outputs) - # return mean_var, transformed_outputs - # Compute the character projections convolution_integrals = self._compute_conv_integral(transformed_outputs) out_dict: Dict[str, TensorMap] = {} + for name, tensor in norms.items(): + out_dict[name] = tensor for name, tensor in mean_var.items(): out_dict[name] = tensor for name, integral in convolution_integrals.items(): @@ -729,6 +732,49 @@ def forward( return out_dict + def _compute_norm_per_property( + self, tensor_dict: Dict[str, TensorMap] + ) -> Dict[str, TensorMap]: + """ + Compute the norm per property of each tensor in ``tensor_dict``. + + :param tensor_dict: dictionary of TensorMaps to compute norms for + :return: dictionary of TensorMaps with norms per property + """ + norms: Dict[str, TensorMap] = {} + for name in tensor_dict: + tensor = tensor_dict[name] + norm_blocks: List[TensorBlock] = [] + for block in tensor.blocks(): + rot_ids = block.samples.column("so3_rotation") + + values_squared = block.values**2 + + view: List[int] = [] + view.append(values_squared.size(0)) + for _ in range(values_squared.ndim - 1): + view.append(1) + values_squared = ( + 0.5 * self.so3_weights[rot_ids].view(view) * values_squared + ) + + norm_blocks.append( + TensorBlock( + values=values_squared, + samples=block.samples, + components=block.components, + properties=block.properties, + ) + ) + + tensor_norm = TensorMap(tensor.keys, norm_blocks) + tensor_norm = mts.sum_over_samples( + tensor_norm.keys_to_samples("inversion"), ["inversion", "so3_rotation"] + ) + + norms[name + "_norm"] = tensor_norm + return norms + def _compute_conv_integral( self, tensor_dict: Dict[str, TensorMap] ) -> Dict[str, TensorMap]: From fc28cd2a3a714e6114b86a89eec9ac4194ba9393 Mon Sep 17 00:00:00 2001 From: ppegolo Date: Fri, 21 Nov 2025 14:55:35 +0100 Subject: [PATCH 44/57] fix normalization --- .../metatomic/torch/symmetrized_model.py | 29 ++++--------------- 1 file changed, 6 insertions(+), 23 deletions(-) diff --git a/python/metatomic_torch/metatomic/torch/symmetrized_model.py b/python/metatomic_torch/metatomic/torch/symmetrized_model.py index 211067c2..c6914904 100644 --- a/python/metatomic_torch/metatomic/torch/symmetrized_model.py +++ b/python/metatomic_torch/metatomic/torch/symmetrized_model.py @@ -370,24 +370,6 @@ def _get_so3_character( return chi -def _get_o3_character( - alphas: np.ndarray, - betas: np.ndarray, - gammas: np.ndarray, - o3_lambda: int, - o3_sigma: int, - tol: float = 1e-13, -) -> np.ndarray: - """ - Numerically stable evaluation of the character function χ_{o3_lambda}(R) over O(3). - """ - return ( - o3_sigma - * ((-1) ** o3_lambda) - * _get_so3_character(alphas, betas, gammas, o3_lambda, tol) - ) - - def compute_characters( o3_lambda_max: int, angles: Tuple[np.ndarray, np.ndarray, np.ndarray], @@ -431,7 +413,7 @@ def _character_convolution( values = block1.values chi = chi.to(dtype=values.dtype, device=values.device) n_rot = chi.size(1) - weight = 0.5 * w.to(dtype=values.dtype, device=values.device) + weight = w.to(dtype=values.dtype, device=values.device) # reshape the values to separate rotations from the other samples new_shape = [n_rot, -1] + list(values.shape[1:]) @@ -760,7 +742,7 @@ def _compute_norm_per_property( norm_blocks.append( TensorBlock( - values=values_squared, + values=values_squared, # /(8 * torch.pi**2), samples=block.samples, components=block.components, properties=block.properties, @@ -772,7 +754,7 @@ def _compute_norm_per_property( tensor_norm.keys_to_samples("inversion"), ["inversion", "so3_rotation"] ) - norms[name + "_norm"] = tensor_norm + norms[name + "_squared_norm"] = tensor_norm return norms def _compute_conv_integral( @@ -831,8 +813,8 @@ def _compute_conv_integral( 0.25 * (first_term.values + second_term.values) + 0.5 * third_term.values ) - * (2 * o3_lambda + 1) - / (8 * torch.pi**2) ** 2, + * (2 * o3_lambda + 1), + # / (8 * torch.pi**2) ** 2, ) new_blocks.append(block) new_keys.append( @@ -963,6 +945,7 @@ def _compute_mean_and_variance( ) mean_var[name + "_mean"] = tensor_mean + mean_var[name + "_norm_squared"] = tensor_second_moment mean_var[name + "_var"] = tensor_variance return mean_var From 3ac3b2154e954ce220ca1fe5905d1377713990d1 Mon Sep 17 00:00:00 2001 From: MichelangeloDomina Date: Tue, 25 Nov 2025 13:54:17 +0100 Subject: [PATCH 45/57] - Added flag for target - Characters now are evaluated on cpu and streamed on gpu (usage now requires passing a model already on the desired device to the SymmetrizedModel class - Removed anything with the name "features" from the evaluation of mean and variances in `_compute_mean_and_variance` --- .../metatomic/torch/symmetrized_model.py | 55 ++++++++++++++----- 1 file changed, 42 insertions(+), 13 deletions(-) diff --git a/python/metatomic_torch/metatomic/torch/symmetrized_model.py b/python/metatomic_torch/metatomic/torch/symmetrized_model.py index c6914904..1e585a33 100644 --- a/python/metatomic_torch/metatomic/torch/symmetrized_model.py +++ b/python/metatomic_torch/metatomic/torch/symmetrized_model.py @@ -544,7 +544,8 @@ class SymmetrizedModel(torch.nn.Module): def __init__( self, base_model, - max_o3_lambda, + max_o3_lambda_grid, + max_o3_lambda_target, batch_size: int = 32, max_o3_lambda_character: Optional[int] = None, ): @@ -559,14 +560,15 @@ def __init__( device = torch.device("cpu") dtype = torch.get_default_dtype() - self.max_o3_lambda = max_o3_lambda + self.max_o3_lambda_grid = max_o3_lambda_grid + self.max_o3_lambda_target = max_o3_lambda_target self.batch_size = batch_size if max_o3_lambda_character is None: - max_o3_lambda_character = max_o3_lambda + max_o3_lambda_character = max_o3_lambda_grid self.max_o3_lambda_character = max_o3_lambda_character # Compute grid (unchanged) - lebedev_order, n_inplane_rotations = _choose_quadrature(self.max_o3_lambda) + lebedev_order, n_inplane_rotations = _choose_quadrature(self.max_o3_lambda_grid) alpha, beta, gamma, w_so3 = get_euler_angles_quadrature( lebedev_order, n_inplane_rotations ) @@ -591,7 +593,7 @@ def __init__( # Since Wigner D matrices are stored in dicts, we need a bit of gymnastics to # register the buffers raw_wigner = _compute_real_wigner_matrices( - self.max_o3_lambda, angles_inverse_rotations + self.max_o3_lambda_target, angles_inverse_rotations ) self._wigner_D_inverse_names: Dict[int, str] = {} for ell, D in raw_wigner.items(): @@ -618,21 +620,24 @@ def __init__( for ell, ch in so3_characters.items(): if isinstance(ch, np.ndarray): ch = torch.from_numpy(ch) - ch = ch.to(dtype=dtype, device=device) + + ch = ch.to(dtype=dtype, device="cpu") # stay on CPU name = f"so3_characters_l{ell}" self.register_buffer(name, ch) self._so3_char_names[ell] = name - self._so3_characters_jit[ell] = ch + + self._so3_characters_jit = {} # kill the CUDA dict cache for ell, ch in pso3_characters.items(): - # here `ell` is your combined "lambda_sigma" string, e.g. "0_+1" if isinstance(ch, np.ndarray): ch = torch.from_numpy(ch) - ch = ch.to(dtype=dtype, device=device) + + ch = ch.to(dtype=dtype, device="cpu") # stay on CPU name = f"pso3_characters_l{ell}" self.register_buffer(name, ch) self._pso3_char_names[ell] = name - self._pso3_characters_jit[ell] = ch + + self._pso3_characters_jit = {} @torch.jit.ignore def _wigner_D_inverse_dict(self) -> Dict[int, torch.Tensor]: @@ -668,12 +673,34 @@ def _get_wigner_D_inverse(self, ell: int) -> torch.Tensor: return self._wigner_D_inverse_jit[ell] def _get_so3_character(self, o3_lambda: int) -> torch.Tensor: - return self._so3_characters_jit[o3_lambda] + name = self._so3_char_names[o3_lambda] + ch_cpu = getattr(self, name) + + # follow the base model device/dtype + try: + ref = next(self.base_model.parameters()) + device = ref.device + dtype = ref.dtype + except StopIteration: + device = torch.device("cpu") + dtype = torch.get_default_dtype() + + return ch_cpu.to(device=device, dtype=dtype, non_blocking=True) def _get_pso3_character(self, o3_lambda: int, o3_sigma: int) -> torch.Tensor: - # TorchScript-safe label build label = str(o3_lambda) + "_" + str(o3_sigma) - return self._pso3_characters_jit[label] + name = self._pso3_char_names[label] + ch_cpu = getattr(self, name) + + try: + ref = next(self.base_model.parameters()) + device = ref.device + dtype = ref.dtype + except StopIteration: + device = torch.device("cpu") + dtype = torch.get_default_dtype() + + return ch_cpu.to(device=device, dtype=dtype, non_blocking=True) def forward( self, @@ -857,6 +884,8 @@ def _compute_mean_and_variance( """ mean_var: Dict[str, TensorMap] = {} for name in tensor_dict: + if "features" in name: + continue tensor = tensor_dict[name] mean_blocks: List[TensorBlock] = [] second_moment_blocks: List[TensorBlock] = [] From f1fcac5b6440528c3653d2124f5c3bd501815bf3 Mon Sep 17 00:00:00 2001 From: MichelangeloDomina Date: Wed, 26 Nov 2025 11:21:51 +0100 Subject: [PATCH 46/57] Added decomposition of the stress into l=0 and l=2 components (and TYPE_CHECKING flag) --- .../metatomic/torch/symmetrized_model.py | 124 +++++++++++++++++- 1 file changed, 120 insertions(+), 4 deletions(-) diff --git a/python/metatomic_torch/metatomic/torch/symmetrized_model.py b/python/metatomic_torch/metatomic/torch/symmetrized_model.py index 1e585a33..bb5f7702 100644 --- a/python/metatomic_torch/metatomic/torch/symmetrized_model.py +++ b/python/metatomic_torch/metatomic/torch/symmetrized_model.py @@ -1,12 +1,26 @@ -from typing import Dict, List, Optional, Tuple - +from typing import Dict, List, Optional, Tuple, TYPE_CHECKING import metatensor.torch as mts + +if TYPE_CHECKING: + class TensorBlock: + ... + class System: + ... + class TensorMap: + ... + class ModelOutput: + ... + class Labels: + ... +else: + from metatensor.torch import Labels, TensorBlock, TensorMap + from metatomic.torch import ModelOutput, System + import numpy as np import torch -from metatensor.torch import Labels, TensorBlock, TensorMap from metatrain.utils.augmentation import _apply_augmentations -from metatomic.torch import ModelOutput, System, register_autograd_neighbors +from metatomic.torch import register_autograd_neighbors try: @@ -298,6 +312,55 @@ def _angles_from_rotations( gammas = gammas.reshape(batch_shape) return alphas, betas, gammas +def _l0_components_from_matrices( + A: torch.Tensor +) -> torch.Tensor: + """ + Extract the L=0 components from a (3, 3) tensor. + """ + # The tensor will have shape (a, 3, 3, b) so we need to move the 3, 3 dimension at the end + A = A.permute(0, 3, 1, 2) + # Test if the last two dimensions are (3, 3) + assert A.shape[-2:] == (3, 3), "The last two dimensions of A must be (3, 3)." + + # Initialize the output tensor for L=0 components to have 1 component in the last dimension + l0_A = torch.empty(A.shape[:-2] + (1,), + dtype=A.dtype, + device=A.device) + + # Compute the L=0 component as the trace of A + l0_A[..., 0] = (A[..., 0, 0] + A[..., 1, 1] + A[..., 2, 2]) + + l0_A = l0_A.permute(0, 2, 1) + return l0_A + + +def _l2_components_from_matrices( + A: torch.Tensor +) -> torch.Tensor: + """ + Extract the L=2 components from a (3, 3) tensor. + """ + # The tensor will have shape (a, 3, 3, b) so we need to move the 3, 3 dimension at the end + A = A.permute(0, 3, 1, 2) + # Test if the last two dimensions are (3, 3) + assert A.shape[-2:] == (3, 3), "The last two dimensions of A must be (3, 3)." + + # Initialize the output tensor for L=2 components to have 5 components in the last dimension + l2_A = torch.empty(A.shape[:-2] + (5,), + dtype=A.dtype, + device=A.device) + + l2_A[..., 0] = (A[..., 0, 1] + A[..., 1, 0]) / 2.0 + l2_A[..., 1] = (A[..., 1, 2] + A[..., 2, 1]) / 2.0 + l2_A[..., 2] = (2.0 * A[..., 2, 2] - A[..., 0, 0] - A[..., 1, 1]) / ((2.0) * np.sqrt(3.0)) + l2_A[..., 3] = (A[..., 0, 2] + A[..., 2, 0]) / 2.0 + l2_A[..., 4] = (A[..., 0, 0] - A[..., 1, 1]) / 2.0 + + l2_A = l2_A.permute(0, 2, 1) + + return l2_A + def _euler_angles_of_combined_rotation( angles1: Tuple[np.ndarray, np.ndarray, np.ndarray], @@ -722,6 +785,9 @@ def forward( systems, outputs, selected_atoms ) + transformed_outputs = self._decompose_stress_tensor(transformed_outputs) + backtransformed_outputs = self._decompose_stress_tensor(backtransformed_outputs) + # Compute norms norms = self._compute_norm_per_property(transformed_outputs) @@ -741,6 +807,56 @@ def forward( return out_dict + def _decompose_stress_tensor( + self, tensor_dict: Dict[str, TensorMap], + ) -> Dict[str, TensorMap]: + """ + Decompose stress tensor into irreducible representations of O(3). + + :param tensor_dict: dictionary of TensorMaps to decompose + :return: dictionary of TensorMaps with decomposed stress tensors + """ + if "non_conservative_stress" not in tensor_dict: + return tensor_dict + else: + tensor = tensor_dict["non_conservative_stress"] + blocks_l0 = [] + blocks_l2 = [] + for block in tensor.blocks(): + + trace_values = _l0_components_from_matrices(block.values) + block_l0 = TensorBlock( + values=trace_values, + samples=block.samples, + components=[Labels( + names="o3_mu", + values=torch.tensor([[0]], device=block.values.device, dtype=torch.int32) + )], + properties=block.properties, + ) + blocks_l0.append(block_l0) + + block_l2 = TensorBlock( + values=_l2_components_from_matrices(block.values), + samples=block.samples, + components=[Labels( + names="o3_mu", + values=torch.tensor([[mu] for mu in range(-2, 3)], device=block.values.device, dtype=torch.int32) + )], + properties=block.properties, + ) + blocks_l2.append(block_l2) + + tensor_dict["non_conservative_stress_l0"] = TensorMap( + tensor.keys, + blocks_l0) + tensor_dict["non_conservative_stress_l2"] = TensorMap( + tensor.keys, + blocks_l2) + tensor_dict.pop("non_conservative_stress") + return tensor_dict + + def _compute_norm_per_property( self, tensor_dict: Dict[str, TensorMap] ) -> Dict[str, TensorMap]: From 50e39d998a369faf7d3beb19cffa413580e3762a Mon Sep 17 00:00:00 2001 From: MichelangeloDomina Date: Thu, 27 Nov 2025 11:01:43 +0100 Subject: [PATCH 47/57] Added flag for full_diagnostic, check on the character vs lebedev order, and renamed _squared_norm in componentwise_norm_squared --- .../metatomic/torch/symmetrized_model.py | 38 +++++++++++-------- 1 file changed, 22 insertions(+), 16 deletions(-) diff --git a/python/metatomic_torch/metatomic/torch/symmetrized_model.py b/python/metatomic_torch/metatomic/torch/symmetrized_model.py index bb5f7702..58e0b1f6 100644 --- a/python/metatomic_torch/metatomic/torch/symmetrized_model.py +++ b/python/metatomic_torch/metatomic/torch/symmetrized_model.py @@ -607,8 +607,9 @@ class SymmetrizedModel(torch.nn.Module): def __init__( self, base_model, - max_o3_lambda_grid, - max_o3_lambda_target, + max_o3_lambda_grid: int, + max_o3_lambda_target: int, + full_diagnostic: bool = True, batch_size: int = 32, max_o3_lambda_character: Optional[int] = None, ): @@ -623,15 +624,19 @@ def __init__( device = torch.device("cpu") dtype = torch.get_default_dtype() - self.max_o3_lambda_grid = max_o3_lambda_grid + self.full_diagnostic = full_diagnostic + self.max_o3_lambda_target = max_o3_lambda_target self.batch_size = batch_size if max_o3_lambda_character is None: - max_o3_lambda_character = max_o3_lambda_grid + max_o3_lambda_grid = 2 * max_o3_lambda_character + 1 + self.max_o3_lambda_grid = max_o3_lambda_grid self.max_o3_lambda_character = max_o3_lambda_character # Compute grid (unchanged) lebedev_order, n_inplane_rotations = _choose_quadrature(self.max_o3_lambda_grid) + if lebedev_order < 2 * self.max_o3_lambda_character: + print("Warning: Lebedev order may be insufficient for character projections.") alpha, beta, gamma, w_so3 = get_euler_angles_quadrature( lebedev_order, n_inplane_rotations ) @@ -788,22 +793,23 @@ def forward( transformed_outputs = self._decompose_stress_tensor(transformed_outputs) backtransformed_outputs = self._decompose_stress_tensor(backtransformed_outputs) - # Compute norms - norms = self._compute_norm_per_property(transformed_outputs) + out_dict: Dict[str, TensorMap] = {} # Compute the O(3) mean and variance mean_var = self._compute_mean_and_variance(backtransformed_outputs) - - # Compute the character projections - convolution_integrals = self._compute_conv_integral(transformed_outputs) - - out_dict: Dict[str, TensorMap] = {} - for name, tensor in norms.items(): - out_dict[name] = tensor for name, tensor in mean_var.items(): out_dict[name] = tensor - for name, integral in convolution_integrals.items(): - out_dict[name] = integral + if self.full_diagnostic: + # Compute norms + norms = self._compute_norm_per_property(transformed_outputs) + for name, tensor in norms.items(): + out_dict[name] = tensor + + # Compute the character projections + convolution_integrals = self._compute_conv_integral(transformed_outputs) + for name, integral in convolution_integrals.items(): + out_dict[name] = integral + return out_dict @@ -897,7 +903,7 @@ def _compute_norm_per_property( tensor_norm.keys_to_samples("inversion"), ["inversion", "so3_rotation"] ) - norms[name + "_squared_norm"] = tensor_norm + norms[name + "componentwise_norm_squared"] = tensor_norm return norms def _compute_conv_integral( From 4e9173889d5fbe938fcece7d8107f893c6b2f93b Mon Sep 17 00:00:00 2001 From: MichelangeloDomina Date: Thu, 27 Nov 2025 11:13:49 +0100 Subject: [PATCH 48/57] Avoid evaluating the characters when full_diagnostic == False --- .../metatomic/torch/symmetrized_model.py | 51 ++++++++++--------- 1 file changed, 26 insertions(+), 25 deletions(-) diff --git a/python/metatomic_torch/metatomic/torch/symmetrized_model.py b/python/metatomic_torch/metatomic/torch/symmetrized_model.py index 58e0b1f6..0c50db1d 100644 --- a/python/metatomic_torch/metatomic/torch/symmetrized_model.py +++ b/python/metatomic_torch/metatomic/torch/symmetrized_model.py @@ -675,37 +675,38 @@ def __init__( self._wigner_D_inverse_jit[ell] = D # Compute characters - so3_characters, pso3_characters = compute_characters( - self.max_o3_lambda_character, - (alpha, beta, gamma), - angles_inverse_rotations, - ) - self._so3_char_names: Dict[int, str] = {} - self._pso3_char_names: Dict[str, str] = {} + if self.full_diagnostic: + so3_characters, pso3_characters = compute_characters( + self.max_o3_lambda_character, + (alpha, beta, gamma), + angles_inverse_rotations, + ) + self._so3_char_names: Dict[int, str] = {} + self._pso3_char_names: Dict[str, str] = {} - # Since characters are stored in dicts, we need a bit of gymnastics to - # register the buffers - for ell, ch in so3_characters.items(): - if isinstance(ch, np.ndarray): - ch = torch.from_numpy(ch) + # Since characters are stored in dicts, we need a bit of gymnastics to + # register the buffers + for ell, ch in so3_characters.items(): + if isinstance(ch, np.ndarray): + ch = torch.from_numpy(ch) - ch = ch.to(dtype=dtype, device="cpu") # stay on CPU - name = f"so3_characters_l{ell}" - self.register_buffer(name, ch) - self._so3_char_names[ell] = name + ch = ch.to(dtype=dtype, device="cpu") # stay on CPU + name = f"so3_characters_l{ell}" + self.register_buffer(name, ch) + self._so3_char_names[ell] = name - self._so3_characters_jit = {} # kill the CUDA dict cache + self._so3_characters_jit = {} # kill the CUDA dict cache - for ell, ch in pso3_characters.items(): - if isinstance(ch, np.ndarray): - ch = torch.from_numpy(ch) + for ell, ch in pso3_characters.items(): + if isinstance(ch, np.ndarray): + ch = torch.from_numpy(ch) - ch = ch.to(dtype=dtype, device="cpu") # stay on CPU - name = f"pso3_characters_l{ell}" - self.register_buffer(name, ch) - self._pso3_char_names[ell] = name + ch = ch.to(dtype=dtype, device="cpu") # stay on CPU + name = f"pso3_characters_l{ell}" + self.register_buffer(name, ch) + self._pso3_char_names[ell] = name - self._pso3_characters_jit = {} + self._pso3_characters_jit = {} @torch.jit.ignore def _wigner_D_inverse_dict(self) -> Dict[int, torch.Tensor]: From 3463def7dbb946bf0ec35e58ac7d8776142457d4 Mon Sep 17 00:00:00 2001 From: MichelangeloDomina Date: Thu, 27 Nov 2025 12:19:39 +0100 Subject: [PATCH 49/57] Removed the flag full_diagnostics. Instead created a duplicate for the energy and summed before doing the variance --- .../metatomic/torch/symmetrized_model.py | 88 ++++++++++--------- 1 file changed, 48 insertions(+), 40 deletions(-) diff --git a/python/metatomic_torch/metatomic/torch/symmetrized_model.py b/python/metatomic_torch/metatomic/torch/symmetrized_model.py index 0c50db1d..71aa1ddc 100644 --- a/python/metatomic_torch/metatomic/torch/symmetrized_model.py +++ b/python/metatomic_torch/metatomic/torch/symmetrized_model.py @@ -609,7 +609,6 @@ def __init__( base_model, max_o3_lambda_grid: int, max_o3_lambda_target: int, - full_diagnostic: bool = True, batch_size: int = 32, max_o3_lambda_character: Optional[int] = None, ): @@ -624,8 +623,6 @@ def __init__( device = torch.device("cpu") dtype = torch.get_default_dtype() - self.full_diagnostic = full_diagnostic - self.max_o3_lambda_target = max_o3_lambda_target self.batch_size = batch_size if max_o3_lambda_character is None: @@ -675,38 +672,37 @@ def __init__( self._wigner_D_inverse_jit[ell] = D # Compute characters - if self.full_diagnostic: - so3_characters, pso3_characters = compute_characters( - self.max_o3_lambda_character, - (alpha, beta, gamma), - angles_inverse_rotations, - ) - self._so3_char_names: Dict[int, str] = {} - self._pso3_char_names: Dict[str, str] = {} + so3_characters, pso3_characters = compute_characters( + self.max_o3_lambda_character, + (alpha, beta, gamma), + angles_inverse_rotations, + ) + self._so3_char_names: Dict[int, str] = {} + self._pso3_char_names: Dict[str, str] = {} - # Since characters are stored in dicts, we need a bit of gymnastics to - # register the buffers - for ell, ch in so3_characters.items(): - if isinstance(ch, np.ndarray): - ch = torch.from_numpy(ch) + # Since characters are stored in dicts, we need a bit of gymnastics to + # register the buffers + for ell, ch in so3_characters.items(): + if isinstance(ch, np.ndarray): + ch = torch.from_numpy(ch) - ch = ch.to(dtype=dtype, device="cpu") # stay on CPU - name = f"so3_characters_l{ell}" - self.register_buffer(name, ch) - self._so3_char_names[ell] = name + ch = ch.to(dtype=dtype, device="cpu") # stay on CPU + name = f"so3_characters_l{ell}" + self.register_buffer(name, ch) + self._so3_char_names[ell] = name - self._so3_characters_jit = {} # kill the CUDA dict cache + self._so3_characters_jit = {} # kill the CUDA dict cache - for ell, ch in pso3_characters.items(): - if isinstance(ch, np.ndarray): - ch = torch.from_numpy(ch) + for ell, ch in pso3_characters.items(): + if isinstance(ch, np.ndarray): + ch = torch.from_numpy(ch) - ch = ch.to(dtype=dtype, device="cpu") # stay on CPU - name = f"pso3_characters_l{ell}" - self.register_buffer(name, ch) - self._pso3_char_names[ell] = name + ch = ch.to(dtype=dtype, device="cpu") # stay on CPU + name = f"pso3_characters_l{ell}" + self.register_buffer(name, ch) + self._pso3_char_names[ell] = name - self._pso3_characters_jit = {} + self._pso3_characters_jit = {} @torch.jit.ignore def _wigner_D_inverse_dict(self) -> Dict[int, torch.Tensor]: @@ -800,16 +796,16 @@ def forward( mean_var = self._compute_mean_and_variance(backtransformed_outputs) for name, tensor in mean_var.items(): out_dict[name] = tensor - if self.full_diagnostic: - # Compute norms - norms = self._compute_norm_per_property(transformed_outputs) - for name, tensor in norms.items(): - out_dict[name] = tensor + + # Compute norms + norms = self._compute_norm_per_property(transformed_outputs) + for name, tensor in norms.items(): + out_dict[name] = tensor - # Compute the character projections - convolution_integrals = self._compute_conv_integral(transformed_outputs) - for name, integral in convolution_integrals.items(): - out_dict[name] = integral + # Compute the character projections + convolution_integrals = self._compute_conv_integral(transformed_outputs) + for name, integral in convolution_integrals.items(): + out_dict[name] = integral return out_dict @@ -904,7 +900,7 @@ def _compute_norm_per_property( tensor_norm.keys_to_samples("inversion"), ["inversion", "so3_rotation"] ) - norms[name + "componentwise_norm_squared"] = tensor_norm + norms[name + "_componentwise_norm_squared"] = tensor_norm return norms def _compute_conv_integral( @@ -1309,7 +1305,19 @@ def _eval_over_grid( perm = _permute_system_before_atom(joined[0].samples.names) joined = mts.permute_dimensions(joined, "samples", perm) backtransformed_outputs_tensor[name] = joined - + + if "energy" in transformed_outputs_tensor: + energy_tm = transformed_outputs_tensor["energy"] + if "atom" in energy_tm[0].samples.names: + # Sum over atoms while keeping system and rotation indices. + energy_total_tm = mts.sum_over_samples(energy_tm, ["atom"]) + transformed_outputs_tensor["energy_total"] = energy_total_tm + + if "energy" in backtransformed_outputs_tensor: + energy_tm_bt = backtransformed_outputs_tensor["energy"] + if "atom" in energy_tm_bt[0].samples.names: + energy_total_tm_bt = mts.sum_over_samples(energy_tm_bt, ["atom"]) + backtransformed_outputs_tensor["energy_total"] = energy_total_tm_bt return transformed_outputs_tensor, backtransformed_outputs_tensor From 8ea74662dd8df47d6eb64c659f2bd132a532305f Mon Sep 17 00:00:00 2001 From: ppegolo Date: Sat, 29 Nov 2025 09:29:13 +0100 Subject: [PATCH 50/57] fix stress decomposition --- .../metatomic/torch/symmetrized_model.py | 140 ++++++++++-------- 1 file changed, 78 insertions(+), 62 deletions(-) diff --git a/python/metatomic_torch/metatomic/torch/symmetrized_model.py b/python/metatomic_torch/metatomic/torch/symmetrized_model.py index 71aa1ddc..4a9bfd71 100644 --- a/python/metatomic_torch/metatomic/torch/symmetrized_model.py +++ b/python/metatomic_torch/metatomic/torch/symmetrized_model.py @@ -1,19 +1,23 @@ -from typing import Dict, List, Optional, Tuple, TYPE_CHECKING +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple + import metatensor.torch as mts + if TYPE_CHECKING: - class TensorBlock: - ... - class System: - ... - class TensorMap: - ... - class ModelOutput: - ... - class Labels: - ... + + class TensorBlock: ... + + class System: ... + + class TensorMap: ... + + class ModelOutput: ... + + class Labels: ... + else: from metatensor.torch import Labels, TensorBlock, TensorMap + from metatomic.torch import ModelOutput, System import numpy as np @@ -312,53 +316,52 @@ def _angles_from_rotations( gammas = gammas.reshape(batch_shape) return alphas, betas, gammas -def _l0_components_from_matrices( - A: torch.Tensor -) -> torch.Tensor: + +def _l0_components_from_matrices(A: torch.Tensor) -> torch.Tensor: """ Extract the L=0 components from a (3, 3) tensor. """ - # The tensor will have shape (a, 3, 3, b) so we need to move the 3, 3 dimension at the end + # The tensor will have shape (a, 3, 3, b) so we need to move the 3, 3 dimension at + # the end A = A.permute(0, 3, 1, 2) # Test if the last two dimensions are (3, 3) assert A.shape[-2:] == (3, 3), "The last two dimensions of A must be (3, 3)." - # Initialize the output tensor for L=0 components to have 1 component in the last dimension - l0_A = torch.empty(A.shape[:-2] + (1,), - dtype=A.dtype, - device=A.device) + # Initialize the output tensor for L=0 components to have 1 component in the last + # dimension + l0_A = torch.empty(A.shape[:-2] + (1,), dtype=A.dtype, device=A.device) # Compute the L=0 component as the trace of A - l0_A[..., 0] = (A[..., 0, 0] + A[..., 1, 1] + A[..., 2, 2]) + l0_A[..., 0] = A[..., 0, 0] + A[..., 1, 1] + A[..., 2, 2] l0_A = l0_A.permute(0, 2, 1) return l0_A -def _l2_components_from_matrices( - A: torch.Tensor -) -> torch.Tensor: +def _l2_components_from_matrices(A: torch.Tensor) -> torch.Tensor: """ Extract the L=2 components from a (3, 3) tensor. """ - # The tensor will have shape (a, 3, 3, b) so we need to move the 3, 3 dimension at the end + # The tensor will have shape (a, 3, 3, b) so we need to move the 3, 3 dimension at + # the end A = A.permute(0, 3, 1, 2) # Test if the last two dimensions are (3, 3) assert A.shape[-2:] == (3, 3), "The last two dimensions of A must be (3, 3)." - # Initialize the output tensor for L=2 components to have 5 components in the last dimension - l2_A = torch.empty(A.shape[:-2] + (5,), - dtype=A.dtype, - device=A.device) + # Initialize the output tensor for L=2 components to have 5 components in the last + # dimension + l2_A = torch.empty(A.shape[:-2] + (5,), dtype=A.dtype, device=A.device) - l2_A[..., 0] = (A[..., 0, 1] + A[..., 1, 0]) / 2.0 - l2_A[..., 1] = (A[..., 1, 2] + A[..., 2, 1]) / 2.0 - l2_A[..., 2] = (2.0 * A[..., 2, 2] - A[..., 0, 0] - A[..., 1, 1]) / ((2.0) * np.sqrt(3.0)) + l2_A[..., 0] = (A[..., 0, 1] + A[..., 1, 0]) / 2.0 + l2_A[..., 1] = (A[..., 1, 2] + A[..., 2, 1]) / 2.0 + l2_A[..., 2] = (2.0 * A[..., 2, 2] - A[..., 0, 0] - A[..., 1, 1]) / ( + (2.0) * np.sqrt(3.0) + ) l2_A[..., 3] = (A[..., 0, 2] + A[..., 2, 0]) / 2.0 l2_A[..., 4] = (A[..., 0, 0] - A[..., 1, 1]) / 2.0 l2_A = l2_A.permute(0, 2, 1) - + return l2_A @@ -633,7 +636,9 @@ def __init__( # Compute grid (unchanged) lebedev_order, n_inplane_rotations = _choose_quadrature(self.max_o3_lambda_grid) if lebedev_order < 2 * self.max_o3_lambda_character: - print("Warning: Lebedev order may be insufficient for character projections.") + print( + "Warning: Lebedev order may be insufficient for character projections." + ) alpha, beta, gamma, w_so3 = get_euler_angles_quadrature( lebedev_order, n_inplane_rotations ) @@ -686,7 +691,7 @@ def __init__( if isinstance(ch, np.ndarray): ch = torch.from_numpy(ch) - ch = ch.to(dtype=dtype, device="cpu") # stay on CPU + ch = ch.to(dtype=dtype, device="cpu") # stay on CPU name = f"so3_characters_l{ell}" self.register_buffer(name, ch) self._so3_char_names[ell] = name @@ -697,7 +702,7 @@ def __init__( if isinstance(ch, np.ndarray): ch = torch.from_numpy(ch) - ch = ch.to(dtype=dtype, device="cpu") # stay on CPU + ch = ch.to(dtype=dtype, device="cpu") # stay on CPU name = f"pso3_characters_l{ell}" self.register_buffer(name, ch) self._pso3_char_names[ell] = name @@ -796,7 +801,7 @@ def forward( mean_var = self._compute_mean_and_variance(backtransformed_outputs) for name, tensor in mean_var.items(): out_dict[name] = tensor - + # Compute norms norms = self._compute_norm_per_property(transformed_outputs) for name, tensor in norms.items(): @@ -807,11 +812,11 @@ def forward( for name, integral in convolution_integrals.items(): out_dict[name] = integral - return out_dict def _decompose_stress_tensor( - self, tensor_dict: Dict[str, TensorMap], + self, + tensor_dict: Dict[str, TensorMap], ) -> Dict[str, TensorMap]: """ Decompose stress tensor into irreducible representations of O(3). @@ -823,42 +828,53 @@ def _decompose_stress_tensor( return tensor_dict else: tensor = tensor_dict["non_conservative_stress"] - blocks_l0 = [] - blocks_l2 = [] - for block in tensor.blocks(): - + blocks: List[TensorBlock] = [] + keys: List[List[int]] = [] + for key, block in tensor.items(): + inversion = int(key["inversion"]) trace_values = _l0_components_from_matrices(block.values) block_l0 = TensorBlock( values=trace_values, samples=block.samples, - components=[Labels( - names="o3_mu", - values=torch.tensor([[0]], device=block.values.device, dtype=torch.int32) - )], + components=[ + Labels( + names="o3_mu", + values=torch.tensor( + [[0]], device=block.values.device, dtype=torch.int32 + ), + ) + ], properties=block.properties, ) - blocks_l0.append(block_l0) + keys.append([0, 1, inversion]) + blocks.append(block_l0) block_l2 = TensorBlock( values=_l2_components_from_matrices(block.values), samples=block.samples, - components=[Labels( - names="o3_mu", - values=torch.tensor([[mu] for mu in range(-2, 3)], device=block.values.device, dtype=torch.int32) - )], + components=[ + Labels( + names="o3_mu", + values=torch.tensor( + [[mu] for mu in range(-2, 3)], + device=block.values.device, + dtype=torch.int32, + ), + ) + ], properties=block.properties, ) - blocks_l2.append(block_l2) - - tensor_dict["non_conservative_stress_l0"] = TensorMap( - tensor.keys, - blocks_l0) - tensor_dict["non_conservative_stress_l2"] = TensorMap( - tensor.keys, - blocks_l2) - tensor_dict.pop("non_conservative_stress") + keys.append([2, 1, inversion]) + blocks.append(block_l2) + + tensor_dict["non_conservative_stress"] = TensorMap( + Labels( + ["o3_lambda", "o3_sigma", "inversion"], + torch.tensor(keys, device=blocks[0].values.device), + ), + blocks, + ) return tensor_dict - def _compute_norm_per_property( self, tensor_dict: Dict[str, TensorMap] @@ -1305,7 +1321,7 @@ def _eval_over_grid( perm = _permute_system_before_atom(joined[0].samples.names) joined = mts.permute_dimensions(joined, "samples", perm) backtransformed_outputs_tensor[name] = joined - + if "energy" in transformed_outputs_tensor: energy_tm = transformed_outputs_tensor["energy"] if "atom" in energy_tm[0].samples.names: From 428d38dd166556dcb1cc4d81c1858bc3ad4336e2 Mon Sep 17 00:00:00 2001 From: Michele Ceriotti Date: Mon, 1 Dec 2025 07:54:05 +0100 Subject: [PATCH 51/57] Clarified docstring --- python/metatomic_torch/metatomic/torch/symmetrized_model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/metatomic_torch/metatomic/torch/symmetrized_model.py b/python/metatomic_torch/metatomic/torch/symmetrized_model.py index 4a9bfd71..e7244df8 100644 --- a/python/metatomic_torch/metatomic/torch/symmetrized_model.py +++ b/python/metatomic_torch/metatomic/torch/symmetrized_model.py @@ -540,9 +540,9 @@ class SymmetrizedModel(torch.nn.Module): Lebedev grid supplemented by in-plane rotations. For each sampled group element, the model outputs are "back-rotated" according to the known :math:`O(3)` action appropriate for their tensorial type (scalar, vector, tensor, etc.). Averaging these - back-rotated predictions over the quadrature grid yields fully - :math:`O(3)`-symmetrized outputs. In addition, two complementary equivariance - metrics are computed: + back-rotated predictions over the quadrature grid yields outputs that are + :math:`O(3)`-symmetrized with an accuracy that depends on the resolution of the grid. + In addition, two complementary equivariance metrics are computed: 1. Variance under :math:`O(3)` of the back-rotated outputs. From 81d18d5d6b51a0e021a29cd69f9fa417894d3cfa Mon Sep 17 00:00:00 2001 From: ppegolo Date: Mon, 1 Dec 2025 12:23:28 +0100 Subject: [PATCH 52/57] fix bug for batched system evaluation --- .../metatomic/torch/symmetrized_model.py | 29 ++++++++++++------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/python/metatomic_torch/metatomic/torch/symmetrized_model.py b/python/metatomic_torch/metatomic/torch/symmetrized_model.py index e7244df8..1990cf7a 100644 --- a/python/metatomic_torch/metatomic/torch/symmetrized_model.py +++ b/python/metatomic_torch/metatomic/torch/symmetrized_model.py @@ -1,3 +1,4 @@ +import warnings from typing import TYPE_CHECKING, Dict, List, Optional, Tuple import metatensor.torch as mts @@ -469,11 +470,12 @@ def _character_convolution( chi: torch.Tensor, block1: TensorBlock, block2: TensorBlock, w: torch.Tensor ) -> TensorBlock: """ - Compute the character convolution between a block containing SO(3)-sampled tensors. + Compute the character convolution of a block containing SO(3)-sampled tensors. Then contract with another block. """ samples = block1.samples assert samples.names[0] == "so3_rotation" + n_rot = chi.size(0) # torch.unique(samples.values[:, 0]).size(0) components = block1.components properties = block1.properties values = block1.values @@ -482,8 +484,10 @@ def _character_convolution( weight = w.to(dtype=values.dtype, device=values.device) # reshape the values to separate rotations from the other samples - new_shape = [n_rot, -1] + list(values.shape[1:]) - reshaped_values = values.reshape(new_shape) + split_sizes = [n_rot] * (values.shape[0] // n_rot) + reshaped_values = torch.stack(torch.split(values, split_sizes, dim=0)) + perm_shape = [1, 0] + list(range(2, reshaped_values.ndim)) + reshaped_values = reshaped_values.permute(perm_shape) # broadcast weights to match reshaped_values view: List[int] = [] @@ -500,8 +504,10 @@ def _character_convolution( values2 = block2.values # reshape the values to separate rotations from the other samples - new_shape = [n_rot, -1] + list(values2.shape[1:]) - reshaped_values2 = values2.reshape(new_shape) + split_sizes = [n_rot] * (values2.shape[0] // n_rot) + reshaped_values2 = torch.stack(torch.split(values2, split_sizes, dim=0)) + perm_shape = [1, 0] + list(range(2, reshaped_values2.ndim)) + reshaped_values2 = reshaped_values2.permute(perm_shape) # broadcast weights to match reshaped_values2 view: List[int] = [] @@ -540,9 +546,9 @@ class SymmetrizedModel(torch.nn.Module): Lebedev grid supplemented by in-plane rotations. For each sampled group element, the model outputs are "back-rotated" according to the known :math:`O(3)` action appropriate for their tensorial type (scalar, vector, tensor, etc.). Averaging these - back-rotated predictions over the quadrature grid yields outputs that are - :math:`O(3)`-symmetrized with an accuracy that depends on the resolution of the grid. - In addition, two complementary equivariance metrics are computed: + back-rotated predictions over the quadrature grid yields fully + :math:`O(3)`-symmetrized outputs. In addition, two complementary equivariance + metrics are computed: 1. Variance under :math:`O(3)` of the back-rotated outputs. @@ -629,15 +635,16 @@ def __init__( self.max_o3_lambda_target = max_o3_lambda_target self.batch_size = batch_size if max_o3_lambda_character is None: - max_o3_lambda_grid = 2 * max_o3_lambda_character + 1 + max_o3_lambda_character = (max_o3_lambda_grid - 1) // 2 self.max_o3_lambda_grid = max_o3_lambda_grid self.max_o3_lambda_character = max_o3_lambda_character # Compute grid (unchanged) lebedev_order, n_inplane_rotations = _choose_quadrature(self.max_o3_lambda_grid) if lebedev_order < 2 * self.max_o3_lambda_character: - print( - "Warning: Lebedev order may be insufficient for character projections." + warnings.warn( + "Lebedev order may be insufficient for character projections.", + stacklevel=2, ) alpha, beta, gamma, w_so3 = get_euler_angles_quadrature( lebedev_order, n_inplane_rotations From 6b21fefcba7574969273f56bcf6f0c04a100f903 Mon Sep 17 00:00:00 2001 From: ppegolo Date: Mon, 1 Dec 2025 13:37:53 +0100 Subject: [PATCH 53/57] start moving SymmetrizedModel methods to external utils --- .../metatomic/torch/symmetrized_model.py | 632 ++++++++++-------- 1 file changed, 337 insertions(+), 295 deletions(-) diff --git a/python/metatomic_torch/metatomic/torch/symmetrized_model.py b/python/metatomic_torch/metatomic/torch/symmetrized_model.py index 1990cf7a..a5d2713f 100644 --- a/python/metatomic_torch/metatomic/torch/symmetrized_model.py +++ b/python/metatomic_torch/metatomic/torch/symmetrized_model.py @@ -16,6 +16,8 @@ class ModelOutput: ... class Labels: ... + class ModelInterface: ... + else: from metatensor.torch import Labels, TensorBlock, TensorMap @@ -25,7 +27,7 @@ class Labels: ... import torch from metatrain.utils.augmentation import _apply_augmentations -from metatomic.torch import register_autograd_neighbors +from metatomic.torch import ModelInterface, register_autograd_neighbors try: @@ -1014,6 +1016,58 @@ def _compute_conv_integral( new_tensors[name + "_character_projection"] = new_tensor return new_tensors + def _eval_over_grid( + self, + systems: List[System], + outputs: Dict[str, ModelOutput], + selected_atoms: Optional[Labels], + ) -> Tuple[Dict[str, TensorMap], Dict[str, TensorMap]]: + """ + Sample the model on the O(3) quadrature. + + :param systems: list of systems to evaluate + :param model: atomistic model to evaluate + :param device: device to use for computation + :return: list of list of model outputs, shape (len(systems), N) + where N is the number of quadrature points + """ + + return_transformed = True + + # Evaluate the model over the grid + results = evaluate_model_over_grid( + self.base_model, + self.batch_size, + self.so3_rotations, + self.so3_inverse_rotations, + self._wigner_D_inverse_jit, + return_transformed, + systems, + outputs, + selected_atoms, + ) + + if return_transformed: + transformed_outputs_tensor, backtransformed_outputs_tensor = results + else: + backtransformed_outputs_tensor = results + transformed_outputs_tensor: Dict[str, TensorMap] = {} + + # TODO: possibly remove + if "energy" in transformed_outputs_tensor: + energy_tm = transformed_outputs_tensor["energy"] + if "atom" in energy_tm[0].samples.names: + # Sum over atoms while keeping system and rotation indices. + energy_total_tm = mts.sum_over_samples(energy_tm, ["atom"]) + transformed_outputs_tensor["energy_total"] = energy_total_tm + + if "energy" in backtransformed_outputs_tensor: + energy_tm_bt = backtransformed_outputs_tensor["energy"] + if "atom" in energy_tm_bt[0].samples.names: + energy_total_tm_bt = mts.sum_over_samples(energy_tm_bt, ["atom"]) + backtransformed_outputs_tensor["energy_total"] = energy_total_tm_bt + return transformed_outputs_tensor, backtransformed_outputs_tensor + def _compute_mean_and_variance( self, tensor_dict: Dict[str, TensorMap], @@ -1024,324 +1078,209 @@ def _compute_mean_and_variance( :param tensor_dict: dictionary of TensorMaps to compute mean and variance for :return: dictionary of TensorMaps with mean and variance """ - mean_var: Dict[str, TensorMap] = {} - for name in tensor_dict: - if "features" in name: - continue - tensor = tensor_dict[name] - mean_blocks: List[TensorBlock] = [] - second_moment_blocks: List[TensorBlock] = [] - for block in tensor.blocks(): - rot_ids = block.samples.column("so3_rotation") - values = block.values - if values.ndim > 2: - dims: List[int] = [] - for i in range(1, values.ndim - 1): - dims.append(i) - values_squared = torch.sum(values**2, dim=dims) - else: - values_squared = values**2 + return symmetrize_over_grid(tensor_dict, self.so3_weights) - view: List[int] = [] - view.append(values.size(0)) - for _ in range(values.ndim - 1): - view.append(1) - values = 0.5 * self.so3_weights[rot_ids].view(view) * values - view: List[int] = [] - view.append(values_squared.size(0)) - for _ in range(values_squared.ndim - 1): - view.append(1) - values_squared = ( - 0.5 * self.so3_weights[rot_ids].view(view) * values_squared - ) +def evaluate_model_over_grid( + model: ModelInterface, + batch_size: int, + so3_rotations: torch.Tensor, + so3_rotations_inverse: torch.Tensor, + wigner_D_inverse: Dict[int, torch.Tensor], + return_transformed: bool, + systems: List[System], + outputs: Dict[str, ModelOutput], + selected_atoms: Optional[Labels] = None, +) -> Dict[str, TensorMap] | Tuple[Dict[str, TensorMap], Dict[str, TensorMap]]: + """ + Sample the model on the O(3) quadrature. - mean_blocks.append( - TensorBlock( - values=values, - samples=block.samples, - components=block.components, - properties=block.properties, + :param model: atomistic model to evaluate + :param systems: list of systems to evaluate + :param so3_rotations: SO(3) rotation matrices to use for sampling + :param batch_size: number of rotations to evaluate in a single batch + :param outputs: dictionary of model outputs to symmetrize + :param selected_atoms: optional Labels specifying which atoms to consider + :return: model outputs evaluated on the rotated systems. It is a dictionary + where each key is a target name and each value is a list of length len(systems). + Each element of the list is a dictionary with keys -1 and 1 (for improper + and proper rotations) and values the corresponding TensorMap outputs of the + model on the SO(3) quadrature. + """ + + device = systems[0].positions.device + dtype = systems[0].positions.dtype + + transformed_outputs: Dict[str, List[Dict[int, TensorMap]]] = {} + for name in outputs: + lst: List[Dict[int, TensorMap]] = [] + for _ in systems: + d: Dict[int, TensorMap] = {} + lst.append(d) + transformed_outputs[name] = lst + for i_sys, system in enumerate(systems): + for inversion in [-1, 1]: + rotation_outputs: List[Dict[str, TensorMap]] = [] + for batch in range(0, len(so3_rotations), batch_size): + transformed_systems = [ + _transform_system( + system, inversion * R.to(device=device, dtype=dtype) ) - ) - second_moment_blocks.append( - TensorBlock( - values=values_squared, - samples=block.samples, - components=[], - properties=block.properties, + for R in so3_rotations[batch : batch + batch_size] + ] + with torch.no_grad(): + out = model( + transformed_systems, + outputs, + selected_atoms, ) + rotation_outputs.append(out) + + # Combine batch outputs + for name in outputs: + combined_: List[TensorMap] = [r[name] for r in rotation_outputs] + combined = mts.join( + combined_, + "samples", + add_dimension="batch_rotation", ) + if "batch_rotation" in combined[0].samples.names: + # Reindex + blocks: List[TensorBlock] = [] + for block in combined.blocks(): + batch_id = block.samples.column("batch_rotation") + rot_id = block.samples.column("system") + new_sample_values = block.samples.values[:, :-1] + new_sample_values[:, 0] = batch_id * batch_size + rot_id + blocks.append( + TensorBlock( + values=block.values.detach(), + samples=Labels( + block.samples.names[:-1], + new_sample_values, + ), + components=block.components, + properties=block.properties, + ) + ) + combined = TensorMap(combined.keys, blocks) + transformed_outputs[name][i_sys][inversion] = combined - # Mean - tensor_mean = TensorMap(tensor.keys, mean_blocks) - tensor_mean = mts.sum_over_samples( - tensor_mean.keys_to_samples("inversion"), ["inversion", "so3_rotation"] - ) + backtransformed_outputs = _backtransform_outputs( + transformed_outputs, systems, so3_rotations_inverse, wigner_D_inverse + ) + backtransformed_outputs_tensor = _to_metatensor(backtransformed_outputs, systems) - # Mean norm - mean_norm_squared_blocks: List[TensorBlock] = [] - for block in tensor_mean.blocks(): - vals = block.values - if vals.ndim > 2: - dims: List[int] = [] - for i in range(1, vals.ndim - 1): - dims.append(i) - vals = torch.sum(vals**2, dim=dims) - else: - vals = vals**2 - mean_norm_squared_blocks.append( - TensorBlock( - values=vals, - samples=block.samples, - components=[], - properties=block.properties, - ) - ) - tensor_mean_norm_squared = TensorMap( - tensor_mean.keys, mean_norm_squared_blocks - ) + if return_transformed: + transformed_outputs_tensor = _to_metatensor(transformed_outputs, systems) + return transformed_outputs_tensor, backtransformed_outputs_tensor + else: + transformed_outputs_tensor: Dict[str, TensorMap] = {} + return backtransformed_outputs_tensor - # Second moment - tensor_second_moment = TensorMap(tensor.keys, second_moment_blocks) - tensor_second_moment = mts.sum_over_samples( - tensor_second_moment.keys_to_samples("inversion"), - ["inversion", "so3_rotation"], - ) - # Variance - tensor_variance = mts.subtract( - tensor_second_moment, tensor_mean_norm_squared +def _to_metatensor( + tensor_dict: Dict[str, TensorMap], systems: List[System] +) -> Dict[str, TensorMap]: + """ + Convert the outputs of the model evaluated on rotated systems to a single + TensorMap per property, with appropriate dimensions for O(3) symmetrization. + """ + + out_tensor_dict: Dict[str, TensorMap] = {} + # Massage outputs to have desired shape + for name in tensor_dict: + joined_plus = mts.join( + [tensor_dict[name][i_sys][1] for i_sys in range(len(systems))], + "samples", + add_dimension="phys_system", + ) + joined_minus = mts.join( + [tensor_dict[name][i_sys][-1] for i_sys in range(len(systems))], + "samples", + add_dimension="phys_system", + ) + joined = mts.join( + [ + mts.append_dimension(joined_plus, "keys", "inversion", 1), + mts.append_dimension(joined_minus, "keys", "inversion", -1), + ], + "samples", + different_keys="union", + ) + joined = mts.rename_dimension(joined, "samples", "system", "so3_rotation") + + if "phys_system" in joined[0].samples.names: + joined = mts.rename_dimension(joined, "samples", "phys_system", "system") + else: + joined = mts.insert_dimension( + joined, + "samples", + 1, + "system", + torch.zeros( + joined[0].samples.values.shape[0], + dtype=torch.long, + device=joined[0].samples.values.device, + ), ) + if "atom" in joined[0].samples.names: + perm = _permute_system_before_atom(joined[0].samples.names) + joined = mts.permute_dimensions(joined, "samples", perm) + out_tensor_dict[name] = joined - mean_var[name + "_mean"] = tensor_mean - mean_var[name + "_norm_squared"] = tensor_second_moment - mean_var[name + "_var"] = tensor_variance - return mean_var + return out_tensor_dict - def _eval_over_grid( - self, - systems: List[System], - outputs: Dict[str, ModelOutput], - selected_atoms: Optional[Labels], - ) -> Tuple[Dict[str, TensorMap], Dict[str, TensorMap]]: - """ - Sample the model on the O(3) quadrature. - :param systems: list of systems to evaluate - :param model: atomistic model to evaluate - :param device: device to use for computation - :return: list of list of model outputs, shape (len(systems), N) - where N is the number of quadrature points - """ +def _backtransform_outputs( + tensor_dict: Dict[str, List[Dict[int, TensorMap]]], + systems: List[System], + so3_rotations_inverse: torch.Tensor, + wigner_D_inverse: Dict[int, torch.Tensor], +) -> Dict[str, List[Dict[int, TensorMap]]]: + """ + Given the outputs of the model evaluated on rotated systems, backtransform them to + the original frame according to the equivariance labels in the TensorMap keys. + """ - device = systems[0].positions.device - dtype = systems[0].positions.dtype + device = systems[0].positions.device + dtype = systems[0].positions.dtype - transformed_outputs = torch.jit.annotate( - Dict[str, List[Dict[int, TensorMap]]], {} - ) - for name in outputs: - lst = torch.jit.annotate(List[Dict[int, TensorMap]], []) - for _ in systems: - d = torch.jit.annotate(Dict[int, TensorMap], {}) - lst.append(d) - transformed_outputs[name] = lst - backtransformed_outputs = torch.jit.annotate( - Dict[str, List[Dict[int, TensorMap]]], {} - ) - for name in outputs: - lst = torch.jit.annotate(List[Dict[int, TensorMap]], []) - for _ in systems: - d = torch.jit.annotate(Dict[int, TensorMap], {}) - lst.append(d) - backtransformed_outputs[name] = lst + backtransformed_tensor_dict: Dict[str, List[Dict[int, TensorMap]]] = {} + for name in tensor_dict: + lst: List[Dict[int, TensorMap]] = [] + for _ in systems: + d: Dict[int, TensorMap] = {} + lst.append(d) + backtransformed_tensor_dict[name] = lst + n_rot = so3_rotations_inverse.size(0) + for name in tensor_dict: for i_sys, system in enumerate(systems): for inversion in [-1, 1]: - rotation_outputs: List[Dict[str, TensorMap]] = [] - for batch in range(0, len(self.so3_rotations), self.batch_size): - transformed_systems = [ - _transform_system( - system, inversion * R.to(device=device, dtype=dtype) - ) - for R in self.so3_rotations[batch : batch + self.batch_size] - ] - with torch.no_grad(): - out = self.base_model( - transformed_systems, - outputs, - selected_atoms, - ) - rotation_outputs.append(out) - - # Combine batch outputs - for name in outputs: - combined_: List[TensorMap] = [r[name] for r in rotation_outputs] - combined = mts.join( - combined_, - "samples", - add_dimension="batch_rotation", - ) - if "batch_rotation" in combined[0].samples.names: - # Reindex - blocks: List[TensorBlock] = [] - for block in combined.blocks(): - batch_id = block.samples.column("batch_rotation") - rot_id = block.samples.column("system") - new_sample_values = block.samples.values[:, :-1] - new_sample_values[:, 0] = ( - batch_id * self.batch_size + rot_id - ) - blocks.append( - TensorBlock( - values=block.values.detach(), - samples=Labels( - block.samples.names[:-1], - new_sample_values, - ), - components=block.components, - properties=block.properties, - ) - ) - combined = TensorMap(combined.keys, blocks) - transformed_outputs[name][i_sys][inversion] = combined - - n_rot = self.so3_rotations.size(0) - for name in transformed_outputs: - for i_sys, system in enumerate(systems): - for inversion in [-1, 1]: - tensor = transformed_outputs[name][i_sys][inversion] - wigner_dict = torch.jit.annotate(Dict[int, List[torch.Tensor]], {}) - for ell in self._wigner_D_inverse_jit: - wigner_dict[ell] = ( - self._get_wigner_D_inverse(ell) - .to(device=device, dtype=dtype) - .unbind(0) - ) - - _, backtransformed, _ = _apply_augmentations( - [system] * n_rot, - {name: tensor}, - list( - ( - self.so3_inverse_rotations.to( - device=device, dtype=dtype - ) - * inversion - ).unbind(0) - ), - wigner_dict, + tensor = tensor_dict[name][i_sys][inversion] + wigner_dict: Dict[int, List[torch.Tensor]] = {} + for ell in wigner_D_inverse: + wigner_dict[ell] = ( + wigner_D_inverse[ell].to(device=device, dtype=dtype).unbind(0) ) - backtransformed_outputs[name][i_sys][inversion] = backtransformed[ - name - ] - - transformed_outputs_tensor: Dict[str, TensorMap] = {} - backtransformed_outputs_tensor: Dict[str, TensorMap] = {} - # Massage outputs to have desired shape - for name in transformed_outputs: - joined_plus = mts.join( - [transformed_outputs[name][i_sys][1] for i_sys in range(len(systems))], - "samples", - add_dimension="phys_system", - ) - joined_minus = mts.join( - [transformed_outputs[name][i_sys][-1] for i_sys in range(len(systems))], - "samples", - add_dimension="phys_system", - ) - joined = mts.join( - [ - mts.append_dimension(joined_plus, "keys", "inversion", 1), - mts.append_dimension(joined_minus, "keys", "inversion", -1), - ], - "samples", - different_keys="union", - ) - joined = mts.rename_dimension(joined, "samples", "system", "so3_rotation") - if "phys_system" in joined[0].samples.names: - joined = mts.rename_dimension( - joined, "samples", "phys_system", "system" - ) - else: - joined = mts.insert_dimension( - joined, - "samples", - 1, - "system", - torch.zeros( - joined[0].samples.values.shape[0], - dtype=torch.long, - device=joined[0].samples.values.device, + _, backtransformed, _ = _apply_augmentations( + [system] * n_rot, + {name: tensor}, + list( + ( + so3_rotations_inverse.to(device=device, dtype=dtype) + * inversion + ).unbind(0) ), + wigner_dict, ) - if "atom" in joined[0].samples.names: - perm = _permute_system_before_atom(joined[0].samples.names) - joined = mts.permute_dimensions(joined, "samples", perm) - transformed_outputs_tensor[name] = joined - - joined_plus = mts.join( - [ - backtransformed_outputs[name][i_sys][1] - for i_sys in range(len(systems)) - ], - "samples", - add_dimension="phys_system", - ) - joined_minus = mts.join( - [ - backtransformed_outputs[name][i_sys][-1] - for i_sys in range(len(systems)) - ], - "samples", - add_dimension="phys_system", - ) - joined = mts.join( - [ - mts.append_dimension(joined_plus, "keys", "inversion", 1), - mts.append_dimension(joined_minus, "keys", "inversion", -1), - ], - "samples", - different_keys="union", - ) - joined = mts.rename_dimension(joined, "samples", "system", "so3_rotation") - if "phys_system" in joined[0].samples.names: - joined = mts.rename_dimension( - joined, "samples", "phys_system", "system" - ) - else: - joined = mts.insert_dimension( - joined, - "samples", - 1, - "system", - torch.zeros( - joined[0].samples.values.shape[0], - dtype=torch.long, - device=joined[0].samples.values.device, - ), - ) - if "atom" in joined[0].samples.names: - perm = _permute_system_before_atom(joined[0].samples.names) - joined = mts.permute_dimensions(joined, "samples", perm) - backtransformed_outputs_tensor[name] = joined - - if "energy" in transformed_outputs_tensor: - energy_tm = transformed_outputs_tensor["energy"] - if "atom" in energy_tm[0].samples.names: - # Sum over atoms while keeping system and rotation indices. - energy_total_tm = mts.sum_over_samples(energy_tm, ["atom"]) - transformed_outputs_tensor["energy_total"] = energy_total_tm - - if "energy" in backtransformed_outputs_tensor: - energy_tm_bt = backtransformed_outputs_tensor["energy"] - if "atom" in energy_tm_bt[0].samples.names: - energy_total_tm_bt = mts.sum_over_samples(energy_tm_bt, ["atom"]) - backtransformed_outputs_tensor["energy_total"] = energy_total_tm_bt - return transformed_outputs_tensor, backtransformed_outputs_tensor + backtransformed_tensor_dict[name][i_sys][inversion] = backtransformed[ + name + ] + return backtransformed_tensor_dict def _permute_system_before_atom(labels: List[str]) -> List[int]: @@ -1368,3 +1307,106 @@ def _permute_system_before_atom(labels: List[str]) -> List[int]: perm.insert(atom_idx, v) return perm + + +def symmetrize_over_grid( + tensor_dict: Dict[str, TensorMap], + so3_weights: torch.Tensor, +) -> Dict[str, TensorMap]: + """ + Compute the mean and variance of the outputs over O(3). + + :param tensor_dict: dictionary of TensorMaps with rotated and backtransformed + outputs to compute mean, variance, and norm squared for + :param so3_weights: weights of the SO(3) quadrature + :return: dictionary of TensorMaps with mean, variance, and norm squared + """ + mean_var: Dict[str, TensorMap] = {} + for name in tensor_dict: + if "features" in name: + continue + tensor = tensor_dict[name] + mean_blocks: List[TensorBlock] = [] + second_moment_blocks: List[TensorBlock] = [] + for block in tensor.blocks(): + rot_ids = block.samples.column("so3_rotation") + + values = block.values + if values.ndim > 2: + dims: List[int] = [] + for i in range(1, values.ndim - 1): + dims.append(i) + values_squared = torch.sum(values**2, dim=dims) + else: + values_squared = values**2 + + view: List[int] = [] + view.append(values.size(0)) + for _ in range(values.ndim - 1): + view.append(1) + values = 0.5 * so3_weights[rot_ids].view(view) * values + + view: List[int] = [] + view.append(values_squared.size(0)) + for _ in range(values_squared.ndim - 1): + view.append(1) + values_squared = 0.5 * so3_weights[rot_ids].view(view) * values_squared + + mean_blocks.append( + TensorBlock( + values=values, + samples=block.samples, + components=block.components, + properties=block.properties, + ) + ) + second_moment_blocks.append( + TensorBlock( + values=values_squared, + samples=block.samples, + components=[], + properties=block.properties, + ) + ) + + # Mean + tensor_mean = TensorMap(tensor.keys, mean_blocks) + tensor_mean = mts.sum_over_samples( + tensor_mean.keys_to_samples("inversion"), ["inversion", "so3_rotation"] + ) + + # Mean norm + mean_norm_squared_blocks: List[TensorBlock] = [] + for block in tensor_mean.blocks(): + vals = block.values + if vals.ndim > 2: + dims: List[int] = [] + for i in range(1, vals.ndim - 1): + dims.append(i) + vals = torch.sum(vals**2, dim=dims) + else: + vals = vals**2 + mean_norm_squared_blocks.append( + TensorBlock( + values=vals, + samples=block.samples, + components=[], + properties=block.properties, + ) + ) + tensor_mean_norm_squared = TensorMap(tensor_mean.keys, mean_norm_squared_blocks) + + # Second moment + tensor_second_moment = TensorMap(tensor.keys, second_moment_blocks) + tensor_second_moment = mts.sum_over_samples( + tensor_second_moment.keys_to_samples("inversion"), + ["inversion", "so3_rotation"], + ) + + # Variance + tensor_variance = mts.subtract(tensor_second_moment, tensor_mean_norm_squared) + + mean_var[name + "_mean"] = tensor_mean + mean_var[name + "_norm_squared"] = tensor_second_moment + mean_var[name + "_var"] = tensor_variance + return mean_var From 59724c557cd2ec6161f6fa85f94ad0d2fad09a4b Mon Sep 17 00:00:00 2001 From: Joseph Abbott Date: Tue, 2 Dec 2025 22:35:57 +0100 Subject: [PATCH 54/57] Decomposition of energy forces, stress --- .../metatomic/torch/symmetrized_model.py | 144 ++++++++++++++---- 1 file changed, 113 insertions(+), 31 deletions(-) diff --git a/python/metatomic_torch/metatomic/torch/symmetrized_model.py b/python/metatomic_torch/metatomic/torch/symmetrized_model.py index a5d2713f..6a5fc13f 100644 --- a/python/metatomic_torch/metatomic/torch/symmetrized_model.py +++ b/python/metatomic_torch/metatomic/torch/symmetrized_model.py @@ -801,8 +801,8 @@ def forward( systems, outputs, selected_atoms ) - transformed_outputs = self._decompose_stress_tensor(transformed_outputs) - backtransformed_outputs = self._decompose_stress_tensor(backtransformed_outputs) + transformed_outputs = self._decompose_tensors(transformed_outputs) + backtransformed_outputs = self._decompose_tensors(backtransformed_outputs) out_dict: Dict[str, TensorMap] = {} @@ -823,49 +823,81 @@ def forward( return out_dict - def _decompose_stress_tensor( + def _decompose_tensors( self, tensor_dict: Dict[str, TensorMap], ) -> Dict[str, TensorMap]: """ - Decompose stress tensor into irreducible representations of O(3). + Decompose tensors in the dictionary into irreducible representations of O(3). :param tensor_dict: dictionary of TensorMaps to decompose - :return: dictionary of TensorMaps with decomposed stress tensors + :return: dictionary of TensorMaps with decomposed tensors """ - if "non_conservative_stress" not in tensor_dict: + tensor_dict = self._decompose_energy_tensor(tensor_dict) + tensor_dict = self._decompose_forces_tensor(tensor_dict) + tensor_dict = self._decompose_stress_tensor(tensor_dict) + return tensor_dict + + def _decompose_energy_tensor( + self, + tensor_dict: Dict[str, TensorMap], + ) -> Dict[str, TensorMap]: + """ + Decompose energy tensor into irreducible representations of O(3). + :param tensor_dict: dictionary of TensorMaps to decompose + :return: dictionary of TensorMaps with decomposed energy tensors + """ + if "energy" not in tensor_dict: return tensor_dict - else: - tensor = tensor_dict["non_conservative_stress"] - blocks: List[TensorBlock] = [] - keys: List[List[int]] = [] - for key, block in tensor.items(): - inversion = int(key["inversion"]) - trace_values = _l0_components_from_matrices(block.values) - block_l0 = TensorBlock( - values=trace_values, + + tensor = tensor_dict["energy"] + tensor_dict["energy_l0"] = TensorMap( + tensor.keys, + [ + TensorBlock( + values=block.values.unsqueeze(1), samples=block.samples, components=[ Labels( - names="o3_mu", + names=["o3_mu"], values=torch.tensor( - [[0]], device=block.values.device, dtype=torch.int32 + [[0]], device=self.so3_weights.device, dtype=torch.int32 ), ) ], properties=block.properties, ) - keys.append([0, 1, inversion]) - blocks.append(block_l0) + for block in tensor + ], + ) + tensor_dict.pop("energy") + return tensor_dict + + def _decompose_forces_tensor( + self, + tensor_dict: Dict[str, TensorMap], + ) -> Dict[str, TensorMap]: + """ + Decompose forces tensor into irreducible representations of O(3). + + :param tensor_dict: dictionary of TensorMaps to decompose + :return: dictionary of TensorMaps with decomposed forces tensors + """ + if "non_conservative_forces" not in tensor_dict: + return tensor_dict - block_l2 = TensorBlock( - values=_l2_components_from_matrices(block.values), + tensor = tensor_dict["non_conservative_forces"] + tensor_dict["non_conservative_forces_l1"] = TensorMap( + tensor.keys, + [ + TensorBlock( + values=block.values.roll(-1, 1), samples=block.samples, components=[ Labels( names="o3_mu", values=torch.tensor( - [[mu] for mu in range(-2, 3)], + [[mu] for mu in range(-1, 2)], device=block.values.device, dtype=torch.int32, ), @@ -873,18 +905,68 @@ def _decompose_stress_tensor( ], properties=block.properties, ) - keys.append([2, 1, inversion]) - blocks.append(block_l2) + for block in tensor + ], + ) + tensor_dict.pop("non_conservative_forces") + return tensor_dict - tensor_dict["non_conservative_stress"] = TensorMap( - Labels( - ["o3_lambda", "o3_sigma", "inversion"], - torch.tensor(keys, device=blocks[0].values.device), - ), - blocks, - ) + def _decompose_stress_tensor( + self, + tensor_dict: Dict[str, TensorMap], + ) -> Dict[str, TensorMap]: + """ + Decompose stress tensor into irreducible representations of O(3). + + :param tensor_dict: dictionary of TensorMaps to decompose + :return: dictionary of TensorMaps with decomposed stress tensors + """ + if "non_conservative_stress" not in tensor_dict: return tensor_dict + tensor = tensor_dict["non_conservative_stress"] + blocks_l0 = [] + blocks_l2 = [] + for block in tensor.blocks(): + trace_values = _l0_components_from_matrices(block.values) + block_l0 = TensorBlock( + values=trace_values, + samples=block.samples, + components=[ + Labels( + names=["o3_mu"], + values=torch.tensor( + [[0]], device=block.values.device, dtype=torch.int32 + ), + ) + ], + properties=block.properties, + ) + blocks_l0.append(block_l0) + + block_l2 = TensorBlock( + values=_l2_components_from_matrices(block.values), + samples=block.samples, + components=[ + Labels( + names="o3_mu", + values=torch.tensor( + [[mu] for mu in range(-2, 3)], + device=block.values.device, + dtype=torch.int32, + ), + ) + ], + properties=block.properties, + ) + blocks_l2.append(block_l2) + + tensor_dict["non_conservative_stress_l0"] = TensorMap(tensor.keys, blocks_l0) + tensor_dict["non_conservative_stress_l2"] = TensorMap(tensor.keys, blocks_l2) + tensor_dict.pop("non_conservative_stress") + + return tensor_dict + def _compute_norm_per_property( self, tensor_dict: Dict[str, TensorMap] ) -> Dict[str, TensorMap]: From 537ad94c40f776d280f21ce38fd35ebf55a5a605 Mon Sep 17 00:00:00 2001 From: Joseph Abbott Date: Tue, 2 Dec 2025 22:36:22 +0100 Subject: [PATCH 55/57] Make compatible with diagnostic 'deep' tokens --- .../metatomic/torch/symmetrized_model.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/python/metatomic_torch/metatomic/torch/symmetrized_model.py b/python/metatomic_torch/metatomic/torch/symmetrized_model.py index 6a5fc13f..b88c7368 100644 --- a/python/metatomic_torch/metatomic/torch/symmetrized_model.py +++ b/python/metatomic_torch/metatomic/torch/symmetrized_model.py @@ -618,10 +618,10 @@ class SymmetrizedModel(torch.nn.Module): def __init__( self, base_model, - max_o3_lambda_grid: int, + max_o3_lambda_character: int, max_o3_lambda_target: int, batch_size: int = 32, - max_o3_lambda_character: Optional[int] = None, + max_o3_lambda_grid: Optional[int] = None, ): super().__init__() self.base_model = base_model @@ -636,8 +636,8 @@ def __init__( self.max_o3_lambda_target = max_o3_lambda_target self.batch_size = batch_size - if max_o3_lambda_character is None: - max_o3_lambda_character = (max_o3_lambda_grid - 1) // 2 + if max_o3_lambda_grid is None: + max_o3_lambda_grid = int(2 * max_o3_lambda_character + 1) self.max_o3_lambda_grid = max_o3_lambda_grid self.max_o3_lambda_character = max_o3_lambda_character @@ -1307,7 +1307,7 @@ def _to_metatensor( device=joined[0].samples.values.device, ), ) - if "atom" in joined[0].samples.names: + if "atom" in joined[0].samples.names or "first_atom" in joined[0].samples.names: perm = _permute_system_before_atom(joined[0].samples.names) joined = mts.permute_dimensions(joined, "samples", perm) out_tensor_dict[name] = joined @@ -1374,6 +1374,8 @@ def _permute_system_before_atom(labels: List[str]) -> List[int]: sys_idx = i elif labels[i] == "atom": atom_idx = i + elif labels[i] == "first_atom": + atom_idx = i # identity permutation perm = list(range(len(labels))) @@ -1405,6 +1407,8 @@ def symmetrize_over_grid( """ mean_var: Dict[str, TensorMap] = {} for name in tensor_dict: + # cannot compute a mean or variance as these have no known behaviour under + # rotations if "features" in name: continue tensor = tensor_dict[name] From 5867928ba840e8bd0be4ded16b9048ea2e186ff7 Mon Sep 17 00:00:00 2001 From: ppegolo Date: Thu, 4 Dec 2025 14:57:35 +0100 Subject: [PATCH 56/57] Fix nasty bug --- .../metatomic/torch/symmetrized_model.py | 30 ++++++++++++------- 1 file changed, 19 insertions(+), 11 deletions(-) diff --git a/python/metatomic_torch/metatomic/torch/symmetrized_model.py b/python/metatomic_torch/metatomic/torch/symmetrized_model.py index b88c7368..06adc417 100644 --- a/python/metatomic_torch/metatomic/torch/symmetrized_model.py +++ b/python/metatomic_torch/metatomic/torch/symmetrized_model.py @@ -477,7 +477,7 @@ def _character_convolution( """ samples = block1.samples assert samples.names[0] == "so3_rotation" - n_rot = chi.size(0) # torch.unique(samples.values[:, 0]).size(0) + n_rot = chi.size(0) components = block1.components properties = block1.properties values = block1.values @@ -485,11 +485,15 @@ def _character_convolution( n_rot = chi.size(1) weight = w.to(dtype=values.dtype, device=values.device) - # reshape the values to separate rotations from the other samples - split_sizes = [n_rot] * (values.shape[0] // n_rot) - reshaped_values = torch.stack(torch.split(values, split_sizes, dim=0)) - perm_shape = [1, 0] + list(range(2, reshaped_values.ndim)) - reshaped_values = reshaped_values.permute(perm_shape) + split_sizes = torch.bincount(samples.values[:, 1]).tolist() + split_by_system = torch.split(values, split_sizes, dim=0) + tensor_list: List[torch.Tensor] = [] + for split_tensor, size in zip(split_by_system, split_sizes, strict=True): + split_size = [size // n_rot] * n_rot + split_by_rotation = torch.stack(torch.split(split_tensor, split_size, dim=0)) + tensor_list.append(split_by_rotation) + split_by_rotation = torch.cat(tensor_list, dim=1) + reshaped_values = split_by_rotation # broadcast weights to match reshaped_values view: List[int] = [] @@ -505,11 +509,15 @@ def _character_convolution( ).reshape(contracted_shape) values2 = block2.values - # reshape the values to separate rotations from the other samples - split_sizes = [n_rot] * (values2.shape[0] // n_rot) - reshaped_values2 = torch.stack(torch.split(values2, split_sizes, dim=0)) - perm_shape = [1, 0] + list(range(2, reshaped_values2.ndim)) - reshaped_values2 = reshaped_values2.permute(perm_shape) + split_sizes = torch.bincount(block2.samples.values[:, 1]).tolist() + split_by_system = torch.split(values2, split_sizes, dim=0) + tensor_list: List[torch.Tensor] = [] + for split_tensor, size in zip(split_by_system, split_sizes, strict=True): + split_size = [size // n_rot] * n_rot + split_by_rotation = torch.stack(torch.split(split_tensor, split_size, dim=0)) + tensor_list.append(split_by_rotation) + split_by_rotation = torch.cat(tensor_list, dim=1) + reshaped_values2 = split_by_rotation # broadcast weights to match reshaped_values2 view: List[int] = [] From 09ec4e3e2f55ea67819d371e3f243015e0edf557 Mon Sep 17 00:00:00 2001 From: Joseph Abbott Date: Sat, 13 Dec 2025 15:06:05 +0100 Subject: [PATCH 57/57] Add arg to switch off token projections --- .../metatomic_torch/metatomic/torch/symmetrized_model.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/python/metatomic_torch/metatomic/torch/symmetrized_model.py b/python/metatomic_torch/metatomic/torch/symmetrized_model.py index 06adc417..9c66c86f 100644 --- a/python/metatomic_torch/metatomic/torch/symmetrized_model.py +++ b/python/metatomic_torch/metatomic/torch/symmetrized_model.py @@ -794,6 +794,7 @@ def forward( systems: List[System], outputs: Dict[str, ModelOutput], selected_atoms: Optional[Labels] = None, + project_tokens: bool = False, ) -> Dict[str, TensorMap]: """ Symmetrize the model outputs over :math:`O(3)` and compute equivariance @@ -806,7 +807,7 @@ def forward( """ # Evaluate the model over the grid transformed_outputs, backtransformed_outputs = self._eval_over_grid( - systems, outputs, selected_atoms + systems, outputs, selected_atoms, return_transformed=project_tokens, ) transformed_outputs = self._decompose_tensors(transformed_outputs) @@ -819,6 +820,9 @@ def forward( for name, tensor in mean_var.items(): out_dict[name] = tensor + if not project_tokens: + return out_dict + # Compute norms norms = self._compute_norm_per_property(transformed_outputs) for name, tensor in norms.items(): @@ -1111,6 +1115,7 @@ def _eval_over_grid( systems: List[System], outputs: Dict[str, ModelOutput], selected_atoms: Optional[Labels], + return_transformed: bool, ) -> Tuple[Dict[str, TensorMap], Dict[str, TensorMap]]: """ Sample the model on the O(3) quadrature. @@ -1122,8 +1127,6 @@ def _eval_over_grid( where N is the number of quadrature points """ - return_transformed = True - # Evaluate the model over the grid results = evaluate_model_over_grid( self.base_model,