"""
FOX.recipes.param
=================
A set of functions for analyzing and plotting ARMC results.
Examples
--------
A general overview of the functions within this module.
.. code:: python
>>> import pandas as pd
>>> from FOX.recipes import get_best, overlay_descriptor, plot_descriptor
>>> hdf5_file: str = ...
>>> param: pd.Series = get_best(hdf5_file, name='param') # Extract the best parameters
>>> rdf: pd.DataFrame = get_best(hdf5_file, name='rdf') # Extract the matching RDF
# Compare the RDF to its reference RDF and plot
>>> rdf_dict = overlay_descriptor(hdf5_file, name='rdf')
>>> plot_descriptor(rdf_dict)
.. image:: rdf.png
:scale: 20 %
:align: center
Examples
--------
A small workflow for calculating free energies using distribution functions
such as the radial distribution function (RDF).
.. code:: python
>>> import pandas as pd
>>> from FOX import get_free_energy
>>> from FOX.recipes import get_best, overlay_descriptor, plot_descriptor
>>> hdf5_file: str = ...
>>> rdf: pd.DataFrame = get_best(hdf5_file, name='rdf')
>>> G: pd.DataFrame = get_free_energy(rdf, unit='kcal/mol')
>>> rdf_dict = overlay_descriptor(hdf5_file, name='rdf')
>>> G_dict = {key: get_free_energy(value) for key, value in rdf_dict.items()}
>>> plot_descriptor(G_dict)
.. image:: G_rdf.png
:scale: 20 %
:align: center
Examples
--------
A workflow for plotting parameters as a function of ARMC iterations.
.. code:: python
>>> import numpy as np
>>> import pandas as pd
>>> from FOX import from_hdf5
>>> from FOX.recipes import plot_descriptor
>>> hdf5_file: str = ...
>>> param: pd.DataFrame = from_hdf5(hdf5_file, 'param')
>>> param.index.name = 'ARMC iteration'
>>> param_dict = {key: param[key] for key in param.columns.levels[0]}
>>> plot_descriptor(param_dict)
.. image:: param.png
:scale: 20 %
:align: center
|
This approach can also be used for the plotting of other properties such as the auxiliary error.
.. code:: python
>>> ...
>>> err: pd.DataFrame = from_hdf5(hdf5_file, 'aux_error')
>>> err.index.name = 'ARMC iteration'
>>> err_dict = {'Auxiliary Error': err}
>>> plot_descriptor(err_dict)
.. image:: err.png
:scale: 20 %
:align: center
|
On occasion it might be desirable to only plot the error of, for example, the accepted iterations.
Given a sequence of booleans (``bool_seq``), one can slice a DataFrame or Series (``df``) using
:code:`df.loc[bool_seq]`.
.. code:: python
>>> ...
>>> acceptance: np.ndarray = from_hdf5(hdf5_file, 'acceptance') # Boolean array
>>> err_slice_dict = {key: df.loc[acceptance] for key, df in err_dict.items()}
>>> plot_descriptor(err_slice_dict)
Index
-----
.. currentmodule:: FOX.recipes.param
.. autosummary::
get_best
overlay_descriptor
plot_descriptor
API
---
.. autofunction:: get_best
.. autofunction:: overlay_descriptor
.. autofunction:: plot_descriptor
"""
from typing import Dict, Union, Iterable, Any, FrozenSet, Tuple, Iterator, Hashable, Optional
from collections import abc
import numpy as np
import pandas as pd
from pandas.core.generic import NDFrame
try:
import matplotlib.pyplot as plt
PltFigure: Union[type, str] = plt.Figure
PLT_ERROR: Optional[str] = None
except ImportError:
PltFigure = 'matplotlib.pyplot.Figure'
PLT_ERROR = (
"Use of the FOX.{} function requires the 'matplotlib' package."
"\n'matplotlib' can be installed via PyPi with the following command:"
"\n\tpip install matplotlib"
)
try:
import h5py
H5PY_ERROR: Optional[str] = None
except ImportError:
H5PY_ERROR = (
"Use of the FOX.{} function requires the 'h5py' package."
"\n'h5py' can be installed via conda with the following command:"
"\n\tconda install -n FOX -c conda-forge h5py"
)
from FOX import from_hdf5, assert_error
from FOX.logger import DEFAULT_LOGGER as logger
#: Public API of this module.
__all__ = [
    'get_best',
    'overlay_descriptor',
    'plot_descriptor',
]

#: The class pandas uses for ``DataFrame.plot``; its ``_all_kinds`` and
#: ``_kind_aliases`` attributes are used to validate plot kinds.
PlotAccessor: type = getattr(pd.DataFrame, 'plot')
@assert_error(H5PY_ERROR)
def get_best(hdf5_file: str, name: str = 'rdf', i: int = 0) -> pd.DataFrame:
    """Return the PES descriptor or ARMC property which yields the lowest error.

    Parameters
    ----------
    hdf5_file : :class:`str`
        The path+filename of the ARMC .hdf5 file.
    name : :class:`str`
        The name of the PES descriptor, *e.g.* ``"rdf"``.
        Alternatively one can supply an ARMC property such as ``"acceptance"``,
        ``"param"`` or ``"aux_error"``.
    i : :class:`int`
        The index of the desired PES.
        Only relevant for PES-descriptors of state-averaged ARMCs.

    Returns
    -------
    :class:`pandas.DataFrame` or :class:`pandas.Series`
        The optimal PES descriptor or other (user-specified) ARMC property.
        The name of the dataset that was read is stored as the column name
        (DataFrame) or as the name (Series) of the returned object.

    """
    descriptor_name = f'{name}.{i}'
    with h5py.File(hdf5_file, 'r', libver='latest') as f:
        # PES descriptors are stored as "<name>.<i>";
        # anything else is looked up under its plain name
        key = descriptor_name if descriptor_name in f.keys() else name
        shape = f['aux_error'].shape[:2]

    # Load the auxiliary error and (if it is not the auxiliary error itself) the property
    if key == 'aux_error':
        aux_error = from_hdf5(hdf5_file, 'aux_error')
        prop = aux_error
    else:
        loaded = from_hdf5(hdf5_file, ['aux_error', key])
        aux_error = loaded['aux_error']
        prop = loaded[key]

    # The optimum is the iteration with the smallest summed error;
    # a NaN anywhere in a row makes that row's sum NaN (skipna=False)
    error_sum = aux_error.sum(axis=1, skipna=False)
    j: int = error_sum.idxmin()
    logger.debug(f"Optimum ARMC cycle: {np.unravel_index(j, shape)}")

    # pandas objects are indexed positionally; other containers directly
    best = prop.iloc[j] if isinstance(prop, NDFrame) else prop[j]

    if isinstance(best, pd.DataFrame):
        best.columns.name = key
    elif isinstance(best, pd.Series):
        best.name = key
    return best
@assert_error(H5PY_ERROR)
def overlay_descriptor(hdf5_file: str, name: str = 'rdf', i: int = 0) -> Dict[str, pd.DataFrame]:
    """Return the PES descriptor which yields the lowest error and overlay it with the reference PES descriptor.

    Parameters
    ----------
    hdf5_file : :class:`str`
        The path+filename of the ARMC .hdf5 file.
    name : :class:`str`
        The name of the PES descriptor, *e.g.* ``"rdf"``.
    i : :class:`int`
        The index of desired PES.
        Only relevant for state-averaged ARMCs.

    Returns
    -------
    :class:`dict` [:class:`str`, :class:`pandas.DataFrame`]
        A dictionary of DataFrames, one per column of the optimal MM descriptor
        (*e.g.* atom pairs such as ``"Cd Cd"``).
        Each DataFrame has two columns: ``"MM-MD"`` and ``"QM-MD"``.

    """  # noqa
    with h5py.File(hdf5_file, 'r', libver='latest') as f:
        shape = f['aux_error'].shape[:2]

    mm_name = f'{name}.{i}'
    qm_name = f'{name}.{i}.ref'
    datasets = from_hdf5(hdf5_file, ['aux_error', mm_name, qm_name])
    aux_error = datasets['aux_error']
    mm_all = datasets[mm_name]
    qm_all = datasets[qm_name]

    # Select the MM descriptor of the iteration with the smallest summed error
    j: int = aux_error.sum(axis=1, skipna=False).idxmin()
    logger.debug(f"Optimum ARMC cycle: {np.unravel_index(j, shape)}")
    mm = mm_all[j]
    qm = qm_all[0]  # only the first entry of the reference dataset is used

    def _overlay(column):
        """Combine the MM and QM data of a single column into one DataFrame."""
        pair = pd.DataFrame({'MM-MD': mm[column], 'QM-MD': qm[column]}, index=mm.index)
        pair.columns.name = mm_name
        return pair

    return {column: _overlay(column) for column in mm}
@assert_error(PLT_ERROR)
def plot_descriptor(descriptor: Union[NDFrame, Iterable[NDFrame]],
                    show_fig: bool = True, kind: str = 'line',
                    sharex: bool = True, sharey: bool = False,
                    **kwargs: Any) -> PltFigure:
    r"""Plot a DataFrame or iterable consisting of one or more DataFrames.

    Requires the matplotlib_ package.

    .. _matplotlib: https://matplotlib.org/

    Every column of a DataFrame, or every element of an iterable/mapping,
    is drawn in its own subplot; all subplots are placed in a single row.

    Parameters
    ----------
    descriptor : :class:`pandas.DataFrame` or :class:`Iterable<collections.abc.Iterable>` [:class:`pandas.DataFrame`]
        A DataFrame or an iterable consisting of DataFrames.
        A :class:`pandas.Series` is treated as a single-column DataFrame.
    show_fig : :class:`bool`
        Whether to show the figure or not.
    kind : :class:`str`
        The plot kind to-be passed to :meth:`pandas.DataFrame.plot`.
    sharex/sharey : :class:`bool`
        Whether or not the to-be created plots should share their x/y-axes.
    \**kwargs : :data:`Any<typing.Any>`
        Further keyword arguments for the :meth:`pandas.DataFrame.plot` method.
        A ``figsize`` keyword is applied to the whole figure;
        it defaults to ``(4 * n, 6)``, ``n`` being the number of subplots.

    Returns
    -------
    :class:`Figure<matplotlib.figure.Figure>`
        A matplotlib Figure.

    Raises
    ------
    :exc:`ValueError`
        Raised if **kind** is not a valid plot kind or if **descriptor** is empty.
    :exc:`TypeError`
        Raised if **kind** is not a string or if **descriptor** is not iterable.

    See Also
    --------
    :func:`get_best`
        Return the PES descriptor or ARMC property which yields the lowest error.
    :func:`overlay_descriptor`
        Return the PES descriptor which yields the lowest error and
        overlay it with the reference PES descriptor.

    """  # noqa
    kind_ = _validate_kind(kind)
    ncols, iterator = _get_df_iterator(descriptor)
    if ncols == 0:
        raise ValueError("'descriptor' expected at least one DataFrame or Series; "
                         "observed an empty object")

    # Remove figsize from kwargs: passing it to DataFrame.plot() both explicitly
    # and via **kwargs raised a "multiple values for keyword argument" TypeError
    figsize = kwargs.pop('figsize', (4 * ncols, 6))

    # squeeze=False always yields a 2D (1, ncols) array of axes,
    # so a single subplot needs no special treatment
    fig, ax_grid = plt.subplots(ncols=ncols, sharex=sharex, sharey=sharey,
                                figsize=figsize, squeeze=False)

    # Construct the actual plots
    for (key, df), ax in zip(iterator, ax_grid[0]):
        if isinstance(key, tuple):  # e.g. MultiIndex column labels
            key = ' '.join(repr(part) for part in key)
        df.plot(ax=ax, title=key, kind=kind_, **kwargs)

    if show_fig:
        plt.show(block=True)
    return fig
#: A :class:`frozenset` with valid values for the **kind** parameter in :func:`plot_descriptor`.
#: Contains all plot kinds known to pandas, plus both the alias names (the keys,
#: which is what a user passes) and their targets from ``PlotAccessor._kind_aliases``.
VALID_KIND: FrozenSet[str] = frozenset(
    PlotAccessor._all_kinds
    + tuple(PlotAccessor._kind_aliases.keys())
    + tuple(PlotAccessor._kind_aliases.values())
)
def _validate_kind(kind: str) -> str:
    """Return the lower-cased **kind** after checking it against :data:`VALID_KIND`.

    Raises
    ------
    :exc:`TypeError`
        Raised if **kind** has no ``lower()`` method (*e.g.* it is not a string).
    :exc:`ValueError`
        Raised if the lower-cased **kind** is not a member of :data:`VALID_KIND`.

    """
    try:
        lowered = kind.lower()
    except AttributeError as ex:
        msg = ("'kind' expected a 'str'; observed type: "
               f"'{kind.__class__.__name__}'")
        raise TypeError(msg).with_traceback(ex.__traceback__)

    if lowered in VALID_KIND:
        return lowered

    accepted = tuple(sorted(VALID_KIND))
    raise ValueError(f"{repr(lowered)} is not a valid value for 'kind'; "
                     f"accepted values: {accepted}")
def _get_df_iterator(descriptor: Union[NDFrame, Iterable[NDFrame]]
                     ) -> Tuple[int, Iterator[Tuple[Hashable, NDFrame]]]:
    """Return the number of plots and an iterator of ``(key, frame)`` pairs for :func:`plot_descriptor`.

    A :class:`pandas.Series` is first converted into a single-column DataFrame.
    The number of plots is the number of keys for objects with a ``keys()`` method
    (*e.g.* DataFrames and dicts), and otherwise the length of the object;
    iterables without a length (*e.g.* generators) are first converted into a list.
    Objects with an ``items()`` method supply their own key/value pairs,
    all others are enumerated.

    Raises
    ------
    :exc:`TypeError`
        Raised if **descriptor** has no keys, no length and is not iterable.

    """
    if isinstance(descriptor, pd.Series):
        descriptor = descriptor.to_frame()

    # Count the number of plots
    try:
        n_plots = len(descriptor.keys())
    except AttributeError:
        try:
            n_plots = len(descriptor)
        except TypeError as ex:
            if isinstance(descriptor, abc.Iterable):
                # Materialize length-less iterables so they can be counted and re-iterated
                descriptor = list(descriptor)
                n_plots = len(descriptor)
            else:
                raise TypeError("'descriptor' expected an iterable; observed type: "
                                f"'{descriptor.__class__.__name__}'"
                                ).with_traceback(ex.__traceback__)

    # Pair every frame with a key: its own key if available, otherwise its position
    try:
        pairs = descriptor.items()
    except (AttributeError, TypeError):
        pairs = enumerate(descriptor)
    return n_plots, pairs