# Source code for FOX.armc.package_manager

"""A module containing the :class:`PackageManager` class.

Index
-----
.. currentmodule:: FOX.armc
.. autosummary::
    PackageManagerABC
    PackageManager

API
---
.. autoclass:: PackageManagerABC
    :members:
.. autoclass:: PackageManager
    :members:

"""

from __future__ import annotations

import os
import shutil
import textwrap
from abc import ABC, abstractmethod
from logging import Logger
from itertools import chain, zip_longest
from collections import abc
from typing import (
    Mapping, TypeVar, Hashable, Any, KeysView, ItemsView, ValuesView, Iterator, overload,
    Union, Dict, List, Optional, Tuple, Iterable, Sequence, cast, TYPE_CHECKING
)

import numpy as np
import pandas as pd
from scm.plams import config, Molecule, JobManager  # type: ignore
from qmflows import Settings as QmSettings
from qmflows.cp2k_utils import prm_to_df
from qmflows.packages import CP2K, CP2K_Result
from noodles import gather, schedule, has_scheduled_methods, run_parallel
from nanoutils import set_docstring, TypedDict

from ..classes import MultiMolecule
from ..functions.cp2k_utils import get_xyz_path
from ..logger import DummyLogger
from ..io.read_xyz import XYZError

if TYPE_CHECKING:
    from qmflows.packages import Result, Package
    from noodles.interface import PromisedObject
else:
    from ..type_alias import PromisedObject, Result, Package

__all__ = ['PackageManagerABC', 'PackageManager']


class PkgDict(TypedDict):
    """A :class:`~typing.TypedDict` representing a single job recipe."""

    # The qmflows Package instance used for executing the job
    # (called as ``type(mol=..., settings=...)`` in ``PackageManager.assemble_job``).
    type: Package

    # The input molecule of the job.
    molecule: Molecule

    # The qmflows-style input settings; coerced into a ``qmflows.Settings``
    # instance by the ``PackageManagerABC.data`` setter.
    settings: QmSettings


T = TypeVar('T')  # Generic default-value type (see :meth:`PackageManagerABC.get`).
RT = TypeVar('RT', bound=Result)  # A qmflows ``Result`` (sub)type.

# A molecule-like object: an iterable of Cartesian coordinates.
MolLike = Iterable[Tuple[float, float, float]]

# The value type of :class:`PackageManagerABC`: a tuple of job recipes.
Value = Tuple[PkgDict, ...]

#: The internal dictionary contained within :class:`PackageManagerABC`.
Data = Dict[str, Value]

# The mapping- and iterable-of-pairs forms accepted by the
# :attr:`PackageManagerABC.data` setter.
DataMap = Mapping[str, Iterable[PkgDict]]
DataIter = Iterable[Tuple[str, Iterable[PkgDict]]]

# A unit-testing hook: an iterator yielding batches of qmflows Result objects.
JobHook = Iterator[Iterable[Result]]

class PackageManagerABC(ABC, Mapping[str, Value]):
    """A class for managing qmflows-style jobs."""

    _data: Data
    _hook: Optional[JobHook]

    def __init__(self, data: Union[DataMap, DataIter],
                 hook: Optional[JobHook] = None,
                 **kwargs: Any) -> None:
        r"""Initialize an instance.

        Parameters
        ----------
        data : :class:`~collections.abc.Mapping` [:class:`str`, :class:`~collections.abc.Iterable` [:class:`~scm.plams.core.basejob.SingleJob`]]
            A mapping with user-defined job descriptor as keys and an iterable
            of Job instances as values.
        hook : :class:`~collections.abc.Iterator` [:class:`~collections.abc.Iterable` [:class:`~qmflows.packages.Result`]], optional
            An iterator yielding multiple qmflows Result objects.
            Can be used as a hook for the purpose of unit-testing.
        **kwargs : :data:`~typing.Any`
            Further keyword arguments which can be customized by
            :class:`PackageManagerABC` subclasses.

        See Also
        --------
        :func:`evaluate_rmsd`
            Evaluate the RMSD of a geometry optimization.

        """  # noqa: E501
        # This base class accepts no extra keyword arguments; subclasses may.
        if kwargs:
            name = next(iter(kwargs))
            raise TypeError(f"Unexpected argument {name!r}")
        super().__init__()

        # Assigned via their respective property setters (validation included).
        self.data = cast(Data, data)
        self.hook = hook

    # Attributes and properties

    @property
    def hook(self) -> Optional[JobHook]:
        """Get or set the :attr:`hook` attribute."""
        return self._hook

    @hook.setter
    def hook(self, value: Optional[JobHook]) -> None:
        if value is None:
            pass
        elif not isinstance(value, abc.Iterator):
            # FIX: corrected typo in the error message ("excpected" -> "expected").
            raise TypeError("'hook' expected an iterator; "
                            f"observed type: {value.__class__.__name__!r}")
        self._hook = value

    @property
    def data(self) -> Data:
        """A property containing this instance's underlying :class:`dict`.

        The getter will simply return the attribute's value.
        The setter will validate and assign any mapping or iterable
        containing of key/value pairs.

        """
        return self._data

    @data.setter
    def data(self, value: Union[DataMap, DataIter]) -> None:  # noqa
        iterable = value.items() if isinstance(value, abc.Mapping) else value
        ret = {k: tuple(v) for k, v in iterable}
        value_len = {len(v) for v in ret.values()}

        # FIX: validate the *constructed* dict rather than the raw input;
        # an exhausted iterator is truthy even though it yields no items,
        # which previously produced the misleading "same length" error.
        if not ret:
            raise ValueError("'data' expected a non-empty Mapping")
        elif len(value_len) != 1:
            raise ValueError("All values passed to 'data' must be of the same length")

        # Ensure all settings are qmflows.Settings instances
        for job_tup in ret.values():
            for job in job_tup:
                job['settings'] = QmSettings(job['settings'])
        self._data = ret

    def __eq__(self, value: Any) -> bool:
        """Implement :code:`self == value`."""
        # Equality requires the exact same class, not merely a common base.
        if type(self) is not type(value):
            return False

        iterator: Iterator[Tuple[PkgDict, PkgDict]]
        iterator = chain.from_iterable(zip_longest(v, value[k]) for k, v in self.items())

        ret = True
        try:
            for job1, job2 in iterator:
                # `zip_longest` pads with ``None`` when tuples differ in length.
                if None in (job1, job2):
                    return False
                for k, v1 in job1.items():
                    v2 = job2[k]
                    # NOTE(review): only Molecule-valued entries contribute to
                    # the result; other entries are looked up (a missing key
                    # short-circuits via KeyError) but never compared.
                    if isinstance(v1, Molecule):
                        ret &= (np.asarray(v1) == np.asarray(v2)).all()
        except KeyError:
            return False
        else:
            return ret

    def __repr__(self) -> str:
        """Implement :code:`repr(self)` and :code:`str(self)`."""
        # Summarize each value by the class names of its first recipe's items.
        data = ''
        for k, v in self.items():
            end = '}, ...)' if len(v) > 1 else '},)'
            data += (
                f',\n{k!r}' + ': ({' +
                ', '.join(f'{_k!r}: {_v.__class__.__name__}(...)' for _k, _v in v[0].items()) +
                end
            )
        return f'{self.__class__.__name__}(' + '{\n' + textwrap.indent(data[2:], 4 * ' ') + '\n})'

    # Mapping implementation

    def __getitem__(self, key: str) -> Value:
        """Implement :code:`self[key]`."""
        return self.data[key]

    def __iter__(self) -> Iterator[str]:
        """Implement :code:`iter(self)`."""
        return iter(self.data)

    def __len__(self) -> int:
        """Implement :code:`len(self)`."""
        return len(self.data)

    def __contains__(self, key: Any) -> bool:
        """Implement :code:`key in self`."""
        return key in self.data
[docs] def keys(self) -> KeysView[str]: """Return a set-like object providing a view of this instance's keys.""" return self.data.keys()
[docs] def items(self) -> ItemsView[str, Value]: """Return a set-like object providing a view of this instance's key/value pairs.""" return self.data.items()
[docs] def values(self) -> ValuesView[Value]: """Return an object providing a view of this instance's values.""" return self.data.values()
@overload def get(self, key: Hashable) -> Optional[Value]: ... @overload def get(self, key: Hashable, default: T) -> Union[Value, T]: ...
[docs] def get(self, key, default=None): # noqa: E301 """Return the value for **key** if it's available; return **default** otherwise.""" return self.data.get(key, default) # type: ignore
# The actual job runner @abstractmethod def __call__( self, logger: Optional[Logger] = None, **kwargs: Any ) -> Union[Tuple[None, None], Tuple[List[MultiMolecule], List[Any]]]: r"""Run all jobs and return a sequence of user-specified results. Parameters ---------- logger : :class:`logging.Logger`, optional A logger for reporting the updated value. \**kwargs : :data:`~typing.Any` Keyword arguments which can be further customized in a sub-class. Returns ------- :class:`~collections.abc.Sequence`, optional Returns ``None`` if one of the jobs crashed; a Sequence of user-specified objects is returned otherwise. The nature of the to-be returned objects should be defined in a sub-class. """ raise NotImplementedError('Trying to call an abstract method')
[docs] @staticmethod @abstractmethod def assemble_job(job: PkgDict, **kwargs: Any) -> Any: """Assemble a :class:`PkgDict` into an actual job.""" raise NotImplementedError('Trying to call an abstract method')
[docs] @abstractmethod def clear_jobs(self, **kwargs: Any) -> None: """Delete all jobs located in :attr:`_job_cache`.""" raise NotImplementedError('Trying to call an abstract method')
[docs] @abstractmethod def update_settings(self, dct_seq: Sequence[dict[str, pd.DataFrame]]) -> None: """Update the Settings embedded in this instance using **dct**.""" raise NotImplementedError('Trying to call an abstract method')
def to_yaml_dict(self) -> Dict[str, Any]: if self.hook is not None: raise NotImplementedError cls = type(self) ret: Dict[str, Any] = { 'type': f'{cls.__module__}.{cls.__name__}', 'molecule': [], } for k, v_tup in self.items(): lst: List[Dict[Any, Any]] = [] for v in v_tup: dct = v['settings'].as_dict() for k_param, v_param in list(dct.items()): if isinstance(v_param, pd.DataFrame): del dct[k_param] lst.append(dct) ret[k] = {'type': f'qmflows.{v["type"].pkg_name}', 'settings': lst} return ret
@set_docstring(PackageManagerABC.__doc__)
@has_scheduled_methods
class PackageManager(PackageManagerABC):

    def __init__(self, data: Union[DataMap, DataIter],
                 hook: Optional[JobHook] = None) -> None:
        # Validation and storage of *data*/*hook* is delegated to the base class.
        super().__init__(data, hook)

        # Transform all forcefield parameter blocks into DataFrames
        job_iterator = (job['settings'] for job in chain.from_iterable(self.values()))
        for settings in job_iterator:  # Type: QmSettings
            prm_to_df(settings)

    def __call__(
        self,
        logger: Optional[Logger] = None,
        n_processes: int = 1
    ) -> Union[Tuple[None, None], Tuple[List[MultiMolecule], List[Any]]]:
        r"""Run all jobs and return a sequence of list of MultiMolecules.

        Parameters
        ----------
        logger : :class:`logging.Logger`, optional
            A logger for reporting job statuses.
        n_processes : :class:`int`
            The number of threads passed to :func:`noodles.run_parallel`.

        Returns
        -------
        :class:`list` [:class:`FOX.MultiMolecule`], optional
            Returns ``None`` if one of the jobs crashed;
            a list of MultiMolecule is returned otherwise.

        """
        # Construct the logger
        if logger is None:
            logger = cast(Logger, DummyLogger())

        # Check if a hook has been specified; if so, bypass the scheduler and
        # pull the pre-baked Result objects from the hook instead.
        if self.hook is not None:
            results = next(self.hook)
            return self._extract_mol(results, logger)

        jobs_iter = iter(self.items())

        # The first batch of jobs has no predecessor to chain from.
        name, jobs = next(jobs_iter)
        promised_jobs: List[PromisedObject] = [self.assemble_job(j, name=name) for j in jobs]

        # Each later batch is chained after the matching job of the previous
        # batch; its Result is forwarded so geometry (and lattice) carry over.
        for name, jobs in jobs_iter:
            promised_jobs = [self.assemble_job(j, p_j, name=name) for
                             j, p_j in zip(jobs, promised_jobs)]

        results = run_parallel(gather(*promised_jobs), n_threads=n_processes)
        return self._extract_mol(results, logger)
[docs] @staticmethod @schedule def assemble_job( job: PkgDict, old_results: Optional[Result] = None, name: Optional[str] = None, ) -> PromisedObject: """Create a :class:`PromisedObject` from a qmflow :class:`Package` instance.""" job_name = name if name is not None else '' obj_type = job['type'] settings = job['settings'] if old_results is None: mol = job['molecule'] else: mol = old_results.geometry if isinstance(obj_type, CP2K) and isinstance(old_results, CP2K_Result): try: lattice: np.ndarray[Any, np.dtype[np.float64]] = old_results.lattice assert lattice is not None except (AssertionError, FileNotFoundError): pass else: settings = settings.copy() settings.cell_parameters = lattice[-1].tolist() return obj_type(mol=mol, job_name=job_name, validate_output=False, settings=settings)
[docs] @staticmethod def clear_jobs() -> None: """Delete all jobs.""" job_manager: JobManager = config.default_jobmanager workdir: Union[str, os.PathLike] = job_manager.workdir for job in job_manager.jobs: name = os.path.join(workdir, job.name) try: shutil.rmtree(name) except FileNotFoundError: pass job_manager.jobs = [] job_manager.names = {}
[docs] def update_settings(self, dct_seq: Sequence[dict[str, pd.DataFrame]]) -> None: """Update all forcefield parameter blocks in this instance's CP2K settings.""" for job_list in self.values(): for job, dct in zip(job_list, dct_seq): job['settings'].update(dct)
@overload @staticmethod def _extract_mol(results: None, logger: Logger) -> Tuple[None, None]: ... @overload # noqa: E301 @staticmethod def _extract_mol( results: Iterable[RT], logger: Logger ) -> Tuple[None, None] | Tuple[List[MultiMolecule], List[RT]]: ... @staticmethod # noqa: E301 def _extract_mol( results: None | Iterable[RT], logger: Logger, ) -> Tuple[None, None] | Tuple[List[MultiMolecule], List[RT]]: """Create a list of MultiMolecule from the passed **results**.""" # `noodles.run_parallel()` can return `None` under certain circumstances if results is None: return None, None mol_list = [] results_list = list(results) for result in results_list: if result.status in {'failed', 'crashed'}: return None, None try: lattice: None | np.ndarray[Any, np.dtype[np.float64]] = result.lattice assert lattice is not None except (AssertionError, FileNotFoundError): lattice = None try: # Construct and return a MultiMolecule object path: str = get_xyz_path(result.archive['work_dir']) # type: ignore mol = MultiMolecule.from_xyz(path) mol.lattice = lattice mol.round(3, inplace=True) mol_list.append(mol) except XYZError: # The .xyz file is unreadable for some reason logger.warning(f"Failed to parse {path!r}") return None, None except Exception as ex: logger.warning(ex) return None, None return mol_list, results_list