
get_calculator

from mlip_arena.tasks.utils import get_calculator
Builds and returns an ASE BaseCalculator for a given model. Automatically selects the freest available GPU via get_freer_device unless you pass an explicit device. When dispersion=True, wraps the base calculator in an ASE SumCalculator that adds a DFT-D3 dispersion correction via TorchDFTD3Calculator.
def get_calculator(
    calculator: str | MLIPEnum | BaseCalculator,
    calculator_kwargs: dict | None = None,
    dispersion: bool = False,
    dispersion_kwargs: dict | None = None,
    device: str | None = None,
) -> BaseCalculator:

Parameters

calculator
str | MLIPEnum | BaseCalculator
required
The model to use. Accepts:
  • An MLIPEnum member (e.g. MLIPEnum["MACE-MP(M)"]).
  • A string name matching a registered model (e.g. "MACE-MP(M)"). Looked up via MLIPEnum[calculator].
  • A BaseCalculator subclass (not an instance). Instantiated with calculator_kwargs.
  • An existing BaseCalculator instance. Returned as-is; calculator_kwargs are ignored.
Raises ValueError if none of the above match.
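The four accepted forms can be illustrated with a minimal, dependency-free sketch. `BaseCalculatorStub`, `ToyModel`, `ModelEnumStub`, and `resolve_calculator` are hypothetical stand-ins for ASE's `BaseCalculator`, a registered model, `MLIPEnum`, and `get_calculator`; this is not the library's source. The sketch also assumes that an unregistered string falls through to `ValueError`, whereas a bare `MLIPEnum[name]` lookup would raise `KeyError`.

```python
from enum import Enum


class BaseCalculatorStub:
    """Hypothetical stand-in for ase.calculators.calculator.BaseCalculator."""

    def __init__(self, **kwargs):
        self.kwargs = kwargs


class ToyModel(BaseCalculatorStub):
    """Hypothetical model calculator, used only in this sketch."""


class ModelEnumStub(Enum):
    """Hypothetical stand-in for MLIPEnum: each member's value is a calculator class."""

    TOY = ToyModel


def resolve_calculator(calculator, calculator_kwargs=None):
    """Sketch of the documented input handling, checked in this order."""
    kwargs = dict(calculator_kwargs or {})
    if isinstance(calculator, ModelEnumStub):
        # Enum member: instantiate the class stored in .value.
        return calculator.value(**kwargs)
    if isinstance(calculator, str) and calculator in ModelEnumStub.__members__:
        # Registered name: look up the member, then instantiate its class.
        return ModelEnumStub[calculator].value(**kwargs)
    if isinstance(calculator, type) and issubclass(calculator, BaseCalculatorStub):
        # Calculator class (not an instance): instantiate it with the kwargs.
        return calculator(**kwargs)
    if isinstance(calculator, BaseCalculatorStub):
        # Existing instance: returned unchanged; kwargs are ignored.
        return calculator
    # Assumption: unregistered strings and any other input end up here.
    raise ValueError(f"Unsupported calculator: {calculator!r}")


existing = ToyModel(z=3)
assert resolve_calculator(existing, {"ignored": True}) is existing
```

Checking the instance case before the class case would not work, because a class is never an instance of `BaseCalculatorStub`; the `isinstance(calculator, type)` guard keeps `issubclass` from raising `TypeError` on non-class inputs.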
calculator_kwargs
dict | None
default:"None" (treated as {})
Keyword arguments forwarded to the calculator constructor. The device key is automatically injected, so you do not need to pass it here.
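How `calculator_kwargs` and the injected `device` key might combine is sketched below with a hypothetical helper, `prepare_calculator_kwargs`, and a hypothetical key, `some_option`. It assumes that `None` is treated as an empty dict, that the resolved device overwrites any `device` entry you supplied, and that the caller's dict is copied rather than mutated; the real function may differ on the last two points.

```python
def prepare_calculator_kwargs(calculator_kwargs, device):
    """Hypothetical helper: normalize kwargs and inject the device key."""
    # None behaves like an empty dict; copy so the caller's dict is untouched.
    kwargs = dict(calculator_kwargs or {})
    # Assumption: the resolved device wins over any user-supplied "device".
    kwargs["device"] = device
    return kwargs


user_kwargs = {"some_option": 1}  # hypothetical constructor keyword
print(prepare_calculator_kwargs(user_kwargs, "cuda:0"))
# {'some_option': 1, 'device': 'cuda:0'}
print(user_kwargs)
# {'some_option': 1}
```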
dispersion
bool
default:"False"
When True, adds a DFT-D3 dispersion correction on top of the base calculator using TorchDFTD3Calculator from the torch_dftd package. The base calculator and the dispersion calculator are combined with ASE’s SumCalculator. Requires torch_dftd to be installed; raises ImportError if the package is missing.
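The `ImportError` behavior amounts to an optional-dependency check. The sketch below uses `importlib.util.find_spec` to test for `torch_dftd` before any dispersion term is built; `require_optional_package` is a hypothetical name, and the real function may instead wrap the import itself in `try`/`except`.

```python
import importlib.util


def require_optional_package(module_name, feature):
    """Hypothetical helper: raise ImportError if a top-level package is missing."""
    if importlib.util.find_spec(module_name) is None:
        raise ImportError(
            f"{feature} requires the '{module_name}' package; install it and retry."
        )


try:
    require_optional_package("torch_dftd", "dispersion=True")
    print("torch_dftd is available")
except ImportError as exc:
    print(f"Skipping dispersion: {exc}")
```

`find_spec` only locates the package without importing it, so the check stays cheap; pass a top-level name, since for dotted names it imports the parent package first.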
dispersion_kwargs
dict | None
Keyword arguments forwarded to TorchDFTD3Calculator. Defaults to:
{"damping": "bj", "xc": "pbe", "cutoff": 40.0 * ase.units.Bohr}
The device key is automatically injected. Common keys:
| Key | Default | Description |
|---|---|---|
| damping | "bj" | Damping scheme: "bj" (Becke-Johnson) or "zero". |
| xc | "pbe" | Exchange-correlation functional of the DFT data the model was trained on; selects the matching D3 parameters. |
| cutoff | 40.0 * units.Bohr | Real-space cutoff for the D3 sum (~21 Å). |
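The default cutoff is given in Bohr; converting it shows where the ~21 Å figure comes from. This sketch hard-codes the CODATA 2018 Bohr radius in Å (`ase.units.Bohr` may differ in the last few digits depending on the CODATA version ASE uses) and defines a hypothetical `resolve_dispersion_kwargs` that falls back to the documented defaults only when `dispersion_kwargs` is `None`. Whether a partial dict is merged with the defaults is not stated above, so the sketch does not merge.

```python
BOHR_IN_ANGSTROM = 0.529177210903  # CODATA 2018 value, in Å

# Documented defaults, with the cutoff expressed in Å.
DEFAULT_D3_KWARGS = {"damping": "bj", "xc": "pbe", "cutoff": 40.0 * BOHR_IN_ANGSTROM}


def resolve_dispersion_kwargs(dispersion_kwargs=None):
    """Hypothetical helper: use the documented defaults only when nothing is passed."""
    if dispersion_kwargs is None:
        return dict(DEFAULT_D3_KWARGS)
    # Assumption: a user-supplied dict is used as-is, not merged with the defaults.
    return dict(dispersion_kwargs)


print(f"{resolve_dispersion_kwargs()['cutoff']:.2f} Å")  # 21.17 Å
```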
device
str | None
default:"None"
Device string passed to the calculator, e.g. "cuda:0" or "cpu". When None, the device is chosen automatically by get_freer_device().
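The automatic choice can be pictured as an arg-max over free GPU memory. In this sketch, `pick_freer_device` and `resolve_device` are hypothetical helpers that take pre-measured free memory per device as input, whereas `get_freer_device()` queries the hardware itself; the `"cpu"` fallback when no GPU is listed is also an assumption.

```python
from __future__ import annotations


def pick_freer_device(free_bytes_by_device: dict[str, int]) -> str:
    """Hypothetical: return the device with the most free memory, else "cpu"."""
    if not free_bytes_by_device:
        return "cpu"  # assumption: fall back to CPU when no GPU is listed
    return max(free_bytes_by_device, key=free_bytes_by_device.__getitem__)


def resolve_device(device: str | None, free_bytes_by_device: dict[str, int]) -> str:
    """An explicit device always wins; otherwise pick the freest one."""
    return device if device is not None else pick_freer_device(free_bytes_by_device)


free = {"cuda:0": 2_000_000_000, "cuda:1": 9_000_000_000}
print(resolve_device(None, free))      # cuda:1
print(resolve_device("cuda:0", free))  # cuda:0
```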

Return value

calc
ase.calculators.calculator.BaseCalculator
A ready-to-use ASE calculator. When dispersion=False this is the model’s own calculator. When dispersion=True this is an ASE SumCalculator([model_calc, disp_calc]).
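With `dispersion=True`, the returned calculator reports the sum of the model and dispersion results. The dependency-free sketch below mimics that additive behavior for energy and forces; `ConstantCalc` and `SumSketch` are hypothetical stand-ins, and the real class is `ase.calculators.mixing.SumCalculator`.

```python
from dataclasses import dataclass


@dataclass
class ConstantCalc:
    """Hypothetical stand-in calculator returning fixed energy (eV) and forces (eV/Å)."""

    energy: float
    forces: list

    def get_potential_energy(self):
        return self.energy

    def get_forces(self):
        return self.forces


class SumSketch:
    """Adds energies and per-atom force components across calculators."""

    def __init__(self, calcs):
        self.calcs = list(calcs)

    def get_potential_energy(self):
        return sum(c.get_potential_energy() for c in self.calcs)

    def get_forces(self):
        per_calc = [c.get_forces() for c in self.calcs]
        # zip(*per_calc) walks atoms; zip(*rows) walks x, y, z for that atom.
        return [[sum(comps) for comps in zip(*rows)] for rows in zip(*per_calc)]


model_calc = ConstantCalc(-10.0, [[0.1, 0.0, 0.0], [-0.1, 0.0, 0.0]])
disp_calc = ConstantCalc(-0.5, [[0.01, 0.0, 0.0], [-0.01, 0.0, 0.0]])
total = SumSketch([model_calc, disp_calc])
print(total.get_potential_energy())  # -10.5
```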

Examples

Basic usage

from mlip_arena.models import MLIPEnum
from mlip_arena.tasks.utils import get_calculator

# Look up by enum member
calc = get_calculator(MLIPEnum["MACE-MP(M)"])

# Equivalent: look up by name string
calc = get_calculator("MACE-MP(M)")

With DFT-D3 dispersion correction

from ase import units
from ase.build import bulk
from mlip_arena.models import MLIPEnum
from mlip_arena.tasks.utils import get_calculator

atoms = bulk("Cu", cubic=True)

calc = get_calculator(
    calculator=MLIPEnum["MACE-MP(M)"],
    dispersion=True,
    dispersion_kwargs={
        "damping": "bj",
        "xc": "pbe",
        "cutoff": 40.0 * units.Bohr,
    },
)

atoms.calc = calc
print(atoms.get_potential_energy())

Explicit device and calculator kwargs

from mlip_arena.models import MLIPEnum
from mlip_arena.tasks.utils import get_calculator

calc = get_calculator(
    calculator=MLIPEnum["CHGNet"],
    calculator_kwargs={"use_device": "cuda:1"},
    device="cuda:1",
)

Custom calculator class

from ase.calculators.lj import LennardJones
from mlip_arena.tasks.utils import get_calculator

# Pass the class, not an instance — get_calculator will instantiate it
calc = get_calculator(LennardJones, calculator_kwargs={"sigma": 2.5})
When you pass an existing BaseCalculator instance, calculator_kwargs are silently ignored. Pass the class if you want kwargs forwarded.

Registered models (MLIPEnum)

MLIPEnum is an Enum whose members are populated at import time from registry.yaml. Each member’s .value is the calculator class for that model.
from mlip_arena.models import MLIPEnum

# List all registered models
print(list(MLIPEnum.__members__.keys()))
# ['MACE-MP(M)', 'CHGNet', 'M3GNet', 'MatterSim', 'ORBv2',
#  'SevenNet', 'eqV2(OMat)', 'MACE-MPA', 'eSEN', ...]

# Access a member
model = MLIPEnum["MACE-MP(M)"]
print(model.name)   # 'MACE-MP(M)'
print(model.value)  # <class 'MACE_MP_Medium'>
Only models whose packages are installed will appear in MLIPEnum. Missing packages produce a warning at import time and the model is skipped.
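The import-time population and skip-with-warning behavior can be approximated with the functional `Enum` API. In this sketch the registry is an in-memory dict of hypothetical entries mapping a model name to a (module, class) pair, standing in for `registry.yaml`; standard-library classes play the role of model calculators, and `build_model_enum` is a hypothetical name, not mlip_arena's code.

```python
import importlib
import warnings
from enum import Enum

# Hypothetical registry: name -> (module path, class name).
REGISTRY = {
    "Fraction": ("fractions", "Fraction"),
    "Decimal": ("decimal", "Decimal"),
    "NotInstalled": ("surely_missing_pkg_abc123", "Model"),
}


def build_model_enum(registry, enum_name="ModelEnum"):
    """Import each registered class; warn about and skip the ones that fail."""
    members = {}
    for name, (module_path, class_name) in registry.items():
        try:
            module = importlib.import_module(module_path)
            members[name] = getattr(module, class_name)
        except (ImportError, AttributeError) as exc:
            warnings.warn(f"Skipping {name}: {exc}")
    return Enum(enum_name, members)


ModelEnum = build_model_enum(REGISTRY)
print(list(ModelEnum.__members__))  # ['Fraction', 'Decimal']
print(ModelEnum["Fraction"].value)  # <class 'fractions.Fraction'>
```

Because each member's `.value` is the class itself, `ModelEnum["Fraction"].value(1, 2)` constructs an instance, which is the same pattern `get_calculator` relies on when it instantiates `MLIPEnum[name].value`.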