Source code for

"""A class for reading and CHARMM .rtf topology files.

.. currentmodule:: FOX
.. autosummary::

.. autoclass:: RTFContainer
    :members: mass, atom, bond, impropers, angles, dihedrals, charmm_version, auto

.. automethod:: RTFContainer.collapse_charges
.. automethod:: RTFContainer.auto_to_explicit
.. automethod:: RTFContainer.from_file
.. automethod:: RTFContainer.concatenate

from __future__ import annotations

import os
import types
import textwrap
import itertools
import warnings
from typing import TYPE_CHECKING, Any, ClassVar, Literal
from collections.abc import Mapping, Iterator, Iterable
from collections import defaultdict

import h5py
import numpy as np
import pandas as pd
from scm.plams import Molecule, Atom

from . import FileIter
from ..functions.molecule_utils import get_angles, get_dihedrals

if TYPE_CHECKING:
    from typing_extensions import Self
    from numpy.typing import NDArray

# Names exported by a star-import of this module
__all__ = [
    "RTFContainer",
]

class RTFContainer:
    """A class for managing CHARMM .rtf topology files.

    Examples
    --------
    .. code:: python

        >>> from FOX import RTFContainer

        >>> input_file = str(...)
        >>> rtf = RTFContainer.from_file(input_file)

    """

    __slots__ = (
        "__weakref__",
        "mass",
        "atom",
        "bond",
        "impr",
        "angles",
        "dihe",
        "charmm_version",
        "auto",
        "_pd_printoptions",
    )

    #: A dataframe holding all MASS-related info.
    mass: pd.DataFrame

    #: A dataframe holding all ATOM-related info.
    atom: pd.DataFrame

    #: A dataframe holding all BOND-related info.
    bond: pd.DataFrame

    #: A dataframe holding all IMPR-related info.
    impr: pd.DataFrame

    #: A dataframe holding all ANGLES-related info.
    angles: pd.DataFrame

    #: A dataframe holding all DIHE-related info.
    dihe: pd.DataFrame

    #: The CHARMM version used for generating the .rtf file
    charmm_version: tuple[int, ...]

    #: A set with all .rtf statements that should be auto-generated.
    auto: set[str]

    #: Print options as used by :meth:`~RTFContainer.__repr__`.
    _pd_printoptions: dict[str, Any]

    #: A mapping with structured dtypes for each dataframe column and index.
    DTYPES: ClassVar[types.MappingProxyType[str, np.dtype[np.void]]] = types.MappingProxyType({
        "MASS": np.dtype([
            ("index", "i8"),
            ("atom_type", "U5"),
            ("mass", "f8"),
            ("atom_name", "U2"),
        ]),
        "ATOM": np.dtype([
            ("molecule", "U5"),
            ("atom1", "i8"),
            ("atom_type", "U5"),
            ("charge", "f8"),
        ]),
        "BOND": np.dtype([
            ("molecule", "U5"),
            ("atom1", "i8"),
            ("atom2", "i8"),
        ]),
        "ANGLES": np.dtype([
            ("molecule", "U5"),
            ("atom1", "i8"),
            ("atom2", "i8"),
            ("atom3", "i8"),
        ]),
        "DIHE": np.dtype([
            ("molecule", "U5"),
            ("atom1", "i8"),
            ("atom2", "i8"),
            ("atom3", "i8"),
            ("atom4", "i8"),
        ]),
        "IMPR": np.dtype([
            ("molecule", "U5"),
            ("atom1", "i8"),
            ("atom2", "i8"),
            ("atom3", "i8"),
            ("atom4", "i8"),
        ]),
    })

    @property
    def impropers(self) -> pd.DataFrame:
        """A dataframe holding all IMPR-related info."""
        return self.impr

    @impropers.setter
    def impropers(self, value: pd.DataFrame) -> None:
        self.impr = value

    @property
    def dihedrals(self) -> pd.DataFrame:
        """A dataframe holding all DIHE-related info."""
        return self.dihe

    @dihedrals.setter
    def dihedrals(self, value: pd.DataFrame) -> None:
        self.dihe = value

    @property
    def pd_printoptions(self) -> Iterator[Any]:
        """Return an iterator flattening :attr:`_pd_printoptions`."""
        return itertools.chain.from_iterable(self._pd_printoptions.items())

    @property
    def residues(self) -> pd.Index:
        """Get all unique residue names."""
        return self.atom.index[~self.atom.index.duplicated()]

    def __init__(
        self,
        mass: pd.DataFrame,
        atom: pd.DataFrame,
        bond: pd.DataFrame,
        impr: pd.DataFrame,
        angles: pd.DataFrame,
        dihe: pd.DataFrame,
        charmm_version: tuple[int, ...] = (0, 0),
        auto: None | set[str] = None,
    ) -> None:
        """Initialize the instance.

        The ``auto`` set is copied: :meth:`auto_to_explicit` removes entries from
        :attr:`auto`, which must not leak back into the caller's set.
        """
        self.mass = mass
        self.atom = atom
        self.bond = bond
        self.impropers = impr
        self.angles = angles
        self.dihedrals = dihe
        self.charmm_version = charmm_version
        self.auto = set(auto) if auto is not None else set()
        self._pd_printoptions = {"display.max_rows": 20}

    def __eq__(self, other: object) -> bool:
        """Implement :meth:`self == other <object.__eq__>`."""
        cls = type(self)
        if not isinstance(other, cls):
            return NotImplemented
        if self.auto != other.auto:
            return False

        df_keys = ["mass", "atom", "bond", "impropers", "angles", "dihedrals"]
        iterator = ((getattr(self, k), getattr(other, k)) for k in df_keys)
        return all(df1.equals(df2) for df1, df2 in iterator)

    def __reduce__(self) -> tuple[type[Self], tuple[Any, ...]]:
        """Helper function for :mod:`pickle`."""
        cls = type(self)
        return cls, (
            self.mass,
            self.atom,
            self.bond,
            self.impropers,
            self.angles,
            self.dihedrals,
            self.charmm_version,
            self.auto,
        )

    def __repr__(self) -> str:
        """Implement :func:`repr(self)<repr>`."""
        # Get all to-be printed attribute (names)
        cls = type(self)
        attr_names = ["mass", "atom", "bond", "impropers", "angles", "dihedrals"]

        # Determine the indentation width
        width = max(len(k) for k in attr_names)
        indent = width + 3

        # Gather string representations of all attributes
        ret = ""
        with pd.option_context(*self.pd_printoptions):
            items = ((k, getattr(self, k)) for k in attr_names)
            for k, _v in items:
                v = textwrap.indent(repr(_v), " " * indent)[indent:]
                ret += f"{k:{width}} = {v},\n"
        ret += f"{'auto':{width}} = {self.auto!r},\n"
        ret += f"{'charmm_version':{width}} = {self.charmm_version!r},\n"
        return f"{cls.__name__}(\n{textwrap.indent(ret[:-2], 4 * ' ')}\n)"

    def collapse_charges(self) -> dict[str, float]:
        """Return a dictionary mapping atom types to atomic charges.

        Charges are rounded to 6 decimals before being compared.

        Returns
        -------
        dict[str, float]

        Raises
        ------
        ValueError:
            Raised if an atom type has multiple unique charges associated with it

        """
        dct: dict[str, set[float]] = defaultdict(set)
        for at, charge in zip(self.atom["atom_type"], self.atom["charge"].round(6)):
            dct[at].add(charge)

        illegal = {k: sorted(v) for k, v in dct.items() if len(v) > 1}
        if illegal:
            raise ValueError(
                f"Found {len(illegal)} atom types with two or more "
                f"distinct charges: {illegal!r}"
            )
        return {k: v.pop() for k, v in dct.items()}

    def auto_to_explicit(self) -> None:
        """Convert all statements in :attr:`~RTFContainer.auto` into explicit dataframes.

        ``ANGLES`` and ``DIHE`` statements are generated from the per-residue bonds, stored
        in :attr:`angles` / :attr:`dihedrals` and removed from :attr:`auto`.
        A warning is issued for all remaining (unsupported) statements.
        """
        if not self.auto:
            return

        # Only build the per-residue molecules if there is something to generate
        supported = [k for k in ("ANGLES", "DIHE") if k in self.auto]
        if supported:
            mol_dict = self._residue_molecules()
            for key in supported:
                if key == "ANGLES":
                    self.angles = self._auto_to_explicit("ANGLES", mol_dict)
                else:
                    self.dihedrals = self._auto_to_explicit("DIHE", mol_dict)
                self.auto.remove(key)

        if self.auto:
            warnings.warn(f"Unsupported auto statements: {sorted(self.auto)!r}", stacklevel=2)

    def _residue_molecules(self) -> dict[str, Molecule]:
        """Construct a dictionary mapping residue names to PLAMS molecules (with bonds)."""
        atom_dict: dict[str, str] = dict(zip(self.mass["atom_type"], self.mass["atom_name"]))
        atom_res = self.atom.index
        bond_res = self.bond.index

        mol_dict: dict[str, Molecule] = {}
        for res in self.residues:
            mol_dict[res] = mol = Molecule()

            # Boolean masks select only this residue's rows and always yield a
            # (possibly empty) Series, even for single-atom or bond-less residues
            for at_type in self.atom.loc[atom_res == res, "atom_type"]:
                mol.add_atom(Atom(symbol=atom_dict[at_type]))

            bonds = self.bond.loc[bond_res == res]
            for i, j in zip(bonds["atom1"], bonds["atom2"]):
                # The .rtf atom indices are 1-based, restarting for every residue
                mol.add_bond(mol[int(i)], mol[int(j)])
        return mol_dict

    def _auto_to_explicit(
        self,
        key: Literal["ANGLES", "DIHE"],
        mol_dict: Mapping[str, Molecule],
    ) -> pd.DataFrame:
        """Generate an explicit ``ANGLES`` or ``DIHE`` dataframe from the passed molecules.

        Raises
        ------
        ValueError
            Raised if ``key`` is neither ``"ANGLES"`` nor ``"DIHE"``

        """
        if key == "ANGLES":
            func = get_angles
        elif key == "DIHE":
            func = get_dihedrals
        else:
            raise ValueError(key)

        dtype = self.DTYPES[key]
        assert dtype.names is not None

        # Compute the angles/dihedrals for all molecules
        array_dict = {res: func(mol) for res, mol in mol_dict.items()}

        # Concatenate the residue-specific angles/dihedrals into a single structured array
        total_array = np.empty(sum(len(a) for a in array_dict.values()), dtype=dtype)
        start = 0
        for res, array in array_dict.items():
            if not len(array):
                continue
            stop = start + len(array)
            total_array["molecule"][start:stop] = res
            for k, field_name in enumerate(dtype.names[1:]):
                total_array[field_name][start:stop] = array[..., k]
            start = stop

        # Convert the structured array into a dataframe
        df = pd.DataFrame(total_array)
        df.set_index("molecule", inplace=True, drop=True)
        return df

    def _to_hdf5_dict(self) -> dict[str, NDArray[np.void]]:
        """Convert all dataframes into h5py-compatible structured arrays, keyed by statement."""
        dct: dict[str, NDArray[np.void]] = {}
        for name, _dtype in self.DTYPES.items():
            assert _dtype.fields is not None

            # Construct a h5py-compatible structured dtype
            dtype_list = []
            for sub_field, (sub_dtype, *_) in _dtype.fields.items():
                if sub_dtype.kind == "U":
                    # numpy unicode stores 4 bytes per character
                    sub_dtype = h5py.string_dtype("utf-8", sub_dtype.itemsize // 4)
                dtype_list.append((sub_field, sub_dtype))
            dtype = np.dtype(dtype_list)

            df: pd.DataFrame = getattr(self, name.lower()).reset_index(inplace=False, drop=False)
            dct[name] = df.to_records(index=False).astype(dtype)
        return dct

    @classmethod
    def _get_err_msg(cls, statement: str, lst: list[tuple[Any, ...]]) -> None | str:
        """Construct an error message for when :meth:`~RTFContainer.from_file` fails to \
        construct an array.

        Parameters
        ----------
        statement : str
            The name of the match statement
        lst : list[tuple[Any, ...]]
            A list of tuples with structured data.
            For all statements except ``MASS`` the first field is the residue name

        Returns
        -------
        str | None
            A newly constructed error message or :data:`None` if one could not be constructed

        """
        dtype = cls.DTYPES[statement]
        i = 0
        residue_old = ""
        for n, tup in enumerate(lst, start=1):
            if statement == "MASS":
                # MASS statements are not grouped by residue; their first field is an index
                location = f"statement {n}"
            else:
                residue: str = tup[0]
                i = 1 if residue != residue_old else i + 1
                residue_old = residue
                location = f"statement {i} in residue {residue!r}"

            try:
                np.array(tup, dtype=dtype)
            except Exception:
                return f"failed to parse {statement!r} {location}"
        return None

    @classmethod
    def from_file(cls, path: str | os.PathLike[str]) -> Self:
        """Construct a new :class:`RTFContainer` from the passed file path.

        Parameters
        ----------
        path : path-like object
            The path to the .rtf file

        Returns
        -------
        FOX.RTFContainer
            A newly constructed .rtf container

        Raises
        ------
        ValueError
            Raised if the file lacks a final ``END`` statement or if a statement
            cannot be parsed

        """
        dct: dict[str, list[tuple[Any, ...]]] = {
            "ATOM": [],
            "BOND": [],
            "IMPR": [],
            "ANGLES": [],
            "DIHE": [],
            "MASS": [],
        }
        auto: set[str] = set()
        atom_dict: dict[str, int] = {}
        with open(path, "r", encoding="utf8") as _f:
            # Strip `!`-delimited comments and surrounding whitespace from every line
            f = FileIter(_f, start=1, stripper=lambda line: line.partition("!")[0].strip())
            statement = "<UNKNOWN>"
            try:
                # Skip the top-most header until the CHARMM version has been reached
                i = "*"
                while i.startswith("*"):
                    i = next(f)
                version = tuple(int(j) for j in i.split())

                # Parse all MASS statements
                i = next(f)
                statement = "MASS"
                while i.startswith("MASS"):
                    dct["MASS"].append(tuple(i.split()[1:]))
                    i = next(f)

                # Find the first RESI statement
                while not i.startswith("RESI"):
                    if i.startswith("AUTO"):
                        auto.update(i.split()[1:])
                    i = next(f)
                statement = "RESI"

                # Keep parsing all RESI-related statements until END has been reached
                res_index = 1
                while True:
                    # RESI-statements are not guaranteed to contain a residue name
                    res_fields = i.split()
                    if len(res_fields) == 2:
                        molecule = f"RES{res_index}"
                    else:
                        molecule = res_fields[1]
                    res_index += 1

                    j = 0
                    while True:
                        # Use `next` rather than a `for` loop: a missing END statement
                        # must raise StopIteration instead of looping forever
                        i = next(f)
                        statement, *rest = i.split()
                        if statement in ("RESI", "END"):
                            break

                        lst = dct.get(statement)
                        if lst is None:
                            continue
                        if statement == "ATOM":
                            j += 1
                            atom_dict[rest[0]] = j
                            lst.append((molecule, j, *rest[1:]))
                        else:
                            lst.append((molecule, *(atom_dict[at] for at in rest)))

                    if statement == "END":
                        break
            except StopIteration as ex:
                raise ValueError(
                    f"{os.fspath(path)!r}: failed to find a `END` statement at the end of the file"
                ) from ex
            except Exception as ex:
                raise ValueError(
                    f"{os.fspath(path)!r}: failed to parse the {statement!r} statement "
                    f"on line {f.index!r}"
                ) from ex

        # Convert the lists into dataframes via a structured array intermediate
        # Numpy arrays have much better dtype control compared to pandas dataframes/series,
        # hence the array intermediate
        kwargs: dict[str, pd.DataFrame] = {}
        for k, v in dct.items():
            try:
                rec_array = np.fromiter(v, dtype=cls.DTYPES[k], count=len(v))
            except Exception as ex:
                msg = cls._get_err_msg(k, v)
                if msg is None:
                    raise
                raise ValueError(f"{os.fspath(path)!r}: {msg}") from ex
            df = pd.DataFrame(rec_array)
            df.set_index("molecule" if k != "MASS" else "index", drop=True, inplace=True)
            kwargs[k.lower()] = df
        return cls(charmm_version=version, auto=auto, **kwargs)

    def _shallow_copy(self) -> Self:
        """Return a new instance sharing all dataframes but with an independent :attr:`auto`."""
        cls = type(self)
        return cls(
            self.mass,
            self.atom,
            self.bond,
            self.impropers,
            self.angles,
            self.dihedrals,
            self.charmm_version,
            set(self.auto),
        )

    def concatenate(self, rtf_iter: Iterable[RTFContainer]) -> Self:
        """Concatenate multiple RTFContainers into a single instance.

        All auto-generatable statements (see :attr:`~RTFContainer.auto`) of this instance
        and of the passed containers are first converted into explicit ones.
        The conversion is performed on copies, so none of the inputs are modified.

        Parameters
        ----------
        rtf_iter : Iterable[FOX.RTFContainer]
            An iterable with other RTFContainers to concatenate

        Returns
        -------
        FOX.RTFContainer
            The new concatenated RTFContainer.
            Its :attr:`charmm_version` is taken from this instance.

        Raises
        ------
        TypeError
            Raised if ``rtf_iter`` contains an object that is not a RTFContainer

        """
        rtf_list: list[RTFContainer] = []
        for rtf in itertools.chain([self], rtf_iter):
            if not isinstance(rtf, RTFContainer):
                raise TypeError("Expected a RTFContainer")
            rtf_copy = rtf._shallow_copy()
            rtf_copy.auto_to_explicit()
            rtf_list.append(rtf_copy)

        dct = {
            "mass": pd.concat([rtf.mass for rtf in rtf_list], ignore_index=True),
            "atom": pd.concat([rtf.atom for rtf in rtf_list]),
            "bond": pd.concat([rtf.bond for rtf in rtf_list]),
            "impr": pd.concat([rtf.impropers for rtf in rtf_list]),
            "angles": pd.concat([rtf.angles for rtf in rtf_list]),
            "dihe": pd.concat([rtf.dihedrals for rtf in rtf_list]),
        }
        dct["mass"].drop_duplicates("atom_type", inplace=True, ignore_index=True)
        cls = type(self)
        return cls(charmm_version=self.charmm_version, **dct)