Skip to content

Commit d089e1e

Browse files
jameskermodeclaude
andcommitted
Refactor ase-ace Python interface for consistency and clarity
- Remove redundant ACECalculatorLazy class (ACECalculator already lazy) - Make ACECalculator inherit from ACECalculatorBase with NotImplementedError stubs for cutoff/species/n_basis/get_descriptors (socket protocol limitation) - Extract duplicate neighbor grouping logic to _group_neighbors() helper - Fix mutable default argument in ACELibraryCalculator.calculate() - Fix misleading threading docstring in batch_energy_forces_virial() - Add __repr__ methods to all three calculator classes - Update README with per-calculator parameter tables and descriptor docs - Document ase_ace.utils helper functions - Remove empty [ipi] optional dependency from pyproject.toml 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent fe5c905 commit d089e1e

File tree

6 files changed

+175
-51
lines changed

6 files changed

+175
-51
lines changed

export/ase-ace/README.md

Lines changed: 86 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ Install the required Julia packages in the ase-ace Julia environment:
5050

5151
```bash
5252
# Navigate to the ase-ace directory
53-
cd path/to/ACEpotentials.jl/ase-ace
53+
cd path/to/ACEpotentials.jl/export/ase-ace
5454

5555
# Install Julia dependencies
5656
julia --project=julia -e '
@@ -85,10 +85,9 @@ pip install -e ".[dev]"
8585
```
8686

8787
**Installation options:**
88-
- `ase-ace` - Base package only
88+
- `ase-ace` - Base package only (includes `ACECalculator`)
8989
- `ase-ace[julia]` - Adds `juliacall` and `juliapkg` for `ACEJuliaCalculator`
9090
- `ase-ace[lib]` - Adds `matscipy` for `ACELibraryCalculator`
91-
- `ase-ace[ipi]` - For `ACECalculator` (no extra Python deps, just Julia)
9291
- `ase-ace[all]` - All optional dependencies
9392

9493
## Quick Start
@@ -169,21 +168,75 @@ print(f"Energy: {energy:.4f} eV")
169168

170169
**Note**: Install with library support: `pip install ase-ace[lib]`
171170

171+
## Computing ACE Descriptors
172+
173+
The `get_descriptors()` method returns the raw ACE basis vectors for each atom,
174+
useful for fitting, analysis, and transfer learning.
175+
176+
**Availability:** `ACEJuliaCalculator` and `ACELibraryCalculator` only.
177+
`ACECalculator` does not support descriptors (socket protocol limitation).
178+
179+
### Example
180+
181+
```python
182+
from ase.build import bulk
183+
from ase_ace import ACELibraryCalculator
184+
185+
atoms = bulk('Si', 'diamond', a=5.43) * (2, 2, 2)
186+
calc = ACELibraryCalculator("deployment/lib/libace_model.so")
187+
188+
# Get descriptors for all atoms
189+
descriptors = calc.get_descriptors(atoms)
190+
print(f"Shape: {descriptors.shape}") # (natoms, n_basis)
191+
192+
# Access model properties
193+
print(f"Cutoff: {calc.cutoff} Å")
194+
print(f"Species: {calc.species}") # Atomic numbers
195+
print(f"Basis size: {calc.n_basis}")
196+
```
197+
198+
### Properties
199+
200+
| Property | Type | Description |
201+
|----------|------|-------------|
202+
| `cutoff` | float | Cutoff radius in Angstroms |
203+
| `species` | List[int] | Supported atomic numbers |
204+
| `n_basis` | int | Number of basis functions per atom |
205+
206+
### Use Cases
207+
208+
- **Linear model verification**: For linear ACE, `E = sum(descriptors @ weights)`
209+
- **Transfer learning**: Use descriptors as features for other ML models
210+
- **Analysis**: Examine local atomic environments
211+
172212
## Configuration
173213

174-
### Constructor Parameters
214+
### ACECalculator Parameters
175215

176216
| Parameter | Type | Default | Description |
177217
|-----------|------|---------|-------------|
178218
| `model_path` | str | required | Path to ACE model JSON file |
179-
| `num_threads` | int/str | 'auto' | Julia threads: 1, 2, 4, 8, or 'auto' |
180-
| `port` | int | 0 | TCP port (0 = auto-assign) |
181-
| `unixsocket` | str | None | Unix socket name (faster for local) |
182-
| `timeout` | float | 60.0 | Connection timeout in seconds |
183-
| `julia_executable` | str | 'julia' | Path to Julia executable |
184-
| `julia_project` | str | None | Julia project path (default: bundled) |
219+
| `num_threads` | int/str | 'auto' | Julia threads |
220+
| `port` | int | 0 | TCP port (0 = auto) |
221+
| `unixsocket` | str | None | Unix socket name |
222+
| `timeout` | float | 60.0 | Connection timeout (seconds) |
223+
| `julia_executable` | str | 'julia' | Path to Julia |
224+
| `julia_project` | str | None | Julia project path |
185225
| `log_level` | str | 'WARNING' | Logging level |
186226

227+
### ACEJuliaCalculator Parameters
228+
229+
| Parameter | Type | Default | Description |
230+
|-----------|------|---------|-------------|
231+
| `model_path` | str | required | Path to ACE model JSON file |
232+
| `num_threads` | int/str | 'auto' | Julia threads |
233+
234+
### ACELibraryCalculator Parameters
235+
236+
| Parameter | Type | Default | Description |
237+
|-----------|------|---------|-------------|
238+
| `library_path` | str | required | Path to compiled .so file |
239+
187240
### Threading
188241

189242
The calculator uses Julia's multi-threading for parallel ACE evaluation:
@@ -360,6 +413,29 @@ If using a specific port that's in use:
360413
calc = ACECalculator('model.json', port=0) # Auto-assign port
361414
```
362415

416+
## Utility Functions
417+
418+
The `ase_ace.utils` module provides helper functions for Julia setup:
419+
420+
```python
421+
from ase_ace.utils import find_julia, check_julia_version, setup_julia_environment
422+
423+
# Find Julia executable
424+
julia_path = find_julia()
425+
426+
# Check Julia version
427+
major, minor, patch = check_julia_version()
428+
429+
# Set up Julia environment with required packages
430+
setup_julia_environment(verbose=True)
431+
```
432+
433+
**Available functions:**
434+
- `find_julia()` - Locate Julia executable in PATH
435+
- `check_julia_version(julia_executable)` - Get Julia version as (major, minor, patch) tuple
436+
- `check_julia_packages(julia_executable, julia_project)` - Check if required packages are installed
437+
- `setup_julia_environment(julia_executable, julia_project, verbose)` - Install and configure Julia dependencies
438+
363439
## Running Tests
364440

365441
```bash

export/ase-ace/pyproject.toml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,6 @@ julia = [
4242
"juliapkg>=0.1.10",
4343
]
4444

45-
# ACECalculator - i-PI socket backend (no extra Python deps, just Julia install)
46-
ipi = []
47-
4845
# Development dependencies
4946
dev = [
5047
"pytest>=7.0",

export/ase-ace/src/ase_ace/base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@ class ACECalculatorBase(Calculator, ABC):
1717
Abstract base class for ACE calculators.
1818
1919
Provides a unified API across different backends:
20-
- ACELibraryCalculator (compiled .so via ctypes)
21-
- ACEJuliaCalculator (JuliaCall)
22-
- ACECalculator (socket-based, partial support)
20+
- ACELibraryCalculator (compiled .so via ctypes) - full support
21+
- ACEJuliaCalculator (JuliaCall) - full support
22+
- ACECalculator (socket-based) - energy/forces/stress only, no descriptors
2323
2424
Subclasses must implement the abstract properties and methods.
2525

export/ase-ace/src/ase_ace/calculator.py

Lines changed: 44 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,16 @@
1313
from pathlib import Path
1414
from typing import Optional, Union, List
1515

16-
from ase.calculators.calculator import Calculator, all_changes
16+
import numpy as np
17+
from ase.calculators.calculator import all_changes
1718

19+
from .base import ACECalculatorBase
1820
from .server import JuliaACEServer, find_free_port
1921

2022
logger = logging.getLogger(__name__)
2123

2224

23-
class ACECalculator(Calculator):
25+
class ACECalculator(ACECalculatorBase):
2426
"""
2527
ASE calculator for ACE potentials via Julia/IPICalculator.
2628
@@ -78,7 +80,6 @@ class ACECalculator(Calculator):
7880
IPICalculator.jl installed.
7981
"""
8082

81-
implemented_properties = ['energy', 'forces', 'stress']
8283
default_parameters = {}
8384

8485
def __init__(
@@ -92,7 +93,7 @@ def __init__(
9293
julia_project: Optional[str] = None,
9394
log_level: str = 'WARNING',
9495
):
95-
Calculator.__init__(self)
96+
super().__init__()
9697

9798
# Configure logging
9899
logging.basicConfig(level=getattr(logging, log_level.upper()))
@@ -183,7 +184,7 @@ def calculate(self, atoms=None, properties=None, system_changes=all_changes):
183184
if properties is None:
184185
properties = self.implemented_properties
185186

186-
Calculator.calculate(self, atoms, properties, system_changes)
187+
super().calculate(atoms, properties, system_changes)
187188

188189
# Start server on first calculation
189190
if not self._started:
@@ -237,13 +238,43 @@ def actual_port(self) -> Optional[int]:
237238
"""The actual TCP port being used (after start)."""
238239
return self._actual_port
239240

241+
# Abstract method implementations (not supported for socket-based calculator)
240242

241-
class ACECalculatorLazy(ACECalculator):
242-
"""
243-
Lazy-initialized ACECalculator.
243+
@property
244+
def cutoff(self) -> float:
245+
"""Cutoff radius in Angstroms."""
246+
raise NotImplementedError(
247+
"cutoff not available for socket-based ACECalculator. "
248+
"Use ACEJuliaCalculator or ACELibraryCalculator instead."
249+
)
244250

245-
Same as ACECalculator but delays Julia startup until the first
246-
calculation is requested. This is the default behavior of
247-
ACECalculator.
248-
"""
249-
pass
251+
@property
252+
def species(self) -> List[int]:
253+
"""List of supported atomic numbers."""
254+
raise NotImplementedError(
255+
"species not available for socket-based ACECalculator. "
256+
"Use ACEJuliaCalculator or ACELibraryCalculator instead."
257+
)
258+
259+
@property
260+
def n_basis(self) -> int:
261+
"""Number of basis functions per atom."""
262+
raise NotImplementedError(
263+
"n_basis not available for socket-based ACECalculator. "
264+
"Use ACEJuliaCalculator or ACELibraryCalculator instead."
265+
)
266+
267+
def get_descriptors(self, atoms) -> np.ndarray:
268+
"""
269+
Compute ACE descriptors (basis values) for all atoms.
270+
271+
Not supported for socket-based ACECalculator.
272+
Use ACEJuliaCalculator or ACELibraryCalculator instead.
273+
"""
274+
raise NotImplementedError(
275+
"get_descriptors() not available for socket-based ACECalculator. "
276+
"Use ACEJuliaCalculator or ACELibraryCalculator instead."
277+
)
278+
279+
def __repr__(self) -> str:
280+
return f"ACECalculator(model={self.model_path.name}, threads={self.num_threads})"

export/ase-ace/src/ase_ace/julia_calculator.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -240,8 +240,7 @@ def calculate(
240240
properties = self.implemented_properties
241241

242242
# Call parent to set self.atoms
243-
from ase.calculators.calculator import Calculator
244-
Calculator.calculate(self, atoms, properties, system_changes)
243+
super().calculate(atoms, properties, system_changes)
245244

246245
self._ensure_initialized()
247246
jl = self._jl
@@ -339,3 +338,6 @@ def get_descriptors(self, atoms) -> np.ndarray:
339338
)(at_jl, self._model)
340339

341340
return np.array(descriptors_jl)
341+
342+
def __repr__(self) -> str:
343+
return f"ACEJuliaCalculator(model={self.model_path.name}, threads={self._num_threads})"

export/ase-ace/src/ase_ace/library_calculator.py

Lines changed: 38 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from pathlib import Path
1818
from typing import Optional, List
1919

20-
from ase.calculators.calculator import Calculator, all_changes
20+
from ase.calculators.calculator import all_changes
2121

2222
from .base import ACECalculatorBase
2323

@@ -189,7 +189,7 @@ def batch_energy_forces_virial(
189189
"""
190190
Compute energies, forces, and virials for multiple atoms.
191191
192-
Uses threading when JULIA_NUM_THREADS > 1.
192+
Note: Threading is NOT available in --trim=safe compiled libraries.
193193
"""
194194
natoms = len(z)
195195
total_neighbors = len(neighbor_z)
@@ -282,16 +282,14 @@ class ACELibraryCalculator(ACECalculatorBase):
282282
>>> print(f"Energy: {atoms.get_potential_energy():.4f} eV")
283283
"""
284284

285-
implemented_properties = ['energy', 'forces', 'stress']
286-
287285
def __init__(self, library_path: str, **kwargs):
288286
if not HAS_MATSCIPY:
289287
raise ImportError(
290288
"matscipy is required for ACELibraryCalculator. "
291289
"Install it with: pip install matscipy"
292290
)
293291

294-
Calculator.__init__(self, **kwargs)
292+
super().__init__(**kwargs)
295293
self.ace = ACELibrary(library_path)
296294

297295
@property
@@ -309,6 +307,30 @@ def n_basis(self) -> int:
309307
"""Number of basis functions (descriptors per atom)."""
310308
return self.ace.n_basis
311309

310+
def _group_neighbors(self, i_idx: np.ndarray, natoms: int) -> tuple:
311+
"""
312+
Group neighbor list by center atom.
313+
314+
Parameters
315+
----------
316+
i_idx : np.ndarray
317+
Center atom indices from neighbor list
318+
natoms : int
319+
Total number of atoms
320+
321+
Returns
322+
-------
323+
tuple
324+
(atom_starts, atom_ends) arrays indicating neighbor ranges
325+
"""
326+
if len(i_idx) > 0:
327+
atom_starts = np.searchsorted(i_idx, np.arange(natoms))
328+
atom_ends = np.searchsorted(i_idx, np.arange(natoms), side='right')
329+
else:
330+
atom_starts = np.zeros(natoms, dtype=int)
331+
atom_ends = np.zeros(natoms, dtype=int)
332+
return atom_starts, atom_ends
333+
312334
def get_descriptors(self, atoms) -> np.ndarray:
313335
"""
314336
Compute ACE descriptors (basis values) for all atoms.
@@ -353,12 +375,7 @@ def get_descriptors(self, atoms) -> np.ndarray:
353375
)
354376

355377
# Group neighbors by center atom
356-
if len(i_idx) > 0:
357-
atom_starts = np.searchsorted(i_idx, np.arange(natoms))
358-
atom_ends = np.searchsorted(i_idx, np.arange(natoms), side='right')
359-
else:
360-
atom_starts = np.zeros(natoms, dtype=int)
361-
atom_ends = np.zeros(natoms, dtype=int)
378+
atom_starts, atom_ends = self._group_neighbors(i_idx, natoms)
362379

363380
# Compute descriptors for each atom
364381
descriptors = np.zeros((natoms, self.n_basis), dtype=np.float64)
@@ -379,13 +396,16 @@ def get_descriptors(self, atoms) -> np.ndarray:
379396

380397
return descriptors
381398

382-
def calculate(self, atoms=None, properties=['energy'], system_changes=all_changes):
399+
def calculate(self, atoms=None, properties=None, system_changes=all_changes):
383400
"""
384401
Perform calculation for given atoms object.
385402
386-
Uses matscipy neighbor list and batch API with threading.
403+
Uses matscipy neighbor list and site-level API.
387404
"""
388-
Calculator.calculate(self, atoms, properties, system_changes)
405+
if properties is None:
406+
properties = self.implemented_properties
407+
408+
super().calculate(atoms, properties, system_changes)
389409

390410
natoms = len(atoms)
391411
numbers = atoms.get_atomic_numbers()
@@ -398,12 +418,7 @@ def calculate(self, atoms=None, properties=['energy'], system_changes=all_change
398418
)
399419

400420
# Group neighbors by center atom
401-
if len(i_idx) > 0:
402-
atom_starts = np.searchsorted(i_idx, np.arange(natoms))
403-
atom_ends = np.searchsorted(i_idx, np.arange(natoms), side='right')
404-
else:
405-
atom_starts = np.zeros(natoms, dtype=int)
406-
atom_ends = np.zeros(natoms, dtype=int)
421+
atom_starts, atom_ends = self._group_neighbors(i_idx, natoms)
407422

408423
neighbor_counts = (atom_ends - atom_starts).astype(np.int32)
409424

@@ -466,3 +481,6 @@ def calculate(self, atoms=None, properties=['energy'], system_changes=all_change
466481
self.results['stress'] = stress
467482
else:
468483
self.results['stress'] = np.zeros(6)
484+
485+
def __repr__(self) -> str:
486+
return f"ACELibraryCalculator(lib={self.ace.lib_path.name})"

0 commit comments

Comments
 (0)