Source code for FOX.armc.monte_carlo

"""A module for performing Monte Carlo-based forcefield parameter optimizations.

Index
-----
.. currentmodule:: FOX.armc
.. autosummary::
    MonteCarloABC

API
---
.. autoclass:: MonteCarloABC
    :members:

"""

from __future__ import annotations

import copy
import warnings
from os import PathLike
from abc import ABC, abstractmethod
from logging import Logger
from functools import wraps, partial
from itertools import repeat, cycle, chain
from collections import abc
from typing import (
    Tuple, List, Dict, Optional, Union, Iterable, Hashable, Iterator, Any, Mapping, Callable,
    KeysView, ValuesView, ItemsView, Sequence, TypeVar, overload, TYPE_CHECKING, cast, ClassVar
)

import numpy as np
from assertionlib.dataclass import AbstractDataClass
from nanoutils import EMPTY_MAPPING
from qmflows.packages import Result
from qmflows.warnings_qmflows import QMFlows_Warning

from ..logger import DEFAULT_LOGGER
from ..type_hints import ArrayOrScalar

if TYPE_CHECKING:
    from numpy.typing import ArrayLike
    from .package_manager import PackageManager
    from .param_mapping import ParamMapping
    from ..classes import MultiMolecule
else:
    from ..type_alias import PackageManager, ParamMapping, MultiMolecule

# The public API of this module
__all__ = ['MonteCarloABC']

# Generic type variable; used for the type of the ``default`` argument in ``MonteCarloABC.get()``
T = TypeVar('T')

# Signature of the PES post-processing callables stored in ``MonteCarloABC.pes_post_process``;
# they are called as ``func(mol_list, mc_instance)`` and their return value is ignored
PostProcess = Callable[[Optional[Iterable[MultiMolecule]], Optional['MonteCarloABC']], None]
# Signature of the PES-descriptor callables stored in ``MonteCarloABC.pes`` and
# ``MonteCarloABC.pes_validation``; they are called as ``func(mol, result)``
GetPesDescriptor = Callable[[MultiMolecule, Result], ArrayOrScalar]
# The key type of the ``MonteCarloABC`` mapping: a tuple of floats
Key = Tuple[float, ...]


def _template_func1(_func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
    _mol, _, *args = args
    return _func(_mol, *args, **kwargs)


def _template_func2(_func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
    _, _result, *args = args
    return _func(_result, *args, **kwargs)


class MonteCarloABC(AbstractDataClass, ABC, Mapping[Key, np.ndarray]):
    r"""The base :class:`.MonteCarloABC` class.

    An abstract base class which bundles a set of molecules, a package manager and
    a parameter mapping, and which stores arrays in an internal dictionary that is
    exposed via the :class:`~collections.abc.Mapping` protocol
    (plus :meth:`__setitem__`).

    Subclasses have to implement :meth:`__call__`.

    """

    # The parameter mapping; called in `move()` to perform the actual parameter move
    param: ParamMapping
    # The package manager; called in `run_jobs()` to run the actual jobs
    package_manager: PackageManager
    # If `False`, `get_pes_descriptors()` clears the jobs once it is done
    keep_files: bool
    # The (path to the) .hdf5 file
    hdf5_file: Union[str, 'PathLike[str]']
    # Callables for constructing PES descriptors; populated by `add_pes_evaluator()`
    pes: Dict[str, GetPesDescriptor]
    # Same as `pes`, but for descriptors registered with `validation=True`
    pes_validation: Dict[str, GetPesDescriptor]
    # NOTE(review): not assigned anywhere in this class; presumably set by subclasses
    swap_phi: Optional[Callable[..., Any]]

    HAS_LOOP: ClassVar[bool] = False

    @property
    def molecule(self) -> Tuple[MultiMolecule, ...]:
        """Get **value** or set **value** as a tuple of MultiMolecule instances."""
        return self._molecule

    @molecule.setter
    def molecule(self, value: Iterable[MultiMolecule]) -> None:
        self._molecule = tuple(value)

    @property
    def logger(self) -> Logger:
        """Get or set the logger; :data:`None` is replaced by the default logger."""
        return self._logger

    @logger.setter
    def logger(self, value: Optional[Logger]) -> None:
        self._logger = value if value is not None else DEFAULT_LOGGER

    @property
    def pes_post_process(self) -> Tuple[PostProcess, ...]:
        """Get or set post-processing functions; :data:`None` is replaced by an empty tuple."""
        return self._pes_post_process

    @pes_post_process.setter
    def pes_post_process(self, value: Optional[Iterable[PostProcess]]) -> None:
        self._pes_post_process = tuple(value) if value is not None else ()

    def __init__(self, molecule: Iterable[MultiMolecule],
                 package_manager: PackageManager,
                 param: ParamMapping,
                 keep_files: bool = False,
                 hdf5_file: Union[str, PathLike] = 'armc.hdf5',
                 logger: Optional[Logger] = None,
                 pes_post_process: Optional[Iterable[PostProcess]] = None,
                 **kwargs: Any) -> None:
        """Initialize a :class:`MonteCarlo` instance.

        Parameters
        ----------
        molecule : :class:`Iterable[FOX.MultiMolecule] <collections.abc.Iterable>`
            An iterable of molecules; stored as a tuple.
        package_manager : :class:`PackageManager`
            The callable responsible for running the jobs.
        param : :class:`ParamMapping`
            The callable responsible for moving the parameters.
        keep_files : :class:`bool`
            Whether the job files should be kept after computing the PES descriptors.
        hdf5_file : :class:`str` or :class:`os.PathLike`
            The (path to the) .hdf5 file.
        logger : :class:`logging.Logger`, optional
            A logger; the default logger is used if :data:`None`.
        pes_post_process : :class:`Iterable[Callable] <collections.abc.Iterable>`, optional
            An iterable of callables which are applied to the molecules
            before any PES descriptors are computed.
        \**kwargs : :data:`~typing.Any`
            Not supported; only present for raising a descriptive exception.

        Raises
        ------
        :exc:`TypeError`
            Raised if any further keyword arguments are passed.

        """
        if kwargs:
            name = next(iter(kwargs))
            raise TypeError(f"Unexpected argument {name!r}")

        super().__init__()

        self.param = param

        # Settings for running the actual MD calculations
        self.molecule = cast(Tuple[MultiMolecule, ...], molecule)
        self.package_manager = package_manager
        self.keep_files = keep_files
        self.pes_post_process = cast(Tuple[PostProcess, ...], pes_post_process)

        # HDF5 settings
        self.hdf5_file = hdf5_file

        # Logging settings
        self.logger = cast(Logger, logger)

        # Internally set attributes
        self.pes = {}
        self.pes_validation = {}
        self._data: Dict[Key, np.ndarray] = {}

    @AbstractDataClass.inherit_annotations()
    def _str_iterator(self):
        # Strip the underscores such that private attributes are shown under their public name
        iterator = ((k.strip('_'), v) for k, v in super()._str_iterator())
        return sorted(iterator)

    def __eq__(self, value: Any) -> bool:
        """Implement :code:`self == value`.

        Two instances are considered equal if they are of the same type and if
        their keys, :attr:`package_manager`, :attr:`param`, ``phi``, molecules
        (including lattices), :attr:`pes` functions and stored arrays all match.

        """
        if type(self) is not type(value):
            return False
        elif self.keys() != value.keys():
            return False

        # NOTE(review): `phi` is not defined by this class;
        # it is presumably provided by the subclasses
        for name in ('package_manager', 'param', 'phi'):
            if getattr(self, name) != getattr(value, name):
                return False

        # Compare the molecules pair-wise rather than stacking both tuples into arrays;
        # the latter raises for ragged shapes and may silently broadcast tuples
        # of different length (or molecules of different shape) against each other
        if len(self.molecule) != len(value.molecule):
            return False
        for mol1, mol2 in zip(self.molecule, value.molecule):
            if np.shape(mol1) != np.shape(mol2) or not np.allclose(mol1, mol2):
                return False

            # Either both or neither of the molecules should have a lattice
            lat1, lat2 = mol1.lattice, mol2.lattice
            if (lat1 is None) != (lat2 is None):
                return False
            elif lat1 is not None:
                if np.shape(lat1) != np.shape(lat2) or not np.allclose(lat1, lat2):
                    return False

        # Ensure that a missing or superfluous key results in `False`, not a `KeyError`
        if self.pes.keys() != value.pes.keys():
            return False
        for k, p1 in self.pes.items():
            p2 = value.pes[k]
            # The PES evaluators are `functools.partial` instances
            if p1.func != p2.func or p1.keywords != p2.keywords:  # type: ignore
                return False

        # `np.array_equal` returns `False` for mismatched shapes instead of raising/broadcasting
        return all(np.array_equal(v, value[k]) for k, v in self.items())

    # Implementation of the Mapping protocol

    def __setitem__(self, key: Key, value: np.ndarray) -> None:
        """Implement :code:`self[key] = value`."""
        self._data[key] = value

    def __getitem__(self, key: Key) -> np.ndarray:
        """Implement :code:`self[key]`."""
        return self._data[key]

    def __iter__(self) -> Iterator[Key]:
        """Implement :code:`iter(self)`."""
        return iter(self._data)

    def __len__(self) -> int:
        """Implement :code:`len(self)`."""
        return len(self._data)

    def __contains__(self, key: Any) -> bool:
        """Implement :code:`key in self`."""
        return key in self._data

    def keys(self) -> KeysView[Key]:
        """Return a set-like object providing a view of this instance's keys."""
        return self._data.keys()

    def items(self) -> ItemsView[Key, np.ndarray]:
        """Return a set-like object providing a view of this instance's key/value pairs."""
        return self._data.items()

    def values(self) -> ValuesView[np.ndarray]:
        """Return an object providing a view of this instance's values."""
        return self._data.values()

    @overload
    def get(self, key: Hashable) -> Optional[np.ndarray]: ...
    @overload
    def get(self, key: Hashable, default: T) -> Union[np.ndarray, T]: ...

    def get(self, key, default=None):  # noqa: E301
        """Return the value for **key** if it's available; return **default** otherwise."""
        return self._data.get(key, default)

    # Monte Carlo stuff

    @overload
    def add_pes_evaluator(
        self,
        name: str,
        func: GetPesDescriptor,
        err_func: Callable[[ArrayLike, ArrayLike], np.number | float],
        args: Sequence[Any],
        kwargs: Mapping[str, Any] = ...,
        validation: bool = ...,
        ref: Optional[Sequence[np.ndarray]] = None,
        weight: float | Sequence[float] = ...,
    ) -> None:
        ...
    @overload  # noqa: E301
    def add_pes_evaluator(
        self,
        name: str,
        func: GetPesDescriptor,
        err_func: Callable[[ArrayLike, ArrayLike], np.number | float],
        args: Sequence[Any],
        kwargs: Iterable[Mapping[str, Any]],
        validation: bool = ...,
        ref: Optional[Sequence[np.ndarray]] = None,
        weight: float | Sequence[float] = ...,
    ) -> None:
        ...

    def add_pes_evaluator(self, name, func, err_func, args=(), kwargs=EMPTY_MAPPING, validation=False, ref=None, weight=0.0):  # noqa: E301, E501
        r"""Add a callable to this instance for constructing PES-descriptors.

        Examples
        --------
        .. code:: python

            >>> import numpy as np
            >>> from FOX import MonteCarlo, MultiMolecule

            >>> mc = MonteCarlo(...)
            >>> mol = MultiMolecule.from_xyz(...)

            # Prepare arguments
            >>> name = 'rdf'
            >>> func = FOX.MultiMolecule.init_rdf
            >>> atom_subset = ['Cd', 'Se', 'O']  # Keyword argument for func

            >>> def err_func(qm, mm):  # Computes the auxiliary error
            ...     return ((np.asarray(qm) - np.asarray(mm))**2).sum()

            # Add the PES-descriptor constructor
            >>> mc.add_pes_evaluator(name, func, err_func, kwargs={'atom_subset': atom_subset})

        Parameters
        ----------
        name : :class:`str`
            The name under which the PES-descriptor will be stored (*e.g.* ``"RDF"``).
            The descriptors are stored under the keys ``f"{name}.{i}"``.
        func : :class:`~collections.abc.Callable`
            The callable for constructing the PES-descriptor.
            The callable should take an array-like object as input and
            return a new array-like object as output.
        err_func : :class:`~collections.abc.Callable`
            The function for computing the auxiliary error.
        args : :class:`~collections.abc.Sequence`
            A sequence of positional arguments.
        kwargs : :class:`dict` or :class:`Iterable[dict] <collections.abc.Iterable>`
            A dictionary or an iterable of dictionaries with keyword arguments.
            Providing an iterable allows one to use a unique set
            of keyword arguments for each molecule in :attr:`MonteCarlo.molecule`.
        validation : :class:`bool`
            Whether the PES-descriptor is used exclusively for validation or not.
            If :data:`True`, store the callable in :attr:`pes_validation` instead of :attr:`pes`.
        ref : :class:`Sequence[np.ndarray] <collections.abc.Sequence>`, optional
            A sequence of reference PES descriptors.
            If :data:`None`, the reference values are computed by
            applying **func** to (post-processed) copies of :attr:`molecule`.
        weight : :class:`float` or :class:`Sequence[float] <collections.abc.Sequence>`
            The weight(s) which are attached to the newly created callables;
            a scalar is used for all of them.

        Raises
        ------
        :exc:`TypeError`
            Raised if a reference PES descriptor does not have a numerical dtype.

        """
        # NOTE(review): the default `weight` is 0.0; how the weight is consumed is not
        # visible from this class, so confirm that a zero default is intended

        # Operate on copies, as the post-processing functions may alter the molecules
        mol_list = [m.copy() for _ in self.param.move_range for m in self.molecule]
        for f1 in self.pes_post_process:
            f1(mol_list, self)

        if ref is None:
            ref_iter: Iterable[Optional[np.ndarray]] = repeat(None)
        else:
            ref_iter = [ar for _ in self.param.move_range for ar in ref]
        weight_iter = weight if isinstance(weight, abc.Iterable) else repeat(weight)

        # A single mapping is shared by all molecules
        if isinstance(kwargs, abc.Mapping):
            kwargs_iter: Iterable[Mapping[str, Any]] = repeat(kwargs, len(mol_list))
        else:
            kwargs_iter = kwargs

        pes_dict = self.pes_validation if validation else self.pes
        iterator = zip(mol_list, ref_iter, kwargs_iter, weight_iter)
        for i, (mol, ref_, kwarg, w) in enumerate(iterator):
            if ref_ is None:
                # No reference is available: compute it from the molecule
                f2 = wraps(func)(partial(_template_func1, func, *args, **kwarg))
                f2.ref = f2(mol, None)
                f2.use_mol = True
            else:
                f2 = wraps(func)(partial(_template_func2, func, *args, **kwarg))
                f2.ref = copy.deepcopy(ref_)
                f2.use_mol = False
            f2.err_func = err_func
            f2.weight = w

            # Check that a numeric value is returned
            dtype = np.asanyarray(f2.ref).dtype
            if dtype.kind not in 'buifc':
                raise TypeError(f"PES descriptor {name!r} has an invalid return dtype: {dtype}")

            pes_dict[f'{name}.{i}'] = f2

    @abstractmethod
    def __call__(self, **kwargs: Any) -> None:
        raise NotImplementedError('Trying to call an abstract method')

    def restart(self, **kwargs: Any) -> None:
        raise NotImplementedError('Method not implemented')

    def to_yaml_dict(self, **kwargs: Any) -> Dict[str, Any]:
        raise NotImplementedError('Method not implemented')

    @property
    def clear_jobs(self) -> Callable[[], None]:
        """Delete all cp2k output files."""
        return self.package_manager.clear_jobs

    def run_jobs(self) -> Union[Tuple[None, None], Tuple[List[MultiMolecule], List[Any]]]:
        """Run all jobs by calling :attr:`package_manager` with this instance's logger.

        Returns
        -------
        :class:`list[FOX.MultiMolecule] <list>` and :class:`list`, optional
            A list of MultiMolecule instance(s) and a list with the matching results.
            Both values are :data:`None` if one of the jobs crashed.

        """
        return self.package_manager(logger=self.logger)

    def move(self, idx: Optional[int] = None) -> Union[Exception, Key]:
        """Update a parameter in **self.param** and propagate the change to the job settings.

        The move itself is performed by calling **self.param**;
        the settings of the jobs in **self.package_manager**
        are subsequently updated to reflect the change in parameters.

        Examples
        --------
        .. code:: python

            >>> print(armc.param['param'])
            charge   Br      -0.731687
                     Cs       0.731687
            epsilon  Br Br    1.045000
                     Cs Br    0.437800
                     Cs Cs    0.300000
            sigma    Br Br    0.421190
                     Cs Br    0.369909
                     Cs Cs    0.592590
            Name: param, dtype: float64

            >>> for _ in range(1000):  # Perform 1000 random moves
            >>>     armc.move()

            >>> print(armc.param['param'])
            charge   Br      -0.597709
                     Cs       0.444592
            epsilon  Br Br    0.653053
                     Cs Br    1.088848
                     Cs Cs    1.025769
            sigma    Br Br    0.339293
                     Cs Br    0.136361
                     Cs Cs    0.101097
            Name: param, dtype: float64

        Parameters
        ----------
        idx : :class:`int`, optional
            The column key for the to-be moved parameters.
            If :data:`None`, column ``0`` is moved and the settings of *all* jobs are updated;
            otherwise only the settings of the **idx**'th job of each job-tuple are updated.

        Returns
        -------
        :class:`tuple[float, ...] <tuple>` or :exc:`Exception`
            A tuple with the (new) parameter values of the moved column.
            If the call to **self.param** returned an exception
            then it is returned as-is and no settings are updated.

        """
        idx_: int = idx or 0

        # Perform the move
        ret = self.param(logger=self.logger, param_idx=idx_)
        if isinstance(ret, Exception):
            return ret
        key, prm_name, _ = ret

        prm_update = self.param.param.loc[(key, prm_name), idx_].to_frame().T
        prm_update.index = [prm_name]

        if idx is None:
            _iterator = chain.from_iterable(self.package_manager.values())
            iterator = (job['settings'] for job in _iterator)
        else:
            iterator = (job_tup[idx_]['settings'] for job_tup in self.package_manager.values())

        # Update the job settings
        for settings in iterator:
            settings[key].update(prm_update)
        return cast(Key, tuple(self.param.param[idx_].values))

    def get_pes_descriptors(
        self,
        get_first_key: bool = False,
    ) -> Tuple[
        Dict[str, ArrayOrScalar],
        Dict[str, ArrayOrScalar],
        Optional[List[MultiMolecule]]
    ]:
        """Run all jobs and compute the PES descriptors of the resulting molecules.

        * The jobs are run via :meth:`run_jobs`.
        * The functions in **self.pes_post_process** are applied to the resulting molecules.
        * The PES descriptors are constructed by the callables
          in **self.pes** and **self.pes_validation**.

        Parameters
        ----------
        get_first_key : :class:`bool`
            Keep both the files and the job_cache if this is the first ARMC iteration.
            Useful for manual inspection in case cp2k hard-crashes at this point.

        Returns
        -------
        :class:`dict[str, np.ndarray[np.float64]] <dict>`, :class:`dict[str, np.ndarray[np.float64]] <dict>` and :class:`list[FOX.MultiMolecule] <list>`
            A dictionary with the PES descriptors, a dictionary with the validation
            PES descriptors and the molecules as returned by :meth:`run_jobs`.
            All descriptors are set to ``np.inf``, and the molecules to :data:`None`,
            if the jobs crashed.

        Raises
        ------
        :exc:`RuntimeError`
            Raised if the computation of a PES descriptor fails.

        """  # noqa: E501
        # Generate PES descriptors
        mol_list, result_list = self.run_jobs()

        ret1: Dict[str, ArrayOrScalar] = {}
        ret2: Dict[str, ArrayOrScalar] = {}
        if mol_list is not None and result_list is not None:
            with warnings.catch_warnings():
                # Escalate qmflows warnings to exceptions for the duration of this block
                warnings.simplefilter("error", QMFlows_Warning)

                self.logger.info("Applying PES post-processing")
                for func1 in self.pes_post_process:
                    func1(mol_list, self)  # Post-process the MultiMolecules

                descriptor_sets = (
                    (self.pes, ret1, "descriptor"),
                    (self.pes_validation, ret2, "validation descriptor"),
                )
                for pes_dict, ret, label in descriptor_sets:
                    # Cycle through the molecules/results, as the number of
                    # PES functions may exceed the number of molecules
                    iterator = zip(pes_dict.items(), cycle(mol_list), cycle(result_list))
                    for (k, func2), mol, result in iterator:
                        _k, i = k.rsplit(".", maxsplit=1)
                        self.logger.info(f"Calculating {label} {_k!r} for PES {i}")
                        try:
                            ret[k] = func2(mol, result)
                        except Exception as ex:
                            raise RuntimeError(
                                f"Failed to compute {label} {_k!r} for PES {i}"
                            ) from ex
        else:  # The MD simulation crashed
            ret1.update((key, np.inf) for key in self.pes.keys())
            ret2.update((key, np.inf) for key in self.pes_validation.keys())

        if not (get_first_key or self.keep_files):
            self.logger.info("Clearing jobs")
            self.clear_jobs()
        return ret1, ret2, mol_list