Source code for FOX.recipes.param

"""A set of functions for analyzing and plotting ARMC results.

Examples
--------
A general overview of the functions within this module.

.. code:: python

    >>> import pandas as pd
    >>> from FOX.recipes import get_best, overlay_descriptor, plot_descriptor

    >>> hdf5_file: str = ...

    >>> param: pd.Series = get_best(hdf5_file, name='param')  # Extract the best parameters
    >>> rdf: pd.DataFrame = get_best(hdf5_file, name='rdf')  # Extract the matching RDF

    # Compare the RDF to its reference RDF and plot
    >>> rdf_dict = overlay_descriptor(hdf5_file, name='rdf')
    >>> plot_descriptor(rdf_dict)

.. image:: rdf.png
    :scale: 20 %
    :align: center


Examples
--------
A small workflow for calculating free energies using distribution functions
such as the radial distribution function (RDF).

.. code:: python

    >>> import pandas as pd
    >>> from FOX import get_free_energy
    >>> from FOX.recipes import get_best, overlay_descriptor, plot_descriptor

    >>> hdf5_file: str = ...

    >>> rdf: pd.DataFrame = get_best(hdf5_file, name='rdf')
    >>> G: pd.DataFrame = get_free_energy(rdf, unit='kcal/mol')

    >>> rdf_dict = overlay_descriptor(hdf5_file, name='rdf')
    >>> G_dict = {key: get_free_energy(value) for key, value in rdf_dict.items()}
    >>> plot_descriptor(G_dict)

.. image:: G_rdf.png
    :scale: 20 %
    :align: center


Examples
--------
A workflow for plotting parameters as a function of ARMC iterations.

.. code:: python

    >>> import numpy as np
    >>> import pandas as pd
    >>> from FOX import from_hdf5
    >>> from FOX.recipes import plot_descriptor

    >>> hdf5_file: str = ...

    >>> param: pd.DataFrame = from_hdf5(hdf5_file, 'param')
    >>> param.index.name = 'ARMC iteration'
    >>> param_dict = {key: param[key] for key in param.columns.levels[0]}

    >>> plot_descriptor(param_dict)

.. image:: param.png
    :scale: 20 %
    :align: center

This approach can also be used for the plotting of other properties such as the auxiliary error.

.. code:: python

    >>> ...

    >>> err: pd.DataFrame = from_hdf5(hdf5_file, 'aux_error')
    >>> err.index.name = 'ARMC iteration'
    >>> err_dict = {'Auxiliary Error': err}

    >>> plot_descriptor(err_dict)

.. image:: err.png
    :scale: 20 %
    :align: center

On occasion it might be desirable to only plot the error of, for example, accepted iterations.
Given a sequence of booleans (``bool_seq``), one can slice a DataFrame or Series (``df``) using
:code:`df.loc[bool_seq]`.

.. code:: python

    >>> ...

    >>> acceptance: np.ndarray = from_hdf5(hdf5_file, 'acceptance')  # Boolean array
    >>> err_slice_dict = {key: df.loc[acceptance] for key, df in err_dict.items()}

    >>> plot_descriptor(err_slice_dict)


Index
-----
.. currentmodule:: FOX.recipes
.. autosummary::
    get_best
    overlay_descriptor
    plot_descriptor

API
---
.. autofunction:: get_best
.. autofunction:: overlay_descriptor
.. autofunction:: plot_descriptor

"""

from __future__ import annotations

from os import PathLike
from typing import Dict, Union, Iterable, Any, Tuple, Iterator, cast, Optional, Mapping, List
from collections import abc

import h5py
import numpy as np
import pandas as pd
from pandas.core.generic import NDFrame
from nanoutils import raise_if

from FOX import from_hdf5
from FOX.logger import DEFAULT_LOGGER as logger

# matplotlib is treated as an optional dependency: an import failure is stored
# in ``PLT_ERROR`` rather than raised while this module is being imported.
# ``PLT_ERROR`` is later handed to the ``raise_if`` decorator on ``plot_descriptor``
# (presumably so the error surfaces only when plotting is actually requested;
# see the nanoutils documentation to confirm).
try:
    import matplotlib.pyplot as plt
    from matplotlib.pyplot import Figure
    PLT_ERROR: Optional[ImportError] = None
except ImportError as ex:
    # Keep the ``Figure`` name defined (it is used in annotations below)
    # by falling back to the substitute provided by FOX.type_alias.
    from FOX.type_alias import Figure
    PLT_ERROR = ex

# Names exported by ``from FOX.recipes.param import *``
__all__ = ['get_best', 'overlay_descriptor', 'plot_descriptor']

PlotAccessor: type = pd.DataFrame.plot  # A class used by Pandas for plotting stuff


def get_best(
    hdf5_file: str | PathLike[str],
    name: str,
    i: int = 0,
    sum_error: None | str | List[str] = None,
    err_dset: str = 'aux_error',
) -> pd.DataFrame | pd.Series:
    """Return the PES descriptor or ARMC property which yields the lowest error.

    Parameters
    ----------
    hdf5_file : :class:`str`
        The path+filename of the ARMC .hdf5 file.
    name : :class:`str`
        The name of the PES descriptor, *e.g.* ``"rdf"``.
        Alternatively one can supply an ARMC property such as ``"acceptance"``,
        ``"param"`` or ``"aux_error"``.
        Leading and trailing ``"/"`` characters are stripped.
    i : :class:`int`
        The index of the desired PES.
        Only relevant for PES-descriptors of state-averaged ARMCs.
    sum_error : :class:`str` or :class:`list[str] <list>`, optional
        Sum all the given aux errors for a given iteration when determining an optimum.
        A single string is treated as a one-element list.
        If :data:`None`, sum over all aux errors.
    err_dset : :class:`str`
        The name of the dataset containing the errors.
        Generally speaking one should pick either
        ``"aux_error"`` or ``"validation/aux_error"``.

    Returns
    -------
    :class:`pandas.DataFrame` or :class:`pandas.Series`
        A DataFrame (or Series) of the optimal PES descriptor or other
        (user-specified) ARMC property.

    """
    name = name.strip('/')
    full_name = f'{name}.{i}'
    with h5py.File(hdf5_file, 'r', libver='latest') as f:
        if full_name not in f.keys():  # i.e. if **name** does not belong to a PES descriptor
            full_name = name
        # The first two dimensions of the error dataset are used for translating
        # the flat optimum index back into a multi-dimensional index (debug log only)
        shape = f[err_dset].shape[:2]

    # Load the DataFrames
    # NOTE(review): this is a substring test, so e.g. ``"aux_error"`` also matches
    # ``"validation/aux_error"``; confirm that this is intended.
    if full_name in err_dset:
        aux_error = prop = from_hdf5(hdf5_file, err_dset)
    else:
        hdf5_dict = from_hdf5(hdf5_file, [err_dset, full_name])
        aux_error, prop = hdf5_dict[err_dset], hdf5_dict[full_name]

    # Select the errors that take part in the summation.
    # A lone string is wrapped in a list such that the selection is always a DataFrame;
    # indexing with a bare string can yield a Series, for which ``sum(axis=1)`` raises.
    if sum_error is None:
        err_subset = aux_error
    else:
        columns = [sum_error] if isinstance(sum_error, str) else sum_error
        err_subset = aux_error[columns]

    # Identify the iteration with the lowest summed error;
    # ``skipna=False`` propagates NaN values into the row sums instead of ignoring them
    j: int = err_subset.sum(axis=1, skipna=False).idxmin()
    logger.debug(f"Optimum ARMC cycle: {np.unravel_index(j, shape)}")

    # Return the best DataFrame (or Series)
    df = prop[j] if not isinstance(prop, NDFrame) else prop.iloc[j]
    if isinstance(df, pd.DataFrame):
        df.columns.name = full_name
    elif isinstance(df, pd.Series):
        df.name = full_name
    return df
def overlay_descriptor(hdf5_file: Union[str, 'PathLike[str]'], name: str = 'rdf', i: int = 0,
                       err_dset: str = 'aux_error') -> Dict[str, pd.DataFrame]:
    """Pair the lowest-error PES descriptor with its reference PES descriptor.

    Parameters
    ----------
    hdf5_file : :class:`str`
        The path+filename of the ARMC .hdf5 file.
    name : :class:`str`
        The name of the PES descriptor, *e.g.* ``"rdf"``.
    i : :class:`int`
        The index of desired PES.
        Only relevant for state-averaged ARMCs.
    err_dset : :class:`str`
        The name of the dataset containing the errors.
        Generally speaking one should pick either
        ``"aux_error"`` or ``"validation/aux_error"``.

    Returns
    -------
    :class:`dict` [:class:`str`, :class:`pandas.DataFrame`]
        A dictionary of DataFrames.
        Every DataFrame has two columns: ``"MM-MD"`` (the optimum) and
        ``"QM-MD"`` (the reference).
        The dictionary keys are taken from the optimal descriptor,
        *e.g.* atom pairs such as ``"Cd Cd"``.

    """  # noqa
    sim_key = f'{name}.{i}'
    ref_key = f'{sim_key}.ref'

    # Only the dataset shape is read here; it is used for the debug message below
    with h5py.File(hdf5_file, 'r', libver='latest') as f:
        shape = f[err_dset].shape[:2]

    loaded = from_hdf5(hdf5_file, [err_dset, sim_key, ref_key])
    errors, all_mm, all_qm = loaded[err_dset], loaded[sim_key], loaded[ref_key]

    # The optimum is the iteration with the lowest summed error
    j: int = errors.sum(axis=1, skipna=False).idxmin()
    logger.debug(f"Optimum ARMC cycle: {np.unravel_index(j, shape)}")

    best_mm = all_mm[j]
    ref_qm = all_qm[0]

    def _overlay(pair: Any) -> pd.DataFrame:
        """Combine the MM and QM data of a single key into one DataFrame."""
        frame = pd.DataFrame({'MM-MD': best_mm[pair], 'QM-MD': ref_qm[pair]},
                             index=best_mm.index)
        frame.columns.name = sim_key
        return frame

    return {pair: _overlay(pair) for pair in best_mm}
@raise_if(PLT_ERROR)
def plot_descriptor(descriptor: Union[NDFrame, Iterable[NDFrame]], show_fig: bool = True,
                    kind: str = 'line', sharex: bool = True, sharey: bool = False,
                    **kwargs: Any) -> Figure:
    r"""Plot a DataFrame or an iterable of one or more DataFrames, one subplot per item.

    Requires the matplotlib_ package.

    .. _matplotlib: https://matplotlib.org/

    Parameters
    ----------
    descriptor : :class:`pandas.DataFrame` or :class:`Iterable<collections.abc.Iterable>` [:class:`pandas.DataFrame`]
        A DataFrame or an iterable consisting of DataFrames.
        Objects with an :code:`items()` method (*e.g.* dictionaries) have
        their keys used as subplot titles.
    show_fig : :class:`bool`
        Whether to show the figure or not.
    kind : :class:`str`
        The plot kind to-be passed to :meth:`pandas.DataFrame.plot`.
    sharex/sharey : :class:`bool`
        Whether or not the to-be created plots should share their x/y-axes.
    \**kwargs : :data:`Any<typing.Any>`
        Further keyword arguments for the :meth:`pandas.DataFrame.plot` method.

    Returns
    -------
    :class:`Figure<matplotlib.figure.Figure>`
        A matplotlib Figure.

    See Also
    --------
    :func:`get_best`
        Return the PES descriptor or ARMC property which yields the lowest error.

    :func:`overlay_descriptor`
        Return the PES descriptor which yields the lowest error and
        overlay it with the reference PES descriptor.

    """  # noqa
    plot_kind = _validate_kind(kind)
    n_plots, labelled_frames = _get_df_iterator(descriptor)

    # One column of axes per to-be plotted object
    fig, axes = plt.subplots(ncols=n_plots, sharex=sharex, sharey=sharey)
    # ``plt.subplots`` returns a bare Axes object for a single plot; wrap it for ``zip``
    axes_seq = (axes,) if n_plots == 1 else axes
    fig_size = (4 * n_plots, 6)

    # Construct the actual plots
    for (label, frame), ax in zip(labelled_frames, axes_seq):
        # Tuple labels are flattened into a single space-separated title
        if isinstance(label, tuple):
            title = ' '.join(repr(part) for part in label)
        else:
            title = label
        frame.plot(ax=ax, title=title, figsize=fig_size, kind=plot_kind, **kwargs)

    if show_fig:
        plt.show(block=True)
    return fig
def _validate_kind(kind: str) -> str:
    """Return the lower-cased **kind** parameter of :func:`plot_descriptor`.

    A :exc:`TypeError` is raised if **kind** lacks a :code:`lower()` method.

    """
    try:
        return kind.lower()
    except AttributeError as ex:
        cls_name = kind.__class__.__name__
        err = TypeError("'kind' expected a 'str'; observed type: "
                        f"'{cls_name}'")
        raise err.with_traceback(ex.__traceback__)


def _get_df_iterator(descriptor: Union[NDFrame, Mapping[Any, NDFrame], Iterable[NDFrame]]
                     ) -> Tuple[int, Iterator[Tuple[Any, NDFrame]]]:
    """Return the number of plots and an iterator of key/DataFrame pairs for :func:`plot_descriptor`.

    Objects with an :code:`items()` method supply their own keys;
    all other iterables are enumerated.
    A :exc:`TypeError` is raised if **descriptor** has neither a length nor is iterable.

    """  # noqa: E501
    # A Series is handled as a single-column DataFrame
    if isinstance(descriptor, pd.Series):
        descriptor = descriptor.to_frame()

    # Figure out the number of plots
    try:
        n_plots: int = len(descriptor.keys())  # type: ignore[union-attr]
    except AttributeError:
        try:
            n_plots = len(descriptor)  # type: ignore
        except TypeError as ex:
            if isinstance(descriptor, abc.Iterable):
                # Unsized iterables (e.g. generators) are exhausted into a list
                descriptor = list(descriptor)
                n_plots = len(descriptor)
            else:
                cls_name = descriptor.__class__.__name__
                err = TypeError("'descriptor' expected an iterable; observed type: "
                                f"'{cls_name}'")
                raise err.with_traceback(ex.__traceback__)

    # Construct an iterator of 2-tuples
    try:
        pairs = cast(Iterator[Tuple[Any, NDFrame]], descriptor.items())  # type: ignore[union-attr]  # noqa: E501
    except (AttributeError, TypeError):
        pairs = enumerate(descriptor)
    return n_plots, pairs