diff --git a/metatomic-torch/include/metatomic/torch/model_output.hpp b/metatomic-torch/include/metatomic/torch/model_output.hpp deleted file mode 100644 index e69de29b..00000000 diff --git a/metatomic-torch/src/model.cpp b/metatomic-torch/src/model.cpp index efa77c9c..95a41016 100644 --- a/metatomic-torch/src/model.cpp +++ b/metatomic-torch/src/model.cpp @@ -1206,6 +1206,24 @@ static std::map KNOWN_QUANTITIES = { }, { // alternative names }}}, + {"mass", Quantity{/* name */ "mass", /* baseline */ "u ", { + {"u", 1.0}, + {"kilogram", 1.66053906892e-27}, + {"gram", 1.66053906892e-24}, + }, { + // alternative names + {"Dalton", "u"}, + {"kg", "kilogram"}, + {"g", "gram"}, + }}}, + {"velocity", Quantity{/* name */ "velocity", /* baseline */ "nm/fs", { + {"nm/fs", 1.0}, + {"A/fs", 1e1}, + {"m/s", 1e6}, + {"nm/ps", 1e3}, + }, { + // alternative names + }}} }; bool metatomic_torch::valid_quantity(const std::string& quantity) { diff --git a/metatomic-torch/tests/models.cpp b/metatomic-torch/tests/models.cpp index 3401c23f..ec8996f1 100644 --- a/metatomic-torch/tests/models.cpp +++ b/metatomic-torch/tests/models.cpp @@ -109,7 +109,7 @@ TEST_CASE("Models metadata") { struct WarningHandler: public torch::WarningHandler { virtual ~WarningHandler() override = default; void process(const torch::Warning& warning) override { - CHECK(warning.msg() == "unknown quantity 'unknown', only [energy force length momentum pressure] are supported"); + CHECK(warning.msg() == "unknown quantity 'unknown', only [energy force length mass momentum pressure velocity] are supported"); } }; diff --git a/python/metatomic_torch/metatomic/torch/ase_calculator.py b/python/metatomic_torch/metatomic/torch/ase_calculator.py index 7c04989f..d6891ec6 100644 --- a/python/metatomic_torch/metatomic/torch/ase_calculator.py +++ b/python/metatomic_torch/metatomic/torch/ase_calculator.py @@ -21,6 +21,7 @@ pick_device, pick_output, register_autograd_neighbors, + unit_conversion_factor, ) @@ -43,6 +44,28 @@ "float64": torch.float64, } +ARRAY_QUANTITIES = { + "momentum": { + "getter": lambda atoms: atoms.get_momenta(), + "unit": "(eV*u)^(1/2)", + }, + "mass": { + "getter": lambda atoms: atoms.get_masses(), + "unit": "u", + }, + "velocity": { + "getter": lambda atoms: atoms.get_velocities(), + "unit": "nm/fs", + }, + "initial_magmoms": {}, + "magmom": {}, + "magmoms": {}, + "initial_charges": {}, + "charges": {}, + "dipole": {}, + "free_energy": {}, +} + class MetatomicCalculator(ase.calculators.calculator.Calculator): """ @@ -339,6 +362,12 @@ def run_model( check_consistency=self.parameters["check_consistency"], ) system.add_neighbor_list(options, neighbors) + # Get the additional inputs requested by the model + for quantity, option in self._model.requested_inputs().items(): + input_tensormap = _get_ase_input( + atoms, option, dtype=self._dtype, device=self._device + ) + system.add_data(quantity, input_tensormap) systems.append(system) available_outputs = self._model.capabilities().outputs @@ -475,7 +504,6 @@ def calculate( with record_function("MetatomicCalculator::compute_neighbors"): # convert from ase.Atoms to metatomic.torch.System system = System(types, positions, cell, pbc) - for options in self._model.requested_neighbor_lists(): neighbors = _compute_ase_neighbors( atoms, options, dtype=self._dtype, device=self._device @@ -486,6 +514,11 @@ def calculate( check_consistency=self.parameters["check_consistency"], ) system.add_neighbor_list(options, neighbors) + for quantity, option in self._model.requested_inputs().items(): + input_tensormap = _get_ase_input( + atoms, option, dtype=self._dtype, device=self._device + ) + system.add_data(quantity, input_tensormap) # no `record_function` here, this will be handled by AtomisticModel outputs = self._model( @@ -906,6 +939,56 @@ def _compute_ase_neighbors(atoms, options, dtype, device): ) +def _get_ase_input( + atoms: ase.Atoms, + option: ModelOutput, + dtype: torch.dtype, + device: torch.device, +) -> "TensorMap": + if option.quantity in ARRAY_QUANTITIES: + if len(ARRAY_QUANTITIES[option.quantity]) == 0: + raise NotImplementedError( + f"Though the quantity {option.quantity} is available in `ase`, it is " + "currently not supported by metatomic." + ) + infos = ARRAY_QUANTITIES[option.quantity] + else: + raise ValueError( + f"The model requested '{option.quantity}', which is not available in `ase`." + ) + + values = infos["getter"](atoms) + if infos["unit"] != option.unit: + conversion = unit_conversion_factor( + option.quantity, + from_unit=infos["unit"], + to_unit=option.unit, + ) + else: + conversion = 1.0 + values = ( + torch.tensor(values[:, :, None] if values.ndim == 2 else values[None, :, None]) + * conversion + ) + + tblock = TensorBlock( + values, + samples=Labels.range("atoms", values.shape[0]), + components=[Labels.range("components", values.shape[1])] + if values.shape[1] != 1 + else [], + properties=Labels([option.quantity], torch.tensor([[0]])), + ) + tmap = TensorMap( + Labels(["_"], torch.tensor([[0]])), + [tblock], + ) + tmap.set_info("quantity", option.quantity) + tmap.set_info("unit", option.unit) + tmap.to(dtype=dtype, device=device) + return tmap + + def _ase_to_torch_data(atoms, dtype, device): """Get the positions, cell and pbc from ASE atoms as torch tensors""" diff --git a/python/metatomic_torch/metatomic/torch/model.py b/python/metatomic_torch/metatomic/torch/model.py index eb1c2703..db209d4a 100644 --- a/python/metatomic_torch/metatomic/torch/model.py +++ b/python/metatomic_torch/metatomic/torch/model.py @@ -176,6 +176,11 @@ def requested_neighbor_lists(self) -> List[NeighborListOptions]: the systems before calling the model. """ + def requested_inputs(self) -> Dict[str, ModelOutput]: + """ + Optional method declaring which inputs this model requires. + """ + class AtomisticModel(torch.nn.Module): """ @@ -292,6 +297,7 @@ class AtomisticModel(torch.nn.Module): # Some annotation to make the TorchScript compiler happy _requested_neighbor_lists: List[NeighborListOptions] + _requested_inputs: Dict[str, ModelOutput] def __init__( self, @@ -327,6 +333,15 @@ def __init__( ) # ============================================================================ # + # recursively explore `module` to get all the requested_additional_inputs + self._requested_inputs = {} + _get_requested_inputs( + module, + self.module.__class__.__name__, + self._requested_inputs, + ) + # ============================================================================ # + self._metadata = metadata self._capabilities = capabilities @@ -380,6 +395,14 @@ def requested_neighbor_lists(self) -> List[NeighborListOptions]: """ return self._requested_neighbor_lists + @torch.jit.export + def requested_inputs(self) -> Dict[str, ModelOutput]: + """ + Get the additional inputs required by the exported model or any of the child + module. + """ + return self._requested_inputs + def forward( self, systems: List[System], @@ -410,10 +433,23 @@ def forward( _check_inputs( capabilities=self._capabilities, requested_neighbor_lists=self._requested_neighbor_lists, + requested_inputs=self._requested_inputs, systems=systems, options=options, expected_dtype=self._model_dtype, ) + # check the requested inputs stored in the `systems` + for system in systems: + system_inputs: Dict[str, TensorMap] = {} + for name in system.known_data(): + system_inputs[name] = system.get_data(name) + _check_outputs( + systems=[system], + requested=self._requested_inputs, + selected_atoms=options.selected_atoms, + outputs=system_inputs, + model_dtype=self._capabilities.dtype, + ) with record_function("AtomisticModel::check_atomic_types"): # always (i.e. even if check_consistency=False) check that the atomic types @@ -622,6 +658,30 @@ def _get_requested_neighbor_lists( ) +def _get_requested_inputs( + module: torch.nn.Module, + module_name: str, + requested: Dict[str, ModelOutput], +): + if hasattr(module, "requested_inputs"): + requested_inputs = module.requested_inputs() + for new_options in requested_inputs: + already_requested = False + for existing in requested: + if existing == new_options: + already_requested = True + + if not already_requested: + requested[new_options] = requested_inputs[new_options] + + for child_name, child in module.named_children(): + _get_requested_inputs( + module=child, + module_name=module_name + "." + child_name, + requested=requested, + ) + + def _check_annotation(module: torch.nn.Module): if isinstance(module, torch.jit.RecursiveScriptModule): _check_annotation_torchscript(module) @@ -750,6 +810,7 @@ def _check_annotation_python(module: torch.nn.Module): def _check_inputs( capabilities: ModelCapabilities, requested_neighbor_lists: List[NeighborListOptions], + requested_inputs: Dict[str, ModelOutput], systems: List[System], options: ModelEvaluationOptions, expected_dtype: torch.dtype, @@ -848,6 +909,23 @@ def _check_inputs( "in the system" ) + # Check additional inputs + # Might be problematic, this requires that only requested inputs are stored as + # the data pf the system + known_additional_inputs = system.known_data() + for request in requested_inputs: + found = False + for known in known_additional_inputs: + if request == known: + found = True + + if not found: + raise ValueError( + "missing additional input in the system: the model requested " + f"a list for {request}, but it was not computed and stored " + "in the system" + ) + def _convert_systems_units( systems: List[System], diff --git a/python/metatomic_torch/tests/ase_calculator.py b/python/metatomic_torch/tests/ase_calculator.py index e654a579..96e8e21d 100644 --- a/python/metatomic_torch/tests/ase_calculator.py +++ b/python/metatomic_torch/tests/ase_calculator.py @@ -14,6 +14,7 @@ import pytest import torch from ase.calculators.calculator import PropertyNotImplementedError +from ase.md.velocitydistribution import MaxwellBoltzmannDistribution from metatensor.torch import Labels, TensorBlock, TensorMap from metatomic.torch import ( @@ -25,6 +26,7 @@ System, ) from metatomic.torch.ase_calculator import ( + ARRAY_QUANTITIES, MetatomicCalculator, _compute_ase_neighbors, _full_3x3_to_voigt_6_stress, @@ -809,3 +811,57 @@ def forward( match = "does not support energy computation" with pytest.raises(ValueError, match=match): calc.compute_energy(atoms) + + +class AdditionalInputModel(torch.nn.Module): + def __init__(self, inputs): + super().__init__() + self._requested_inputs = inputs + + def requested_inputs(self) -> List[str]: + return self._requested_inputs + + def forward( + self, + systems: List[System], + outputs: Dict[str, ModelOutput], + selected_atoms: Optional[Labels] = None, + ) -> Dict[str, TensorMap]: + return { + ("extra::" + input): systems[0].get_data(input) + for input in self._requested_inputs + } + + +def test_additional_input(atoms): + inputs = { + "mass": ModelOutput(quantity="mass", unit="u", per_atom=True), + "velocity": ModelOutput(quantity="velocity", unit="A/fs", per_atom=True), + } + outputs = {("extra::" + prop): inputs[prop] for prop in inputs} + capabilities = ModelCapabilities( + outputs=outputs, + atomic_types=[28], + interaction_range=0.0, + supported_devices=["cpu"], + dtype="float64", + ) + + model = AtomisticModel( + AdditionalInputModel(inputs).eval(), ModelMetadata(), capabilities + ) + MaxwellBoltzmannDistribution(atoms, temperature_K=300.0) + calculator = MetatomicCalculator(model) + results = calculator.run_model(atoms, outputs) + for k, v in results.items(): + head, prop = k.split("::") + assert head == "extra" + assert prop in inputs + assert len(v.keys.names) == 1 + assert v.get_info("quantity") == inputs[prop].quantity + shape = v[0].values.numpy().shape + assert np.allclose( + v[0].values.numpy(), + ARRAY_QUANTITIES[prop]["getter"](atoms).reshape(shape) + * (10 if prop == "velocity" else 1), + )