diff --git a/python/metatomic_torch/metatomic/torch/__init__.py b/python/metatomic_torch/metatomic/torch/__init__.py index 5c1469bf..b27af50e 100644 --- a/python/metatomic_torch/metatomic/torch/__init__.py +++ b/python/metatomic_torch/metatomic/torch/__init__.py @@ -69,5 +69,10 @@ def __getattr__(name): import metatomic.torch.ase_calculator return metatomic.torch.ase_calculator + + elif name == "rotational_utils": + import metatomic.torch.rotational_utils + + return metatomic.torch.rotational_utils else: raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/python/metatomic_torch/metatomic/torch/rotational_utils.py b/python/metatomic_torch/metatomic/torch/rotational_utils.py new file mode 100644 index 00000000..a2a56132 --- /dev/null +++ b/python/metatomic_torch/metatomic/torch/rotational_utils.py @@ -0,0 +1,788 @@ +""" +Utilities for diagnosing rotational equivariance of models and for enforcing +rotational symmetry in data augmentation and model evaluation. +""" + +import warnings +from typing import Dict, List, Optional, Tuple + +import metatensor.torch as mts +import numpy as np +import torch +from metatensor.torch import TensorMap +from metatrain.utils.augmentation import ( + _apply_augmentations, + _complex_to_real_spherical_harmonics_transform, + _scipy_quaternion_to_quaternionic, +) + +import metatomic.torch # noqa: F401 +from metatomic.torch import ModelEvaluationOptions, System, register_autograd_neighbors +from metatomic.torch.model import AtomisticModel + + +try: + from scipy.spatial.transform import Rotation # noqa: F401 +except ImportError as e: + raise ImportError( + "To perform data augmentation on spherical targets, please " + "install the `scipy` package with `pip install scipy`." + ) from e + + +def _choose_quadrature(L_max: int) -> Tuple[int, int]: + """ + Choose a Lebedev quadrature order and number of in-plane rotations to integrate + spherical harmonics up to degree ``L_max``. + + :param L_max: maximum spherical harmonic degree + :return: (lebedev_order, n_inplane_rotations) + """ + available = [ + 3, + 5, + 7, + 9, + 11, + 13, + 15, + 17, + 19, + 21, + 23, + 25, + 27, + 29, + 31, + 35, + 41, + 47, + 53, + 59, + 65, + 71, + 77, + 83, + 89, + 95, + 101, + 107, + 113, + 119, + 125, + 131, + ] + # pick smallest order >= L_max + n = min(o for o in available if o >= L_max) + # minimal gamma count + K = 2 * L_max + 1 + return n, K + + +def get_euler_angles_quadrature(lebedev_order: int, n_rotations: int): + """ + Get the Euler angles and weights for Lebedev quadrature. + + :param lebedev_order: order of the Lebedev quadrature on the unit sphere + :param n_rotations: number of in-plane rotations per Lebedev node + :return: alpha, beta, gamma, w arrays of shape (M,), (M,), (K,), (M,) + respectively, where M is the number of Lebedev nodes and K is the number of + in-plane rotations. + """ + from scipy.integrate import lebedev_rule + + # Lebedev nodes (X: (3, M)) + X, w = lebedev_rule(lebedev_order) # w sums to 4*pi + x, y, z = X + alpha = np.arctan2(y, x) # (M,) + beta = np.arccos(np.clip(z, -1.0, 1.0)) # (M,) + gamma = np.linspace(0.0, 2 * np.pi, n_rotations, endpoint=False) # (n_rotations,) + + w_so3 = np.repeat(w / (4 * np.pi * n_rotations), repeats=gamma.size) # (N,) + + return alpha, beta, gamma, w_so3 + + +def _rotations_from_angles(alpha, beta, gamma): + # Build all combinations (alpha_i, beta_i, gamma_j) + A = np.repeat(alpha, gamma.size) # (N,) + B = np.repeat(beta, gamma.size) # (N,) + G = np.tile(gamma, alpha.size) # (N,) + + # Compose ZYZ rotations in SO(3) + Rot = ( + Rotation.from_euler("z", A) + * Rotation.from_euler("y", B) + * Rotation.from_euler("z", G) + ) + + return Rot + + +def _transform_system(system: System, transformation: torch.Tensor) -> System: + transformed_system = System( + positions=system.positions @ transformation.T, + types=system.types, + cell=system.cell @ transformation.T, + pbc=system.pbc, + ) + for options in system.known_neighbor_lists(): + neighbors = mts.detach_block(system.get_neighbor_list(options)) + + neighbors.values[:] = ( + neighbors.values.squeeze(-1) @ transformation.T + ).unsqueeze(-1) + + register_autograd_neighbors(system, neighbors) + transformed_system.add_neighbor_list(options, neighbors) + return transformed_system + + +def evaluate_model_on_quadrature(model, systems, L_max: int, device="cpu"): + pass + + +############ + + +def _extract_euler_zyz( + R: torch.Tensor, + eps: float = 1e-6, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Extract Z-Y-Z Euler angles (alpha, beta, gamma) from rotation matrices, with + explicit handling of the gimbal-lock cases (beta≈0 and beta≈pi). + TODO: This function is extremely sensitive to eps and will be modified. + Parameters + ---------- + R : np.ndarray + Rotation matrices with arbitrary batch shape `(..., 3, 3)`. + eps : float + Tolerance used to detect gimbal lock via `sin(beta) < eps`. + + Returns + ------- + (alphas, betas, gammas) : Tuple[np.ndarray, np.ndarray, np.ndarray] + Each with the same batch shape as `R[..., 0, 0]` (i.e., `R.shape[:-2]`). + + Notes + ----- + Conventions: + - Base convention is Z-Y-Z (Rz(alpha) Ry(beta) Rz(gamma)). + - For beta≈0: set beta=0, gamma=0, alpha=atan2(R[1,0], R[0,0]). + - For beta≈pi: set beta=pi, alpha=0, gamma=atan2(R[1,0], -R[0,0]). + These conventions ensure a deterministic inverse where the standard formulas + are ill-conditioned. + """ + # Accept any batch shape. Flatten to (N, 3, 3) for clarity, then unflatten. + batch_shape = R.shape[:-2] + R_flat = R.reshape(-1, 3, 3) + + # Read commonly-used entries with explicit names for readability + R00 = R_flat[:, 0, 0] + # R01 = R_flat[:, 0, 1] # unused + R02 = R_flat[:, 0, 2] + R10 = R_flat[:, 1, 0] + # R11 = R_flat[:, 1, 1] # unused + R12 = R_flat[:, 1, 2] + R20 = R_flat[:, 2, 0] + R21 = R_flat[:, 2, 1] + R22 = R_flat[:, 2, 2] + + # Default (non-singular) extraction + zz = torch.clip(R22, -1.0, 1.0) + betas = torch.arccos(zz) + + # For Z-Y-Z, standard formulas away from the singular set + alphas = torch.arctan2(R12, R02) + gammas = torch.arctan2(R21, -R20) + + # Normalize into [0, 2π) + two_pi = 2.0 * torch.pi + alphas = torch.remainder(alphas, two_pi) + gammas = torch.remainder(gammas, two_pi) + + # Gimbal-lock detection via sin(beta) + sinb = torch.sin(betas) + near = torch.abs(sinb) < eps + if torch.any(near): + # Split the two singular bands using zz = cos(beta) + near_zero = near & (zz > 0) # beta≈0 + near_pi = near & (zz < 0) # beta≈pi + + if torch.any(near_zero): + # beta≈0: rotation ≈ Rz(alpha+gamma). Choose gamma=0, recover alpha from 2x2 + # block. + betas[near_zero] = 0.0 + gammas[near_zero] = 0.0 + alphas[near_zero] = torch.arctan2(R10[near_zero], R00[near_zero]) + alphas[near_zero] = torch.remainder(alphas[near_zero], two_pi) + + if torch.any(near_pi): + # beta≈pi: choose alpha=0, recover gamma from 2x2 block with sign flip on + # R00. + betas[near_pi] = torch.pi + alphas[near_pi] = 0.0 + gammas[near_pi] = torch.arctan2(R10[near_pi], -R00[near_pi]) + gammas[near_pi] = torch.remainder(gammas[near_pi], two_pi) + + # Unflatten back to the original batch shape + alphas = alphas.reshape(batch_shape) + betas = betas.reshape(batch_shape) + gammas = gammas.reshape(batch_shape) + return alphas, betas, gammas + + +def get_so3_character( + alphas: torch.Tensor, + betas: torch.Tensor, + gammas: torch.Tensor, + o3_lambda: int, + tol: float = 1e-7, +) -> torch.Tensor: + """ + Numerically stable evaluation of the character function χ_{o3_lambda}(R) over SO(3). + + Uses a small-angle Taylor expansion for χ_ℓ(ω) = sin((2ℓ+1)t)/sin(t) with t = ω/2 + when |t| is very small, and a guarded ratio otherwise. + """ + # Compute half-angle t = ω/2 via Z–Y–Z relation: cos t = cos(β/2) cos((α+γ)/2) + cos_t = torch.cos(betas / 2.0) * torch.cos((alphas + gammas) / 2.0) + cos_t = torch.clip(cos_t, -1.0, 1.0) + t = torch.arccos(cos_t) + + # Output array + chi = torch.empty_like(t) + + # Parameters for χ + L = o3_lambda + a = 2 * L + 1 + ll1 = L * (L + 1) + + small = torch.abs(t) < tol + if torch.any(small): + # Series up to t^4: χ ≈ a [1 - (2/3) ℓ(ℓ+1) t^2 + (1/45) ℓ(ℓ+1)(3ℓ^2+3ℓ-1) t^4] + ts = t[small] + t2 = ts * ts + coeff4 = ll1 * (3 * L * L + 3 * L - 1) + chi[small] = a * ( + 1.0 - (2.0 / 3.0) * ll1 * t2 + (1.0 / 45.0) * coeff4 * t2 * t2 + ) + + # Large-angle (or not-so-small) branch: safe ratio with guard + large = ~small + if torch.any(large): + tl = t[large] + sin_t = torch.sin(tl) + numer = torch.sin(a * tl) + mask = torch.abs(sin_t) >= tol + out = torch.empty_like(tl) + torch.div(numer, sin_t, out=out) # TODO figure out with numpy divide + out[~mask] = a # exact limit as t -> 0 + chi[large] = out + + return chi + + +def get_so3_characters_dict( + alphas: torch.Tensor, betas: torch.Tensor, gammas: torch.Tensor, o3_lambda_max: int +) -> Dict[int, torch.Tensor]: + """ + Returns a dictionary of the SO(3) characters for all o3_lambda in [0, + o3_lambda_max]. + """ + characters = {} + for o3_lambda in range(o3_lambda_max + 1): + characters[o3_lambda] = get_so3_character(alphas, betas, gammas, o3_lambda) + return characters + + +def get_pso3_characters_dict( + so3_character: Dict[int, torch.Tensor], o3_lambda_max: int +) -> Dict[Tuple[int, int], torch.Tensor]: + """ + Returns a dictionary of the P⋅SO(3) characters for all (o3_lambda, o3_sigma) pairs + with o3_lambda in [0, o3_lambda_max] and o3_sigma in {-1, +1}. + Requires a pre-computed dictionary of SO(3) characters. + """ + characters = {} + for o3_lambda in range(o3_lambda_max + 1): + for o3_sigma in [-1, +1]: + characters[(o3_lambda, o3_sigma)] = ( + o3_sigma * ((-1) ** o3_lambda) * so3_character[o3_lambda] + ) + return characters + + +############ + + +class O3Sampler: + """ + Compute model predictions on a quadrature over the O(3) group. + + :param quad_l_max: maximum spherical harmonic degree for quadrature + :param project_l_max: maximum spherical harmonic degree to project onto + :param batch_size: number of rotations to process in a single batch. + """ + + def __init__(self, quad_l_max: int, project_l_max: int, batch_size: int = 1): + try: + from scipy.spatial.transform import Rotation # noqa: F401 + except ImportError as e: + raise ImportError( + "To perform data augmentation on spherical targets, please " + "install the `scipy` package with `pip install scipy`." + ) from e + + self.quad_l_max = quad_l_max + """Maximum spherical harmonic degree for quadrature.""" + + self.project_l_max = project_l_max + """Maximum spherical harmonic degree to project onto.""" + if self.project_l_max + 2 > self.quad_l_max: + warnings.warn( + ( + f"Projecting up to l={self.project_l_max} with quadrature up " + f"to l={self.quad_l_max} may be inaccurate." + ), + stacklevel=2, + ) + + # Get the quadrature + self.lebedev_order: int + """Number of Lebedev nodes on the unit sphere.""" + + self.n_inplane_rotations: int + """Number of in-plane rotations per Lebedev node.""" + self.lebedev_order, self.n_inplane_rotations = _choose_quadrature( + self.quad_l_max + ) + + self.w_so3: torch.Tensor + """Weights associated to each rotation in the SO(3) Haar measure.""" + + alpha, beta, gamma, self.w_so3 = get_euler_angles_quadrature( + self.lebedev_order, self.n_inplane_rotations + ) + self.w_so3 = torch.from_numpy(self.w_so3) + + # For active rotation of systems + self.R_so3 = torch.from_numpy( + _rotations_from_angles(alpha, beta, gamma).as_matrix() + ) + """Rotation matrices.""" + + self.n_rotations = self.R_so3.size(0) + + # For inverse rotation of tensors + R_pso3 = _rotations_from_angles(np.pi - alpha, beta, np.pi - gamma) + self.wigner_D: Dict[int, torch.Tensor] = _compute_wigner_D_matrices( + self.project_l_max, R_pso3 + ) + """Dict mapping l to (N, (2l+1), (2l+1)) torch.Tensor of Wigner D matrices.""" + + self.R_pso3 = torch.from_numpy(R_pso3.as_matrix()) + """Inverse rotation matrices.""" + + self.batch_size = batch_size + + _external_product_euler_angles = _extract_euler_zyz( + (self.R_so3[:, None, :, :] @ self.R_pso3[None, :, :, :]), eps=1e-6 + ) + self.so3_characters = get_so3_characters_dict( + *_external_product_euler_angles, self.project_l_max + ) + self.pso3_characters = get_pso3_characters_dict( + self.so3_characters, self.project_l_max + ) + + def evaluate( + self, + model: AtomisticModel, + systems: List[System], + options: ModelEvaluationOptions, + check_consistency: bool = False, + ): + """ + Sample the model on the O(3) quadrature. + + :param systems: list of systems to evaluate + :param model: atomistic model to evaluate + :param device: device to use for computation + :return: list of list of model outputs, shape (len(systems), N) + where N is the number of quadrature points + """ + + device = systems[0].positions.device + dtype = systems[0].positions.dtype + + transformed_outputs = { + name: [{-1: None, 1: None} for _ in systems] + for name in options.outputs.keys() + } + backtransformed_outputs = { + name: [{-1: None, 1: None} for _ in systems] + for name in options.outputs.keys() + } + for i_sys, system in enumerate(systems): + for inversion in [-1, 1]: + rotation_outputs = [] + for batch in range(0, len(self.R_so3), self.batch_size): + transformed_systems = [ + _transform_system( + system, inversion * R.to(device=device, dtype=dtype) + ) + for R in self.R_so3[batch : batch + self.batch_size] + ] + outputs = model( + transformed_systems, + options=options, + check_consistency=check_consistency, + ) + rotation_outputs.append(outputs) + + for name in transformed_outputs: + tensor = mts.join( + [r[name] for r in rotation_outputs], + "samples", + remove_tensor_name=True, + ) + transformed_outputs[name][i_sys][inversion] = mts.rename_dimension( + tensor, "samples", "tensor", "o3_sample" + ) + + n_rot = self.R_so3.size(0) + for name in transformed_outputs: + for i_sys, system in enumerate(systems): + for inversion in [-1, 1]: + tensor = transformed_outputs[name][i_sys][inversion] + _, backtransformed, _ = _apply_augmentations( + [system] * n_rot, + {name: tensor}, + list( + ( + self.R_pso3.to(device=device, dtype=dtype) * inversion + ).unbind(0) + ), + self.wigner_D, + ) + backtransformed_outputs[name][i_sys][inversion] = backtransformed[ + name + ] + + return transformed_outputs, backtransformed_outputs + + +class TokenProjector(torch.nn.Module): + """ + Wrap an atomistic model to project its predictions onto spherical sectors. + + :param model: atomistic model to wrap + :param quad_l_max: maximum spherical harmonic degree for quadrature + :param project_l_max: maximum spherical harmonic degree to project onto + :param batch_size: number of rotations to process in a single batch + """ + + def __init__( + self, + model: AtomisticModel, + quad_l_max: int, + project_l_max: int, + batch_size: Optional[int] = None, + ) -> None: + super().__init__() + self.model = model + """The underlying atomistic model.""" + self.o3_sampler = O3Sampler(quad_l_max, project_l_max, batch_size=batch_size) + """The projector onto spherical sectors.""" + + def forward( + self, + systems: List[System], + options: ModelEvaluationOptions, + check_consistency: bool = False, + ) -> torch.Tensor: + """ + :param systems: list of systems to evaluate + :param options: model evaluation options + :param check_consistency: whether to check model consistency + :return: TODO + """ + + transformed_outputs, _ = self.o3_sampler.evaluate( + self.model, systems, options, check_consistency + ) + + # TODO do projection operations + pass + + +class SymmetrizedAtomisticModel(torch.nn.Module): + """ + Wrap an atomistic model to symmetrize its predictions over a quadrature and compute + O(3) averages, variances, and equivariance score. + + :param model: atomistic model to wrap + :param quad_l_max: maximum spherical harmonic degree for quadrature + :param project_l_max: maximum spherical harmonic degree to project onto + :param batch_size: number of rotations to process in a single batch + """ + + def __init__( + self, + model: AtomisticModel, + quad_l_max: int, + project_l_max: int, + batch_size: Optional[int] = None, + ): + super().__init__("SymmetrizedAtomisticModel") + self.model = model + """The underlying atomistic model.""" + self.o3_sampler = O3Sampler(quad_l_max, project_l_max, batch_size=batch_size) + """The projector onto spherical sectors.""" + + def forward( + self, + systems: List[System], + options: ModelEvaluationOptions, + check_consistency: bool = False, + ) -> torch.Tensor: + """ + :param systems: list of systems to evaluate + :param options: model evaluation options + :param check_consistency: whether to check model consistency + :return: + """ + + transformed_outputs, _ = self.o3_sampler.evaluate( + systems, self.model, options, check_consistency + ) + + return compute_projections( + self.o3_sampler.project_l_max, + systems, + transformed_outputs, + self.o3_sampler.w_so3, + self.o3_sampler.so3_characters, + self.o3_sampler.pso3_characters, + ) + + +def _compute_wigner_D_matrices( + l_max: int, + rotations: List["Rotation"], + complex_to_real: Optional[np.ndarray] = None, +) -> dict: + """ + Compute Wigner D matrices for all l <= project_l_max. + + :param l_max: maximum spherical harmonic degree + :param rotations: list of scipy Rotation objects + :param complex_to_real: optional dict mapping l to (2l+1, (2l+1)) array to convert + complex spherical harmonics to real spherical harmonics + :return: dict mapping l to (N, (2l+1), (2l+1)) array of Wigner D matrices + """ + + try: + import spherical + except ImportError as e: + # quaternionic (used below) is a dependency of spherical + raise ImportError( + "To perform data augmentation on spherical targets, please " + "install the `spherical` package with `pip install spherical`." + ) from e + + wigner = spherical.Wigner(l_max) + scipy_quaternions = [r.as_quat() for r in rotations] + quaternionic_quaternions = [ + _scipy_quaternion_to_quaternionic(q) for q in scipy_quaternions + ] + wigner_D_matrices_complex = [wigner.D(q) for q in quaternionic_quaternions] + + if complex_to_real is None: + complex_to_real = { + ell: _complex_to_real_spherical_harmonics_transform(ell) + for ell in range(l_max + 1) + } + + wigner_D_matrices = {} + for ell in range(l_max + 1): + U = complex_to_real[ell] + wigner_D_matrices_l = [] + for wigner_D_matrix_complex in wigner_D_matrices_complex: + wigner_D_matrix = np.zeros((2 * ell + 1, 2 * ell + 1), dtype=np.complex128) + for mp in range(-ell, ell + 1): + for m in range(-ell, ell + 1): + wigner_D_matrix[m + ell, mp + ell] = ( + wigner_D_matrix_complex[wigner.Dindex(ell, m, mp)] + ).conj() + + wigner_D_matrix = U.conj() @ wigner_D_matrix @ U.T + assert np.allclose(wigner_D_matrix.imag, 0.0) + wigner_D_matrix = wigner_D_matrix.real + wigner_D_matrices_l.append(torch.from_numpy(wigner_D_matrix)) + wigner_D_matrices[ell] = wigner_D_matrices_l + + return wigner_D_matrices + + +# O3-integrals utilities + + +def compute_projections( + max_l: int, + systems: List[System], + transformed_outputs: Dict[str, List[TensorMap]], + weights: torch.Tensor, + so3_characters: Dict[int, torch.Tensor], + pso3_characters: Dict[Tuple[int, int], torch.Tensor], +) -> Tuple[ + Dict[str, List[Dict[int, TensorMap]]], + Dict[str, List[Dict[Tuple[int, int], TensorMap]]], + Dict[str, List[Dict[Tuple[int, int], TensorMap]]], +]: + """ + + TODO docstring, check type annotations + + - Take model outputs on a quadrature + - Manipulate dimensions + - Compute some integrals + - Return projections + + """ + + device = systems[0].positions.device + dtype = systems[0].positions.dtype + + weights = weights.to(device, dtype) + so3_characters = {k: v.to(device, dtype) for k, v in so3_characters.items()} + pso3_characters = {k: v.to(device, dtype) for k, v in pso3_characters.items()} + + n_rotations = len(weights) + norms = {} + convolution_integrals = {} + normalized_convolution_integrals = {} + # Loop over targets + for name, transformed_output in transformed_outputs.items(): + norms[name] = [] + convolution_integrals[name] = [] + normalized_convolution_integrals[name] = [] + for o3_output_for_system in transformed_output: + proper = o3_output_for_system[1] + improper = o3_output_for_system[-1] + + # Weighting the tensors + broadcasted_w = ( + weights[proper[0].samples.column("o3_sample")] / 16 / torch.pi**2 + ) + proper_weighted = proper.copy() + improper_weighted = improper.copy() + for k in proper_weighted.keys: + proper_block = proper_weighted[k] + improper_block = improper_weighted[k] + proper_block.values[:] *= broadcasted_w.view( + -1, *[1] * (proper_block.values.ndim - 1) + ) + improper_block.values[:] *= broadcasted_w.view( + -1, *[1] * (improper_block.values.ndim - 1) + ) + + # Compute norms + proper_norm = mts.multiply(proper, proper_weighted) + improper_norm = mts.multiply(improper, improper_weighted) + norm = mts.add(proper_norm, improper_norm) + norm = mts.sum_over_samples(norm, "o3_sample") + norms[name].append(norm) + + # Compute convolution integrals + convolution_integral = {} + normalized_convolution_integral = {} + for ell in range(max_l + 1): + so3_char = so3_characters[ell] + for sigma in [-1, 1]: + pso3_char = pso3_characters[(ell, sigma)] + + integral_blocks = [] + for k in proper.keys: + proper_block = proper[k].values.reshape( + -1, n_rotations, *proper[k].shape[1:] + ) + improper_block = improper[k].values.reshape( + -1, n_rotations, *improper[k].shape[1:] + ) + integral_values = ( + ( + 0.25 + * torch.einsum( + "ij...,nij...->n...", + so3_char, + proper_block[:, :, None, ...] + * proper_block[:, None, :, ...] + + improper_block[:, :, None, ...] + * improper_block[:, None, :, ...], + ) + + 0.5 + * torch.einsum( + "ij...,nij...->n...", + pso3_char, + proper_block[:, :, None, ...] + * improper_block[:, None, :, ...], + ) + ) + * (2 * ell + 1) + / (8 * torch.pi**2) ** 2 + ) + integral_blocks.append( + mts.TensorBlock( + samples=norm[k].samples, + components=norm[k].components, + properties=norm[k].properties, + values=integral_values, + ) + ) + convolution_integral[(ell, sigma)] = mts.TensorMap( + keys=norm.keys, blocks=integral_blocks + ) + normalized_convolution_integral[(ell, sigma)] = mts.divide( + convolution_integral[(ell, sigma)], norm + ) + convolution_integrals[name].append(convolution_integral) + normalized_convolution_integrals[name].append( + normalized_convolution_integral + ) + + return norms, convolution_integrals, normalized_convolution_integrals + + +# IO utilities + + +def norms_to(norms, dtype, device): + """Moves the TensorMap of norms to dtype and device""" + + norms_to = {} + for output_name in norms.keys(): + quantity_list = [] + for quantity in norms[output_name]: + quantity_list.append(quantity.to(dtype=dtype, device=device)) + norms_to[output_name] = quantity_list + + return norms_to + + +def integrals_to(integral, dtype, device): + """Moves the TensorMap of integrals to dtype and device""" + + integral_to = {} + for output_name in integral.keys(): + quantity_list = [] + for quantity_dict in integral[output_name]: + quantity_dict_to = {} + for key, quantity in quantity_dict.items(): + quantity_dict_to[key] = quantity.to(dtype=dtype, device=device) + quantity_list.append(quantity_dict_to) + integral_to[output_name] = quantity_list + + return integral_to diff --git a/python/metatomic_torch/metatomic/torch/symmetrized_model.py b/python/metatomic_torch/metatomic/torch/symmetrized_model.py new file mode 100644 index 00000000..9c66c86f --- /dev/null +++ b/python/metatomic_torch/metatomic/torch/symmetrized_model.py @@ -0,0 +1,1509 @@ +import warnings +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple + +import metatensor.torch as mts + + +if TYPE_CHECKING: + + class TensorBlock: ... + + class System: ... + + class TensorMap: ... + + class ModelOutput: ... + + class Labels: ... + + class ModelInterface: ... + +else: + from metatensor.torch import Labels, TensorBlock, TensorMap + + from metatomic.torch import ModelOutput, System + +import numpy as np +import torch +from metatrain.utils.augmentation import _apply_augmentations + +from metatomic.torch import ModelInterface, register_autograd_neighbors + + +try: + from scipy.integrate import lebedev_rule # noqa: F401 + from scipy.spatial.transform import Rotation # noqa: F401 +except ImportError as e: + raise ImportError( + "To perform data augmentation on spherical targets, please " + "install the `scipy` package with `pip install scipy`." + ) from e +try: + import spherical # noqa: F401 +except ImportError as e: + raise ImportError( + "To perform data augmentation on spherical targets, please " + "install the `spherical` package with `pip install spherical`." + ) from e +try: + import quaternionic # noqa: F401 +except ImportError as e: + raise ImportError( + "To perform data augmentation on spherical targets, please " + "install the `quaternionic` package with `pip install quaternionic`." + ) from e + + +def _choose_quadrature(L_max: int) -> Tuple[int, int]: + """ + Choose a Lebedev quadrature order and number of in-plane rotations to integrate + spherical harmonics up to degree ``L_max``. + + :param L_max: maximum spherical harmonic degree + :return: (lebedev_order, n_inplane_rotations) + """ + available = [ + 3, + 5, + 7, + 9, + 11, + 13, + 15, + 17, + 19, + 21, + 23, + 25, + 27, + 29, + 31, + 35, + 41, + 47, + 53, + 59, + 65, + 71, + 77, + 83, + 89, + 95, + 101, + 107, + 113, + 119, + 125, + 131, + ] + # pick smallest order >= L_max + n = min(o for o in available if o >= L_max) + # minimal gamma count + K = L_max + 1 + return n, K + + +def get_euler_angles_quadrature(lebedev_order: int, n_rotations: int): + """ + Get the Euler angles and weights for a Lebedev quadrature combined with in-plane + rotations for SO(3) integration. + + :param lebedev_order: order of the Lebedev quadrature on the unit sphere + :param n_rotations: number of in-plane rotations per Lebedev node + :return: alpha, beta, gamma, w arrays of shape (M,), (M,), (K,), (M,) + respectively, where M is the number of Lebedev nodes and K is the number of + in-plane rotations. + """ + + # Lebedev nodes (X: (3, M)) + X, w = lebedev_rule(lebedev_order) # w sums to 4*pi + x, y, z = X + alpha = np.arctan2(y, x) # (M,) + beta = np.arccos(np.clip(z, -1.0, 1.0)) # (M,) + gamma = np.linspace(0.0, 2 * np.pi, n_rotations, endpoint=False) # (K,) + + w_so3 = np.repeat(w / (4 * np.pi * n_rotations), repeats=gamma.size) # (M*K,) + + A = np.repeat(alpha, gamma.size) # (N,) + B = np.repeat(beta, gamma.size) # (N,) + G = np.tile(gamma, alpha.size) # (N,) + + return A, B, G, w_so3 + + +def _rotations_from_angles( + alpha: np.ndarray, beta: np.ndarray, gamma: np.ndarray +) -> Rotation: + """ + Compose rotations from ZYZ Euler angles. + + :param alpha: array of alpha angles (M,) + :param beta: array of beta angles (M,) + :param gamma: array of gamma angles (K,) + :return: Rotation object containing all (M*K,) rotations + """ + + # Compose ZYZ rotations in SO(3) + Rot = ( + Rotation.from_euler("z", alpha) + * Rotation.from_euler("y", beta) + * Rotation.from_euler("z", gamma) + ) + + return Rot + + +def _transform_system(system: System, transformation: torch.Tensor) -> System: + transformed_system = System( + positions=system.positions @ transformation.T, + types=system.types, + cell=system.cell @ transformation.T, + pbc=system.pbc, + ) + for options in system.known_neighbor_lists(): + neighbors = mts.detach_block(system.get_neighbor_list(options)) + + neighbors.values[:] = ( + neighbors.values.squeeze(-1) @ transformation.T + ).unsqueeze(-1) + + register_autograd_neighbors(system, neighbors) + transformed_system.add_neighbor_list(options, neighbors) + return transformed_system + + +def _complex_to_real_spherical_harmonics_transform(ell: int) -> np.ndarray: + """ + Generate the transformation matrix from complex spherical harmonics + to real spherical harmonics for a given l. + Returns a transformation matrix of shape ((2l+1), (2l+1)). + """ + if ell < 0 or not isinstance(ell, int): + raise ValueError("l must be a non-negative integer.") + + # The size of the transformation matrix is (2l+1) x (2l+1) + size = 2 * ell + 1 + T = np.zeros((size, size), dtype=complex) + + for m in range(-ell, ell + 1): + m_index = m + ell # Index in the matrix + if m > 0: + # Real part of Y_{l}^{m} + T[m_index, ell + m] = 1 / np.sqrt(2) * (-1) ** m + T[m_index, ell - m] = 1 / np.sqrt(2) + elif m < 0: + # Imaginary part of Y_{l}^{|m|} + T[m_index, ell + abs(m)] = -1j / np.sqrt(2) * (-1) ** m + T[m_index, ell - abs(m)] = 1j / np.sqrt(2) + else: # m == 0 + # Y_{l}^{0} remains unchanged + T[m_index, ell] = 1 + + # Return the transformation matrix to convert complex to real spherical harmonics + return T + + +def _compute_real_wigner_matrices( + o3_lambda_max: int, + angles: Tuple[np.ndarray, np.ndarray, np.ndarray], # alpha, beta, gamma +) -> Dict[int, np.ndarray]: + wigner = spherical.Wigner(o3_lambda_max) + R = quaternionic.array.from_euler_angles(*angles) + D = wigner.D(R) + wigner_D_matrices = {} + for ell in range(o3_lambda_max + 1): + wigner_D_matrices[ell] = np.zeros( + angles[0].shape + (2 * ell + 1, 2 * ell + 1), dtype=np.complex128 + ) + for mp in range(-ell, ell + 1): + for m in range(-ell, ell + 1): + # There is an unexplained conjugation factor in the definition given in + # the quaternionic library. + wigner_D_matrices[ell][..., mp + ell, m + ell] = ( + D[..., wigner.Dindex(ell, mp, m)] + ).conj() + U = _complex_to_real_spherical_harmonics_transform(ell) + wigner_D_matrices[ell] = np.einsum( + "ij,...jk,kl->...il", U.conj(), wigner_D_matrices[ell], U.T + ) + assert np.allclose(wigner_D_matrices[ell].imag, 0) + wigner_D_matrices[ell] = torch.from_numpy(wigner_D_matrices[ell].real) + + return wigner_D_matrices + + +def _angles_from_rotations( + R: np.ndarray, + eps: float = 1e-6, +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Extract Z-Y-Z Euler angles (alpha, beta, gamma) from rotation matrices, with + explicit handling of the gimbal-lock cases (beta≈0 and beta≈pi). + TODO: This function is extremely sensitive to eps and will be modified. + Parameters + ---------- + R : np.ndarray + Rotation matrices with arbitrary batch shape `(..., 3, 3)`. + eps : float + Tolerance used to detect gimbal lock via `sin(beta) < eps`. + + Returns + ------- + (alphas, betas, gammas) : Tuple[np.ndarray, np.ndarray, np.ndarray] + Each with the same batch shape as `R[..., 0, 0]` (i.e., `R.shape[:-2]`). + + Notes + ----- + Conventions: + - Base convention is Z-Y-Z (Rz(alpha) Ry(beta) Rz(gamma)). + - For beta≈0: set beta=0, gamma=0, alpha=atan2(R[1,0], R[0,0]). + - For beta≈pi: set beta=pi, alpha=0, gamma=atan2(R[1,0], -R[0,0]). + These conventions ensure a deterministic inverse where the standard formulas + are ill-conditioned. + """ + # Accept any batch shape. Flatten to (N, 3, 3) for clarity, then unflatten. + batch_shape = R.shape[:-2] + R_flat = R.reshape(-1, 3, 3) + + # Read commonly-used entries with explicit names for readability + R00 = R_flat[:, 0, 0] + # R01 = R_flat[:, 0, 1] + R02 = R_flat[:, 0, 2] + R10 = R_flat[:, 1, 0] + # R11 = R_flat[:, 1, 1] + R12 = R_flat[:, 1, 2] + R20 = R_flat[:, 2, 0] + R21 = R_flat[:, 2, 1] + R22 = R_flat[:, 2, 2] + + # Default (non-singular) extraction + zz = np.clip(R22, -1.0, 1.0) + betas = np.arccos(zz) + + # For Z–Y–Z, standard formulas away from the singular set + alphas = np.arctan2(R12, R02) + gammas = np.arctan2(R21, -R20) + + # Normalize into [0, 2π) + two_pi = 2.0 * np.pi + alphas = np.mod(alphas, two_pi) + gammas = np.mod(gammas, two_pi) + + # Gimbal-lock detection via sin(beta) + sinb = np.sin(betas) + near = np.abs(sinb) < eps + if np.any(near): + # Split the two singular bands using zz = cos(beta) + near_zero = near & (zz > 0) # beta≈0 + near_pi = near & (zz < 0) # beta≈pi + + if np.any(near_zero): + # beta≈0: rotation ≈ Rz(alpha+gamma). Choose gamma=0, recover alpha from + # 2x2 block. + betas[near_zero] = 0.0 + gammas[near_zero] = 0.0 + alphas[near_zero] = np.arctan2(R10[near_zero], R00[near_zero]) + alphas[near_zero] = np.mod(alphas[near_zero], two_pi) + + if np.any(near_pi): + # beta≈pi: choose alpha=0, recover gamma from 2x2 block with sign flip on + # R00. + betas[near_pi] = np.pi + alphas[near_pi] = 0.0 + gammas[near_pi] = np.arctan2(R10[near_pi], -R00[near_pi]) + gammas[near_pi] = np.mod(gammas[near_pi], two_pi) + + # Unflatten back to the original batch shape + alphas = alphas.reshape(batch_shape) + betas = betas.reshape(batch_shape) + gammas = gammas.reshape(batch_shape) + return alphas, betas, gammas + + +def _l0_components_from_matrices(A: torch.Tensor) -> torch.Tensor: + """ + Extract the L=0 components from a (3, 3) tensor. + """ + # The tensor will have shape (a, 3, 3, b) so we need to move the 3, 3 dimension at + # the end + A = A.permute(0, 3, 1, 2) + # Test if the last two dimensions are (3, 3) + assert A.shape[-2:] == (3, 3), "The last two dimensions of A must be (3, 3)." + + # Initialize the output tensor for L=0 components to have 1 component in the last + # dimension + l0_A = torch.empty(A.shape[:-2] + (1,), dtype=A.dtype, device=A.device) + + # Compute the L=0 component as the trace of A + l0_A[..., 0] = A[..., 0, 0] + A[..., 1, 1] + A[..., 2, 2] + + l0_A = l0_A.permute(0, 2, 1) + return l0_A + + +def _l2_components_from_matrices(A: torch.Tensor) -> torch.Tensor: + """ + Extract the L=2 components from a (3, 3) tensor. + """ + # The tensor will have shape (a, 3, 3, b) so we need to move the 3, 3 dimension at + # the end + A = A.permute(0, 3, 1, 2) + # Test if the last two dimensions are (3, 3) + assert A.shape[-2:] == (3, 3), "The last two dimensions of A must be (3, 3)." + + # Initialize the output tensor for L=2 components to have 5 components in the last + # dimension + l2_A = torch.empty(A.shape[:-2] + (5,), dtype=A.dtype, device=A.device) + + l2_A[..., 0] = (A[..., 0, 1] + A[..., 1, 0]) / 2.0 + l2_A[..., 1] = (A[..., 1, 2] + A[..., 2, 1]) / 2.0 + l2_A[..., 2] = (2.0 * A[..., 2, 2] - A[..., 0, 0] - A[..., 1, 1]) / ( + (2.0) * np.sqrt(3.0) + ) + l2_A[..., 3] = (A[..., 0, 2] + A[..., 2, 0]) / 2.0 + l2_A[..., 4] = (A[..., 0, 0] - A[..., 1, 1]) / 2.0 + + l2_A = l2_A.permute(0, 2, 1) + + return l2_A + + +def _euler_angles_of_combined_rotation( + angles1: Tuple[np.ndarray, np.ndarray, np.ndarray], + angles2: Tuple[np.ndarray, np.ndarray, np.ndarray], +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Given two sets of Euler angles (alpha, beta, gamma), returns the Euler angles + of all pairwise compositions + """ + + R1 = _rotations_from_angles(*angles1).as_matrix() # (N1, 3, 3) + R2 = _rotations_from_angles(*angles2).as_matrix() # (N2, 3, 3) + + # Broadcasted pairwise multiplication to shape (N1, N2, 3, 3): R1[p] @ R2[a] + R_product = R1[:, None, :, :] @ R2[None, :, :, :] + + # Extract Euler angles from the combined rotation matrices (robust to gimbal lock) + alpha, beta, gamma = _angles_from_rotations(R_product, eps=1e-6) + return alpha, beta, gamma + + +def _get_so3_character( + alphas: np.ndarray, + betas: np.ndarray, + gammas: np.ndarray, + o3_lambda: int, + tol: float = 1e-7, +) -> np.ndarray: + """ + Numerically stable evaluation of the character function χ_{o3_lambda}(R) over SO(3). + + Uses a small-angle Taylor expansion for χ_l(ω) = sin((2l+1)t)/sin(t) with t = ω/2 + when |t| is very small, and a guarded ratio otherwise. + """ + # Compute half-angle t = ω/2 via Z–Y–Z relation: cos t = cos(β/2) cos((α+γ)/2) + cos_t = np.cos(betas / 2.0) * np.cos((alphas + gammas) / 2.0) + cos_t = np.clip(cos_t, -1.0, 1.0) + t = np.arccos(cos_t) + + # Output array + chi = np.empty_like(t) + + # Parameters for χ + L = o3_lambda + a = 2 * L + 1 + ll1 = L * (L + 1) + + small = np.abs(t) < tol + if np.any(small): + # Series up to t^4: χ ≈ a [1 - (2/3) ℓ(ℓ+1) t^2 + (1/45) ℓ(ℓ+1)(3ℓ^2+3ℓ-1) t^4] + ts = t[small] + t2 = ts * ts + coeff4 = ll1 * (3 * L * L + 3 * L - 1) + chi[small] = a * ( + 1.0 - (2.0 / 3.0) * ll1 * t2 + (1.0 / 45.0) * coeff4 * t2 * t2 + ) + + # Large-angle (or not-so-small) branch: safe ratio with guard + large = ~small + if np.any(large): + tl = t[large] + sin_t = np.sin(tl) + numer = np.sin(a * tl) + mask = np.abs(sin_t) >= tol + out = np.empty_like(tl) + np.divide(numer, sin_t, out=out, where=mask) + out[~mask] = a # exact limit as t -> 0 + chi[large] = out + + return chi + + +def compute_characters( + o3_lambda_max: int, + angles: Tuple[np.ndarray, np.ndarray, np.ndarray], + inverse_angles: Tuple[np.ndarray, np.ndarray, np.ndarray], +) -> Tuple[Dict[int, torch.Tensor], Dict[str, torch.Tensor]]: + alpha, beta, gamma = _euler_angles_of_combined_rotation(angles, inverse_angles) + + so3_characters = { + o3_lambda: _get_so3_character(alpha, beta, gamma, o3_lambda) + for o3_lambda in range(o3_lambda_max + 1) + } + + pso3_characters = {} + for o3_lambda in range(o3_lambda_max + 1): + for o3_sigma in [-1, +1]: + pso3_characters[f"{o3_lambda}_{o3_sigma}"] = ( + o3_sigma * ((-1) ** o3_lambda) * so3_characters[o3_lambda] + ) + + so3_characters = { + key: torch.from_numpy(value) for key, value in so3_characters.items() + } + pso3_characters = { + key: torch.from_numpy(value) for key, value in pso3_characters.items() + } + + return so3_characters, pso3_characters + + +def _character_convolution( + chi: torch.Tensor, block1: TensorBlock, block2: TensorBlock, w: torch.Tensor +) -> TensorBlock: + """ + Compute the character convolution of a block containing SO(3)-sampled tensors. + Then contract with another block. + """ + samples = block1.samples + assert samples.names[0] == "so3_rotation" + n_rot = chi.size(0) + components = block1.components + properties = block1.properties + values = block1.values + chi = chi.to(dtype=values.dtype, device=values.device) + n_rot = chi.size(1) + weight = w.to(dtype=values.dtype, device=values.device) + + split_sizes = torch.bincount(samples.values[:, 1]).tolist() + split_by_system = torch.split(values, split_sizes, dim=0) + tensor_list: List[torch.Tensor] = [] + for split_tensor, size in zip(split_by_system, split_sizes, strict=True): + split_size = [size // n_rot] * n_rot + split_by_rotation = torch.stack(torch.split(split_tensor, split_size, dim=0)) + tensor_list.append(split_by_rotation) + split_by_rotation = torch.cat(tensor_list, dim=1) + reshaped_values = split_by_rotation + + # broadcast weights to match reshaped_values + view: List[int] = [] + view.append(-1) + for _ in range(reshaped_values.ndim - 1): + view.append(1) + weighted_values = weight.view(view) * reshaped_values + + # broadcast characters to match reshaped_values + contracted_shape: List[int] = [chi.shape[0]] + list(weighted_values.shape[1:]) + contracted_values = ( + chi @ weighted_values.reshape(weighted_values.shape[0], -1) + ).reshape(contracted_shape) + + values2 = block2.values + split_sizes = torch.bincount(block2.samples.values[:, 1]).tolist() + split_by_system = torch.split(values2, split_sizes, dim=0) + tensor_list: List[torch.Tensor] = [] + for split_tensor, size in zip(split_by_system, split_sizes, strict=True): + split_size = [size // n_rot] * n_rot + split_by_rotation = torch.stack(torch.split(split_tensor, split_size, dim=0)) + tensor_list.append(split_by_rotation) + split_by_rotation = torch.cat(tensor_list, dim=1) + reshaped_values2 = split_by_rotation + + # broadcast weights to match reshaped_values2 + view: List[int] = [] + view.append(-1) + for _ in range(reshaped_values2.ndim - 1): + view.append(1) + weighted_values2 = weight.view(view) * reshaped_values2 + + # contract weighted_values2 with contracted_values + contracted_values = torch.einsum( + "i...,i...->...", + weighted_values2, + contracted_values, + ) + + names: List[str] = [] + for name in samples.names: + if name != "so3_rotation": + names.append(name) + new_block = TensorBlock( + samples=Labels(names, samples.values[samples.values[:, 0] == 0][:, 1:]), + components=components, + properties=properties, + values=contracted_values, + ) + + return new_block + + +class SymmetrizedModel(torch.nn.Module): + """ + Wrapper around an atomistic model that symmetrizes its outputs over :math:`O(3)` + and computes equivariance metrics. + + The model is evaluated over a quadrature grid on :math:`O(3)`, constructed from a + Lebedev grid supplemented by in-plane rotations. For each sampled group element, the + model outputs are "back-rotated" according to the known :math:`O(3)` action + appropriate for their tensorial type (scalar, vector, tensor, etc.). Averaging these + back-rotated predictions over the quadrature grid yields fully + :math:`O(3)`-symmetrized outputs. In addition, two complementary equivariance + metrics are computed: + + 1. Variance under :math:`O(3)` of the back-rotated outputs. + + For a perfectly equivariant model, the back-rotated output :math:`x(g)` is + independent of the group element :math:`g`. Deviations from perfect equivariance + are quantified by the difference between the average squared norm over + :math:`O(3)` and the squared norm of the :math:`O(3)`-averaged output: + + .. math:: + + \mathrm{Var}_{O(3)}[x] + = + \left\langle \,\| x(g) \|^{2} \,\right\rangle_{O(3)} + - + \left\| \left\langle x(g) \right\rangle_{O(3)} \right\|^{2} . + + Here, :math:`\|\cdot\|` denotes the Euclidean norm over the ``component`` axis, + and :math:`\langle \cdot \rangle_{O(3)}` denotes averaging over the quadrature + grid. This quantity is the squared norm of the component orthogonal to the + perfectly equivariant subspace and therefore provides a scalar measure of the + deviation from exact equivariance. + + 2. Decomposition into isotypical components of :math:`O(3)`. + + Each output component may be viewed as a scalar function on :math:`O(3)`, + which can be decomposed into isotypical components labeled by the irreducible + representations :math:`\ell,\sigma` of :math:`O(3)`. The projection onto the + :math:`(\ell,\sigma)`-th isotypical subspace is computed as a convolution with + the corresponding character :math:`\chi_{\ell}`: + + .. math:: + + (P_{\ell,\sigma} x)(g) + = + \int_{O(3)} \chi_{\ell,\sigma}(h^{-1} g)\, x(h)\, \mathrm{d}\mu(h). + + Its squared :math:`L^{2}` norm over :math:`O(3)` is + + .. math:: + + \| P_{\ell,\sigma} x \|^{2} + = + \left\langle \, | (P_{\ell,\sigma} x)(g) |^{2} \, \right\rangle_{O(3)} . + + These quantities describe how the model output is distributed across the + different :math:`O(3)` irreducible sectors. The complementary component, + orthogonal to all isotypical subspaces, is given by + + .. math:: + + \| x \|^{2} + - + \sum_{\ell,\sigma} \| P_{\ell,\sigma} x \|^{2} , + + and provides a refined measure of the deviation from lying entirely within any + prescribed set of :math:`O(3)` irreducible representations. + + :param base_model: atomistic model to symmetrize + :param max_o3_lambda: maximum O(3) angular momentum the grid integrates exactly + :param batch_size: number of rotations to evaluate in a single batch + :param max_o3_lambda_character: maximum O(3) angular momentum for character + projections. If None, set to ``max_o3_lambda``. + """ + + def __init__( + self, + base_model, + max_o3_lambda_character: int, + max_o3_lambda_target: int, + batch_size: int = 32, + max_o3_lambda_grid: Optional[int] = None, + ): + super().__init__() + self.base_model = base_model + + try: + ref_param = next(base_model.parameters()) + device = ref_param.device + dtype = ref_param.dtype + except StopIteration: + device = torch.device("cpu") + dtype = torch.get_default_dtype() + + self.max_o3_lambda_target = max_o3_lambda_target + self.batch_size = batch_size + if max_o3_lambda_grid is None: + max_o3_lambda_grid = int(2 * max_o3_lambda_character + 1) + self.max_o3_lambda_grid = max_o3_lambda_grid + self.max_o3_lambda_character = max_o3_lambda_character + + # Compute grid (unchanged) + lebedev_order, n_inplane_rotations = _choose_quadrature(self.max_o3_lambda_grid) + if lebedev_order < 2 * self.max_o3_lambda_character: + warnings.warn( + "Lebedev order may be insufficient for character projections.", + stacklevel=2, + ) + alpha, beta, gamma, w_so3 = get_euler_angles_quadrature( + lebedev_order, n_inplane_rotations + ) + so3_weights = torch.from_numpy(w_so3).to(device=device, dtype=dtype) + self.register_buffer("so3_weights", so3_weights) + + so3_rotations = torch.from_numpy( + _rotations_from_angles(alpha, beta, gamma).as_matrix() + ).to(device=device, dtype=dtype) + self.register_buffer("so3_rotations", so3_rotations) + self.n_so3_rotations = self.so3_rotations.size(0) + + angles_inverse_rotations = (np.pi - gamma, beta, np.pi - alpha) + so3_inverse_rotations = torch.from_numpy( + _rotations_from_angles(*angles_inverse_rotations).as_matrix() + ).to(device=device, dtype=dtype) + self.register_buffer("so3_inverse_rotations", so3_inverse_rotations) + + self._wigner_D_inverse_jit: Dict[int, torch.Tensor] = {} + self._so3_characters_jit: Dict[int, torch.Tensor] = {} + self._pso3_characters_jit: Dict[str, torch.Tensor] = {} + # Since Wigner D matrices are stored in dicts, we need a bit of gymnastics to + # register the buffers + raw_wigner = _compute_real_wigner_matrices( + self.max_o3_lambda_target, angles_inverse_rotations + ) + self._wigner_D_inverse_names: Dict[int, str] = {} + for ell, D in raw_wigner.items(): + if isinstance(D, np.ndarray): + D = torch.from_numpy(D) + D = D.to(dtype=dtype, device=device) + name = f"wigner_D_inverse_rotations_l{ell}" + self.register_buffer(name, D) + self._wigner_D_inverse_names[ell] = name + # TorchScript dict view uses the same tensor + self._wigner_D_inverse_jit[ell] = D + + # Compute characters + so3_characters, pso3_characters = compute_characters( + self.max_o3_lambda_character, + (alpha, beta, gamma), + angles_inverse_rotations, + ) + self._so3_char_names: Dict[int, str] = {} + self._pso3_char_names: Dict[str, str] = {} + + # Since characters are stored in dicts, we need a bit of gymnastics to + # register the buffers + for ell, ch in so3_characters.items(): + if isinstance(ch, np.ndarray): + ch = torch.from_numpy(ch) + + ch = ch.to(dtype=dtype, device="cpu") # stay on CPU + name = f"so3_characters_l{ell}" + self.register_buffer(name, ch) + self._so3_char_names[ell] = name + + self._so3_characters_jit = {} # kill the CUDA dict cache + + for ell, ch in pso3_characters.items(): + if isinstance(ch, np.ndarray): + ch = torch.from_numpy(ch) + + ch = ch.to(dtype=dtype, device="cpu") # stay on CPU + name = f"pso3_characters_l{ell}" + self.register_buffer(name, ch) + self._pso3_char_names[ell] = name + + self._pso3_characters_jit = {} + + @torch.jit.ignore + def _wigner_D_inverse_dict(self) -> Dict[int, torch.Tensor]: + return { + ell: getattr(self, name) + for ell, name in self._wigner_D_inverse_names.items() + } + + @property + def wigner_D_inverse_rotations(self) -> Dict[int, torch.Tensor]: + # Python-only nice view + return self._wigner_D_inverse_dict() + + @torch.jit.ignore + def _so3_characters_dict(self) -> Dict[int, torch.Tensor]: + return {ell: getattr(self, name) for ell, name in self._so3_char_names.items()} + + @property + def so3_characters(self) -> Dict[int, torch.Tensor]: + # Python-only nice view + return self._so3_characters_dict() + + @torch.jit.ignore + def _pso3_characters_dict(self) -> Dict[str, torch.Tensor]: + return {key: getattr(self, name) for key, name in self._pso3_char_names.items()} + + @property + def pso3_characters(self) -> Dict[str, torch.Tensor]: + # Python-only nice view + return self._pso3_characters_dict() + + def _get_wigner_D_inverse(self, ell: int) -> torch.Tensor: + return self._wigner_D_inverse_jit[ell] + + def _get_so3_character(self, o3_lambda: int) -> torch.Tensor: + name = self._so3_char_names[o3_lambda] + ch_cpu = getattr(self, name) + + # follow the base model device/dtype + try: + ref = next(self.base_model.parameters()) + device = ref.device + dtype = ref.dtype + except StopIteration: + device = torch.device("cpu") + dtype = torch.get_default_dtype() + + return ch_cpu.to(device=device, dtype=dtype, non_blocking=True) + + def _get_pso3_character(self, o3_lambda: int, o3_sigma: int) -> torch.Tensor: + label = str(o3_lambda) + "_" + str(o3_sigma) + name = self._pso3_char_names[label] + ch_cpu = getattr(self, name) + + try: + ref = next(self.base_model.parameters()) + device = ref.device + dtype = ref.dtype + except StopIteration: + device = torch.device("cpu") + dtype = torch.get_default_dtype() + + return ch_cpu.to(device=device, dtype=dtype, non_blocking=True) + + def forward( + self, + systems: List[System], + outputs: Dict[str, ModelOutput], + selected_atoms: Optional[Labels] = None, + project_tokens: bool = False, + ) -> Dict[str, TensorMap]: + """ + Symmetrize the model outputs over :math:`O(3)` and compute equivariance + metrics. + + :param systems: list of systems to evaluate + :param outputs: dictionary of model outputs to symmetrize + :param selected_atoms: optional Labels specifying which atoms to consider + :return: dictionary with symmetrized outputs and equivariance metrics + """ + # Evaluate the model over the grid + transformed_outputs, backtransformed_outputs = self._eval_over_grid( + systems, outputs, selected_atoms, return_transformed=project_tokens, + ) + + transformed_outputs = self._decompose_tensors(transformed_outputs) + backtransformed_outputs = self._decompose_tensors(backtransformed_outputs) + + out_dict: Dict[str, TensorMap] = {} + + # Compute the O(3) mean and variance + mean_var = self._compute_mean_and_variance(backtransformed_outputs) + for name, tensor in mean_var.items(): + out_dict[name] = tensor + + if not project_tokens: + return out_dict + + # Compute norms + norms = self._compute_norm_per_property(transformed_outputs) + for name, tensor in norms.items(): + out_dict[name] = tensor + + # Compute the character projections + convolution_integrals = self._compute_conv_integral(transformed_outputs) + for name, integral in convolution_integrals.items(): + out_dict[name] = integral + + return out_dict + + def _decompose_tensors( + self, + tensor_dict: Dict[str, TensorMap], + ) -> Dict[str, TensorMap]: + """ + Decompose tensors in the dictionary into irreducible representations of O(3). + + :param tensor_dict: dictionary of TensorMaps to decompose + :return: dictionary of TensorMaps with decomposed tensors + """ + tensor_dict = self._decompose_energy_tensor(tensor_dict) + tensor_dict = self._decompose_forces_tensor(tensor_dict) + tensor_dict = self._decompose_stress_tensor(tensor_dict) + return tensor_dict + + def _decompose_energy_tensor( + self, + tensor_dict: Dict[str, TensorMap], + ) -> Dict[str, TensorMap]: + """ + Decompose energy tensor into irreducible representations of O(3). + :param tensor_dict: dictionary of TensorMaps to decompose + :return: dictionary of TensorMaps with decomposed energy tensors + """ + if "energy" not in tensor_dict: + return tensor_dict + + tensor = tensor_dict["energy"] + tensor_dict["energy_l0"] = TensorMap( + tensor.keys, + [ + TensorBlock( + values=block.values.unsqueeze(1), + samples=block.samples, + components=[ + Labels( + names=["o3_mu"], + values=torch.tensor( + [[0]], device=self.so3_weights.device, dtype=torch.int32 + ), + ) + ], + properties=block.properties, + ) + for block in tensor + ], + ) + tensor_dict.pop("energy") + return tensor_dict + + def _decompose_forces_tensor( + self, + tensor_dict: Dict[str, TensorMap], + ) -> Dict[str, TensorMap]: + """ + Decompose forces tensor into irreducible representations of O(3). + + :param tensor_dict: dictionary of TensorMaps to decompose + :return: dictionary of TensorMaps with decomposed forces tensors + """ + if "non_conservative_forces" not in tensor_dict: + return tensor_dict + + tensor = tensor_dict["non_conservative_forces"] + tensor_dict["non_conservative_forces_l1"] = TensorMap( + tensor.keys, + [ + TensorBlock( + values=block.values.roll(-1, 1), + samples=block.samples, + components=[ + Labels( + names="o3_mu", + values=torch.tensor( + [[mu] for mu in range(-1, 2)], + device=block.values.device, + dtype=torch.int32, + ), + ) + ], + properties=block.properties, + ) + for block in tensor + ], + ) + tensor_dict.pop("non_conservative_forces") + return tensor_dict + + def _decompose_stress_tensor( + self, + tensor_dict: Dict[str, TensorMap], + ) -> Dict[str, TensorMap]: + """ + Decompose stress tensor into irreducible representations of O(3). + + :param tensor_dict: dictionary of TensorMaps to decompose + :return: dictionary of TensorMaps with decomposed stress tensors + """ + if "non_conservative_stress" not in tensor_dict: + return tensor_dict + + tensor = tensor_dict["non_conservative_stress"] + blocks_l0 = [] + blocks_l2 = [] + for block in tensor.blocks(): + trace_values = _l0_components_from_matrices(block.values) + block_l0 = TensorBlock( + values=trace_values, + samples=block.samples, + components=[ + Labels( + names=["o3_mu"], + values=torch.tensor( + [[0]], device=block.values.device, dtype=torch.int32 + ), + ) + ], + properties=block.properties, + ) + blocks_l0.append(block_l0) + + block_l2 = TensorBlock( + values=_l2_components_from_matrices(block.values), + samples=block.samples, + components=[ + Labels( + names="o3_mu", + values=torch.tensor( + [[mu] for mu in range(-2, 3)], + device=block.values.device, + dtype=torch.int32, + ), + ) + ], + properties=block.properties, + ) + blocks_l2.append(block_l2) + + tensor_dict["non_conservative_stress_l0"] = TensorMap(tensor.keys, blocks_l0) + tensor_dict["non_conservative_stress_l2"] = TensorMap(tensor.keys, blocks_l2) + tensor_dict.pop("non_conservative_stress") + + return tensor_dict + + def _compute_norm_per_property( + self, tensor_dict: Dict[str, TensorMap] + ) -> Dict[str, TensorMap]: + """ + Compute the norm per property of each tensor in ``tensor_dict``. + + :param tensor_dict: dictionary of TensorMaps to compute norms for + :return: dictionary of TensorMaps with norms per property + """ + norms: Dict[str, TensorMap] = {} + for name in tensor_dict: + tensor = tensor_dict[name] + norm_blocks: List[TensorBlock] = [] + for block in tensor.blocks(): + rot_ids = block.samples.column("so3_rotation") + + values_squared = block.values**2 + + view: List[int] = [] + view.append(values_squared.size(0)) + for _ in range(values_squared.ndim - 1): + view.append(1) + values_squared = ( + 0.5 * self.so3_weights[rot_ids].view(view) * values_squared + ) + + norm_blocks.append( + TensorBlock( + values=values_squared, # /(8 * torch.pi**2), + samples=block.samples, + components=block.components, + properties=block.properties, + ) + ) + + tensor_norm = TensorMap(tensor.keys, norm_blocks) + tensor_norm = mts.sum_over_samples( + tensor_norm.keys_to_samples("inversion"), ["inversion", "so3_rotation"] + ) + + norms[name + "_componentwise_norm_squared"] = tensor_norm + return norms + + def _compute_conv_integral( + self, tensor_dict: Dict[str, TensorMap] + ) -> Dict[str, TensorMap]: + """ + Compute the O(3)-convolution of each tensor in ``tensor_dict`` with O(3) + characters. + + :param tensor_dict: dictionary of TensorMaps to compute convolution integral for + :return: dictionary of TensorMaps with convolution integrals + """ + + new_tensors: Dict[str, TensorMap] = {} + # loop over tensormaps + for name, tensor in tensor_dict.items(): + keys = tensor.keys + remaining_keys = Labels( + keys.names[:-1], + keys.values[keys.column("inversion") == 1][:, :-1], + ) + new_blocks: List[TensorBlock] = [] + new_keys: List[torch.Tensor] = [] + # loop over keys in the final tensormap + for key_values in remaining_keys.values: + key_to_match_plus: Dict[str, int] = {} + key_to_match_minus: Dict[str, int] = {} + for k, v in zip(remaining_keys.names, key_values, strict=True): + key_to_match_plus[k] = int(v) + key_to_match_minus[k] = int(v) + key_to_match_plus["inversion"] = 1 + key_to_match_minus["inversion"] = -1 + # get the corresponding blocks for proper and improper rotations + so3_block = tensor.block(key_to_match_plus) + pso3_block = tensor.block(key_to_match_minus) + + # loop over SO(3) irreps + for o3_lambda in range(self.max_o3_lambda_character + 1): + so3_chi = self._get_so3_character(o3_lambda) + first_term = _character_convolution( + so3_chi, so3_block, so3_block, self.so3_weights + ) + second_term = _character_convolution( + so3_chi, pso3_block, pso3_block, self.so3_weights + ) + for o3_sigma in [1, -1]: + pso3_chi = self._get_pso3_character(o3_lambda, o3_sigma) + third_term = _character_convolution( + pso3_chi, pso3_block, so3_block, self.so3_weights + ) + block = TensorBlock( + samples=first_term.samples, + components=first_term.components, + properties=first_term.properties, + values=( + 0.25 * (first_term.values + second_term.values) + + 0.5 * third_term.values + ) + * (2 * o3_lambda + 1), + # / (8 * torch.pi**2) ** 2, + ) + new_blocks.append(block) + new_keys.append( + torch.cat( + [ + key_values, + torch.tensor( + [o3_lambda, o3_sigma], + device=key_values.device, + dtype=key_values.dtype, + ), + ] + ) + ) + key_names: List[str] = [] + for key_name in tensor.keys.names: + if key_name != "inversion": + key_names.append(key_name) + new_tensor = TensorMap( + Labels( + key_names + ["chi_lambda", "chi_sigma"], + torch.stack(new_keys), + ), + new_blocks, + ) + if "_" in new_tensor.keys.names: + new_tensor = mts.remove_dimension(new_tensor, "keys", "_") + new_tensors[name + "_character_projection"] = new_tensor + return new_tensors + + def _eval_over_grid( + self, + systems: List[System], + outputs: Dict[str, ModelOutput], + selected_atoms: Optional[Labels], + return_transformed: bool, + ) -> Tuple[Dict[str, TensorMap], Dict[str, TensorMap]]: + """ + Sample the model on the O(3) quadrature. + + :param systems: list of systems to evaluate + :param model: atomistic model to evaluate + :param device: device to use for computation + :return: list of list of model outputs, shape (len(systems), N) + where N is the number of quadrature points + """ + + # Evaluate the model over the grid + results = evaluate_model_over_grid( + self.base_model, + self.batch_size, + self.so3_rotations, + self.so3_inverse_rotations, + self._wigner_D_inverse_jit, + return_transformed, + systems, + outputs, + selected_atoms, + ) + + if return_transformed: + transformed_outputs_tensor, backtransformed_outputs_tensor = results + else: + backtransformed_outputs_tensor = results + transformed_outputs_tensor: Dict[str, TensorMap] = {} + + # TODO: possibly remove + if "energy" in transformed_outputs_tensor: + energy_tm = transformed_outputs_tensor["energy"] + if "atom" in energy_tm[0].samples.names: + # Sum over atoms while keeping system and rotation indices. + energy_total_tm = mts.sum_over_samples(energy_tm, ["atom"]) + transformed_outputs_tensor["energy_total"] = energy_total_tm + + if "energy" in backtransformed_outputs_tensor: + energy_tm_bt = backtransformed_outputs_tensor["energy"] + if "atom" in energy_tm_bt[0].samples.names: + energy_total_tm_bt = mts.sum_over_samples(energy_tm_bt, ["atom"]) + backtransformed_outputs_tensor["energy_total"] = energy_total_tm_bt + return transformed_outputs_tensor, backtransformed_outputs_tensor + + def _compute_mean_and_variance( + self, + tensor_dict: Dict[str, TensorMap], + ) -> Dict[str, TensorMap]: + """ + Compute the mean and variance of the outputs over O(3). + + :param tensor_dict: dictionary of TensorMaps to compute mean and variance for + :return: dictionary of TensorMaps with mean and variance + """ + + return symmetrize_over_grid(tensor_dict, self.so3_weights) + + +def evaluate_model_over_grid( + model: ModelInterface, + batch_size: int, + so3_rotations: torch.Tensor, + so3_rotations_inverse: torch.Tensor, + wigner_D_inverse: Dict[int, torch.Tensor], + return_transformed: bool, + systems: List[System], + outputs: Dict[str, ModelOutput], + selected_atoms: Optional[Labels] = None, +) -> Dict[str, TensorMap] | Tuple[Dict[str, TensorMap], Dict[str, TensorMap]]: + """ + Sample the model on the O(3) quadrature. + + :param model: atomistic model to evaluate + :param systems: list of systems to evaluate + :param so3_rotations: SO(3) rotation matrices to use for sampling + :param batch_size: number of rotations to evaluate in a single batch + :param outputs: dictionary of model outputs to symmetrize + :param selected_atoms: optional Labels specifying which atoms to consider + :return: model outputs evaluated on the rotated systems. It is a dictionary + where each key is a target name and each value is a list of length len(systems). + Each element of the list is a dictionary with keys -1 and 1 (for improper + and proper rotations) and values the corresponding TensorMap outputs of the + model on the SO(3) quadrature. + """ + + device = systems[0].positions.device + dtype = systems[0].positions.dtype + + transformed_outputs: Dict[str, List[Dict[int, TensorMap]]] = {} + for name in outputs: + lst: List[Dict[int, TensorMap]] = [] + for _ in systems: + d: Dict[int, TensorMap] = {} + lst.append(d) + transformed_outputs[name] = lst + for i_sys, system in enumerate(systems): + for inversion in [-1, 1]: + rotation_outputs: List[Dict[str, TensorMap]] = [] + for batch in range(0, len(so3_rotations), batch_size): + transformed_systems = [ + _transform_system( + system, inversion * R.to(device=device, dtype=dtype) + ) + for R in so3_rotations[batch : batch + batch_size] + ] + with torch.no_grad(): + out = model( + transformed_systems, + outputs, + selected_atoms, + ) + rotation_outputs.append(out) + + # Combine batch outputs + for name in outputs: + combined_: List[TensorMap] = [r[name] for r in rotation_outputs] + combined = mts.join( + combined_, + "samples", + add_dimension="batch_rotation", + ) + if "batch_rotation" in combined[0].samples.names: + # Reindex + blocks: List[TensorBlock] = [] + for block in combined.blocks(): + batch_id = block.samples.column("batch_rotation") + rot_id = block.samples.column("system") + new_sample_values = block.samples.values[:, :-1] + new_sample_values[:, 0] = batch_id * batch_size + rot_id + blocks.append( + TensorBlock( + values=block.values.detach(), + samples=Labels( + block.samples.names[:-1], + new_sample_values, + ), + components=block.components, + properties=block.properties, + ) + ) + combined = TensorMap(combined.keys, blocks) + transformed_outputs[name][i_sys][inversion] = combined + + backtransformed_outputs = _backtransform_outputs( + transformed_outputs, systems, so3_rotations_inverse, wigner_D_inverse + ) + backtransformed_outputs_tensor = _to_metatensor(backtransformed_outputs, systems) + + if return_transformed: + transformed_outputs_tensor = _to_metatensor(transformed_outputs, systems) + return transformed_outputs_tensor, backtransformed_outputs_tensor + else: + transformed_outputs_tensor: Dict[str, TensorMap] = {} + return backtransformed_outputs_tensor + + +def _to_metatensor( + tensor_dict: Dict[str, TensorMap], systems: List[System] +) -> Dict[str, TensorMap]: + """ + Convert the outputs of the model evaluated on rotated systems to a single + TensorMap per property, with appropriate dimensions for O(3) symmetrization. + """ + + out_tensor_dict: Dict[str, TensorMap] = {} + # Massage outputs to have desired shape + for name in tensor_dict: + joined_plus = mts.join( + [tensor_dict[name][i_sys][1] for i_sys in range(len(systems))], + "samples", + add_dimension="phys_system", + ) + joined_minus = mts.join( + [tensor_dict[name][i_sys][-1] for i_sys in range(len(systems))], + "samples", + add_dimension="phys_system", + ) + joined = mts.join( + [ + mts.append_dimension(joined_plus, "keys", "inversion", 1), + mts.append_dimension(joined_minus, "keys", "inversion", -1), + ], + "samples", + different_keys="union", + ) + joined = mts.rename_dimension(joined, "samples", "system", "so3_rotation") + + if "phys_system" in joined[0].samples.names: + joined = mts.rename_dimension(joined, "samples", "phys_system", "system") + else: + joined = mts.insert_dimension( + joined, + "samples", + 1, + "system", + torch.zeros( + joined[0].samples.values.shape[0], + dtype=torch.long, + device=joined[0].samples.values.device, + ), + ) + if "atom" in joined[0].samples.names or "first_atom" in joined[0].samples.names: + perm = _permute_system_before_atom(joined[0].samples.names) + joined = mts.permute_dimensions(joined, "samples", perm) + out_tensor_dict[name] = joined + + return out_tensor_dict + + +def _backtransform_outputs( + tensor_dict: Dict[str, List[Dict[int, TensorMap]]], + systems: List[System], + so3_rotations_inverse: torch.Tensor, + wigner_D_inverse: Dict[int, torch.Tensor], +) -> Dict[str, List[Dict[int, TensorMap]]]: + """ + Given the outputs of the model evaluated on rotated systems, backtransform them to + the original frame according to the equivariance labels in the TensorMap keys. + """ + + device = systems[0].positions.device + dtype = systems[0].positions.dtype + + backtransformed_tensor_dict: Dict[str, List[Dict[int, TensorMap]]] = {} + for name in tensor_dict: + lst: List[Dict[int, TensorMap]] = [] + for _ in systems: + d: Dict[int, TensorMap] = {} + lst.append(d) + backtransformed_tensor_dict[name] = lst + + n_rot = so3_rotations_inverse.size(0) + for name in tensor_dict: + for i_sys, system in enumerate(systems): + for inversion in [-1, 1]: + tensor = tensor_dict[name][i_sys][inversion] + wigner_dict: Dict[int, List[torch.Tensor]] = {} + for ell in wigner_D_inverse: + wigner_dict[ell] = ( + wigner_D_inverse[ell].to(device=device, dtype=dtype).unbind(0) + ) + + _, backtransformed, _ = _apply_augmentations( + [system] * n_rot, + {name: tensor}, + list( + ( + so3_rotations_inverse.to(device=device, dtype=dtype) + * inversion + ).unbind(0) + ), + wigner_dict, + ) + backtransformed_tensor_dict[name][i_sys][inversion] = backtransformed[ + name + ] + return backtransformed_tensor_dict + + +def _permute_system_before_atom(labels: List[str]) -> List[int]: + # find positions + sys_idx = -1 + atom_idx = -1 + for i in range(len(labels)): + if labels[i] == "system": + sys_idx = i + elif labels[i] == "atom": + atom_idx = i + elif labels[i] == "first_atom": + atom_idx = i + + # identity permutation + perm = list(range(len(labels))) + + # reorder only if both present and system is after atom + if sys_idx != -1 and atom_idx != -1 and sys_idx > atom_idx: + v = perm[sys_idx] + # remove system + for k in range(sys_idx, len(perm) - 1): + perm[k] = perm[k + 1] + perm.pop() + # insert before atom + perm.insert(atom_idx, v) + + return perm + + +def symmetrize_over_grid( + tensor_dict: Dict[str, TensorMap], + so3_weights: torch.Tensor, +) -> Dict[str, TensorMap]: + """ + Compute the mean and variance of the outputs over O(3). + + :param tensor_dict: dictionary of TensorMaps with rotated and backtransformed + outputs to compute mean, variance, and norm squared for + :param so3_weights: weights of the SO(3) quadrature + :return: dictionary of TensorMaps with mean, variance, and norm squared + """ + mean_var: Dict[str, TensorMap] = {} + for name in tensor_dict: + # cannot compute a mean or variance as these have no known behaviour under + # rotations + if "features" in name: + continue + tensor = tensor_dict[name] + mean_blocks: List[TensorBlock] = [] + second_moment_blocks: List[TensorBlock] = [] + for block in tensor.blocks(): + rot_ids = block.samples.column("so3_rotation") + + values = block.values + if values.ndim > 2: + dims: List[int] = [] + for i in range(1, values.ndim - 1): + dims.append(i) + values_squared = torch.sum(values**2, dim=dims) + else: + values_squared = values**2 + + view: List[int] = [] + view.append(values.size(0)) + for _ in range(values.ndim - 1): + view.append(1) + values = 0.5 * so3_weights[rot_ids].view(view) * values + + view: List[int] = [] + view.append(values_squared.size(0)) + for _ in range(values_squared.ndim - 1): + view.append(1) + values_squared = 0.5 * so3_weights[rot_ids].view(view) * values_squared + + mean_blocks.append( + TensorBlock( + values=values, + samples=block.samples, + components=block.components, + properties=block.properties, + ) + ) + second_moment_blocks.append( + TensorBlock( + values=values_squared, + samples=block.samples, + components=[], + properties=block.properties, + ) + ) + + # Mean + tensor_mean = TensorMap(tensor.keys, mean_blocks) + tensor_mean = mts.sum_over_samples( + tensor_mean.keys_to_samples("inversion"), ["inversion", "so3_rotation"] + ) + + # Mean norm + mean_norm_squared_blocks: List[TensorBlock] = [] + for block in tensor_mean.blocks(): + vals = block.values + if vals.ndim > 2: + dims: List[int] = [] + for i in range(1, vals.ndim - 1): + dims.append(i) + vals = torch.sum(vals**2, dim=dims) + else: + vals = vals**2 + mean_norm_squared_blocks.append( + TensorBlock( + values=vals, + samples=block.samples, + components=[], + properties=block.properties, + ) + ) + tensor_mean_norm_squared = TensorMap(tensor_mean.keys, mean_norm_squared_blocks) + + # Second moment + tensor_second_moment = TensorMap(tensor.keys, second_moment_blocks) + tensor_second_moment = mts.sum_over_samples( + tensor_second_moment.keys_to_samples("inversion"), + ["inversion", "so3_rotation"], + ) + + # Variance + tensor_variance = mts.subtract(tensor_second_moment, tensor_mean_norm_squared) + + mean_var[name + "_mean"] = tensor_mean + mean_var[name + "_norm_squared"] = tensor_second_moment + mean_var[name + "_var"] = tensor_variance + return mean_var