Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
18 changes: 18 additions & 0 deletions metatomic-torch/src/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1206,6 +1206,24 @@ static std::map<std::string, Quantity> KNOWN_QUANTITIES = {
}, {
// alternative names
}}},
{"mass", Quantity{/* name */ "mass", /* baseline */ "u ", {
{"u", 1.0},
{"kilogram", 1.66053906892e-27},
{"gram", 1.66053906892e-24},
}, {
// alternative names
{"Dalton", "u"},
{"kg", "kilogram"},
{"g", "gram"},
}}},
{"velocity", Quantity{/* name */ "velocity", /* baseline */ "nm/fs", {
{"nm/fs", 1.0},
{"A/fs", 1e1},
{"m/s", 1e6},
{"nm/ps", 1e3},
}, {
// alternative names
}}}
};

bool metatomic_torch::valid_quantity(const std::string& quantity) {
Expand Down
2 changes: 1 addition & 1 deletion metatomic-torch/tests/models.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ TEST_CASE("Models metadata") {
struct WarningHandler: public torch::WarningHandler {
virtual ~WarningHandler() override = default;
void process(const torch::Warning& warning) override {
CHECK(warning.msg() == "unknown quantity 'unknown', only [energy force length momentum pressure] are supported");
CHECK(warning.msg() == "unknown quantity 'unknown', only [energy force length mass momentum pressure velocity] are supported");
}
};

Expand Down
85 changes: 84 additions & 1 deletion python/metatomic_torch/metatomic/torch/ase_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
pick_device,
pick_output,
register_autograd_neighbors,
unit_conversion_factor,
)


Expand All @@ -43,6 +44,28 @@
"float64": torch.float64,
}

ARRAY_QUANTITIES = {
"momentum": {
"getter": lambda atoms: atoms.get_momenta(),
"unit": "(eV*u)^(1/2)",
},
"mass": {
"getter": lambda atoms: atoms.get_masses(),
"unit": "u",
},
"velocity": {
"getter": lambda atoms: atoms.get_velocities(),
"unit": "nm/fs",
},
"initial_magmoms": {},
"magmom": {},
"magmoms": {},
"initial_charges": {},
"charges": {},
"dipole": {},
"free_energy": {},
}


class MetatomicCalculator(ase.calculators.calculator.Calculator):
"""
Expand Down Expand Up @@ -339,6 +362,12 @@ def run_model(
check_consistency=self.parameters["check_consistency"],
)
system.add_neighbor_list(options, neighbors)
# Get the additional inputs requested by the model
for quantity, option in self._model.requested_inputs().items():
input_tensormap = _get_ase_input(
atoms, option, dtype=self._dtype, device=self._device
)
system.add_data(quantity, input_tensormap)
systems.append(system)

available_outputs = self._model.capabilities().outputs
Expand Down Expand Up @@ -475,7 +504,6 @@ def calculate(
with record_function("MetatomicCalculator::compute_neighbors"):
# convert from ase.Atoms to metatomic.torch.System
system = System(types, positions, cell, pbc)

for options in self._model.requested_neighbor_lists():
neighbors = _compute_ase_neighbors(
atoms, options, dtype=self._dtype, device=self._device
Expand All @@ -486,6 +514,11 @@ def calculate(
check_consistency=self.parameters["check_consistency"],
)
system.add_neighbor_list(options, neighbors)
for quantity, option in self._model.requested_inputs().items():
input_tensormap = _get_ase_input(
atoms, option, dtype=self._dtype, device=self._device
)
system.add_data(quantity, input_tensormap)

# no `record_function` here, this will be handled by AtomisticModel
outputs = self._model(
Expand Down Expand Up @@ -906,6 +939,56 @@ def _compute_ase_neighbors(atoms, options, dtype, device):
)


def _get_ase_input(
atoms: ase.Atoms,
option: ModelOutput,
dtype: torch.dtype,
device: torch.device,
) -> "TensorMap":
if option.quantity in ARRAY_QUANTITIES:
if len(ARRAY_QUANTITIES[option.quantity]) == 0:
raise NotImplementedError(
f"Though the quantity {option.quantity} is available in `ase`, it is "
"currently not supported by metatomic."
)
infos = ARRAY_QUANTITIES[option.quantity]
else:
raise ValueError(
f"The model requested '{option.quantity}', which is not available in `ase`."
)

values = infos["getter"](atoms)
if infos["unit"] != option.unit:
conversion = unit_conversion_factor(
option.quantity,
from_unit=infos["unit"],
to_unit=option.unit,
)
else:
conversion = 1.0
values = (
torch.tensor(values[:, :, None] if values.ndim == 2 else values[None, :, None])
* conversion
)

tblock = TensorBlock(
values,
samples=Labels.range("atoms", values.shape[0]),
components=[Labels.range("components", values.shape[1])]
if values.shape[1] != 1
else [],
properties=Labels([option.quantity], torch.tensor([[0]])),
)
tmap = TensorMap(
Labels(["_"], torch.tensor([[0]])),
[tblock],
)
tmap.set_info("quantity", option.quantity)
tmap.set_info("unit", option.unit)
tmap.to(dtype=dtype, device=device)
return tmap


def _ase_to_torch_data(atoms, dtype, device):
"""Get the positions, cell and pbc from ASE atoms as torch tensors"""

Expand Down
78 changes: 78 additions & 0 deletions python/metatomic_torch/metatomic/torch/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,11 @@ def requested_neighbor_lists(self) -> List[NeighborListOptions]:
the systems before calling the model.
"""

def requested_inputs(self) -> Dict[str, ModelOutput]:
"""
Optional method declaring which inputs this model requires.
"""


class AtomisticModel(torch.nn.Module):
"""
Expand Down Expand Up @@ -292,6 +297,7 @@ class AtomisticModel(torch.nn.Module):

# Some annotation to make the TorchScript compiler happy
_requested_neighbor_lists: List[NeighborListOptions]
_requested_inputs: Dict[str, ModelOutput]

def __init__(
self,
Expand Down Expand Up @@ -327,6 +333,15 @@ def __init__(
)
# ============================================================================ #

# recursively explore `module` to get all the requested_additional_inputs
self._requested_inputs = {}
_get_requested_inputs(
module,
self.module.__class__.__name__,
self._requested_inputs,
)
# ============================================================================ #

self._metadata = metadata
self._capabilities = capabilities

Expand Down Expand Up @@ -380,6 +395,14 @@ def requested_neighbor_lists(self) -> List[NeighborListOptions]:
"""
return self._requested_neighbor_lists

@torch.jit.export
def requested_inputs(self) -> Dict[str, ModelOutput]:
"""
Get the additional inputs required by the exported model or any of the child
module.
"""
return self._requested_inputs

def forward(
self,
systems: List[System],
Expand Down Expand Up @@ -410,10 +433,23 @@ def forward(
_check_inputs(
capabilities=self._capabilities,
requested_neighbor_lists=self._requested_neighbor_lists,
requested_inputs=self._requested_inputs,
systems=systems,
options=options,
expected_dtype=self._model_dtype,
)
# check the requested inputs stored in the `systems`
for system in systems:
system_inputs: Dict[str, TensorMap] = {}
for name in system.known_data():
system_inputs[name] = system.get_data(name)
_check_outputs(
systems=[system],
requested=self._requested_inputs,
selected_atoms=options.selected_atoms,
outputs=system_inputs,
model_dtype=self._capabilities.dtype,
)

with record_function("AtomisticModel::check_atomic_types"):
# always (i.e. even if check_consistency=False) check that the atomic types
Expand Down Expand Up @@ -622,6 +658,30 @@ def _get_requested_neighbor_lists(
)


def _get_requested_inputs(
module: torch.nn.Module,
module_name: str,
requested: Dict[str, ModelOutput],
):
if hasattr(module, "requested_inputs"):
requested_inputs = module.requested_inputs()
for new_options in requested_inputs:
already_requested = False
for existing in requested:
if existing == new_options:
already_requested = True

if not already_requested:
requested[new_options] = requested_inputs[new_options]

for child_name, child in module.named_children():
_get_requested_inputs(
module=child,
module_name=module_name + "." + child_name,
requested=requested,
)


def _check_annotation(module: torch.nn.Module):
if isinstance(module, torch.jit.RecursiveScriptModule):
_check_annotation_torchscript(module)
Expand Down Expand Up @@ -750,6 +810,7 @@ def _check_annotation_python(module: torch.nn.Module):
def _check_inputs(
capabilities: ModelCapabilities,
requested_neighbor_lists: List[NeighborListOptions],
requested_inputs: Dict[str, ModelOutput],
systems: List[System],
options: ModelEvaluationOptions,
expected_dtype: torch.dtype,
Expand Down Expand Up @@ -848,6 +909,23 @@ def _check_inputs(
"in the system"
)

# Check additional inputs
# Might be problematic, this requires that only requested inputs are stored as
# the data pf the system
known_additional_inputs = system.known_data()
for request in requested_inputs:
found = False
for known in known_additional_inputs:
if request == known:
found = True

if not found:
raise ValueError(
"missing additional input in the system: the model requested "
f"a list for {request}, but it was not computed and stored "
"in the system"
)


def _convert_systems_units(
systems: List[System],
Expand Down
56 changes: 56 additions & 0 deletions python/metatomic_torch/tests/ase_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import pytest
import torch
from ase.calculators.calculator import PropertyNotImplementedError
from ase.md.velocitydistribution import MaxwellBoltzmannDistribution
from metatensor.torch import Labels, TensorBlock, TensorMap

from metatomic.torch import (
Expand All @@ -25,6 +26,7 @@
System,
)
from metatomic.torch.ase_calculator import (
ARRAY_QUANTITIES,
MetatomicCalculator,
_compute_ase_neighbors,
_full_3x3_to_voigt_6_stress,
Expand Down Expand Up @@ -809,3 +811,57 @@ def forward(
match = "does not support energy computation"
with pytest.raises(ValueError, match=match):
calc.compute_energy(atoms)


class AdditionalInputModel(torch.nn.Module):
def __init__(self, inputs):
super().__init__()
self._requested_inputs = inputs

def requested_inputs(self) -> List[str]:
return self._requested_inputs

def forward(
self,
systems: List[System],
outputs: Dict[str, ModelOutput],
selected_atoms: Optional[Labels] = None,
) -> Dict[str, TensorMap]:
return {
("extra::" + input): systems[0].get_data(input)
for input in self._requested_inputs
}


def test_additional_input(atoms):
inputs = {
"mass": ModelOutput(quantity="mass", unit="u", per_atom=True),
"velocity": ModelOutput(quantity="velocity", unit="A/fs", per_atom=True),
}
outputs = {("extra::" + prop): inputs[prop] for prop in inputs}
capabilities = ModelCapabilities(
outputs=outputs,
atomic_types=[28],
interaction_range=0.0,
supported_devices=["cpu"],
dtype="float64",
)

model = AtomisticModel(
AdditionalInputModel(inputs).eval(), ModelMetadata(), capabilities
)
MaxwellBoltzmannDistribution(atoms, temperature_K=300.0)
calculator = MetatomicCalculator(model)
results = calculator.run_model(atoms, outputs)
for k, v in results.items():
head, prop = k.split("::")
assert head == "extra"
assert prop in inputs
assert len(v.keys.names) == 1
assert v.get_info("quantity") == inputs[prop].quantity
shape = v[0].values.numpy().shape
assert np.allclose(
v[0].values.numpy(),
ARRAY_QUANTITIES[prop]["getter"](atoms).reshape(shape)
* (10 if prop == "velocity" else 1),
)
Loading