diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 9528c632..12c8887d 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -167,7 +167,7 @@ jobs:
       - name: Find example scripts
         id: set-matrix
         run: |
-          EXAMPLES=$(find examples -name "*.py" | jq -R -s -c 'split("\n")[:-1]')
+          EXAMPLES=$(find examples -name "*.py" ! -path "examples/scaling/*" | jq -R -s -c 'split("\n")[:-1]')
           echo "examples=$EXAMPLES" >> $GITHUB_OUTPUT

   test-examples:
diff --git a/examples/scaling/scaling_nve.py b/examples/scaling/scaling_nve.py
new file mode 100644
index 00000000..a0dbc5af
--- /dev/null
+++ b/examples/scaling/scaling_nve.py
@@ -0,0 +1,77 @@
+"""Scaling for TorchSim NVE."""
+# %%
+# /// script
+# dependencies = [
+#     "torch_sim_atomistic[mace,test]"
+# ]
+# ///
+
+import time
+import typing
+
+import torch
+from ase.build import bulk
+from mace.calculators.foundations_models import mace_mp
+from pymatgen.io.ase import AseAtomsAdaptor
+
+import torch_sim as ts
+from torch_sim.models.mace import MaceModel, MaceUrls
+
+
+N_STRUCTURES = [1, 1, 1, 10, 100, 500, 1000, 1500, 5000, 10000]
+
+
+MD_STEPS = 10
+
+
+def run_torchsim_nve(
+    n_structures_list: list[int],
+    base_structure: typing.Any,
+) -> list[float]:
+    """Load model, run NVE MD for MD_STEPS per n; return times."""
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    loaded_model = mace_mp(
+        model=MaceUrls.mace_mpa_medium,
+        return_raw_model=True,
+        default_dtype="float64",
+        device=str(device),
+    )
+    max_memory_scaler = 400_000
+    memory_scales_with = "n_atoms_x_density"
+    model = MaceModel(
+        model=typing.cast("torch.nn.Module", loaded_model),
+        device=device,
+        compute_forces=True,
+        compute_stress=True,
+        dtype=torch.float64,
+        enable_cueq=False,
+    )
+    times: list[float] = []
+    for n in n_structures_list:
+        structures = [base_structure] * n
+        t0 = time.perf_counter()
+        ts.integrate(
+            system=structures,
+            model=model,
+            integrator=ts.Integrator.nve,
+            n_steps=MD_STEPS,
+            temperature=300.0,
+            timestep=0.002,
+            autobatcher=ts.BinningAutoBatcher(
+                model=model,
+                max_memory_scaler=max_memory_scaler,
+                memory_scales_with=memory_scales_with,
+            ),
+        )
+        if device.type == "cuda":
+            torch.cuda.empty_cache()
+        elapsed = time.perf_counter() - t0
+        times.append(elapsed)
+        print(f" n={n} nve_time={elapsed:.6f}s")
+    return times
+
+
+if __name__ == "__main__":
+    mgo_ase = bulk(name="MgO", crystalstructure="rocksalt", a=4.21, cubic=True)
+    base_structure = AseAtomsAdaptor.get_structure(atoms=mgo_ase)
+    sweep_totals = run_torchsim_nve(N_STRUCTURES, base_structure=base_structure)
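
For orientation: with the "n_atoms_x_density" metric (defined in torch_sim/autobatching.py below as n_atoms * (n_atoms / volume), with volume in nm^3), the max_memory_scaler=400_000 budget used in these scripts translates into a concrete batch size for the replicated MgO cell. A back-of-the-envelope sketch, assuming the conventional 8-atom rocksalt cell with a = 4.21 A:

    n_atoms = 8                                # conventional MgO rocksalt cell
    volume_nm3 = 4.21**3 / 1000                # ~0.0746 nm^3
    scaler = n_atoms * (n_atoms / volume_nm3)  # ~858 per structure
    structures_per_bin = 400_000 / scaler      # ~466 MgO cells per bin

So, roughly, sweep points up to a few hundred structures fit into a single bin, and the larger n values are what actually exercise the autobatcher.
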
diff --git a/examples/scaling/scaling_nvt.py b/examples/scaling/scaling_nvt.py
new file mode 100644
index 00000000..66b28d8a
--- /dev/null
+++ b/examples/scaling/scaling_nvt.py
@@ -0,0 +1,77 @@
+"""Scaling for TorchSim NVT (Nose-Hoover)."""
+# %%
+# /// script
+# dependencies = [
+#     "torch_sim_atomistic[mace,test]"
+# ]
+# ///
+
+import time
+import typing
+
+import torch
+from ase.build import bulk
+from mace.calculators.foundations_models import mace_mp
+from pymatgen.io.ase import AseAtomsAdaptor
+
+import torch_sim as ts
+from torch_sim.models.mace import MaceModel, MaceUrls
+
+
+N_STRUCTURES = [1, 1, 1, 10, 100, 500, 1000, 1500, 5000, 10000]
+
+
+MD_STEPS = 10
+
+
+def run_torchsim_nvt(
+    n_structures_list: list[int],
+    base_structure: typing.Any,
+) -> list[float]:
+    """Load model, run NVT MD for MD_STEPS per n; return times."""
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    loaded_model = mace_mp(
+        model=MaceUrls.mace_mpa_medium,
+        return_raw_model=True,
+        default_dtype="float64",
+        device=str(device),
+    )
+    max_memory_scaler = 400_000
+    memory_scales_with = "n_atoms_x_density"
+    model = MaceModel(
+        model=typing.cast("torch.nn.Module", loaded_model),
+        device=device,
+        compute_forces=True,
+        compute_stress=True,
+        dtype=torch.float64,
+        enable_cueq=False,
+    )
+    times: list[float] = []
+    for n in n_structures_list:
+        structures = [base_structure] * n
+        t0 = time.perf_counter()
+        ts.integrate(
+            system=structures,
+            model=model,
+            integrator=ts.Integrator.nvt_nose_hoover,
+            n_steps=MD_STEPS,
+            temperature=300.0,
+            timestep=0.002,
+            autobatcher=ts.BinningAutoBatcher(
+                model=model,
+                max_memory_scaler=max_memory_scaler,
+                memory_scales_with=memory_scales_with,
+            ),
+        )
+        if device.type == "cuda":
+            torch.cuda.empty_cache()
+        elapsed = time.perf_counter() - t0
+        times.append(elapsed)
+        print(f" n={n} nvt_time={elapsed:.6f}s")
+    return times
+
+
+if __name__ == "__main__":
+    mgo_ase = bulk(name="MgO", crystalstructure="rocksalt", a=4.21, cubic=True)
+    base_structure = AseAtomsAdaptor.get_structure(atoms=mgo_ase)
+    sweep_totals = run_torchsim_nvt(N_STRUCTURES, base_structure=base_structure)
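
The NVE and NVT scripts are identical apart from the integrator passed to ts.integrate and the label in the printed line. If more integrators join the sweep, a shared helper along these lines would keep them in sync (hypothetical sketch, not part of this change; it assumes model, base_structure and MD_STEPS are already set up as in the scripts above):

    def run_md_sweep(
        integrator: ts.Integrator, label: str, n_structures_list: list[int]
    ) -> list[float]:
        """Time MD_STEPS of MD per sweep point with the given integrator."""
        times: list[float] = []
        for n in n_structures_list:
            t0 = time.perf_counter()
            ts.integrate(
                system=[base_structure] * n,
                model=model,
                integrator=integrator,
                n_steps=MD_STEPS,
                temperature=300.0,
                timestep=0.002,
                autobatcher=ts.BinningAutoBatcher(
                    model=model,
                    max_memory_scaler=400_000,
                    memory_scales_with="n_atoms_x_density",
                ),
            )
            times.append(time.perf_counter() - t0)
            print(f" n={n} {label}_time={times[-1]:.6f}s")
        return times
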
diff --git a/examples/scaling/scaling_relax.py b/examples/scaling/scaling_relax.py
new file mode 100644
index 00000000..617f00d3
--- /dev/null
+++ b/examples/scaling/scaling_relax.py
@@ -0,0 +1,86 @@
+"""Scaling for TorchSim relax."""
+# %%
+# /// script
+# dependencies = [
+#     "torch_sim_atomistic[mace,test]"
+# ]
+# ///
+
+import time
+import typing
+
+import torch
+from ase.build import bulk
+from mace.calculators.foundations_models import mace_mp
+from pymatgen.io.ase import AseAtomsAdaptor
+
+import torch_sim as ts
+from torch_sim.models.mace import MaceModel, MaceUrls
+
+
+N_STRUCTURES = [1, 1, 1, 10, 100, 500, 1000, 1500]
+
+
+RELAX_STEPS = 10
+
+
+def run_torchsim_relax(
+    n_structures_list: list[int],
+    base_structure: typing.Any,
+) -> list[float]:
+    """Load TorchSim model once, run 10-step relaxation with ts.optimize for each n;
+    return timings.
+    """
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    loaded_model = mace_mp(
+        model=MaceUrls.mace_mpa_medium,
+        return_raw_model=True,
+        default_dtype="float64",
+        device=str(device),
+    )
+    model = MaceModel(
+        model=typing.cast("torch.nn.Module", loaded_model),
+        device=device,
+        compute_forces=True,
+        compute_stress=True,
+        dtype=torch.float64,
+        enable_cueq=False,
+    )
+    autobatcher = ts.InFlightAutoBatcher(
+        model=model,
+        max_memory_scaler=400_000,
+        memory_scales_with="n_atoms_x_density",
+    )
+    times: list[float] = []
+    for n in n_structures_list:
+        structures = [base_structure] * n
+        t0 = time.perf_counter()
+        ts.optimize(
+            system=structures,
+            model=model,
+            optimizer=ts.optimizers.Optimizer.fire,
+            init_kwargs={
+                "cell_filter": ts.optimizers.cell_filters.CellFilter.frechet,
+                "constant_volume": False,
+                "hydrostatic_strain": True,
+            },
+            max_steps=RELAX_STEPS,
+            convergence_fn=ts.runners.generate_force_convergence_fn(
+                force_tol=1e-3,
+                include_cell_forces=True,
+            ),
+            autobatcher=autobatcher,
+        )
+        if device.type == "cuda":
+            torch.cuda.synchronize()
+            torch.cuda.empty_cache()
+        elapsed = time.perf_counter() - t0
+        times.append(elapsed)
+        print(f" n={n} relax_{RELAX_STEPS}_time={elapsed:.6f}s")
+    return times
+
+
+if __name__ == "__main__":
+    mgo_ase = bulk(name="MgO", crystalstructure="rocksalt", a=4.21, cubic=True)
+    base_structure = AseAtomsAdaptor.get_structure(atoms=mgo_ase)
+    sweep_totals = run_torchsim_relax(N_STRUCTURES, base_structure)
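
The relax sweep uses InFlightAutoBatcher rather than BinningAutoBatcher, which is designed to swap converged systems out of the running batch, so each printed wall time covers at most RELAX_STEPS FIRE steps per system. If the raw numbers need to be compared across sweep points, a simple normalization helps (hypothetical post-processing, not part of this change; 8 atoms per MgO cell):

    # seconds per atom per optimization step, using the sweep output above
    seconds_per_atom_step = [
        t / (n * 8 * RELAX_STEPS)
        for n, t in zip(N_STRUCTURES, sweep_totals)
    ]
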
diff --git a/examples/scaling/scaling_static.py b/examples/scaling/scaling_static.py
new file mode 100644
index 00000000..681553e2
--- /dev/null
+++ b/examples/scaling/scaling_static.py
@@ -0,0 +1,71 @@
+"""Scaling for TorchSim static."""
+# %%
+# /// script
+# dependencies = [
+#     "torch_sim_atomistic[mace,test]"
+# ]
+# ///
+
+import time
+import typing
+
+import torch
+from ase.build import bulk
+from mace.calculators.foundations_models import mace_mp
+from pymatgen.io.ase import AseAtomsAdaptor
+
+import torch_sim as ts
+from torch_sim.models.mace import MaceModel, MaceUrls
+
+
+N_STRUCTURES = [1, 1, 1, 10, 100, 500, 1000, 1500, 5000, 10000, 50000, 100000]
+
+
+def run_torchsim_static(
+    n_structures_list: list[int],
+    base_structure: typing.Any,
+) -> list[float]:
+    """Load TorchSim model once, run static for each n using O(1)
+    batched path, return timings.
+    """
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    loaded_model = mace_mp(
+        model=MaceUrls.mace_mpa_medium,
+        return_raw_model=True,
+        default_dtype="float64",
+        device=str(device),
+    )
+    model = MaceModel(
+        model=typing.cast("torch.nn.Module", loaded_model),
+        device=device,
+        compute_forces=True,
+        compute_stress=True,
+        dtype=torch.float64,
+        enable_cueq=False,
+    )
+    batcher = ts.BinningAutoBatcher(
+        model=model,
+        max_memory_scaler=400_000,
+        memory_scales_with="n_atoms_x_density",
+    )
+    times: list[float] = []
+    for n in n_structures_list:
+        structures = [base_structure] * n
+        t0 = time.perf_counter()
+        state = ts.initialize_state(structures, model.device, model.dtype)
+        batcher.load_states(state)
+        for sub_state, _ in batcher:
+            model(sub_state)
+        if device.type == "cuda":
+            torch.cuda.synchronize()
+            torch.cuda.empty_cache()
+        elapsed = time.perf_counter() - t0
+        times.append(elapsed)
+        print(f" n={n} static_time={elapsed:.6f}s")
+    return times
+
+
+if __name__ == "__main__":
+    mgo_ase = bulk(name="MgO", crystalstructure="rocksalt", a=4.21, cubic=True)
+    base_structure = AseAtomsAdaptor.get_structure(atoms=mgo_ase)
+    sweep_totals = run_torchsim_static(N_STRUCTURES, base_structure)
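
The static script drives BinningAutoBatcher by hand: ts.initialize_state builds one batched state, load_states bins it, and iterating the batcher yields one batched sub-state per bin (plus a second value the script ignores) that is fed straight to the model. The outputs are discarded because only timing matters here; if energies are wanted for a spot check, the loop could be extended along these lines (illustrative sketch, assuming the MACE wrapper returns a dict with an "energy" tensor per batch):

    energies = []
    for sub_state, _ in batcher:
        out = model(sub_state)
        energies.append(out["energy"].detach().cpu())
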
diff --git a/tests/test_autobatching.py b/tests/test_autobatching.py
index 7ec870d3..c016f88a 100644
--- a/tests/test_autobatching.py
+++ b/tests/test_autobatching.py
@@ -95,12 +95,14 @@ def test_calculate_scaling_metric(si_sim_state: ts.SimState) -> None:
     # Test n_atoms metric
     n_atoms_metric = calculate_memory_scaler(si_sim_state, "n_atoms")
     assert n_atoms_metric == si_sim_state.n_atoms
+    assert n_atoms_metric == 8

     # Test n_atoms_x_density metric
     density_metric = calculate_memory_scaler(si_sim_state, "n_atoms_x_density")
     volume = torch.abs(torch.linalg.det(si_sim_state.cell[0])) / 1000
     expected = si_sim_state.n_atoms * (si_sim_state.n_atoms / volume.item())
     assert pytest.approx(density_metric, rel=1e-5) == expected
+    assert pytest.approx(density_metric, rel=1e-5) == (8**2) * 1000 / (5.43**3)

     # Test invalid metric
     with pytest.raises(ValueError, match="Invalid metric"):
@@ -109,15 +111,24 @@ def test_calculate_scaling_metric(si_sim_state: ts.SimState) -> None:

 def test_calculate_scaling_metric_non_periodic(benzene_sim_state: ts.SimState) -> None:
     """Test calculation of scaling metrics for a non-periodic state."""
-    # Test that calculate passes
     n_atoms_metric = calculate_memory_scaler(benzene_sim_state, "n_atoms")
     assert n_atoms_metric == benzene_sim_state.n_atoms
+    assert n_atoms_metric == 12

-    # Test n_atoms_x_density metric works for non-periodic systems
     n_atoms_x_density_metric = calculate_memory_scaler(
         benzene_sim_state, "n_atoms_x_density"
     )
     assert n_atoms_x_density_metric > 0
+    bbox = (
+        benzene_sim_state.positions.max(dim=0).values
+        - benzene_sim_state.positions.min(dim=0).values
+    ).clone()
+    for i, p in enumerate(benzene_sim_state.pbc):
+        if not p:
+            bbox[i] += 2.0
+    assert pytest.approx(n_atoms_x_density_metric, rel=1e-5) == (
+        benzene_sim_state.n_atoms**2 / (bbox.prod().item() / 1000)
+    )


 def test_split_state(si_double_sim_state: ts.SimState) -> None:
diff --git a/torch_sim/autobatching.py b/torch_sim/autobatching.py
index a0cf6aab..70846652 100644
--- a/torch_sim/autobatching.py
+++ b/torch_sim/autobatching.py
@@ -322,6 +322,41 @@ def determine_max_batch_size(
     return sizes[-1]


+def calculate_batched_memory_scalers(
+    state: SimState,
+    memory_scales_with: MemoryScaling = "n_atoms_x_density",
+) -> list[float]:
+    """Compute memory scalers for all systems in a batched state in O(n_systems).
+
+    When all systems are periodic and memory_scales_with is "n_atoms_x_density",
+    uses only state.n_atoms_per_system and state.volume (no split, no position
+    scan). For non-periodic systems, splits and uses calculate_memory_scaler per
+    system so that bbox-based volume is used consistently.
+
+    Args:
+        state: Batched SimState (n_systems >= 1).
+        memory_scales_with: Same as in calculate_memory_scaler.
+
+    Returns:
+        list[float]: One scaler per system, length state.n_systems.
+    """
+    if memory_scales_with == "n_atoms":
+        return state.n_atoms_per_system.tolist()
+    if memory_scales_with == "n_atoms_x_density":
+        if not state.pbc.all().item():
+            return [calculate_memory_scaler(s, memory_scales_with) for s in state.split()]
+        n_per = state.n_atoms_per_system.to(state.volume.dtype)
+        vol_nm3 = torch.abs(state.volume) / 1000  # A^3 -> nm^3
+        # Where volume <= 0, use n_atoms (degenerate cell)
+        safe_vol = torch.where(vol_nm3 > 0, vol_nm3, torch.ones_like(vol_nm3))
+        scalers = n_per * (n_per / safe_vol)
+        scalers = torch.where(vol_nm3 > 0, scalers, n_per)
+        return scalers.tolist()
+    raise ValueError(
+        f"Invalid metric: {memory_scales_with}, must be one of {get_args(MemoryScaling)}"
+    )
+
+
 def calculate_memory_scaler(
     state: SimState,
     memory_scales_with: MemoryScaling = "n_atoms_x_density",
@@ -556,11 +591,17 @@ def load_states(self, states: T | Sequence[T]) -> float:
         This method resets the current state bin index, so any ongoing iteration
         will be restarted when this method is called.
         """
-        self.state_slices = states.split() if isinstance(states, SimState) else states
-        self.memory_scalers = [
-            calculate_memory_scaler(state_slice, self.memory_scales_with)
-            for state_slice in self.state_slices
-        ]
+        if isinstance(states, SimState):
+            self.memory_scalers = calculate_batched_memory_scalers(
+                states, self.memory_scales_with
+            )
+            self.state_slices = states.split()
+        else:
+            self.state_slices = states
+            self.memory_scalers = [
+                calculate_memory_scaler(s, self.memory_scales_with)
+                for s in self.state_slices
+            ]
         if not self.max_memory_scaler:
             self.max_memory_scaler = estimate_max_memory_scaler(
                 self.state_slices,
@@ -589,9 +630,14 @@ def load_states(self, states: T | Sequence[T]) -> float:
         )  # list[dict[original_index: int, memory_scale:float]]
         # Convert to list of lists of indices
         self.index_bins = [list(batch.keys()) for batch in self.index_bins]
-        self.batched_states = []
-        for index_bin in self.index_bins:
-            self.batched_states.append([self.state_slices[idx] for idx in index_bin])
+        # Build batches: one sliced state per bin
+        if isinstance(states, SimState):
+            self.batched_states = [[states[index_bin]] for index_bin in self.index_bins]
+        else:
+            self.batched_states = [
+                [self.state_slices[idx] for idx in index_bin]
+                for index_bin in self.index_bins
+            ]
         self.current_state_bin = 0

         return self.max_memory_scaler
@@ -846,7 +892,7 @@ def load_states(self, states: Sequence[T] | Iterator[T] | T) -> None:
         """
         if isinstance(states, SimState):
             states = states.split()
-        if isinstance(states, list | tuple):
+        if not isinstance(states, Iterator):
            states = iter(states)
         self.states_iterator = states

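
calculate_batched_memory_scalers is the fast path that lets BinningAutoBatcher.load_states above skip the per-system split when it is handed a single batched SimState. For fully periodic states it should agree with the per-system metric; a quick illustrative check (not a test in this change; batched_state stands for any periodic multi-system state):

    fast = calculate_batched_memory_scalers(batched_state, "n_atoms_x_density")
    slow = [
        calculate_memory_scaler(s, "n_atoms_x_density")
        for s in batched_state.split()
    ]
    assert all(abs(a - b) < 1e-6 * max(abs(b), 1.0) for a, b in zip(fast, slow))
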
diff --git a/torch_sim/state.py b/torch_sim/state.py
index ec2c589d..9ce96324 100644
--- a/torch_sim/state.py
+++ b/torch_sim/state.py
@@ -10,7 +10,7 @@
 from collections import defaultdict
 from collections.abc import Generator, Sequence
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Any, ClassVar, Literal, Self
+from typing import TYPE_CHECKING, Any, ClassVar, Literal, Self, overload

 import torch

@@ -430,16 +430,71 @@ def to_phonopy(self) -> list["PhonopyAtoms"]:
         """
         return ts.io.state_to_phonopy(self)

-    def split(self) -> list[Self]:
-        """Split the SimState into a list of single-system SimStates.
+    def split(self) -> Sequence[Self]:  # noqa: C901
+        """Split the SimState into a sequence of single-system SimStates (O(1)).

-        Divides the current state into separate states, each containing a single system,
-        preserving all properties appropriately for each system.
+        Each single-system state is created on first access (index or iteration),
+        so the call itself is O(1). Use like a list: len(s), s[i], for x in s.

         Returns:
-            list[SimState]: A list of SimState objects, one per system
+            Sequence[SimState]: A sequence of SimState objects, one per system
         """
-        return _split_state(self)
+        state = self
+
+        def _get_system_slice(s: Self, sys_idx: int) -> Self:
+            n_systems = s.n_systems
+            if sys_idx < 0 or sys_idx >= n_systems:
+                raise IndexError(f"system index {sys_idx} out of range [0, {n_systems})")
+            cumsum_atoms = torch.cat(
+                (s.n_atoms_per_system.new_zeros(1), s.n_atoms_per_system.cumsum(0))
+            )
+            start = cumsum_atoms[sys_idx].item()
+            end = cumsum_atoms[sys_idx + 1].item()
+            n_atoms_i = end - start
+
+            system_attrs: dict[str, Any] = {
+                "system_idx": torch.zeros(n_atoms_i, device=s.device, dtype=torch.int64),
+                **dict(get_attrs_for_scope(s, "global")),
+            }
+            for attr_name, attr_value in get_attrs_for_scope(s, "per-atom"):
+                if attr_name != "system_idx":
+                    system_attrs[attr_name] = attr_value[start:end]
+            for attr_name, attr_value in get_attrs_for_scope(s, "per-system"):
+                if isinstance(attr_value, torch.Tensor):
+                    system_attrs[attr_name] = attr_value[sys_idx : sys_idx + 1]
+                else:
+                    system_attrs[attr_name] = attr_value
+
+            atom_idx = torch.arange(start, end, device=s.device)
+            new_constraints = [
+                new_constraint
+                for constraint in s.constraints
+                if (new_constraint := constraint.select_sub_constraint(atom_idx, sys_idx))
+            ]
+            system_attrs["_constraints"] = new_constraints
+            return type(s)(**system_attrs)  # type: ignore[invalid-argument-type]
+
+        class _SplitSeq(Sequence[Self]):
+            def __len__(self) -> int:
+                return state.n_systems
+
+            @overload
+            def __getitem__(self, i: int) -> Self: ...
+
+            @overload
+            def __getitem__(self, s: slice) -> list[Self]: ...
+
+            def __getitem__(self, i: int | slice) -> Self | list[Self]:
+                if isinstance(i, slice):
+                    start, stop, step = i.indices(len(self))
+                    return [_get_system_slice(state, j) for j in range(start, stop, step)]
+                return _get_system_slice(state, i)
+
+            def __iter__(self) -> Generator[Self, None, None]:
+                for j in range(len(self)):
+                    yield _get_system_slice(state, j)
+
+        return _SplitSeq()  # type: ignore[return-value]

     def pop(self, system_indices: int | list[int] | slice | torch.Tensor) -> list[Self]:
         """Pop off states with the specified system indices.
@@ -813,73 +868,6 @@ def _filter_attrs_by_mask(
     return filtered_attrs


-def _split_state[T: SimState](state: T) -> list[T]:
-    """Split a SimState into a list of states, each containing a single system.
-
-    Divides a multi-system state into individual single-system states, preserving
-    appropriate properties for each system.
-
-    Args:
-        state (SimState): The SimState to split
-
-    Returns:
-        list[SimState]: A list of SimState objects, each containing a single
-            system
-    """
-    system_sizes = state.n_atoms_per_system.tolist()
-
-    split_per_atom = {}
-    for attr_name, attr_value in get_attrs_for_scope(state, "per-atom"):
-        if attr_name != "system_idx":
-            split_per_atom[attr_name] = torch.split(attr_value, system_sizes, dim=0)
-
-    split_per_system = {}
-    for attr_name, attr_value in get_attrs_for_scope(state, "per-system"):
-        if isinstance(attr_value, torch.Tensor):
-            split_per_system[attr_name] = torch.split(attr_value, 1, dim=0)
-        else:  # Non-tensor attributes are replicated for each split
-            split_per_system[attr_name] = [attr_value] * state.n_systems
-
-    global_attrs = dict(get_attrs_for_scope(state, "global"))
-
-    # Create a state for each system
-    states: list[T] = []
-    n_systems = len(system_sizes)
-    zero_tensor = torch.tensor([0], device=state.device, dtype=torch.int64)
-    cumsum_atoms = torch.cat((zero_tensor, torch.cumsum(state.n_atoms_per_system, dim=0)))
-    for sys_idx in range(n_systems):
-        system_attrs = {
-            # Create a system tensor with all zeros for this system
-            "system_idx": torch.zeros(
-                system_sizes[sys_idx], device=state.device, dtype=torch.int64
-            ),
-            # Add the split per-atom attributes
-            **{
-                attr_name: split_per_atom[attr_name][sys_idx]
-                for attr_name in split_per_atom
-            },
-            # Add the split per-system attributes
-            **{
-                attr_name: split_per_system[attr_name][sys_idx]
-                for attr_name in split_per_system
-            },
-            # Add the global attributes
-            **global_attrs,
-        }
-
-        atom_idx = torch.arange(cumsum_atoms[sys_idx], cumsum_atoms[sys_idx + 1])
-        new_constraints = [
-            new_constraint
-            for constraint in state.constraints
-            if (new_constraint := constraint.select_sub_constraint(atom_idx, sys_idx))
-        ]
-
-        system_attrs["_constraints"] = new_constraints
-        states.append(type(state)(**system_attrs))  # type: ignore[invalid-argument-type]
-
-    return states
-
-
 def _pop_states[T: SimState](
     state: T, pop_indices: list[int] | torch.Tensor
 ) -> tuple[T, list[T]]:
@@ -922,7 +910,7 @@ def _pop_states[T: SimState](

     # Create and split the pop state
     pop_state: T = type(state)(**pop_attrs)  # type: ignore[assignment]
-    pop_states = _split_state(pop_state)
+    pop_states = list(pop_state.split())

     return keep_state, pop_states

@@ -931,12 +919,12 @@ def _slice_state[T: SimState](state: T, system_indices: list[int] | torch.Tensor
     """Slice a substate from the SimState containing only the specified system indices.

     Creates a new SimState containing only the specified systems, preserving
-    all relevant properties.
+    their requested order (not natural 0,1,2 order).

     Args:
         state (SimState): The state to slice
         system_indices (list[int] | torch.Tensor): System indices to include in the
-            sliced state
+            sliced state (order preserved in the result)

     Returns:
         SimState: A new SimState object containing only the specified systems
@@ -952,15 +940,52 @@ def _slice_state[T: SimState](state: T, system_indices: list[int] | torch.Tensor
     if len(system_indices) == 0:
         raise ValueError("system_indices cannot be empty")

-    # Create masks for the atoms and systems to include
-    system_range = torch.arange(state.n_systems, device=state.device)
-    system_mask = torch.isin(system_range, system_indices)
-    atom_mask = torch.isin(state.system_idx, system_indices)
-
-    # Filter attributes
-    filtered_attrs = _filter_attrs_by_mask(state, atom_mask, system_mask)
+    system_indices = system_indices.reshape(-1)
+    cumsum = torch.cat(
+        (
+            state.n_atoms_per_system.new_zeros(1),
+            state.n_atoms_per_system.cumsum(0),
+        )
+    )
+    atom_index_list = [
+        torch.arange(
+            cumsum[sys_idx].item(),
+            cumsum[sys_idx + 1].item(),
+            device=state.device,
+        )
+        for sys_idx in system_indices
+    ]
+    atom_indices = torch.cat(atom_index_list)

-    # Create the sliced state
+    atom_mask = torch.zeros(state.n_atoms, dtype=torch.bool, device=state.device)
+    atom_mask[atom_indices] = True
+    system_mask = torch.zeros(state.n_systems, dtype=torch.bool, device=state.device)
+    system_mask[system_indices] = True
+    filtered_attrs = dict(get_attrs_for_scope(state, "global"))
+    filtered_attrs["_constraints"] = [
+        constraint.select_constraint(atom_mask, system_mask)
+        for constraint in copy.deepcopy(state.constraints)
+    ]
+    filtered_attrs["_constraints"] = [
+        c for c in filtered_attrs["_constraints"] if c is not None
+    ]
+    for attr_name, attr_value in get_attrs_for_scope(state, "per-atom"):
+        if attr_name == "system_idx":
+            old_system_indices = attr_value[atom_indices]
+            inv = torch.empty(
+                system_indices.max().item() + 1,
+                device=state.device,
+                dtype=torch.long,
+            )
+            inv[system_indices] = torch.arange(len(system_indices), device=state.device)
+            filtered_attrs[attr_name] = inv[old_system_indices]
+        else:
+            filtered_attrs[attr_name] = attr_value[atom_indices]
+    for attr_name, attr_value in get_attrs_for_scope(state, "per-system"):
+        if isinstance(attr_value, torch.Tensor):
+            filtered_attrs[attr_name] = attr_value[system_indices]
+        else:
+            filtered_attrs[attr_name] = attr_value
     return type(state)(**filtered_attrs)  # type: ignore[invalid-return-type]
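
Taken together, the state.py changes keep the autobatching path allocation-light: split() defers the per-system copies until they are indexed or iterated, and _slice_state keeps systems in the order they were requested (the behaviour the new docstring calls out). A short illustrative sketch, where batched_state is a placeholder for any multi-system SimState; note that integer indexing of the split sequence does not accept negative indices:

    parts = batched_state.split()    # O(1): no per-system states built yet
    n_systems = len(parts)           # cheap, read from n_atoms_per_system
    first = parts[0]                 # materializes only system 0
    middle = parts[1:3]              # slicing returns a plain list
    eager = list(parts)              # recovers the old list-of-states behaviour

    sliced = _slice_state(batched_state, torch.tensor([2, 0]))
    # sliced holds old systems 2 and 0 in that order; system_idx is relabeled
    # so the new system 0 is the old system 2.
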