From fa0081fee0f3289b244ab6db109c4f02f96274d6 Mon Sep 17 00:00:00 2001 From: GardevoirX Date: Fri, 7 Nov 2025 17:13:08 +0100 Subject: [PATCH 1/9] Draft for the interface --- .../metatomic/torch/ase_calculator.py | 98 ++++++++++++++++++- .../metatomic_torch/metatomic/torch/model.py | 67 +++++++++++++ 2 files changed, 160 insertions(+), 5 deletions(-) diff --git a/python/metatomic_torch/metatomic/torch/ase_calculator.py b/python/metatomic_torch/metatomic/torch/ase_calculator.py index 7c04989f..78c56514 100644 --- a/python/metatomic_torch/metatomic/torch/ase_calculator.py +++ b/python/metatomic_torch/metatomic/torch/ase_calculator.py @@ -43,6 +43,32 @@ "float64": torch.float64, } +# `Atoms` properties that can be get with `ase.Atoms.get_property` +# See: https://gitlab.com/ase/ase/-/blob/master/ase/calculators/abc.py#L12 +MIXIN_PROPERTIES = [ + "free_energy", + "energy", + "energies", + "forces", + "stress", + "stresses", + "dipole", + "charges", + "magmom", + "magmoms", +] + +# `Atoms` properties that are stored in the `ase.Atoms.arrays` dict +# See: https://gitlab.com/ase/ase/-/blob/master/ase/atoms.py#L29 +ARRAY_PROPERTIES = [ + "numbers", + "positions", + "momenta", + "masses", + "initial_magmons", + "initial_charges", +] + class MetatomicCalculator(ase.calculators.calculator.Calculator): """ @@ -237,9 +263,9 @@ def __init__( for name, output in additional_outputs.items(): assert isinstance(name, str) assert isinstance(output, torch.ScriptObject) - assert "explicit_gradients_setter" in output._method_names(), ( - "outputs must be ModelOutput instances" - ) + assert ( + "explicit_gradients_setter" in output._method_names() + ), "outputs must be ModelOutput instances" self._additional_output_requests = additional_outputs @@ -317,6 +343,7 @@ def run_model( :param outputs: outputs of the model that should be predicted :param selected_atoms: subset of atoms on which to run the calculation """ + print("Running") if isinstance(atoms, ase.Atoms): atoms_list = [atoms] else: @@ -339,6 +366,13 @@ def run_model( check_consistency=self.parameters["check_consistency"], ) system.add_neighbor_list(options, neighbors) + # Get the additional inputs requested by the model + for option in self._model.requested_additional_inputs(): + print(option) + input_tensormap = _get_ase_additional_input( + atoms, option, dtype=self._dtype, device=self._device + ) + system.add_data(option, input_tensormap) systems.append(system) available_outputs = self._model.capabilities().outputs @@ -447,7 +481,27 @@ def calculate( types, positions, cell, pbc = _ase_to_torch_data( atoms=atoms, dtype=self._dtype, device=self._device + ) # , velocities, masses + '''velocities_block = TensorBlock( + velocities[None, :, :], + samples=Labels(["system"], torch.tensor([[0]])), + components=[Labels.range("atom", velocities.shape[0])], + properties=Labels.range("component", 3), ) + velocities_map = TensorMap( + Labels(["velocities"], torch.tensor([[0]])), + [velocities_block], + ) + masses_block = TensorBlock( + masses[None, :, :], + samples=Labels(["system"], torch.tensor([[0]])), + components=[Labels.range("atom", masses.shape[0])], + properties=Labels.range("mass", 1), + ) + mass_map = TensorMap( + Labels(["mass"], torch.tensor([[0]])), + [masses_block], + )''' do_backward = False if calculate_forces and not self.parameters["non_conservative"]: @@ -475,7 +529,8 @@ def calculate( with record_function("MetatomicCalculator::compute_neighbors"): # convert from ase.Atoms to metatomic.torch.System system = System(types, positions, cell, pbc) - + # system.add_data("velocities", velocities_map) + # system.add_data("mass", mass_map) for options in self._model.requested_neighbor_lists(): neighbors = _compute_ase_neighbors( atoms, options, dtype=self._dtype, device=self._device @@ -486,6 +541,12 @@ def calculate( check_consistency=self.parameters["check_consistency"], ) system.add_neighbor_list(options, neighbors) + for option in self._model.requested_additional_inputs(): + print(option) + input_tensormap = _get_ase_additional_input( + atoms, option, dtype=self._dtype, device=self._device + ) + system.add_data(option, input_tensormap) # no `record_function` here, this will be handled by AtomisticModel outputs = self._model( @@ -906,6 +967,31 @@ def _compute_ase_neighbors(atoms, options, dtype, device): ) +def _get_ase_additional_input( + atoms: ase.Atoms, + option: str, + dtype: torch.dtype, + device: torch.device, +) -> "TensorMap": + if option in MIXIN_PROPERTIES: + values = atoms.get_properties(option) + elif option in ARRAY_PROPERTIES: + values = atoms.arrays[option] + else: + raise NotImplementedError + tblock = TensorBlock( + torch.tensor(values[None, :, :] if len(values.shape == 2) else values[None, None, :]), + samples=Labels(["system"], torch.tensor([[0]])), + components=[Labels.range("atom", values.shape[0])], + properties=Labels.range("components", values.shape[1]), + ) + tmap = TensorMap( + Labels([option], torch.tensor([[0]])), + [tblock], + ) + return tmap.to(dtype=dtype, device=device) + + def _ase_to_torch_data(atoms, dtype, device): """Get the positions, cell and pbc from ASE atoms as torch tensors""" @@ -913,10 +999,12 @@ def _ase_to_torch_data(atoms, dtype, device): positions = torch.from_numpy(atoms.positions).to(dtype=dtype, device=device) cell = torch.zeros((3, 3), dtype=dtype, device=device) pbc = torch.tensor(atoms.pbc, dtype=torch.bool, device=device) + # velocities = torch.from_numpy(atoms.get_velocities()).to(dtype=dtype, device=device) + # masses = torch.from_numpy(atoms.get_masses()).to(dtype=dtype, device=device) cell[pbc] = torch.tensor(atoms.cell[atoms.pbc], dtype=dtype, device=device) - return types, positions, cell, pbc + return types, positions, cell, pbc # , velocities, masses def _full_3x3_to_voigt_6_stress(stress): diff --git a/python/metatomic_torch/metatomic/torch/model.py b/python/metatomic_torch/metatomic/torch/model.py index eb1c2703..554d9ab9 100644 --- a/python/metatomic_torch/metatomic/torch/model.py +++ b/python/metatomic_torch/metatomic/torch/model.py @@ -176,6 +176,11 @@ def requested_neighbor_lists(self) -> List[NeighborListOptions]: the systems before calling the model. """ + def requested_additional_inputs(self) -> List[str]: + """ + Optional method declaring which additional inputs this model requires. + """ + class AtomisticModel(torch.nn.Module): """ @@ -292,6 +297,7 @@ class AtomisticModel(torch.nn.Module): # Some annotation to make the TorchScript compiler happy _requested_neighbor_lists: List[NeighborListOptions] + _requested_additional_inputs: List[str] def __init__( self, @@ -327,6 +333,15 @@ def __init__( ) # ============================================================================ # + # recursively explore `module` to get all the requested_additional_inputs + self._requested_additional_inputs = [] + _get_requested_additional_inputs( + module, + self.module.__class__.__name__, + self._requested_additional_inputs, + ) + # ============================================================================ # + self._metadata = metadata self._capabilities = capabilities @@ -380,6 +395,14 @@ def requested_neighbor_lists(self) -> List[NeighborListOptions]: """ return self._requested_neighbor_lists + @torch.jit.export + def requested_additional_inputs(self) -> List[str]: + """ + Get the additional inputs required by the exported model or any of the child + module. + """ + return self._requested_additional_inputs + def forward( self, systems: List[System], @@ -410,6 +433,7 @@ def forward( _check_inputs( capabilities=self._capabilities, requested_neighbor_lists=self._requested_neighbor_lists, + requested_additional_inputs=self._requested_additional_inputs, systems=systems, options=options, expected_dtype=self._model_dtype, @@ -622,6 +646,33 @@ def _get_requested_neighbor_lists( ) +def _get_requested_additional_inputs( + module: torch.nn.Module, + module_name: str, + requested: List[str], +): + if hasattr(module, "requested_additional_inputs"): + for new_options in module.requested_additional_inputs(): + # new_options.add_requestor(module_name) + + already_requested = False + for existing in requested: + if existing == new_options: + already_requested = True + # for requestor in new_options.requestors(): + # existing.append(requestor) + + if not already_requested: + requested.append(new_options) + + for child_name, child in module.named_children(): + _get_requested_additional_inputs( + module=child, + module_name=module_name + "." + child_name, + requested=requested, + ) + + def _check_annotation(module: torch.nn.Module): if isinstance(module, torch.jit.RecursiveScriptModule): _check_annotation_torchscript(module) @@ -750,6 +801,7 @@ def _check_annotation_python(module: torch.nn.Module): def _check_inputs( capabilities: ModelCapabilities, requested_neighbor_lists: List[NeighborListOptions], + requested_additional_inputs: List[str], systems: List[System], options: ModelEvaluationOptions, expected_dtype: torch.dtype, @@ -847,6 +899,21 @@ def _check_inputs( f"a list for {request}, but it was not computed and stored " "in the system" ) + + # Check additional inputs + known_additional_inputs = system.known_data() + for request in requested_additional_inputs: + found = False + for known in known_additional_inputs: + if request == known: + found = True + + if not found: + raise ValueError( + "missing additional input in the system: the model requested " + f"a list for {request}, but it was not computed and stored " + "in the system" + ) def _convert_systems_units( From 42f69cdd94a5531d90fc3bf3ce954bf408220b2c Mon Sep 17 00:00:00 2001 From: GardevoirX Date: Wed, 19 Nov 2025 22:58:31 +0100 Subject: [PATCH 2/9] Add a test for additional inputs --- .../include/metatomic/torch/model_output.hpp | 0 .../metatomic/torch/ase_calculator.py | 60 ++++++------------- .../metatomic_torch/metatomic/torch/model.py | 2 +- .../metatomic_torch/tests/ase_calculator.py | 50 ++++++++++++++++ 4 files changed, 70 insertions(+), 42 deletions(-) delete mode 100644 metatomic-torch/include/metatomic/torch/model_output.hpp diff --git a/metatomic-torch/include/metatomic/torch/model_output.hpp b/metatomic-torch/include/metatomic/torch/model_output.hpp deleted file mode 100644 index e69de29b..00000000 diff --git a/python/metatomic_torch/metatomic/torch/ase_calculator.py b/python/metatomic_torch/metatomic/torch/ase_calculator.py index 78c56514..26775c91 100644 --- a/python/metatomic_torch/metatomic/torch/ase_calculator.py +++ b/python/metatomic_torch/metatomic/torch/ase_calculator.py @@ -47,11 +47,11 @@ # See: https://gitlab.com/ase/ase/-/blob/master/ase/calculators/abc.py#L12 MIXIN_PROPERTIES = [ "free_energy", - "energy", - "energies", - "forces", - "stress", - "stresses", + # "energy", + # "energies", + # "forces", + # "stress", + # "stresses", "dipole", "charges", "magmom", @@ -62,10 +62,10 @@ # See: https://gitlab.com/ase/ase/-/blob/master/ase/atoms.py#L29 ARRAY_PROPERTIES = [ "numbers", - "positions", + # "positions", "momenta", "masses", - "initial_magmons", + "initial_magmoms", "initial_charges", ] @@ -263,9 +263,9 @@ def __init__( for name, output in additional_outputs.items(): assert isinstance(name, str) assert isinstance(output, torch.ScriptObject) - assert ( - "explicit_gradients_setter" in output._method_names() - ), "outputs must be ModelOutput instances" + assert "explicit_gradients_setter" in output._method_names(), ( + "outputs must be ModelOutput instances" + ) self._additional_output_requests = additional_outputs @@ -343,7 +343,6 @@ def run_model( :param outputs: outputs of the model that should be predicted :param selected_atoms: subset of atoms on which to run the calculation """ - print("Running") if isinstance(atoms, ase.Atoms): atoms_list = [atoms] else: @@ -368,7 +367,6 @@ def run_model( system.add_neighbor_list(options, neighbors) # Get the additional inputs requested by the model for option in self._model.requested_additional_inputs(): - print(option) input_tensormap = _get_ase_additional_input( atoms, option, dtype=self._dtype, device=self._device ) @@ -481,27 +479,7 @@ def calculate( types, positions, cell, pbc = _ase_to_torch_data( atoms=atoms, dtype=self._dtype, device=self._device - ) # , velocities, masses - '''velocities_block = TensorBlock( - velocities[None, :, :], - samples=Labels(["system"], torch.tensor([[0]])), - components=[Labels.range("atom", velocities.shape[0])], - properties=Labels.range("component", 3), - ) - velocities_map = TensorMap( - Labels(["velocities"], torch.tensor([[0]])), - [velocities_block], ) - masses_block = TensorBlock( - masses[None, :, :], - samples=Labels(["system"], torch.tensor([[0]])), - components=[Labels.range("atom", masses.shape[0])], - properties=Labels.range("mass", 1), - ) - mass_map = TensorMap( - Labels(["mass"], torch.tensor([[0]])), - [masses_block], - )''' do_backward = False if calculate_forces and not self.parameters["non_conservative"]: @@ -542,7 +520,6 @@ def calculate( ) system.add_neighbor_list(options, neighbors) for option in self._model.requested_additional_inputs(): - print(option) input_tensormap = _get_ase_additional_input( atoms, option, dtype=self._dtype, device=self._device ) @@ -978,12 +955,15 @@ def _get_ase_additional_input( elif option in ARRAY_PROPERTIES: values = atoms.arrays[option] else: - raise NotImplementedError + raise ValueError(f"The property {option} is not available in `ase`.") + values = torch.tensor( + values[:, :, None] if values.ndim == 2 else values[None, :, None] + ) tblock = TensorBlock( - torch.tensor(values[None, :, :] if len(values.shape == 2) else values[None, None, :]), - samples=Labels(["system"], torch.tensor([[0]])), - components=[Labels.range("atom", values.shape[0])], - properties=Labels.range("components", values.shape[1]), + values, + samples=Labels.range("atoms", values.shape[0]), + components=[Labels.range("components", values.shape[1])], + properties=Labels(["property"], torch.tensor([[0]])), ) tmap = TensorMap( Labels([option], torch.tensor([[0]])), @@ -999,12 +979,10 @@ def _ase_to_torch_data(atoms, dtype, device): positions = torch.from_numpy(atoms.positions).to(dtype=dtype, device=device) cell = torch.zeros((3, 3), dtype=dtype, device=device) pbc = torch.tensor(atoms.pbc, dtype=torch.bool, device=device) - # velocities = torch.from_numpy(atoms.get_velocities()).to(dtype=dtype, device=device) - # masses = torch.from_numpy(atoms.get_masses()).to(dtype=dtype, device=device) cell[pbc] = torch.tensor(atoms.cell[atoms.pbc], dtype=dtype, device=device) - return types, positions, cell, pbc # , velocities, masses + return types, positions, cell, pbc def _full_3x3_to_voigt_6_stress(stress): diff --git a/python/metatomic_torch/metatomic/torch/model.py b/python/metatomic_torch/metatomic/torch/model.py index 554d9ab9..6a00c662 100644 --- a/python/metatomic_torch/metatomic/torch/model.py +++ b/python/metatomic_torch/metatomic/torch/model.py @@ -899,7 +899,7 @@ def _check_inputs( f"a list for {request}, but it was not computed and stored " "in the system" ) - + # Check additional inputs known_additional_inputs = system.known_data() for request in requested_additional_inputs: diff --git a/python/metatomic_torch/tests/ase_calculator.py b/python/metatomic_torch/tests/ase_calculator.py index e654a579..4265f4cf 100644 --- a/python/metatomic_torch/tests/ase_calculator.py +++ b/python/metatomic_torch/tests/ase_calculator.py @@ -809,3 +809,53 @@ def forward( match = "does not support energy computation" with pytest.raises(ValueError, match=match): calc.compute_energy(atoms) + + +class AdditionalInputModel(torch.nn.Module): + def __init__(self, additional_inputs): + super().__init__() + self._additional_inputs = additional_inputs + + def requested_additional_inputs(self) -> List[str]: + return self._additional_inputs + + def forward( + self, + systems: List[System], + outputs: Dict[str, ModelOutput], + selected_atoms: Optional[Labels] = None, + ) -> Dict[str, TensorMap]: + return { + ("extra::" + additional_input): systems[0].get_data(additional_input) + for additional_input in self._additional_inputs + } + + +def test_additional_input(atoms): + additional_inputs = ["initial_magmoms", "numbers"] + outputs = { + ("extra::" + additional_input): ModelOutput( + quantity=additional_input, per_atom=True + ) + for additional_input in additional_inputs + } + capabilities = ModelCapabilities( + outputs=outputs, + atomic_types=[28], + interaction_range=0.0, + supported_devices=["cpu"], + dtype="float64", + ) + + model = AtomisticModel( + AdditionalInputModel(additional_inputs).eval(), ModelMetadata(), capabilities + ) + calculator = MetatomicCalculator(model) + results = calculator.run_model(atoms, outputs) + for k, v in results.items(): + head, prop = k.split("::") + assert head == "extra" + assert prop in additional_inputs + assert len(v.keys.names) == 1 + assert v.keys.names[0] == prop + assert np.allclose(v[0].values.numpy(), atoms.arrays[prop]) From 87926917a1a87fff2a2837de957c213b1c340693 Mon Sep 17 00:00:00 2001 From: GardevoirX Date: Thu, 4 Dec 2025 14:49:00 +0100 Subject: [PATCH 3/9] Include `ModelOutput` as a part of input request --- .../metatomic/torch/ase_calculator.py | 10 ++--- .../metatomic_torch/metatomic/torch/model.py | 39 +++++++++---------- .../metatomic_torch/tests/ase_calculator.py | 4 +- 3 files changed, 25 insertions(+), 28 deletions(-) diff --git a/python/metatomic_torch/metatomic/torch/ase_calculator.py b/python/metatomic_torch/metatomic/torch/ase_calculator.py index 26775c91..fc1ad689 100644 --- a/python/metatomic_torch/metatomic/torch/ase_calculator.py +++ b/python/metatomic_torch/metatomic/torch/ase_calculator.py @@ -366,8 +366,8 @@ def run_model( ) system.add_neighbor_list(options, neighbors) # Get the additional inputs requested by the model - for option in self._model.requested_additional_inputs(): - input_tensormap = _get_ase_additional_input( + for option in self._model.requested_inputs(): + input_tensormap = _get_ase_input( atoms, option, dtype=self._dtype, device=self._device ) system.add_data(option, input_tensormap) @@ -519,8 +519,8 @@ def calculate( check_consistency=self.parameters["check_consistency"], ) system.add_neighbor_list(options, neighbors) - for option in self._model.requested_additional_inputs(): - input_tensormap = _get_ase_additional_input( + for option in self._model.requested_inputs(): + input_tensormap = _get_ase_input( atoms, option, dtype=self._dtype, device=self._device ) system.add_data(option, input_tensormap) @@ -944,7 +944,7 @@ def _compute_ase_neighbors(atoms, options, dtype, device): ) -def _get_ase_additional_input( +def _get_ase_input( atoms: ase.Atoms, option: str, dtype: torch.dtype, diff --git a/python/metatomic_torch/metatomic/torch/model.py b/python/metatomic_torch/metatomic/torch/model.py index 6a00c662..c5220394 100644 --- a/python/metatomic_torch/metatomic/torch/model.py +++ b/python/metatomic_torch/metatomic/torch/model.py @@ -176,9 +176,9 @@ def requested_neighbor_lists(self) -> List[NeighborListOptions]: the systems before calling the model. """ - def requested_additional_inputs(self) -> List[str]: + def requested_inputs(self) -> Dict[str, ModelOutput]: """ - Optional method declaring which additional inputs this model requires. + Optional method declaring which inputs this model requires. """ @@ -297,7 +297,7 @@ class AtomisticModel(torch.nn.Module): # Some annotation to make the TorchScript compiler happy _requested_neighbor_lists: List[NeighborListOptions] - _requested_additional_inputs: List[str] + _requested_inputs: Dict[str, ModelOutput] def __init__( self, @@ -334,11 +334,11 @@ def __init__( # ============================================================================ # # recursively explore `module` to get all the requested_additional_inputs - self._requested_additional_inputs = [] - _get_requested_additional_inputs( + self._requested_inputs = {} + _get_requested_inputs( module, self.module.__class__.__name__, - self._requested_additional_inputs, + self._requested_inputs, ) # ============================================================================ # @@ -396,12 +396,12 @@ def requested_neighbor_lists(self) -> List[NeighborListOptions]: return self._requested_neighbor_lists @torch.jit.export - def requested_additional_inputs(self) -> List[str]: + def requested_inputs(self) -> Dict[str, ModelOutput]: """ Get the additional inputs required by the exported model or any of the child module. """ - return self._requested_additional_inputs + return self._requested_inputs def forward( self, @@ -433,7 +433,7 @@ def forward( _check_inputs( capabilities=self._capabilities, requested_neighbor_lists=self._requested_neighbor_lists, - requested_additional_inputs=self._requested_additional_inputs, + requested_inputs=self._requested_inputs, systems=systems, options=options, expected_dtype=self._model_dtype, @@ -646,27 +646,24 @@ def _get_requested_neighbor_lists( ) -def _get_requested_additional_inputs( +def _get_requested_inputs( module: torch.nn.Module, module_name: str, - requested: List[str], + requested: Dict[str, ModelOutput], ): - if hasattr(module, "requested_additional_inputs"): - for new_options in module.requested_additional_inputs(): - # new_options.add_requestor(module_name) - + if hasattr(module, "requested_inputs"): + requested_inputs = module.requested_inputs() + for new_options in requested_inputs: already_requested = False for existing in requested: if existing == new_options: already_requested = True - # for requestor in new_options.requestors(): - # existing.append(requestor) if not already_requested: - requested.append(new_options) + requested[new_options] = requested_inputs[new_options] for child_name, child in module.named_children(): - _get_requested_additional_inputs( + _get_requested_inputs( module=child, module_name=module_name + "." + child_name, requested=requested, @@ -801,7 +798,7 @@ def _check_annotation_python(module: torch.nn.Module): def _check_inputs( capabilities: ModelCapabilities, requested_neighbor_lists: List[NeighborListOptions], - requested_additional_inputs: List[str], + requested_inputs: Dict[str, ModelOutput], systems: List[System], options: ModelEvaluationOptions, expected_dtype: torch.dtype, @@ -902,7 +899,7 @@ def _check_inputs( # Check additional inputs known_additional_inputs = system.known_data() - for request in requested_additional_inputs: + for request in requested_inputs: found = False for known in known_additional_inputs: if request == known: diff --git a/python/metatomic_torch/tests/ase_calculator.py b/python/metatomic_torch/tests/ase_calculator.py index 4265f4cf..85335050 100644 --- a/python/metatomic_torch/tests/ase_calculator.py +++ b/python/metatomic_torch/tests/ase_calculator.py @@ -816,7 +816,7 @@ def __init__(self, additional_inputs): super().__init__() self._additional_inputs = additional_inputs - def requested_additional_inputs(self) -> List[str]: + def requested_inputs(self) -> Dict[str, ModelOutput]: return self._additional_inputs def forward( @@ -832,7 +832,7 @@ def forward( def test_additional_input(atoms): - additional_inputs = ["initial_magmoms", "numbers"] + additional_inputs = {"initial_magmoms": ModelOutput(), "numbers": ModelOutput()} outputs = { ("extra::" + additional_input): ModelOutput( quantity=additional_input, per_atom=True From e6f170a786841863b920938e5fae99d1dab877e5 Mon Sep 17 00:00:00 2001 From: GardevoirX Date: Fri, 5 Dec 2025 11:22:13 +0100 Subject: [PATCH 4/9] Add support for `TensorMap.info` --- metatomic-torch/src/model.cpp | 10 ++ .../metatomic/torch/ase_calculator.py | 95 +++++++++++++------ .../metatomic_torch/tests/ase_calculator.py | 4 +- 3 files changed, 79 insertions(+), 30 deletions(-) diff --git a/metatomic-torch/src/model.cpp b/metatomic-torch/src/model.cpp index efa77c9c..09c260c0 100644 --- a/metatomic-torch/src/model.cpp +++ b/metatomic-torch/src/model.cpp @@ -1206,6 +1206,16 @@ static std::map KNOWN_QUANTITIES = { }, { // alternative names }}}, + {"mass", Quantity{/* name */ "mass", /* baseline */ "u ", { + {"u", 1.0}, + {"kilogram", 1.66053906892e-27}, + {"gram", 1.66053906892e-24}, + }, { + // alternative names + {"Dalton", "u"}, + {"kg", "kilogram"}, + {"g", "gram"}, + }}}, }; bool metatomic_torch::valid_quantity(const std::string& quantity) { diff --git a/python/metatomic_torch/metatomic/torch/ase_calculator.py b/python/metatomic_torch/metatomic/torch/ase_calculator.py index fc1ad689..de0cfd3e 100644 --- a/python/metatomic_torch/metatomic/torch/ase_calculator.py +++ b/python/metatomic_torch/metatomic/torch/ase_calculator.py @@ -21,6 +21,7 @@ pick_device, pick_output, register_autograd_neighbors, + unit_conversion_factor, ) @@ -45,29 +46,38 @@ # `Atoms` properties that can be get with `ase.Atoms.get_property` # See: https://gitlab.com/ase/ase/-/blob/master/ase/calculators/abc.py#L12 -MIXIN_PROPERTIES = [ - "free_energy", +MIXIN_PROPERTIES = { + "free_energy": {}, # "energy", # "energies", # "forces", # "stress", # "stresses", - "dipole", - "charges", - "magmom", - "magmoms", -] + "dipole": {}, + "charges": {}, + "magmom": {}, + "magmoms": {}, +} # `Atoms` properties that are stored in the `ase.Atoms.arrays` dict # See: https://gitlab.com/ase/ase/-/blob/master/ase/atoms.py#L29 -ARRAY_PROPERTIES = [ - "numbers", +ARRAY_PROPERTIES = { + "numbers": { + "quantity": "atomic_number", + "unit": "", + }, # "positions", - "momenta", - "masses", - "initial_magmoms", - "initial_charges", -] + "momenta": { + "quantity": "momentum", + "unit": "(eV*u)^(1/2)", + }, + "masses": { + "quantity": "mass", + "unit": "u", + }, + "initial_magmoms": {}, + "initial_charges": {}, +} class MetatomicCalculator(ase.calculators.calculator.Calculator): @@ -366,11 +376,11 @@ def run_model( ) system.add_neighbor_list(options, neighbors) # Get the additional inputs requested by the model - for option in self._model.requested_inputs(): + for quantity, option in self._model.requested_inputs().items(): input_tensormap = _get_ase_input( - atoms, option, dtype=self._dtype, device=self._device + atoms, quantity, option, dtype=self._dtype, device=self._device ) - system.add_data(option, input_tensormap) + system.add_data(quantity, input_tensormap) systems.append(system) available_outputs = self._model.capabilities().outputs @@ -519,11 +529,11 @@ def calculate( check_consistency=self.parameters["check_consistency"], ) system.add_neighbor_list(options, neighbors) - for option in self._model.requested_inputs(): + for quantity, option in self._model.requested_inputs().items(): input_tensormap = _get_ase_input( - atoms, option, dtype=self._dtype, device=self._device + atoms, quantity, option, dtype=self._dtype, device=self._device ) - system.add_data(option, input_tensormap) + system.add_data(quantity, input_tensormap) # no `record_function` here, this will be handled by AtomisticModel outputs = self._model( @@ -946,19 +956,43 @@ def _compute_ase_neighbors(atoms, options, dtype, device): def _get_ase_input( atoms: ase.Atoms, - option: str, + quantity: str, + option: ModelOutput, dtype: torch.dtype, device: torch.device, ) -> "TensorMap": - if option in MIXIN_PROPERTIES: - values = atoms.get_properties(option) - elif option in ARRAY_PROPERTIES: - values = atoms.arrays[option] + if quantity in MIXIN_PROPERTIES: + if len(MIXIN_PROPERTIES[quantity]) == 0: + raise NotImplementedError( + f"Though the property {quantity} is available in `ase`, it is " + "currently not supported by metatomic." + ) + values = atoms.get_properties(quantity) + infos = MIXIN_PROPERTIES[quantity] + elif quantity in ARRAY_PROPERTIES: + if len(ARRAY_PROPERTIES[quantity]) == 0: + raise NotImplementedError( + f"Though the property {quantity} is available in `ase`, it is " + "currently not supported by metatomic." + ) + values = atoms.arrays[quantity] + infos = ARRAY_PROPERTIES[quantity] else: - raise ValueError(f"The property {option} is not available in `ase`.") - values = torch.tensor( - values[:, :, None] if values.ndim == 2 else values[None, :, None] + raise ValueError(f"The property {quantity} is not available in `ase`.") + + if infos["unit"] != option.unit: + conversion = unit_conversion_factor( + quantity, + from_unit=infos["unit"], + to_unit=option.unit, + ) + else: + conversion = 1.0 + values = ( + torch.tensor(values[:, :, None] if values.ndim == 2 else values[None, :, None]) + * conversion ) + tblock = TensorBlock( values, samples=Labels.range("atoms", values.shape[0]), @@ -966,9 +1000,12 @@ def _get_ase_input( properties=Labels(["property"], torch.tensor([[0]])), ) tmap = TensorMap( - Labels([option], torch.tensor([[0]])), + Labels([quantity], torch.tensor([[0]])), [tblock], ) + tmap.set_info("quantity", quantity) + tmap.set_info("unit", option.unit) + return tmap.to(dtype=dtype, device=device) diff --git a/python/metatomic_torch/tests/ase_calculator.py b/python/metatomic_torch/tests/ase_calculator.py index 85335050..fcc14b7a 100644 --- a/python/metatomic_torch/tests/ase_calculator.py +++ b/python/metatomic_torch/tests/ase_calculator.py @@ -832,7 +832,9 @@ def forward( def test_additional_input(atoms): - additional_inputs = {"initial_magmoms": ModelOutput(), "numbers": ModelOutput()} + additional_inputs = { + "numbers": ModelOutput(quantity="atomic_number", unit="", per_atom=True) + } outputs = { ("extra::" + additional_input): ModelOutput( quantity=additional_input, per_atom=True From c7719fc1d368917a884c3bd8c5bbbf3b72d9d21e Mon Sep 17 00:00:00 2001 From: GardevoirX Date: Mon, 8 Dec 2025 11:21:21 +0100 Subject: [PATCH 5/9] Support velocities as an additional input --- metatomic-torch/src/model.cpp | 8 ++++ .../metatomic/torch/ase_calculator.py | 45 +++++++------------ .../metatomic_torch/tests/ase_calculator.py | 39 +++++++++------- 3 files changed, 47 insertions(+), 45 deletions(-) diff --git a/metatomic-torch/src/model.cpp b/metatomic-torch/src/model.cpp index 09c260c0..95a41016 100644 --- a/metatomic-torch/src/model.cpp +++ b/metatomic-torch/src/model.cpp @@ -1216,6 +1216,14 @@ static std::map KNOWN_QUANTITIES = { {"kg", "kilogram"}, {"g", "gram"}, }}}, + {"velocity", Quantity{/* name */ "velocity", /* baseline */ "nm/fs", { + {"nm/fs", 1.0}, + {"A/fs", 1e1}, + {"m/s", 1e6}, + {"nm/ps", 1e3}, + }, { + // alternative names + }}} }; bool metatomic_torch::valid_quantity(const std::string& quantity) { diff --git a/python/metatomic_torch/metatomic/torch/ase_calculator.py b/python/metatomic_torch/metatomic/torch/ase_calculator.py index de0cfd3e..bc9a73d2 100644 --- a/python/metatomic_torch/metatomic/torch/ase_calculator.py +++ b/python/metatomic_torch/metatomic/torch/ase_calculator.py @@ -44,39 +44,35 @@ "float64": torch.float64, } -# `Atoms` properties that can be get with `ase.Atoms.get_property` -# See: https://gitlab.com/ase/ase/-/blob/master/ase/calculators/abc.py#L12 -MIXIN_PROPERTIES = { - "free_energy": {}, - # "energy", - # "energies", - # "forces", - # "stress", - # "stresses", - "dipole": {}, - "charges": {}, - "magmom": {}, - "magmoms": {}, -} - -# `Atoms` properties that are stored in the `ase.Atoms.arrays` dict -# See: https://gitlab.com/ase/ase/-/blob/master/ase/atoms.py#L29 ARRAY_PROPERTIES = { "numbers": { + "getter": lambda atoms: atoms.get_atomic_numbers(), "quantity": "atomic_number", "unit": "", }, # "positions", "momenta": { + "getter": lambda atoms: atoms.get_momenta(), "quantity": "momentum", "unit": "(eV*u)^(1/2)", }, "masses": { + "getter": lambda atoms: atoms.get_masses(), "quantity": "mass", "unit": "u", }, + "velocities": { + "getter": lambda atoms: atoms.get_velocities(), + "quantity": "velocity", + "unit": "nm/fs", + }, "initial_magmoms": {}, + "magmom": {}, + "magmoms": {}, "initial_charges": {}, + "charges": {}, + "dipole": {}, + "free_energy": {}, } @@ -376,6 +372,7 @@ def run_model( ) system.add_neighbor_list(options, neighbors) # Get the additional inputs requested by the model + print(f"Getting additional inputs: {self._model.requested_inputs()}") for quantity, option in self._model.requested_inputs().items(): input_tensormap = _get_ase_input( atoms, quantity, option, dtype=self._dtype, device=self._device @@ -961,28 +958,20 @@ def _get_ase_input( dtype: torch.dtype, device: torch.device, ) -> "TensorMap": - if quantity in MIXIN_PROPERTIES: - if len(MIXIN_PROPERTIES[quantity]) == 0: - raise NotImplementedError( - f"Though the property {quantity} is available in `ase`, it is " - "currently not supported by metatomic." - ) - values = atoms.get_properties(quantity) - infos = MIXIN_PROPERTIES[quantity] - elif quantity in ARRAY_PROPERTIES: + if quantity in ARRAY_PROPERTIES: if len(ARRAY_PROPERTIES[quantity]) == 0: raise NotImplementedError( f"Though the property {quantity} is available in `ase`, it is " "currently not supported by metatomic." ) - values = atoms.arrays[quantity] infos = ARRAY_PROPERTIES[quantity] else: raise ValueError(f"The property {quantity} is not available in `ase`.") + values = infos["getter"](atoms) if infos["unit"] != option.unit: conversion = unit_conversion_factor( - quantity, + option.quantity, from_unit=infos["unit"], to_unit=option.unit, ) diff --git a/python/metatomic_torch/tests/ase_calculator.py b/python/metatomic_torch/tests/ase_calculator.py index fcc14b7a..67210b3c 100644 --- a/python/metatomic_torch/tests/ase_calculator.py +++ b/python/metatomic_torch/tests/ase_calculator.py @@ -14,6 +14,7 @@ import pytest import torch from ase.calculators.calculator import PropertyNotImplementedError +from ase.md.velocitydistribution import MaxwellBoltzmannDistribution from metatensor.torch import Labels, TensorBlock, TensorMap from metatomic.torch import ( @@ -25,6 +26,7 @@ System, ) from metatomic.torch.ase_calculator import ( + ARRAY_PROPERTIES, MetatomicCalculator, _compute_ase_neighbors, _full_3x3_to_voigt_6_stress, @@ -812,12 +814,12 @@ def forward( class AdditionalInputModel(torch.nn.Module): - def __init__(self, additional_inputs): + def __init__(self, inputs): super().__init__() - self._additional_inputs = additional_inputs + self._requested_inputs = inputs - def requested_inputs(self) -> Dict[str, ModelOutput]: - return self._additional_inputs + def requested_inputs(self) -> List[str]: + return self._requested_inputs def forward( self, @@ -825,22 +827,19 @@ def forward( outputs: Dict[str, ModelOutput], selected_atoms: Optional[Labels] = None, ) -> Dict[str, TensorMap]: + print(systems[0].known_data()) return { - ("extra::" + additional_input): systems[0].get_data(additional_input) - for additional_input in self._additional_inputs + ("extra::" + input): systems[0].get_data(input) + for input in self._requested_inputs } def test_additional_input(atoms): - additional_inputs = { - "numbers": ModelOutput(quantity="atomic_number", unit="", per_atom=True) - } - outputs = { - ("extra::" + additional_input): ModelOutput( - quantity=additional_input, per_atom=True - ) - for additional_input in additional_inputs + inputs = { + "numbers": ModelOutput(quantity="atomic_number", unit="", per_atom=True), + "velocities": ModelOutput(quantity="velocity", unit="A/fs", per_atom=True), } + outputs = {("extra::" + prop): inputs[prop] for prop in inputs} capabilities = ModelCapabilities( outputs=outputs, atomic_types=[28], @@ -850,14 +849,20 @@ def test_additional_input(atoms): ) model = AtomisticModel( - AdditionalInputModel(additional_inputs).eval(), ModelMetadata(), capabilities + AdditionalInputModel(inputs).eval(), ModelMetadata(), capabilities ) + MaxwellBoltzmannDistribution(atoms, temperature_K=300.0) calculator = MetatomicCalculator(model) results = calculator.run_model(atoms, outputs) for k, v in results.items(): head, prop = k.split("::") assert head == "extra" - assert prop in additional_inputs + assert prop in inputs assert len(v.keys.names) == 1 assert v.keys.names[0] == prop - assert np.allclose(v[0].values.numpy(), atoms.arrays[prop]) + shape = v[0].values.numpy().shape + assert np.allclose( + v[0].values.numpy(), + ARRAY_PROPERTIES[prop]["getter"](atoms).reshape(shape) + * (10 if prop == "velocities" else 1), + ) From 8612d72355101456cefc5fd04e2f7aff9078eed5 Mon Sep 17 00:00:00 2001 From: GardevoirX Date: Fri, 12 Dec 2025 10:36:37 +0100 Subject: [PATCH 6/9] Fix --- metatomic-torch/tests/models.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/metatomic-torch/tests/models.cpp b/metatomic-torch/tests/models.cpp index 3401c23f..ec8996f1 100644 --- a/metatomic-torch/tests/models.cpp +++ b/metatomic-torch/tests/models.cpp @@ -109,7 +109,7 @@ TEST_CASE("Models metadata") { struct WarningHandler: public torch::WarningHandler { virtual ~WarningHandler() override = default; void process(const torch::Warning& warning) override { - CHECK(warning.msg() == "unknown quantity 'unknown', only [energy force length momentum pressure] are supported"); + CHECK(warning.msg() == "unknown quantity 'unknown', only [energy force length mass momentum pressure velocity] are supported"); } }; From e3340cd8f81bfbfaaed8666d12849d934eb72856 Mon Sep 17 00:00:00 2001 From: Qianjun Xu <92628709+GardevoirX@users.noreply.github.com> Date: Sun, 14 Dec 2025 21:51:39 +0100 Subject: [PATCH 7/9] Apply suggestions from code review Co-authored-by: Guillaume Fraux --- .../metatomic_torch/metatomic/torch/ase_calculator.py | 10 +--------- python/metatomic_torch/tests/ase_calculator.py | 1 - 2 files changed, 1 insertion(+), 10 deletions(-) diff --git a/python/metatomic_torch/metatomic/torch/ase_calculator.py b/python/metatomic_torch/metatomic/torch/ase_calculator.py index bc9a73d2..d1c0f51c 100644 --- a/python/metatomic_torch/metatomic/torch/ase_calculator.py +++ b/python/metatomic_torch/metatomic/torch/ase_calculator.py @@ -45,12 +45,6 @@ } ARRAY_PROPERTIES = { - "numbers": { - "getter": lambda atoms: atoms.get_atomic_numbers(), - "quantity": "atomic_number", - "unit": "", - }, - # "positions", "momenta": { "getter": lambda atoms: atoms.get_momenta(), "quantity": "momentum", @@ -514,8 +508,6 @@ def calculate( with record_function("MetatomicCalculator::compute_neighbors"): # convert from ase.Atoms to metatomic.torch.System system = System(types, positions, cell, pbc) - # system.add_data("velocities", velocities_map) - # system.add_data("mass", mass_map) for options in self._model.requested_neighbor_lists(): neighbors = _compute_ase_neighbors( atoms, options, dtype=self._dtype, device=self._device @@ -966,7 +958,7 @@ def _get_ase_input( ) infos = ARRAY_PROPERTIES[quantity] else: - raise ValueError(f"The property {quantity} is not available in `ase`.") + raise ValueError(f"The model requested '{quantity}', which is not available in `ase`.") values = infos["getter"](atoms) if infos["unit"] != option.unit: diff --git a/python/metatomic_torch/tests/ase_calculator.py b/python/metatomic_torch/tests/ase_calculator.py index 67210b3c..dd6463c5 100644 --- a/python/metatomic_torch/tests/ase_calculator.py +++ b/python/metatomic_torch/tests/ase_calculator.py @@ -827,7 +827,6 @@ def forward( outputs: Dict[str, ModelOutput], selected_atoms: Optional[Labels] = None, ) -> Dict[str, TensorMap]: - print(systems[0].known_data()) return { ("extra::" + input): systems[0].get_data(input) for input in self._requested_inputs From 3edf9312d6d9f83e04ca5f30125e27774319176a Mon Sep 17 00:00:00 2001 From: GardevoirX Date: Mon, 15 Dec 2025 11:32:53 +0100 Subject: [PATCH 8/9] Update tensormap naming and tests --- .../metatomic/torch/ase_calculator.py | 19 +++++++++++-------- .../metatomic_torch/tests/ase_calculator.py | 4 ++-- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/python/metatomic_torch/metatomic/torch/ase_calculator.py b/python/metatomic_torch/metatomic/torch/ase_calculator.py index d1c0f51c..1762fc8a 100644 --- a/python/metatomic_torch/metatomic/torch/ase_calculator.py +++ b/python/metatomic_torch/metatomic/torch/ase_calculator.py @@ -366,7 +366,6 @@ def run_model( ) system.add_neighbor_list(options, neighbors) # Get the additional inputs requested by the model - print(f"Getting additional inputs: {self._model.requested_inputs()}") for quantity, option in self._model.requested_inputs().items(): input_tensormap = _get_ase_input( atoms, quantity, option, dtype=self._dtype, device=self._device @@ -958,7 +957,9 @@ def _get_ase_input( ) infos = ARRAY_PROPERTIES[quantity] else: - raise ValueError(f"The model requested '{quantity}', which is not available in `ase`.") + raise ValueError( + f"The model requested '{quantity}', which is not available in `ase`." + ) values = infos["getter"](atoms) if infos["unit"] != option.unit: @@ -977,17 +978,19 @@ def _get_ase_input( tblock = TensorBlock( values, samples=Labels.range("atoms", values.shape[0]), - components=[Labels.range("components", values.shape[1])], - properties=Labels(["property"], torch.tensor([[0]])), + components=[Labels.range("components", values.shape[1])] + if values.shape[1] != 1 + else [], + properties=Labels([option.quantity], torch.tensor([[0]])), ) tmap = TensorMap( - Labels([quantity], torch.tensor([[0]])), + Labels(["_"], torch.tensor([[0]])), [tblock], ) - tmap.set_info("quantity", quantity) + tmap.set_info("quantity", option.quantity) tmap.set_info("unit", option.unit) - - return tmap.to(dtype=dtype, device=device) + tmap.to(dtype=dtype, device=device) + return tmap def _ase_to_torch_data(atoms, dtype, device): diff --git a/python/metatomic_torch/tests/ase_calculator.py b/python/metatomic_torch/tests/ase_calculator.py index dd6463c5..4f3bdd0b 100644 --- a/python/metatomic_torch/tests/ase_calculator.py +++ b/python/metatomic_torch/tests/ase_calculator.py @@ -835,7 +835,7 @@ def forward( def test_additional_input(atoms): inputs = { - "numbers": ModelOutput(quantity="atomic_number", unit="", per_atom=True), + "masses": ModelOutput(quantity="mass", unit="u", per_atom=True), "velocities": ModelOutput(quantity="velocity", unit="A/fs", per_atom=True), } outputs = {("extra::" + prop): inputs[prop] for prop in inputs} @@ -858,7 +858,7 @@ def test_additional_input(atoms): assert head == "extra" assert prop in inputs assert len(v.keys.names) == 1 - assert v.keys.names[0] == prop + assert v.get_info("quantity") == inputs[prop].quantity shape = v[0].values.numpy().shape assert np.allclose( v[0].values.numpy(), From b132ce6211726a1a4988f46c36a90996bcb67729 Mon Sep 17 00:00:00 2001 From: GardevoirX Date: Tue, 16 Dec 2025 12:44:10 +0100 Subject: [PATCH 9/9] Check the requested inputs with `_check_outputs` --- .../metatomic/torch/ase_calculator.py | 26 ++++++++----------- .../metatomic_torch/metatomic/torch/model.py | 14 ++++++++++ .../metatomic_torch/tests/ase_calculator.py | 10 +++---- 3 files changed, 30 insertions(+), 20 deletions(-) diff --git a/python/metatomic_torch/metatomic/torch/ase_calculator.py b/python/metatomic_torch/metatomic/torch/ase_calculator.py index 1762fc8a..d6891ec6 100644 --- a/python/metatomic_torch/metatomic/torch/ase_calculator.py +++ b/python/metatomic_torch/metatomic/torch/ase_calculator.py @@ -44,20 +44,17 @@ "float64": torch.float64, } -ARRAY_PROPERTIES = { - "momenta": { +ARRAY_QUANTITIES = { + "momentum": { "getter": lambda atoms: atoms.get_momenta(), - "quantity": "momentum", "unit": "(eV*u)^(1/2)", }, - "masses": { + "mass": { "getter": lambda atoms: atoms.get_masses(), - "quantity": "mass", "unit": "u", }, - "velocities": { + "velocity": { "getter": lambda atoms: atoms.get_velocities(), - "quantity": "velocity", "unit": "nm/fs", }, "initial_magmoms": {}, @@ -368,7 +365,7 @@ def run_model( # Get the additional inputs requested by the model for quantity, option in self._model.requested_inputs().items(): input_tensormap = _get_ase_input( - atoms, quantity, option, dtype=self._dtype, device=self._device + atoms, option, dtype=self._dtype, device=self._device ) system.add_data(quantity, input_tensormap) systems.append(system) @@ -519,7 +516,7 @@ def calculate( system.add_neighbor_list(options, neighbors) for quantity, option in self._model.requested_inputs().items(): input_tensormap = _get_ase_input( - atoms, quantity, option, dtype=self._dtype, device=self._device + atoms, option, dtype=self._dtype, device=self._device ) system.add_data(quantity, input_tensormap) @@ -944,21 +941,20 @@ def _compute_ase_neighbors(atoms, options, dtype, device): def _get_ase_input( atoms: ase.Atoms, - quantity: str, option: ModelOutput, dtype: torch.dtype, device: torch.device, ) -> "TensorMap": - if quantity in ARRAY_PROPERTIES: - if len(ARRAY_PROPERTIES[quantity]) == 0: + if option.quantity in ARRAY_QUANTITIES: + if len(ARRAY_QUANTITIES[option.quantity]) == 0: raise NotImplementedError( - f"Though the property {quantity} is available in `ase`, it is " + f"Though the quantity {option.quantity} is available in `ase`, it is " "currently not supported by metatomic." ) - infos = ARRAY_PROPERTIES[quantity] + infos = ARRAY_QUANTITIES[option.quantity] else: raise ValueError( - f"The model requested '{quantity}', which is not available in `ase`." + f"The model requested '{option.quantity}', which is not available in `ase`." ) values = infos["getter"](atoms) diff --git a/python/metatomic_torch/metatomic/torch/model.py b/python/metatomic_torch/metatomic/torch/model.py index c5220394..db209d4a 100644 --- a/python/metatomic_torch/metatomic/torch/model.py +++ b/python/metatomic_torch/metatomic/torch/model.py @@ -438,6 +438,18 @@ def forward( options=options, expected_dtype=self._model_dtype, ) + # check the requested inputs stored in the `systems` + for system in systems: + system_inputs: Dict[str, TensorMap] = {} + for name in system.known_data(): + system_inputs[name] = system.get_data(name) + _check_outputs( + systems=[system], + requested=self._requested_inputs, + selected_atoms=options.selected_atoms, + outputs=system_inputs, + model_dtype=self._capabilities.dtype, + ) with record_function("AtomisticModel::check_atomic_types"): # always (i.e. even if check_consistency=False) check that the atomic types @@ -898,6 +910,8 @@ def _check_inputs( ) # Check additional inputs + # Might be problematic, this requires that only requested inputs are stored as + # the data pf the system known_additional_inputs = system.known_data() for request in requested_inputs: found = False diff --git a/python/metatomic_torch/tests/ase_calculator.py b/python/metatomic_torch/tests/ase_calculator.py index 4f3bdd0b..96e8e21d 100644 --- a/python/metatomic_torch/tests/ase_calculator.py +++ b/python/metatomic_torch/tests/ase_calculator.py @@ -26,7 +26,7 @@ System, ) from metatomic.torch.ase_calculator import ( - ARRAY_PROPERTIES, + ARRAY_QUANTITIES, MetatomicCalculator, _compute_ase_neighbors, _full_3x3_to_voigt_6_stress, @@ -835,8 +835,8 @@ def forward( def test_additional_input(atoms): inputs = { - "masses": ModelOutput(quantity="mass", unit="u", per_atom=True), - "velocities": ModelOutput(quantity="velocity", unit="A/fs", per_atom=True), + "mass": ModelOutput(quantity="mass", unit="u", per_atom=True), + "velocity": ModelOutput(quantity="velocity", unit="A/fs", per_atom=True), } outputs = {("extra::" + prop): inputs[prop] for prop in inputs} capabilities = ModelCapabilities( @@ -862,6 +862,6 @@ def test_additional_input(atoms): shape = v[0].values.numpy().shape assert np.allclose( v[0].values.numpy(), - ARRAY_PROPERTIES[prop]["getter"](atoms).reshape(shape) - * (10 if prop == "velocities" else 1), + ARRAY_QUANTITIES[prop]["getter"](atoms).reshape(shape) + * (10 if prop == "velocity" else 1), )