diff --git a/deeprank2/dataset.py b/deeprank2/dataset.py
index 13c9ea804..53e90890a 100644
--- a/deeprank2/dataset.py
+++ b/deeprank2/dataset.py
@@ -7,7 +7,7 @@
 import sys
 import warnings
 from ast import literal_eval
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Literal

 import h5py
 import matplotlib.pyplot as plt
@@ -27,16 +27,17 @@ class DeeprankDataset(Dataset):
-    def __init__(self, # pylint: disable=too-many-arguments
-                 hdf5_path: Union[str, List[str]],
-                 subset: Optional[List[str]],
-                 target: Optional[str],
-                 task: Optional[str],
-                 classes: Optional[Union[List[str], List[int], List[float]]],
-                 use_tqdm: bool,
-                 root_directory_path: str,
-                 target_filter: Union[Dict[str, str], None],
-                 check_integrity: bool
+    def __init__( # pylint: disable=too-many-arguments
+        self,
+        hdf5_path: str | list[str],
+        subset: list[str] | None,
+        target: str | None,
+        task: str | None,
+        classes: list[str] | list[int] | list[float] | None,
+        use_tqdm: bool,
+        root_directory_path: str,
+        target_filter: dict[str, str] | None,
+        check_integrity: bool
     ):
         """Parent class of :class:`GridDataset` and :class:`GraphDataset` which inherits from :class:`torch_geometric.data.dataset.Dataset`.
@@ -97,7 +98,7 @@ def _check_hdf5_files(self):
         for hdf5_path in to_be_removed:
             self.hdf5_paths.remove(hdf5_path)

-    def _check_task_and_classes(self, task: str, classes: Optional[str] = None):
+    def _check_task_and_classes(self, task: str, classes: str | None = None):

         if self.target in [targets.IRMSD, targets.LRMSD, targets.FNAT, targets.DOCKQ]:
             self.task = targets.REGRESS
@@ -132,14 +133,14 @@ def _check_task_and_classes(self, task: str, classes: Optional[str] = None):

     def _check_inherited_params(
         self,
-        inherited_params: List[str],
-        dataset_train: Union[GraphDataset, GridDataset],
+        inherited_params: list[str],
+        dataset_train: GraphDataset | GridDataset,
     ):
         """Check if the parameters for validation and/or testing are the same as in the training set.

         Args:
-            inherited_params (List[str]): List of parameters that need to be checked for inheritance.
-            dataset_train (Union[class:`GraphDataset`, class:`GridDataset`]): The parameters in `inherited_param` will be inherited from `dataset_train`.
+            inherited_params (list[str]): List of parameters that need to be checked for inheritance.
+            dataset_train (class:`GraphDataset` | class:`GridDataset`): The parameters in `inherited_params` will be inherited from `dataset_train`.
         """

         self_vars = vars(self)

@@ -302,25 +303,25 @@ def hdf5_to_pandas( # noqa: MC0001, pylint: disable=too-many-locals

     def save_hist( # pylint: disable=too-many-arguments, too-many-branches, useless-suppression
         self,
-        features: Union[str,List[str]],
+        features: str | list[str],
         fname: str = 'features_hist.png',
-        bins: Union[int,List[float],str] = 10,
-        figsize: Tuple = (15, 15),
+        bins: int | list[float] | str = 10,
+        figsize: tuple = (15, 15),
         log: bool = False
     ):
         """After having generated a pd.DataFrame using hdf5_to_pandas method, histograms of the features can be saved in an image.

         Args:
-            features (Union[str, List[str]]): Features to be plotted.
+            features (str | list[str]): Features to be plotted.
             fname (str): str or path-like or binary file-like object. Defaults to 'features_hist.png'.
-            bins (Union[int, List[float], str, optional]): If bins is an integer, it defines the number of equal-width bins in the range.
+            bins (int | list[float] | str, optional): If bins is an integer, it defines the number of equal-width bins in the range.
                If bins is a sequence, it defines the bin edges, including the left edge of the first bin and the right edge of the last bin;
                in this case, bins may be unequally spaced. All but the last (righthand-most) bin is half-open.
                If bins is a string, it is one of the binning strategies supported by numpy.histogram_bin_edges:
                'auto', 'fd', 'doane', 'scott', 'stone', 'rice', 'sturges', or 'sqrt'.
                Defaults to 10.
-            figsize (Tuple, optional): Saved figure sizes. Defaults to (15, 15).
+            figsize (tuple, optional): Saved figure sizes. Defaults to (15, 15).
             log (bool): Whether to apply log transformation to the data indicated by the `features` parameter. Defaults to False.
         """

         if self.df is None:
@@ -409,38 +410,38 @@ def _compute_mean_std(self):

 class GridDataset(DeeprankDataset):
     def __init__( # pylint: disable=too-many-arguments
         self,
-        hdf5_path: Union[str, list],
-        subset: Optional[List[str]] = None,
+        hdf5_path: str | list[str],
+        subset: list[str] | None = None,
         train: bool = True,
-        dataset_train: Optional[GridDataset] = None,
-        features: Optional[Union[List[str], str]] = "all",
-        target: Optional[str] = None,
-        target_transform: Optional[bool] = False,
-        target_filter: Optional[Dict[str, str]] = None,
-        task: Optional[str] = None,
-        classes: Optional[Union[List[str], List[int], List[float]]] = None,
-        tqdm: Optional[bool] = True,
-        root: Optional[str] = "./",
+        dataset_train: GridDataset | None = None,
+        features: list[str] | str | Literal["all"] | None = "all",
+        target: str | None = None,
+        target_transform: bool = False,
+        target_filter: dict[str, str] | None = None,
+        task: Literal["regress", "classif"] | None = None,
+        classes: list[str] | list[int] | list[float] | None = None,
+        tqdm: bool = True,
+        root: str = "./",
         check_integrity: bool = True
     ):
         """Class to load the .HDF5 files data into grids.

         Args:
-            hdf5_path (Union[str,list]): Path to .HDF5 file(s). For multiple .HDF5 files, insert the paths in a List. Defaults to None.
-            subset (Optional[List[str]], optional): List of keys from .HDF5 file to include. Defaults to None (meaning include all).
+            hdf5_path (str | list[str]): Path to .HDF5 file(s). For multiple .HDF5 files, insert the paths in a list.
+            subset (list[str] | None, optional): List of keys from .HDF5 file to include. Defaults to None (meaning include all).
             train (bool, optional): Boolean flag to determine if the instance represents the training set.
                If False, a dataset_train of the same class must be provided as well.
                The latter will be used to scale the validation/testing set according to its features values and to match the datasets' parameters.
                Defaults to True.
-            dataset_train (class:`GridDataset`, optional): If `train` is True, assign here the training set.
+            dataset_train (class:`GridDataset` | None, optional): If `train` is False, assign here the training set.
                If `train` is False and `dataset_train` is assigned,
                the parameters `features`, `target`, `target_transform`, `task`, and `classes` will be inherited from `dataset_train`.
                Defaults to None.
-            features (Optional[Union[List[str], str]], optional): Consider all pre-computed features ("all") or some defined node features
+            features (list[str] | str | Literal["all"] | None, optional): Consider all pre-computed features ("all") or some defined node features
                (provide a list, example: ["res_type", "polarity", "bsa"]).
                The complete list can be found in `deeprank2.domain.gridstorage`.
                Value will be ignored and inherited from `dataset_train` if `train` is set as False and `dataset_train` is assigned.
Defaults to "all". - target (Optional[str], optional): Default options are irmsd, lrmsd, fnat, binary, capri_class, and dockq. It can also be + target (str | None, optional): Default options are irmsd, lrmsd, fnat, binary, capri_class, and dockq. It can also be a custom-defined target given to the Query class as input (see: `deeprank2.query`); in this case, the task parameter needs to be explicitly specified as well. Only numerical target variables are supported, not categorical. @@ -448,25 +449,25 @@ def __init__( # pylint: disable=too-many-arguments numerical class indices before defining the :class:`GraphDataset` instance. Value will be ignored and inherited from `dataset_train` if `train` is set as False and `dataset_train` is assigned. Defaults to None. - target_transform (Optional[bool], optional): Apply a log and then a sigmoid transformation to the target (for regression only). + target_transform (bool, optional): Apply a log and then a sigmoid transformation to the target (for regression only). This puts the target value between 0 and 1, and can result in a more uniform target distribution and speed up the optimization. Value will be ignored and inherited from `dataset_train` if `train` is set as False and `dataset_train` is assigned. Defaults to False. - target_filter (Optional[Dict[str, str]], optional): Dictionary of type [target: cond] to filter the molecules. + target_filter (dict[str, str] | None, optional): Dictionary of type [target: cond] to filter the molecules. Note that the you can filter on a different target than the one selected as the dataset target. Defaults to None. - task (Optional[str], optional): 'regress' for regression or 'classif' for classification. Required if target not in + task (Literal["regress", "classif"] | None, optional): 'regress' for regression or 'classif' for classification. Required if target not in ['irmsd', 'lrmsd', 'fnat', 'binary', 'capri_class', or 'dockq'], otherwise this setting is ignored. Automatically set to 'classif' if the target is 'binary' or 'capri_classes'. Automatically set to 'regress' if the target is 'irmsd', 'lrmsd', 'fnat', or 'dockq'. Value will be ignored and inherited from `dataset_train` if `train` is set as False and `dataset_train` is assigned. Defaults to None. - classes (Optional[Union[List[str], List[int], List[float]], optional]): Define the dataset target classes in classification mode. + classes (list[str] | list[int] | list[float] | None): Define the dataset target classes in classification mode. Value will be ignored and inherited from `dataset_train` if `train` is set as False and `dataset_train` is assigned. Defaults to None. - tqdm (Optional[bool], optional): Show progress bar. + tqdm (bool, optional): Show progress bar. Defaults to True. - root (Optional[str], optional): Root directory where the dataset should be saved. + root (str, optional): Root directory where the dataset should be saved. Defaults to "./". check_integrity (bool, optional): Whether to check the integrity of the hdf5 files. Defaults to True. 
@@ -619,50 +620,47 @@ def load_one_grid(self, hdf5_path: str, entry_name: str) -> Data:

 class GraphDataset(DeeprankDataset):
     def __init__( # noqa: MC0001, pylint: disable=too-many-arguments, too-many-locals
         self,
-        hdf5_path: Union[str, List[str]],
-        subset: Optional[List[str]] = None,
+        hdf5_path: str | list[str],
+        subset: list[str] | None = None,
         train: bool = True,
-        dataset_train: Optional[GraphDataset] = None,
-        node_features: Optional[Union[List[str], str]] = "all",
-        edge_features: Optional[Union[List[str], str]] = "all",
-        features_transform: Optional[dict] = None,
-        clustering_method: Optional[str] = None,
-        target: Optional[str] = None,
-        target_transform: Optional[bool] = False,
-        target_filter: Optional[Dict[str, str]] = None,
-        task: Optional[str] = None,
-        classes: Optional[Union[List[str], List[int], List[float]]] = None,
-        tqdm: Optional[bool] = True,
-        root: Optional[str] = "./",
-        check_integrity: bool = True,
+        dataset_train: GraphDataset | None = None,
+        node_features: list[str] | str | Literal["all"] | None = "all",
+        edge_features: list[str] | str | Literal["all"] | None = "all",
+        features_transform: dict | None = None,
+        clustering_method: str | None = None,
+        target: str | None = None,
+        target_transform: bool = False,
+        target_filter: dict[str, str] | None = None,
+        task: Literal["regress", "classif"] | None = None,
+        classes: list[str] | list[int] | list[float] | None = None,
+        tqdm: bool = True,
+        root: str = "./",
+        check_integrity: bool = True
     ):
         """Class to load the .HDF5 files data into graphs.

         Args:
-            hdf5_path (Union[str, List[str]]): Path to .HDF5 file(s). For multiple .HDF5 files, insert the paths in a List.
-                Defaults to None.
-            subset (Optional[List[str]], optional): List of keys from .HDF5 file to include.
-                Defaults to None (meaning include all).
+            hdf5_path (str | list[str]): Path to .HDF5 file(s). For multiple .HDF5 files, insert the paths in a list.
+            subset (list[str] | None, optional): List of keys from .HDF5 file to include. Defaults to None (meaning include all).
             train (bool, optional): Boolean flag to determine if the instance represents the training set.
                If False, a dataset_train of the same class must be provided as well.
                The latter will be used to scale the validation/testing set according to its features values and to match the datasets' parameters. Defaults to True.
-            dataset_train (class:`GraphDataset`, optional): If `train` is True, assign here the training set.
+            dataset_train (class:`GraphDataset` | None, optional): If `train` is False, assign here the training set.
                If `train` is False and `dataset_train` is assigned,
-                the parameters `node_features`, `edge_features`, `features_transform`, `target`,
-                `target_transform`, `task` and `classes` will be inherited from `dataset_train`.
+                the parameters `node_features`, `edge_features`, `features_transform`, `target`,
+                `target_transform`, `task`, and `classes` will be inherited from `dataset_train`.
                Defaults to None.
-            node_features (Optional[Union[List[str], str], optional): Consider all pre-computed node features ("all") or
+            node_features (list[str] | str | Literal["all"] | None, optional): Consider all pre-computed node features ("all") or
                some defined node features (provide a list, example: ["res_type", "polarity", "bsa"]).
                The complete list can be found in `deeprank2.domain.nodestorage`.
                Value will be ignored and inherited from `dataset_train` if `train` is set as False and `dataset_train` is assigned.
                Defaults to "all".
-            edge_features (Optional[Union[List[str], str], optional): Consider all pre-computed edge features ("all") or
+            edge_features (list[str] | str | Literal["all"] | None, optional): Consider all pre-computed edge features ("all") or
                some defined edge features (provide a list, example: ["dist", "coulomb"]).
                The complete list can be found in `deeprank2.domain.edgestorage`.
                Value will be ignored and inherited from `dataset_train` if `train` is set as False and `dataset_train` is assigned.
                Defaults to "all".
-            features_transform (Optional[dict], optional): Dictionary to indicate the transformations to apply to each feature in the dictionary, being the
+            features_transform (dict | None, optional): Dictionary to indicate the transformations to apply to each feature in the dictionary, being the
                transformations lambda functions and/or standardization.
                Example: `features_transform = {'bsa': {'transform': lambda t: np.log(t + 1), 'standardize': True}}` for the feature `bsa`.
                An `all` key can be set in the dictionary for indicating to apply the same `standardize` and `transform` to all the features.
@@ -670,7 +668,7 @@ def __init__( # noqa: MC0001, pylint: disable=too-many-arguments, too-many-local
                If both `all` and feature name/s are present, the latter take priority over what is indicated in `all`.
                Value will be ignored and inherited from `dataset_train` if `train` is set as False and `dataset_train` is assigned.
                Defaults to None.
-            clustering_method (Optional[str], optional): "mcl" for Markov cluster algorithm (see https://micans.org/mcl/),
+            clustering_method (str | None, optional): "mcl" for Markov cluster algorithm (see https://micans.org/mcl/),
                or "louvain" for Louvain method (see https://en.wikipedia.org/wiki/Louvain_method).
                In both options, for each graph, the chosen method first finds communities (clusters) of nodes and generates
                a torch tensor whose elements represent the cluster to which the node belongs. Each tensor is then saved in
@@ -679,7 +677,7 @@ def __init__( # noqa: MC0001, pylint: disable=too-many-arguments, too-many-local
                The latter tensor is saved into the .HDF5 file as a :class:`Dataset` called "depth_1".
                Both "depth_0" and "depth_1" :class:`Datasets` belong to the "cluster" Group.
                They are saved in the .HDF5 file to make them available to networks that make use of clustering methods. Defaults to None.
-            target (Optional[str], optional): Default options are irmsd, lrmsd, fnat, binary, capri_class, and dockq.
+            target (str | None, optional): Default options are irmsd, lrmsd, fnat, binary, capri_class, and dockq.
                It can also be a custom-defined target given to the Query class as input (see: `deeprank2.query`);
                in this case, the task parameter needs to be explicitly specified as well.
                Only numerical target variables are supported, not categorical.
@@ -687,24 +685,26 @@ def __init__( # noqa: MC0001, pylint: disable=too-many-arguments, too-many-local
                numerical class indices before defining the :class:`GraphDataset` instance.
                Value will be ignored and inherited from `dataset_train` if `train` is set as False and `dataset_train` is assigned.
                Defaults to None.
-            target_transform (Optional[bool], optional): Apply a log and then a sigmoid transformation to the target (for regression only).
+            target_transform (bool, optional): Apply a log and then a sigmoid transformation to the target (for regression only).
                This puts the target value between 0 and 1, and can result in a more uniform target distribution and speed up the optimization.
                Value will be ignored and inherited from `dataset_train` if `train` is set as False and `dataset_train` is assigned.
                Defaults to False.
-            target_filter (Optional[Dict[str, str]], optional): Dictionary of type [target: cond] to filter the molecules.
+            target_filter (dict[str, str] | None, optional): Dictionary of type [target: cond] to filter the molecules.
                Note that you can filter on a different target than the one selected as the dataset target.
                Defaults to None.
-            task (Optional[str], optional): 'regress' for regression or 'classif' for classification. Required if target not in
+            task (Literal["regress", "classif"] | None, optional): 'regress' for regression or 'classif' for classification. Required if target not in
                ['irmsd', 'lrmsd', 'fnat', 'binary', 'capri_class', or 'dockq'], otherwise this setting is ignored.
                Automatically set to 'classif' if the target is 'binary' or 'capri_class'.
                Automatically set to 'regress' if the target is 'irmsd', 'lrmsd', 'fnat', or 'dockq'.
                Value will be ignored and inherited from `dataset_train` if `train` is set as False and `dataset_train` is assigned.
                Defaults to None.
-            classes (Optional[Union[List[str], List[int], List[float]]], optional): Define the dataset target classes in classification mode.
+            classes (list[str] | list[int] | list[float] | None, optional): Define the dataset target classes in classification mode.
                Value will be ignored and inherited from `dataset_train` if `train` is set as False and `dataset_train` is assigned.
                Defaults to None.
-            tqdm (Optional[bool], optional): Show progress bar. Defaults to True.
-            root (Optional[str], optional): Root directory where the dataset should be saved. Defaults to "./".
+            tqdm (bool, optional): Show progress bar.
+                Defaults to True.
+            root (str, optional): Root directory where the dataset should be saved.
+                Defaults to "./".
             check_integrity (bool, optional): Whether to check the integrity of the hdf5 files. Defaults to True.
         """
@@ -1014,7 +1014,7 @@ def _check_features(self): #pylint: disable=too-many-branches

 def save_hdf5_keys(
     f_src_path: str,
-    src_ids: List[str],
+    src_ids: list[str],
     f_dest_path: str,
     hardcopy = False
 ):
@@ -1022,7 +1022,7 @@ def save_hdf5_keys(

     Args:
         f_src_path (str): The path to the .HDF5 file containing the keys.
-        src_ids (List[str]): Keys to be saved in the new .HDF5 file. It should be a list containing at least one key.
+        src_ids (list[str]): Keys to be saved in the new .HDF5 file. It should be a list containing at least one key.
         f_dest_path (str): The path to the new .HDF5 file.
         hardcopy (bool, optional): If False, the new file contains only references (external links, see :class:`ExternalLink` class from `h5py`)
            to the original .HDF5 file.
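A minimal sketch of the `features_transform` convention documented above, using the updated `GraphDataset` signature (the file name is a hypothetical placeholder):

```python
import numpy as np

from deeprank2.dataset import GraphDataset

# Per-feature entries take priority over the 'all' fallback.
features_transform = {
    "all": {"standardize": True},
    "bsa": {"transform": lambda t: np.log(t + 1), "standardize": True},
}

dataset_train = GraphDataset(
    hdf5_path="train.hdf5",  # hypothetical path
    node_features=["res_type", "polarity", "bsa"],
    edge_features=["dist", "coulomb"],
    features_transform=features_transform,
    target="binary",
)
```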
diff --git a/deeprank2/domain/aminoacidlist.py b/deeprank2/domain/aminoacidlist.py
index b57831290..c5727bf3c 100644
--- a/deeprank2/domain/aminoacidlist.py
+++ b/deeprank2/domain/aminoacidlist.py
@@ -1,5 +1,3 @@
-from typing import Optional
-
 from deeprank2.molstruct.aminoacid import AminoAcid, Polarity

 # All info below sourced from above websites in December 2022 and summarized in deeprank2/domain/aminoacid_summary.xlsx
@@ -352,7 +350,7 @@
     # pyrrolysine,
 ]

-def convert_aa_nomenclature(aa: str, output_type: Optional[int] = None):
+def convert_aa_nomenclature(aa: str, output_type: int | None = None):
     try:
         if len(aa) == 1:
             aa: AminoAcid = [entry for entry in amino_acids if entry.one_letter_code.lower() == aa.lower()][0]
diff --git a/deeprank2/features/components.py b/deeprank2/features/components.py
index e7cb90d2f..73a77b76b 100644
--- a/deeprank2/features/components.py
+++ b/deeprank2/features/components.py
@@ -1,5 +1,4 @@
 import logging
-from typing import Optional

 import numpy as np
@@ -12,9 +11,10 @@
 _log = logging.getLogger(__name__)

 def add_features( # pylint: disable=unused-argument
-    pdb_path: str, graph: Graph,
-    single_amino_acid_variant: Optional[SingleResidueVariant] = None
-    ):
+    pdb_path: str,
+    graph: Graph,
+    single_amino_acid_variant: SingleResidueVariant | None = None,
+):

     for node in graph.nodes:
         if isinstance(node.id, Residue):
diff --git a/deeprank2/features/conservation.py b/deeprank2/features/conservation.py
index 4a2c46f43..8ab5ad56e 100644
--- a/deeprank2/features/conservation.py
+++ b/deeprank2/features/conservation.py
@@ -1,4 +1,3 @@
-from typing import Optional

 import numpy as np
@@ -10,9 +9,10 @@

 def add_features( # pylint: disable=unused-argument
-    pdb_path: str, graph: Graph,
-    single_amino_acid_variant: Optional[SingleResidueVariant] = None
-    ):
+    pdb_path: str,
+    graph: Graph,
+    single_amino_acid_variant: SingleResidueVariant | None = None,
+):

     profile_amino_acid_order = sorted(amino_acids, key=lambda aa: aa.three_letter_code)
diff --git a/deeprank2/features/contact.py b/deeprank2/features/contact.py
index e8eaa2e83..31436bd2d 100644
--- a/deeprank2/features/contact.py
+++ b/deeprank2/features/contact.py
@@ -1,9 +1,8 @@
 import logging
 import warnings
-from typing import List, Optional, Tuple

 import numpy as np
-import numpy.typing as npt
+from numpy.typing import NDArray
 from scipy.spatial import distance_matrix

 from deeprank2.domain import edgestorage as Efeat
@@ -21,21 +20,21 @@
 cutoff_14 = 4.2

 def _get_nonbonded_energy( #pylint: disable=too-many-locals
-    atoms: List[Atom],
-    distances: npt.NDArray[np.float64],
-    ) -> Tuple [npt.NDArray[np.float64], npt.NDArray[np.float64]]:
+    atoms: list[Atom],
+    distances: NDArray[np.float64],
+    ) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
     """Calculates all pairwise electrostatic (Coulomb) and Van der Waals (Lennard-Jones) potential energies between all atoms in the structure.

     Warning: there's no distance cutoff here. The radius of influence is assumed to be infinite.
    However, the potential tends to 0 at large distance.

     Args:
-        atoms (List[Atom]): list of all atoms in the structure
-        distances (npt.NDArray[np.float64]): matrix of pairwise distances between all atoms in the structure
+        atoms (list[Atom]): List of all atoms in the structure.
+        distances (NDArray[np.float64]): Matrix of pairwise distances between all atoms in the structure,
            in the format that is the output of scipy.spatial's distance_matrix (i.e.
            a diagonally symmetric matrix)

     Returns:
-        Tuple [npt.NDArray[np.float64], npt.NDArray[np.float64]]: matrices in same format as `distances` containing
+        tuple[NDArray[np.float64], NDArray[np.float64]]: Matrices in the same format as `distances`, containing
            all pairwise electrostatic potential energies and all pairwise Van der Waals potential energies
     """

@@ -76,9 +75,10 @@ def _get_nonbonded_energy( #pylint: disable=too-many-locals

 def add_features( # pylint: disable=unused-argument, too-many-locals
-    pdb_path: str, graph: Graph,
-    single_amino_acid_variant: Optional[SingleResidueVariant] = None
-    ):
+    pdb_path: str,
+    graph: Graph,
+    single_amino_acid_variant: SingleResidueVariant | None = None,
+):

     # assign each atom (from all edges) a unique index
     all_atoms = set()
diff --git a/deeprank2/features/exposure.py b/deeprank2/features/exposure.py
index 07aed2fc9..9b86423e6 100644
--- a/deeprank2/features/exposure.py
+++ b/deeprank2/features/exposure.py
@@ -2,7 +2,6 @@
 import signal
 import sys
 import warnings
-from typing import Optional

 import numpy as np
 from Bio.PDB.Atom import PDBConstructionWarning
@@ -34,9 +33,10 @@ def space_if_none(value):

 def add_features( # pylint: disable=unused-argument
-    pdb_path: str, graph: Graph,
-    single_amino_acid_variant: Optional[SingleResidueVariant] = None
-    ):
+    pdb_path: str,
+    graph: Graph,
+    single_amino_acid_variant: SingleResidueVariant | None = None,
+):

     signal.signal(signal.SIGINT, handle_sigint)
     signal.signal(signal.SIGALRM, handle_timeout)
diff --git a/deeprank2/features/irc.py b/deeprank2/features/irc.py
index f28215885..acaf40c17 100644
--- a/deeprank2/features/irc.py
+++ b/deeprank2/features/irc.py
@@ -1,6 +1,5 @@
 import logging
 from itertools import combinations_with_replacement as combinations
-from typing import Dict, List, Optional, Tuple

 import pdb2sql
@@ -14,7 +13,7 @@
 _log = logging.getLogger(__name__)

-def _id_from_residue(residue: Tuple[str, int, str]) -> str:
+def _id_from_residue(residue: tuple[str, int, str]) -> str:
     """Create an id from pdb2sql-rendered residues that is similar to the id of residue nodes.

     Args:
@@ -32,7 +31,7 @@ class _ContactDensity:
     """Internal class that holds contact density information for a given residue.
     """

-    def __init__(self, residue: Tuple[str, int, str], polarity: Polarity):
+    def __init__(self, residue: tuple[str, int, str], polarity: Polarity):
         self.res = residue
         self.polarity = polarity
         self.id = _id_from_residue(self.res)
@@ -42,12 +41,12 @@ def __init__(self, residue: Tuple[str, int, str], polarity: Polarity):
         self.connections['all'] = []

-def get_IRCs(pdb_path: str, chains: List[str], cutoff: float = 5.5) -> Dict[str, _ContactDensity]:
+def get_IRCs(pdb_path: str, chains: list[str], cutoff: float = 5.5) -> dict[str, _ContactDensity]:
     """Get all close contact residues from the opposite chain.

     Args:
         pdb_path (str): Path to pdb file to read molecular information from.
-        chains (Sequence[str]): List (or list-like object) containing strings of the chains to be considered.
+        chains (list[str]): List (or list-like object) containing strings of the chains to be considered.
         cutoff (float, optional): Cutoff distance (in Ångström) to be considered a close contact. Defaults to 5.5.

     Returns:
@@ -56,7 +55,7 @@ def get_IRCs(pdb_path: str, chains: List[str], cutoff: float = 5.5) -> Dict[str,
         items: _ContactDensity objects, containing all contact density information for the residue.
""" - residue_contacts: Dict[str, _ContactDensity] = {} + residue_contacts: dict[str, _ContactDensity] = {} sql = pdb2sql.interface(pdb_path) pdb2sql_contacts = sql.get_contact_residues( @@ -104,11 +103,12 @@ def get_IRCs(pdb_path: str, chains: List[str], cutoff: float = 5.5) -> Dict[str, def add_features( - pdb_path: str, graph: Graph, - single_amino_acid_variant: Optional[SingleResidueVariant] = None - ): + pdb_path: str, + graph: Graph, + single_amino_acid_variant: SingleResidueVariant | None = None, +): - if not single_amino_acid_variant: # VariantQueries do not use this feature + if not single_amino_acid_variant: # VariantQueries do not use this feature polarity_pairs = list(combinations(Polarity, 2)) polarity_pair_string = [f'irc_{x[0].name.lower()}_{x[1].name.lower()}' for x in polarity_pairs] diff --git a/deeprank2/features/secondary_structure.py b/deeprank2/features/secondary_structure.py index d6b4a258d..c893f9cfd 100644 --- a/deeprank2/features/secondary_structure.py +++ b/deeprank2/features/secondary_structure.py @@ -1,6 +1,5 @@ from enum import Enum from pathlib import Path -from typing import Dict, List, Optional import numpy as np from Bio.PDB import PDBParser @@ -31,7 +30,7 @@ def onehot(self): return t -def _get_records(lines: List[str]): +def _get_records(lines: list[str]): seen = set() seen_add = seen.add return [x.split()[0] for x in lines if not (x in seen or seen_add(x))] @@ -82,7 +81,7 @@ def _classify_secstructure(subtype: str): return None -def _get_secstructure(pdb_path: str) -> Dict: +def _get_secstructure(pdb_path: str) -> dict: """Process the DSSP output to extract secondary structure information. Args: @@ -124,8 +123,8 @@ def _get_secstructure(pdb_path: str) -> Dict: def add_features( # pylint: disable=unused-argument pdb_path: str, graph: Graph, - single_amino_acid_variant: Optional[SingleResidueVariant] = None - ): + single_amino_acid_variant: SingleResidueVariant | None = None, +): sec_structure_features = _get_secstructure(pdb_path) diff --git a/deeprank2/features/surfacearea.py b/deeprank2/features/surfacearea.py index d36ffe0bc..022bbccf7 100644 --- a/deeprank2/features/surfacearea.py +++ b/deeprank2/features/surfacearea.py @@ -1,5 +1,4 @@ import logging -from typing import Optional import freesasa import numpy as np @@ -81,7 +80,7 @@ def add_bsa(graph: Graph): sasa_complete_result = freesasa.calc(sasa_complete_structure) sasa_chain_results = {chain_id: freesasa.calc(structure) - for chain_id, structure in sasa_chain_structures.items()} + for chain_id, structure in sasa_chain_structures.items()} for node in graph.nodes: if isinstance(node.id, Residue): @@ -95,7 +94,7 @@ def add_bsa(graph: Graph): chain_id = atom.residue.chain.id area_key = "atom" selection = ("atom, (name %s) and (resi %s) and (chain %s)" % \ - (atom.name, atom.residue.number_string, atom.residue.chain.id),) # pylint: disable=consider-using-f-string + (atom.name, atom.residue.number_string, atom.residue.chain.id),) # pylint: disable=consider-using-f-string area_monomer = freesasa.selectArea(selection, sasa_chain_structures[chain_id], \ sasa_chain_results[chain_id])[area_key] @@ -105,9 +104,10 @@ def add_bsa(graph: Graph): def add_features( # pylint: disable=unused-argument - pdb_path: str, graph: Graph, - single_amino_acid_variant: Optional[SingleResidueVariant] = None - ): + pdb_path: str, + graph: Graph, + single_amino_acid_variant: SingleResidueVariant | None = None, +): """calculates the Buried Surface Area (BSA) and the Solvent Accessible Surface Area (SASA): BSA: the area of the 
    protein that only gets exposed in monomeric state."""

diff --git a/deeprank2/molstruct/atom.py b/deeprank2/molstruct/atom.py
index 56ee7e10e..6bdef04de 100644
--- a/deeprank2/molstruct/atom.py
+++ b/deeprank2/molstruct/atom.py
@@ -3,6 +3,7 @@
 from enum import Enum

 import numpy as np
+from numpy.typing import NDArray

 from deeprank2.molstruct.residue import Residue
@@ -31,7 +32,7 @@ def __init__( # pylint: disable=too-many-arguments
         residue: Residue,
         name: str,
         element: AtomicElement,
-        position: np.array,
+        position: NDArray,
         occupancy: float,
     ):
         """
@@ -81,7 +82,7 @@ def occupancy(self) -> float:
         return self._occupancy

     @property
-    def position(self) -> np.array:
+    def position(self) -> NDArray:
         return self._position

     @property
diff --git a/deeprank2/molstruct/residue.py b/deeprank2/molstruct/residue.py
index ef8700532..e9938b219 100644
--- a/deeprank2/molstruct/residue.py
+++ b/deeprank2/molstruct/residue.py
@@ -1,8 +1,9 @@
 from __future__ import annotations

-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING

 import numpy as np
+from numpy.typing import NDArray

 from deeprank2.molstruct.aminoacid import AminoAcid
 from deeprank2.molstruct.structure import Chain
@@ -25,8 +26,8 @@ def __init__(
         self,
         chain: Chain,
         number: int,
-        amino_acid: Optional[AminoAcid] = None,
-        insertion_code: Optional[str] = None,
+        amino_acid: AminoAcid | None = None,
+        insertion_code: str | None = None,
     ):
         """
         Args:
@@ -98,7 +99,7 @@ def __repr__(self) -> str:
     def position(self) -> np.array:
         return np.mean([atom.position for atom in self._atoms], axis=0)

-    def get_center(self) -> np.ndarray:
+    def get_center(self) -> NDArray:
         """Find the center position of a `Residue`.

         Center position is found as follows:
diff --git a/deeprank2/molstruct/structure.py b/deeprank2/molstruct/structure.py
index 4b508ab04..f1f5e93d7 100644
--- a/deeprank2/molstruct/structure.py
+++ b/deeprank2/molstruct/structure.py
@@ -1,6 +1,6 @@
 from __future__ import annotations

-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING

 from deeprank2.utils.pssmdata import PssmRow
@@ -18,7 +18,7 @@ class PDBStructure:
     particular `AminoAcid` type and in turn consists of a number of `Atom`s.
     """

-    def __init__(self, id_: Optional[str] = None):
+    def __init__(self, id_: str | None = None):
         """
         Args:
            id_ (str, optional): A unique identifier for this structure, can be the pdb accession code.
@@ -72,7 +72,7 @@ class Chain:
     In other words: each `Chain` in a `PDBStructure` is a separate molecule.
     """

-    def __init__(self, model: PDBStructure, id_: Optional[str]):
+    def __init__(self, model: PDBStructure, id_: str | None):
        """One chain of a PDBStructure.
        Args:
@@ -99,10 +99,10 @@ def pssm(self, pssm: PssmRow):
     def add_residue(self, residue: Residue):
         self._residues[(residue.number, residue.insertion_code)] = residue

-    def has_residue(self, residue_number: int, insertion_code: Optional[str] = None) -> bool:
+    def has_residue(self, residue_number: int, insertion_code: str | None = None) -> bool:
         return (residue_number, insertion_code) in self._residues

-    def get_residue(self, residue_number: int, insertion_code: Optional[str] = None) -> Residue:
+    def get_residue(self, residue_number: int, insertion_code: str | None = None) -> Residue:
         return self._residues[(residue_number, insertion_code)]

     @property
diff --git a/deeprank2/neuralnets/cnn/model3d.py b/deeprank2/neuralnets/cnn/model3d.py
index c8441b4c3..cbae80b4d 100644
--- a/deeprank2/neuralnets/cnn/model3d.py
+++ b/deeprank2/neuralnets/cnn/model3d.py
@@ -1,5 +1,3 @@
-from typing import Tuple
-
 import torch
 import torch.nn
 import torch.nn.functional as F
@@ -25,7 +23,7 @@

 class CnnRegression(torch.nn.Module):
-    def __init__(self, num_features: int, box_shape: Tuple[int]):
+    def __init__(self, num_features: int, box_shape: tuple[int]):
         super().__init__()

         self.convlayer_000 = torch.nn.Conv3d(num_features, 4, kernel_size=2)
@@ -38,7 +36,7 @@ def __init__(self, num_features: int, box_shape: Tuple[int]):
         self.fclayer_000 = torch.nn.Linear(size, 84)
         self.fclayer_001 = torch.nn.Linear(84, 1)

-    def _get_conv_output(self, num_features: int, shape: Tuple[int]):
+    def _get_conv_output(self, num_features: int, shape: tuple[int]):
         num_data_points = 2
         input_ = Variable(torch.rand(num_data_points, num_features, *shape))
         output = self._forward_features(input_)
diff --git a/deeprank2/tools/target.py b/deeprank2/tools/target.py
index 80f57be9f..52e365220 100644
--- a/deeprank2/tools/target.py
+++ b/deeprank2/tools/target.py
@@ -1,6 +1,5 @@
 import glob
 import os
-from typing import Dict, List, Union

 import h5py
 import numpy as np
@@ -9,11 +8,16 @@
 from deeprank2.domain import targetstorage as targets

-def add_target(graph_path: Union[str, List[str]], target_name: str, target_list: str, sep: str = " "):
+def add_target(
+    graph_path: str | list[str],
+    target_name: str,
+    target_list: str,
+    sep: str = " ",
+):
     """Add a target to all the graphs in hdf5 files.

     Args:
-        graph_path (Union[str, List(str)]): Either a directory containing all the hdf5 files,
+        graph_path (str | list[str]): Either a directory containing all the hdf5 files,
            or a single hdf5 filename
            or a list of hdf5 filenames.
         target_name (str): The name of the new target.
@@ -83,7 +87,7 @@ def add_target(graph_path: Union[str, List[str]], target_name: str, target_list:
             print(f"no graph for {hdf5}")

-def compute_ppi_scores(pdb_path: str, reference_pdb_path: str) -> Dict[str, Union[float, int]]:
+def compute_ppi_scores(pdb_path: str, reference_pdb_path: str) -> dict[str, float | int]:
     """Compute structure similarity scores for the input docking model and return them as a dictionary.
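As a usage sketch for the reformatted `add_target` signature (the paths are hypothetical; going by the `sep` parameter, `target_list` is assumed to be a text file pairing each entry name with a target value, one pair per line):

```python
from deeprank2.tools.target import add_target

add_target(
    graph_path="processed/",    # directory of hdf5 files, one file, or a list of files
    target_name="dockq",
    target_list="targets.txt",  # assumed format: "entry_name value" per line
    sep=" ",
)
```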
diff --git a/deeprank2/trainer.py b/deeprank2/trainer.py
index 3e9f14933..cffcb0eec 100644
--- a/deeprank2/trainer.py
+++ b/deeprank2/trainer.py
@@ -1,7 +1,6 @@
 import copy
 import logging
 from time import time
-from typing import List, Optional, Tuple, Union

 import h5py
 import numpy as np
@@ -27,16 +26,16 @@ class Trainer():
     def __init__( # pylint: disable=too-many-arguments # noqa: MC0001
         self,
         neuralnet = None,
-        dataset_train: Optional[Union[GraphDataset, GridDataset]] = None,
-        dataset_val: Optional[Union[GraphDataset, GridDataset]] = None,
-        dataset_test: Optional[Union[GraphDataset, GridDataset]] = None,
-        val_size: Optional[Union[float, int]] = None,
-        test_size: Optional[Union[float, int]] = None,
+        dataset_train: GraphDataset | GridDataset | None = None,
+        dataset_val: GraphDataset | GridDataset | None = None,
+        dataset_test: GraphDataset | GridDataset | None = None,
+        val_size: float | int | None = None,
+        test_size: float | int | None = None,
         class_weights: bool = False,
-        pretrained_model: Optional[str] = None,
+        pretrained_model: str | None = None,
         cuda: bool = False,
         ngpu: int = 0,
-        output_exporters: Optional[List[OutputExporter]] = None,
+        output_exporters: list[OutputExporter] | None = None,
     ):
         """Class from which the network is trained, evaluated and tested.

@@ -46,24 +45,25 @@ def __init__( # pylint: disable=too-many-arguments # noqa: MC0001
                in terms of output shape (:class:`Trainer` class takes care of formatting the output shape according to the task).
                More specifically, in classification task cases, softmax shouldn't be used as the last activation function.
                Defaults to None.
-            dataset_train (Optional[Union[:class:`GraphDataset`, :class:`GridDataset`]], optional): Training set used during training.
+            dataset_train (:class:`GraphDataset` | :class:`GridDataset` | None, optional): Training set used during training.
                Can't be None if pretrained_model is also None. Defaults to None.
-            dataset_val (Optional[Union[:class:`GraphDataset`, :class:`GridDataset`]], optional): Evaluation set used during training.
+            dataset_val (:class:`GraphDataset` | :class:`GridDataset` | None, optional): Evaluation set used during training.
                If None, training set will be split randomly into training set and validation set during training, using val_size parameter.
                Defaults to None.
-            dataset_test (Optional[Union[:class:`GraphDataset`, :class:`GridDataset`]], optional): Independent evaluation set. Defaults to None.
-            val_size (Optional[Union[float,int]], optional): Fraction of dataset (if float) or number of datapoints (if int) to use for validation.
+            dataset_test (:class:`GraphDataset` | :class:`GridDataset` | None, optional): Independent evaluation set. Defaults to None.
+            val_size (float | int | None, optional): Fraction of dataset (if float) or number of datapoints (if int) to use for validation.
                Only used if dataset_val is not specified. Can be set to 0 if no validation set is needed. Defaults to None (in _divide_dataset function).
-            test_size (Optional[Union[float,int]], optional): Fraction of dataset (if float) or number of datapoints (if int) to use for test dataset.
+            test_size (float | int | None, optional): Fraction of dataset (if float) or number of datapoints (if int) to use for test dataset.
                Only used if dataset_test is not specified. Can be set to 0 if no test set is needed. Defaults to None.
             class_weights (bool, optional): Assign class weights based on the dataset content.
                Defaults to False.
-            pretrained_model (Optional[str], optional): Path to pre-trained model. Defaults to None.
+            pretrained_model (str | None, optional): Path to pre-trained model. Defaults to None.
             cuda (bool, optional): Whether to use CUDA. Defaults to False.
             ngpu (int, optional): Number of GPUs to be used. Defaults to 0.
-            output_exporters (Optional[List[OutputExporter]], optional): The output exporters to use for saving/exploring/plotting predictions/targets/losses
+            output_exporters (list[OutputExporter] | None, optional): The output exporters to use for saving/exploring/plotting predictions/targets/losses
                over the epochs. If None, defaults to :class:`HDF5OutputExporter`, which saves all the results in an .HDF5 file stored in ./output directory.
                Defaults to None.
         """
+
         self.batch_size_train = None
         self.batch_size_test = None
         self.shuffle = None
@@ -175,19 +175,20 @@ def __init__( # pylint: disable=too-many-arguments # noqa: MC0001
             self._load_params()
             self._load_pretrained_model()

-    def _init_output_exporters(self, output_exporters: Optional[List[OutputExporter]]):
-
+    def _init_output_exporters(self, output_exporters: list[OutputExporter] | None):
         if output_exporters is not None:
             self._output_exporters = OutputExporterCollection(*output_exporters)
         else:
             self._output_exporters = OutputExporterCollection(HDF5OutputExporter('./output'))

-    def _init_datasets(self, # pylint: disable=too-many-arguments
-                       dataset_train: Union[GraphDataset, GridDataset],
-                       dataset_val: Optional[Union[GraphDataset, GridDataset]],
-                       dataset_test: Optional[Union[GraphDataset, GridDataset]],
-                       val_size: Optional[Union[int, float]],
-                       test_size: Optional[Union[int, float]]):
+    def _init_datasets( # pylint: disable=too-many-arguments
+        self,
+        dataset_train: GraphDataset | GridDataset,
+        dataset_val: GraphDataset | GridDataset | None,
+        dataset_test: GraphDataset | GridDataset | None,
+        val_size: int | float | None,
+        test_size: int | float | None,
+    ):

         self._check_dataset_equivalence(dataset_train, dataset_val, dataset_test)
@@ -216,7 +217,7 @@ def _init_datasets(self, # pylint: disable=too-many-arguments
         else:
             self._init_from_dataset(self.dataset_test)

-    def _init_from_dataset(self, dataset: Union[GraphDataset, GridDataset]):
+    def _init_from_dataset(self, dataset: GraphDataset | GridDataset):

         if isinstance(dataset, GraphDataset):
             self.clustering_method = dataset.clustering_method
@@ -328,12 +329,12 @@ def _precluster(self, dataset: GraphDataset):

         f5.close()

-    def _put_model_to_device(self, dataset: Union[GraphDataset, GridDataset]):
+    def _put_model_to_device(self, dataset: GraphDataset | GridDataset):
         """
         Puts the model on the available device

         Args:
-            dataset (Union[:class:`GraphDataset`, :class:`GridDataset`]): GraphDataset object.
+            dataset (:class:`GraphDataset` | :class:`GridDataset`): GraphDataset or GridDataset object.

         Raises:
             ValueError: Incorrect output shape
@@ -490,13 +491,13 @@ def train( # pylint: disable=too-many-arguments, too-many-branches, too-many-loc
         nepoch: int = 1,
         batch_size: int = 32,
         shuffle: bool = True,
-        earlystop_patience: Optional[int] = None,
-        earlystop_maxgap: Optional[float] = None,
+        earlystop_patience: int | None = None,
+        earlystop_maxgap: float | None = None,
         min_epoch: int = 10,
         validate: bool = False,
         num_workers: int = 0,
         best_model: bool = True,
-        filename: Optional[str] = 'model.pth.tar'
+        filename: str | None = 'model.pth.tar'
     ):
         """
         Performs the training of the model.

@@ -508,9 +509,9 @@
             Defaults to 32.
             shuffle (bool, optional): Whether to shuffle the training dataloaders data (train set and validation set).
                Defaults to True.
-            earlystop_patience (Optional[int], optional): Training ends if the model has run for this number of epochs without improving the validation loss.
+            earlystop_patience (int | None, optional): Training ends if the model has run for this number of epochs without improving the validation loss.
                Defaults to None.
-            earlystop_maxgap (Optional[float], optional): Training ends if the difference between validation and training loss exceeds this value.
+            earlystop_maxgap (float | None, optional): Training ends if the difference between validation and training loss exceeds this value.
                Defaults to None.
             min_epoch (int, optional): Minimum epoch to be reached before looking at maxgap.
                Defaults to 10.
@@ -908,14 +909,16 @@ def _save_model(self):

         return state

-def _divide_dataset(dataset: Union[GraphDataset, GridDataset], splitsize: Optional[Union[float, int]] = None) -> \
-        Union[Tuple[GraphDataset, GraphDataset], Tuple[GridDataset, GridDataset]]:
+def _divide_dataset(
+    dataset: GraphDataset | GridDataset,
+    splitsize: float | int | None = None,
+) -> tuple[GraphDataset, GraphDataset] | tuple[GridDataset, GridDataset]:
     """Divides the dataset into a training set and an evaluation set.

     Args:
-        dataset (Union[:class:`GraphDataset`, :class:`GridDataset`]): Input dataset to be split into training and validation data.
-        splitsize (Optional[Union[float, int]], optional): Fraction of dataset (if float) or number of datapoints (if int) to use for validation.
+        dataset (:class:`GraphDataset` | :class:`GridDataset`): Input dataset to be split into training and validation data.
+        splitsize (float | int | None, optional): Fraction of dataset (if float) or number of datapoints (if int) to use for validation.
            Defaults to None.
     """
diff --git a/deeprank2/utils/buildgraph.py b/deeprank2/utils/buildgraph.py
index 9f952be8a..88f4f7124 100644
--- a/deeprank2/utils/buildgraph.py
+++ b/deeprank2/utils/buildgraph.py
@@ -1,6 +1,5 @@
 import logging
 import os
-from typing import List

 import numpy as np
 from pdb2sql import interface as get_interface
@@ -155,9 +154,9 @@ def get_structure(pdb, id_: str) -> PDBStructure:

 def get_contact_atoms( # pylint: disable=too-many-locals
     pdb_path: str,
-    chain_ids: List[str],
+    chain_ids: list[str],
     interaction_radius: float
-) -> List[Atom]:
+) -> list[Atom]:
     """Gets the contact atoms from pdb2sql and wraps them in python objects."""

     interface = get_interface(pdb_path)
@@ -210,7 +209,7 @@ def get_residue_contact_pairs( # pylint: disable=too-many-locals
     chain_id1: str,
     chain_id2: str,
     interaction_radius: float,
-) -> List[Pair]:
+) -> list[Pair]:
     """Find all residue pairs that may influence each other.

     Args:
@@ -221,7 +220,7 @@ def get_residue_contact_pairs( # pylint: disable=too-many-locals
         interaction_radius (float): Maximum distance between residues to consider them as interacting.

     Returns:
-        List[Pair]: The pairs of contacting residues.
+        list[Pair]: The pairs of contacting residues.
""" # Find out which residues are pairs diff --git a/deeprank2/utils/earlystopping.py b/deeprank2/utils/earlystopping.py index eeedb1bfa..aa6acbaf5 100644 --- a/deeprank2/utils/earlystopping.py +++ b/deeprank2/utils/earlystopping.py @@ -1,4 +1,4 @@ -from typing import Callable, Optional +from typing import Callable class EarlyStopping: @@ -6,7 +6,7 @@ def __init__( # pylint: disable=too-many-arguments self, patience: int = 10, delta: float = 0, - maxgap: Optional[float] = None, + maxgap: float | None = None, min_epoch: int = 10, verbose: bool = True, trace_func: Callable = print, diff --git a/deeprank2/utils/exporters.py b/deeprank2/utils/exporters.py index d792a981b..b8ce1863b 100644 --- a/deeprank2/utils/exporters.py +++ b/deeprank2/utils/exporters.py @@ -2,7 +2,6 @@ import os import random from math import sqrt -from typing import Any, Dict, List, Optional, Tuple import pandas as pd from matplotlib import pyplot @@ -35,14 +34,15 @@ def __exit__(self, exception_type, exception, traceback): pass # pylint: disable=unnecessary-pass def process(self, pass_name: str, epoch_number: int, # pylint: disable=too-many-arguments - entry_names: List[str], output_values: List[Any], target_values: List[Any], loss: float): + entry_names: list[str], output_values: list, target_values: list, loss: float): "the entry_names, output_values, target_values MUST have the same length" pass # pylint: disable=unnecessary-pass def is_compatible_with( # pylint: disable=unused-argument self, output_data_shape: int, - target_data_shape: Optional[int] = None) -> bool: + target_data_shape: int | None = None, + ) -> bool: "true if this exporter can work with the given data shapes" return True @@ -50,7 +50,7 @@ def is_compatible_with( # pylint: disable=unused-argument class OutputExporterCollection: """It allows a series of output exporters to be used at the same time.""" - def __init__(self, *args: List[OutputExporter]): + def __init__(self, *args: list[OutputExporter]): self._output_exporters = args def __enter__(self): @@ -64,7 +64,7 @@ def __exit__(self, exception_type, exception, traceback): output_exporter.__exit__(exception_type, exception, traceback) def process(self, pass_name: str, epoch_number: int, # pylint: disable=too-many-arguments - entry_names: List[str], output_values: List[Any], target_values: List[Any], loss: float): + entry_names: list[str], output_values: list, target_values: list, loss: float): for output_exporter in self._output_exporters: output_exporter.process(pass_name, epoch_number, entry_names, output_values, target_values, loss) @@ -94,7 +94,7 @@ def __exit__(self, exception_type, exception, traceback): self._writer.__exit__(exception_type, exception, traceback) def process(self, pass_name: str, epoch_number: int, # pylint: disable=too-many-arguments, too-many-locals - entry_names: List[str], output_values: List[Any], target_values: List[Any], loss: float): + entry_names: list[str], output_values: list, target_values: list, loss: float): "write to tensorboard" ce_loss = cross_entropy(tensor(output_values), tensor(target_values)).item() @@ -139,7 +139,11 @@ def process(self, pass_name: str, epoch_number: int, # pylint: disable=too-many- roc_auc = roc_auc_score(target_values, probabilities) self._writer.add_scalar(f"{pass_name} ROC AUC", roc_auc, epoch_number) - def is_compatible_with(self, output_data_shape: int, target_data_shape: Optional[int] = None) -> bool: + def is_compatible_with( + self, + output_data_shape: int, + target_data_shape: int | None = None, + ) -> bool: """For regression, 
         return output_data_shape == 2 and target_data_shape == 1

@@ -189,7 +193,7 @@ def _get_color(pass_name):
         return random.choice(["yellow", "cyan", "magenta"])

     @staticmethod
-    def _plot(epoch_number: int, data: Dict[str, Tuple[List[float], List[float]]], png_path: str):
+    def _plot(epoch_number: int, data: dict[str, tuple[list[float], list[float]]], png_path: str):

         pyplot.title(f"Epoch {epoch_number}")
@@ -204,7 +208,7 @@ def _plot(epoch_number: int, data: Dict[str, Tuple[List[float], List[float]]], p

         pyplot.close()

     def process(self, pass_name: str, epoch_number: int, # pylint: disable=too-many-arguments
-                entry_names: List[str], output_values: List[Any], target_values: List[Any], loss: float):
+                entry_names: list[str], output_values: list, target_values: list, loss: float):
         """Make the plot, if the epoch matches with the interval."""

         if epoch_number % self._epoch_interval == 0:
@@ -217,7 +221,11 @@ def process(self, pass_name: str, epoch_number: int, # pylint: disable=too-many-
                 path = self.get_filename(epoch_number)
                 self._plot(epoch_number, self._plot_data[epoch_number], path)

-    def is_compatible_with(self, output_data_shape: int, target_data_shape: Optional[int] = None) -> bool:
+    def is_compatible_with(
+        self,
+        output_data_shape: int,
+        target_data_shape: int | None = None,
+    ) -> bool:
         """For regression, target data is needed and output data must be a list of one-dimensional values."""
         return output_data_shape == 1 and target_data_shape == 1

@@ -266,9 +274,9 @@
     def process( # pylint: disable=too-many-arguments
         self,
         pass_name: str,
         epoch_number: int,
-        entry_names: List[str],
-        output_values: List[Any],
-        target_values: List[Any],
+        entry_names: list[str],
+        output_values: list,
+        target_values: list,
         loss: float):

         self.phase = pass_name
diff --git a/deeprank2/utils/graph.py b/deeprank2/utils/graph.py
index c57706e03..174eeb41c 100644
--- a/deeprank2/utils/graph.py
+++ b/deeprank2/utils/graph.py
@@ -1,10 +1,11 @@
 import logging
 import os
-from typing import Callable, List, Optional, Union
+from typing import Callable

 import h5py
 import numpy as np
 import pdb2sql.transform
+from numpy.typing import NDArray
 from scipy.spatial import distance_matrix

 from deeprank2.domain import edgestorage as Efeat
@@ -49,7 +50,7 @@ def has_nan(self) -> bool:

 class Node:
-    def __init__(self, id_: Union[Atom, Residue]):
+    def __init__(self, id_: Atom | Residue):
         if isinstance(id_, Atom):
             self._type = "atom"
         elif isinstance(id_, Residue):
@@ -76,7 +77,7 @@ def has_nan(self) -> bool:
     def add_feature(
         self,
         feature_name: str,
-        feature_function: Callable[[Union[Atom, Residue]], np.ndarray],
+        feature_function: Callable[[Atom | Residue], NDArray],
     ):
         feature_value = feature_function(self.id)

@@ -109,7 +110,7 @@ def __init__(self, id_: str):
     def add_node(self, node: Node):
         self._nodes[node.id] = node

-    def get_node(self, id_: Union[Atom, Residue]) -> Node:
+    def get_node(self, id_: Atom | Residue) -> Node:
         return self._nodes[id_]

     def add_edge(self, edge: Edge):
@@ -119,11 +120,11 @@ def get_edge(self, id_: Contact) -> Edge:
         return self._edges[id_]

     @property
-    def nodes(self) -> List[Node]:
+    def nodes(self) -> list[Node]:
         return list(self._nodes.values())

     @property
-    def edges(self) -> List[Node]:
+    def edges(self) -> list[Edge]:
         return list(self._edges.values())

     def has_nan(self) -> bool:
@@ -140,9 +141,9 @@ def has_nan(self) -> bool:
         return False

     def _map_point_features(self, grid: Grid, method: MapMethod, # pylint: disable=too-many-arguments
-                            feature_name: str, points: List[np.ndarray],
-                            values: List[Union[float, np.ndarray]],
-                            augmentation: Optional[Augmentation] = None):
+                            feature_name: str, points: list[NDArray],
+                            values: list[float | NDArray],
+                            augmentation: Augmentation | None = None):

         points = np.stack(points, axis=0)
@@ -158,7 +159,7 @@ def _map_point_features(self, grid: Grid, method: MapMethod, # pylint: disable=

             grid.map_feature(position, feature_name, value, method)

-    def map_to_grid(self, grid: Grid, method: MapMethod, augmentation: Optional[Augmentation] = None):
+    def map_to_grid(self, grid: Grid, method: MapMethod, augmentation: Augmentation | None = None):

         # order edge features by xyz point
         points = []
@@ -282,7 +283,7 @@ def write_as_grid_to_hdf5(
         self,
         hdf5_path: str,
         settings: GridSettings,
         method: MapMethod,
-        augmentation: Optional[Augmentation] = None
+        augmentation: Augmentation | None = None
     ) -> str:

         id_ = self.id
@@ -308,7 +309,7 @@ def write_as_grid_to_hdf5(

         return hdf5_path

-    def get_all_chains(self) -> List[str]:
+    def get_all_chains(self) -> list[str]:
         if isinstance(self.nodes[0].id, Residue):
             chains = set(str(res.chain).split()[1] for res in [node.id for node in self.nodes])
         elif isinstance(self.nodes[0].id, Atom):
@@ -319,7 +320,7 @@

 def build_atomic_graph( # pylint: disable=too-many-locals
-    atoms: List[Atom], graph_id: str, max_edge_distance: float
+    atoms: list[Atom], graph_id: str, max_edge_distance: float
 ) -> Graph:
     """Builds a graph, using the atoms as nodes.
@@ -354,7 +355,7 @@

 def build_residue_graph( # pylint: disable=too-many-locals
-    residues: List[Residue], graph_id: str, max_edge_distance: float
+    residues: list[Residue], graph_id: str, max_edge_distance: float
 ) -> Graph:
     """Builds a graph, using the residues as nodes.
diff --git a/deeprank2/utils/grid.py b/deeprank2/utils/grid.py
index 648da0f0a..9a8071e86 100644
--- a/deeprank2/utils/grid.py
+++ b/deeprank2/utils/grid.py
@@ -3,10 +3,10 @@
 import itertools
 import logging
 from enum import Enum
-from typing import Dict, List, Union

 import h5py
 import numpy as np
+from numpy.typing import NDArray
 from scipy.signal import bspline

 from deeprank2.domain import gridstorage
@@ -29,12 +29,12 @@ class MapMethod(Enum):

 class Augmentation:
     """A rotation around an axis, to be applied to a feature before mapping it to a grid."""

-    def __init__(self, axis: np.ndarray, angle: float):
+    def __init__(self, axis: NDArray, angle: float):
         self._axis = axis
         self._angle = angle

     @property
-    def axis(self) -> np.ndarray:
+    def axis(self) -> NDArray:
         return self._axis

     @property
@@ -54,8 +54,8 @@ class GridSettings:

     def __init__(
         self,
-        points_counts: List[int],
-        sizes: List[float]
+        points_counts: list[int],
+        sizes: list[float]
     ):
         assert len(points_counts) == 3
         assert len(sizes) == 3

         self._points_counts = points_counts
         self._sizes = sizes

     @property
-    def resolutions(self) -> List[float]:
+    def resolutions(self) -> list[float]:
         return [self._sizes[i] / self._points_counts[i] for i in range(3)]

     @property
-    def sizes(self) -> List[float]:
+    def sizes(self) -> list[float]:
         return self._sizes

     @property
-    def points_counts(self) -> List[int]:
+    def points_counts(self) -> list[int]:
         return self._points_counts

@@ -84,7 +84,7 @@ class Grid:
     - feature values on each point
     """

-    def __init__(self, id_: str, center: List[float], settings: GridSettings):
+    def __init__(self, id_: str, center: list[float], settings: GridSettings):
         self.id = id_

         self._center = np.array(center)
@@ -95,7 +95,7 @@ def __init__(self, id_: str, center: List[float], settings: GridSettings):

         self._features = {}

-    def _set_mesh(self, center: np.ndarray, settings: GridSettings):
+    def _set_mesh(self, center: NDArray, settings: GridSettings):
         """Builds the grid points."""

         half_size_x = settings.sizes[0] / 2
@@ -119,38 +119,38 @@ def _set_mesh(self, center: np.ndarray, settings: GridSettings):
         )

     @property
-    def center(self) -> np.ndarray:
+    def center(self) -> NDArray:
         return self._center

     @property
-    def xs(self) -> np.array:
+    def xs(self) -> NDArray:
         return self._xs

     @property
-    def xgrid(self) -> np.array:
+    def xgrid(self) -> NDArray:
         return self._xgrid

     @property
-    def ys(self) -> np.array:
+    def ys(self) -> NDArray:
         return self._ys

     @property
-    def ygrid(self) -> np.array:
+    def ygrid(self) -> NDArray:
         return self._ygrid

     @property
-    def zs(self) -> np.array:
+    def zs(self) -> NDArray:
         return self._zs

     @property
-    def zgrid(self) -> np.array:
+    def zgrid(self) -> NDArray:
         return self._zgrid

     @property
-    def features(self) -> Dict[str, np.array]:
+    def features(self) -> dict[str, NDArray]:
         return self._features

-    def add_feature_values(self, feature_name: str, data: np.ndarray):
+    def add_feature_values(self, feature_name: str, data: NDArray):
         """Makes sure feature values per grid point get stored.

         This method may be called repeatedly to add on to existing grid point values.
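Since the `resolutions` property above divides each axis size by its point count, the new annotations can be checked with a quick sketch:

```python
from deeprank2.utils.grid import GridSettings

# A cubic 30 Å box sampled with 10 points per axis.
settings = GridSettings(points_counts=[10, 10, 10], sizes=[30.0, 30.0, 30.0])
print(settings.resolutions)  # [3.0, 3.0, 3.0], i.e. size / points count per axis
```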
@@ -162,8 +162,10 @@ def add_feature_values(self, feature_name: str, data: np.ndarray): self._features[feature_name] += data def _get_mapped_feature_gaussian( - self, position: np.ndarray, value: float - ) -> np.ndarray: + self, + position: NDArray, + value: float + ) -> NDArray: beta = 1.0 @@ -175,8 +177,8 @@ def _get_mapped_feature_gaussian( return value * np.exp(-beta * distances) def _get_mapped_feature_fast_gaussian( - self, position: np.ndarray, value: float - ) -> np.ndarray: + self, position: NDArray, value: float + ) -> NDArray: beta = 1.0 cutoff = 5.0 * beta @@ -195,8 +197,8 @@ def _get_mapped_feature_fast_gaussian( return data def _get_mapped_feature_bsp_line( - self, position: np.ndarray, value: float - ) -> np.ndarray: + self, position: NDArray, value: float + ) -> NDArray: order = 4 @@ -210,8 +212,8 @@ def _get_mapped_feature_bsp_line( return value * bsp_data def _get_mapped_feature_nearest_neighbour( # pylint: disable=too-many-locals - self, position: np.ndarray, value: float - ) -> np.ndarray: + self, position: NDArray, value: float + ) -> NDArray: fx, _, _ = position distances_x = np.abs(self.xs - fx) @@ -248,14 +250,14 @@ def _get_mapped_feature_nearest_neighbour( # pylint: disable=too-many-locals return neighbour_data - def _get_atomic_density_koes(self, position: np.ndarray, vanderwaals_radius: float) -> np.ndarray: + def _get_atomic_density_koes(self, position: NDArray, vanderwaals_radius: float) -> NDArray: """Function to map individual atomic density on the grid. The formula is equation (1) of the Koes paper "Protein-Ligand Scoring with Convolutional Neural Networks" (arXiv:1612.02751v1). Returns: - np.ndarray: The mapped density. + NDArray: The mapped density. """ distances = np.sqrt(np.square(self.xgrid - position[0]) + @@ -276,9 +278,9 @@ def _get_atomic_density_koes(self, position: np.ndarray, vanderwaals_radius: flo def map_feature( self, - position: np.ndarray, + position: NDArray, feature_name: str, - feature_value: Union[np.ndarray, float], + feature_value: NDArray | float, method: MapMethod, ): """Maps point feature data at a given position to the grid, using the given method.
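To make the `GridSettings`/`Grid` relationship above concrete, a minimal sketch (values are arbitrary, and `MapMethod.GAUSSIAN` is assumed to be one of the enum's members):

```python
import numpy as np

from deeprank2.utils.grid import Grid, GridSettings, MapMethod

# 30 points per axis over a 30.0 Å span gives 1.0 Å per axis, matching
# the resolutions property: sizes[i] / points_counts[i].
settings = GridSettings(points_counts=[30, 30, 30], sizes=[30.0, 30.0, 30.0])
assert settings.resolutions == [1.0, 1.0, 1.0]

grid = Grid("example-grid", center=[0.0, 0.0, 0.0], settings=settings)

# Map a scalar feature value at a 3D position onto the grid points.
grid.map_feature(np.array([1.0, -2.0, 0.5]), "example_feature", 1.0, MapMethod.GAUSSIAN)
print(grid.features["example_feature"].shape)  # expected: (30, 30, 30)
```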
diff --git a/deeprank2/utils/parsing/patch.py b/deeprank2/utils/parsing/patch.py index 147184b47..b03df3a0f 100644 --- a/deeprank2/utils/parsing/patch.py +++ b/deeprank2/utils/parsing/patch.py @@ -1,6 +1,6 @@ import re from enum import Enum -from typing import Any, Dict +from typing import Any class PatchActionType(Enum): @@ -15,7 +15,7 @@ def __init__(self, residue_type: str, atom_name: str): class PatchAction: - def __init__(self, type_: str, selection: PatchSelection, kwargs: Dict[str, Any]): + def __init__(self, type_: str, selection: PatchSelection, kwargs: dict[str, Any]): self.type = type_ self.selection = selection self.kwargs = kwargs diff --git a/deeprank2/utils/parsing/residue.py b/deeprank2/utils/parsing/residue.py index 640bf4f5d..4609029aa 100644 --- a/deeprank2/utils/parsing/residue.py +++ b/deeprank2/utils/parsing/residue.py @@ -1,14 +1,13 @@ import re -from typing import List, Union class ResidueClassCriterium: def __init__( self, class_name: str, - amino_acid_names: Union[str, List[str]], - present_atom_names: List[str], - absent_atom_names: List[str], + amino_acid_names: str | list[str], + present_atom_names: list[str], + absent_atom_names: list[str], ): self.class_name = class_name @@ -17,7 +16,7 @@ def __init__( self.present_atom_names = present_atom_names self.absent_atom_names = absent_atom_names - def matches(self, amino_acid_name: str, atom_names: List[str]) -> bool: + def matches(self, amino_acid_name: str, atom_names: list[str]) -> bool: # check the amino acid name if self.amino_acid_names != "all": diff --git a/deeprank2/utils/parsing/top.py b/deeprank2/utils/parsing/top.py index 6d77b5284..27b58d140 100644 --- a/deeprank2/utils/parsing/top.py +++ b/deeprank2/utils/parsing/top.py @@ -1,12 +1,16 @@ import logging import re -from typing import Any, Dict +from typing import Any logging.getLogger(__name__) class TopRowObject: - def __init__(self, residue_name: str, - atom_name: str, kwargs: Dict[str, Any]): + def __init__( + self, + residue_name: str, + atom_name: str, + kwargs: dict[str, Any], + ): self.residue_name = residue_name self.atom_name = atom_name self.kwargs = kwargs diff --git a/deeprank2/utils/pssmdata.py b/deeprank2/utils/pssmdata.py index 066bc93e1..b10a27ce1 100644 --- a/deeprank2/utils/pssmdata.py +++ b/deeprank2/utils/pssmdata.py @@ -1,17 +1,15 @@ -from typing import Dict, List, Optional - from deeprank2.molstruct.aminoacid import AminoAcid class PssmRow: """Holds data for one position-specific scoring matrix row.""" - def __init__(self, conservations: Dict[AminoAcid, float], information_content: float): + def __init__(self, conservations: dict[AminoAcid, float], information_content: float): self._conservations = conservations self._information_content = information_content @property - def conservations(self) -> Dict[AminoAcid, float]: + def conservations(self) -> dict[AminoAcid, float]: return self._conservations @property @@ -25,7 +23,7 @@ def get_conservation(self, amino_acid: AminoAcid) -> float: class PssmTable: """Holds data for one position-specific scoring table.""" - def __init__(self, rows: Optional[List[PssmRow]] = None): + def __init__(self, rows: list[PssmRow] | None = None): if rows is None: self._rows = {} else: diff --git a/docs/features.md b/docs/features.md index 0fac1f047..32b052bf6 100644 --- a/docs/features.md +++ b/docs/features.md @@ -8,22 +8,22 @@ Features implemented in the code-base are defined in `deeprank2.feature` subpack Users can add custom features by creating a new module and placing it in `deeprank2.feature` 
subpackage. One requirement for any feature module is to implement an `add_features` function, as shown below. This will be used in `deeprank2.models.query` to add the features to the nodes or edges of the graph. ```python -from typing import Optional from deeprank2.molstruct.residue import SingleResidueVariant from deeprank2.utils.graph import Graph def add_features( - pdb_path: str, graph: Graph, - single_amino_acid_variant: Optional[SingleResidueVariant] = None - ): + pdb_path: str, + graph: Graph, + single_amino_acid_variant: SingleResidueVariant | None = None +): pass ``` -The following is a brief description of the features already implemented in the code-base, for each features' module. +The following is a brief description of the features already implemented in the code-base, for each feature module. -## Default node features +## Default node features For atomic graphs, when features relate to residues then _all_ atoms of one residue receive the feature value for that residue. ### Core properties of atoms and residues: `deeprank2.features.components` These features are only used in atomic graphs. - `atom_type`: One-hot encoding of the atomic element. Options are: C, O, N, S, P, H. - `atom_charge`: Atomic charge in Coulomb (float). Taken from `deeprank2.domain.forcefield.patch.top`. -- `pdb_occupancy`: Proportion of structures where the atom was detected at this position (float). In some cases a single atom was detected at different positions, in which case separate structures exist whose occupancies sum to 1. Only the highest occupancy atom is used by deeprank2. +- `pdb_occupancy`: Proportion of structures where the atom was detected at this position (float). In some cases a single atom was detected at different positions, in which case separate structures exist whose occupancies sum to 1. Only the highest occupancy atom is used by deeprank2. #### Residue properties: - `res_type`: One-hot encoding of the amino acid residue (size 20). - `polarity`: One-hot encoding of the polarity of the amino acid (options: NONPOLAR, POLAR, NEGATIVE, POSITIVE). Note that sources vary on the polarity for a few of the amino acids; see detailed information in `deeprank2.domain.aminoacidlist.py`. -- `res_size`: The number of non-hydrogen atoms in the side chain (int). +- `res_size`: The number of non-hydrogen atoms in the side chain (int). - `res_mass`: The (average) residue mass in Da (float). - `res_charge`: The charge of the residue (in fully protonated state) in Coulomb (int). Charge is calculated from summing all atoms in the residue, which results in a charge of 0 for all polar and nonpolar residues, +1 for positive residues and -1 for negative residues. - `res_pI`: The isoelectric point, i.e. the pH at which the molecule has no net electric charge (float). @@ -55,10 +55,10 @@ These features are only used in SingleResidueVariant queries. ### Conservation features: `deeprank2.features.conservation` These features relate to the conservation state of individual residues. -- `pssm`: [Position-specific scoring matrix](https://en.wikipedia.org/wiki/Position_weight_matrix) (also known as position weight matrix, PWM) values relative to the residue, is a score of the conservation of the amino acid along all 20 amino acids. +- `pssm`: [Position-specific scoring matrix](https://en.wikipedia.org/wiki/Position_weight_matrix) (also known as position weight matrix, PWM) values for the residue: a score of the conservation of the amino acid across all 20 amino acids.
- `info_content`: Information content is the difference between the given PSSM for an amino acid and a uniform distribution (float). - `conservation` (only used in SingleResidueVariant queries): Conservation of the wild type amino acid (float). *More details required.* -- `diff_conservation` (only used in SingleResidueVariant queries): Subtraction of wildtype conservation from the variant conservation (float). +- `diff_conservation` (only used in SingleResidueVariant queries): Subtraction of wildtype conservation from the variant conservation (float). ### Protein context features: @@ -85,14 +85,14 @@ These features are only calculated for ProteinProteinInterface queries. - `irc_nonpolar_nonpolar`, `irc_nonpolar_polar`, `irc_nonpolar_negative`, `irc_nonpolar_positive`, `irc_polar_polar`, `irc_polar_negative`, `irc_polar_positive`, `irc_negative_negative`, `irc_positive_positive`, `irc_negative_positive`: As above, but for specific residue polarity pairings. -## Default edge features +## Default edge features ### Contact features: `deeprank2.features.contact` These features relate to relationships between individual nodes. For atomic graphs, when features relate to residues then _all_ atoms of one residue receive the feature value for that residue. #### Distance: -- `distance`: Interatomic distance between atoms in Å, computed from the xyz atomic coordinates taken from the .pdb file (float). For residue graphs, the the minimum distance between any atom of each residues is used. +- `distance`: Interatomic distance between atoms in Å, computed from the xyz atomic coordinates taken from the .pdb file (float). For residue graphs, the minimum distance between any atom of each residue is used. #### Structure: These features relate to the structural relationship between nodes. @@ -101,7 +101,7 @@ These features relate to the structural relationship between nodes. - `covalent`: Boolean indicating whether nodes are covalently bound (1) or not (0). Note that covalency is not directly assessed, but any edge with a maximum distance of 2.1 Å is considered covalent. #### Nonbond energies: -These features measure nonbond energy potentials between nodes. +These features measure nonbond energy potentials between nodes. For residue graphs, the pairwise sum of potentials for all atoms from each residue is used. Note that no distance cutoff is used and the radius of influence is assumed to be infinite, although the potentials tend to 0 at large distances. Also, edges are only assigned within a given cutoff radius when graphs are created. Nonbond energies are set to 0 for any atom pairs (on the same chain) that are within a cutoff radius of 3.6 Å, as these are assumed to be covalent neighbors or linked by no more than 2 covalent bonds (i.e. 1-3 pairs).
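Tying the `docs/features.md` changes together: a minimal custom feature module honoring the `add_features` contract shown above might look like the sketch below. The feature name `my_feature` and the constant value are hypothetical, and the `node.features` mapping is assumed to behave as in the built-in feature modules:

```python
from deeprank2.molstruct.residue import SingleResidueVariant
from deeprank2.utils.graph import Graph


def add_features(
    pdb_path: str,
    graph: Graph,
    single_amino_acid_variant: SingleResidueVariant | None = None,
):
    # A real module would derive values from the structure behind pdb_path;
    # here every node just gets a hypothetical constant feature.
    for node in graph.nodes:
        node.features["my_feature"] = 1.0
```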
diff --git a/tests/features/__init__.py b/tests/features/__init__.py index 7e61216bb..ea7cab7a6 100644 --- a/tests/features/__init__.py +++ b/tests/features/__init__.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Literal, Optional, Tuple, Union +from typing import Literal from pdb2sql import pdb2sql @@ -28,10 +28,10 @@ def build_testgraph( # pylint: disable=too-many-locals, too-many-arguments # noq detail: Literal['atomic', 'residue'], interaction_radius: float, max_edge_distance: float, - central_res: Optional[int] = None, - variant: Optional[AminoAcid] = None, - chain_ids: Optional[Union[str, Tuple[str, str]]] = None, - ) -> Tuple[Graph, Union[SingleResidueVariant, None]]: + central_res: int | None = None, + variant: AminoAcid | None = None, + chain_ids: str | tuple[str, str] | None = None, +) -> tuple[Graph, SingleResidueVariant | None]: """ Creates a Graph object for feature tests. @@ -40,12 +40,12 @@ detail (Literal['atomic', 'residue']): Level of detail. interaction_radius (float): max distance to include in graph. max_edge_distance (float): max distance to create an edge. - central_res (Optional[int], optional): Residue to center a single-chain graph around. + central_res (int | None, optional): Residue to center a single-chain graph around. Use None to create a 2-chain graph, or any value for a single-chain graph. Defaults to None. - variant (Optional[AminoAcid], optional): Amino acid to use as a variant amino acid. + variant (AminoAcid | None, optional): Amino acid to use as a variant amino acid. Defaults to None. - chain_ids (Optional[Union[str, Tuple[str, str]]], optional): Explicitly specify which chain(s) to use. + chain_ids (str | tuple[str, str] | None, optional): Explicitly specify which chain(s) to use. Defaults to None, which will use the first (two) chain(s) from the structure.
Raises: diff --git a/tests/features/test_contact.py b/tests/features/test_contact.py index 92ff862f3..95d40dba0 100644 --- a/tests/features/test_contact.py +++ b/tests/features/test_contact.py @@ -1,7 +1,9 @@ -from typing import Tuple from uuid import uuid4 import numpy as np +from pdb2sql import pdb2sql + +from deeprank2.domain import edgestorage as Efeat from deeprank2.features.contact import (add_features, covalent_cutoff, cutoff_13, cutoff_14) from deeprank2.molstruct.atom import Atom @@ -9,9 +11,6 @@ from deeprank2.molstruct.structure import Chain from deeprank2.utils.buildgraph import get_structure from deeprank2.utils.graph import Edge, Graph -from pdb2sql import pdb2sql - -from deeprank2.domain import edgestorage as Efeat def _get_atom(chain: Chain, residue_number: int, atom_name: str) -> Atom: @@ -32,14 +31,14 @@ def _wrap_in_graph(edge: Edge): def _get_contact( # pylint: disable=too-many-arguments - pdb_id: str, - residue_num1: int, - atom_name1: str, - residue_num2: int, - atom_name2: str, - residue_level: bool = False, - chains: Tuple[str,str] = None, - ) -> Edge: + pdb_id: str, + residue_num1: int, + atom_name1: str, + residue_num2: int, + atom_name2: str, + residue_level: bool = False, + chains: tuple[str, str] | None = None, +) -> Edge: pdb_path = f"tests/data/pdb/{pdb_id}/{pdb_id}.pdb" diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 98223acd0..92406b205 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -3,7 +3,6 @@ import warnings from shutil import rmtree from tempfile import mkdtemp -from typing import List, Union import h5py import numpy as np @@ -18,14 +17,17 @@ node_feats = [Nfeat.RESTYPE, Nfeat.POLARITY, Nfeat.BSA, Nfeat.RESDEPTH, Nfeat.HSE, Nfeat.INFOCONTENT, Nfeat.PSSM] def _compute_features_manually( # noqa: MC0001, pylint: disable=too-many-locals - hdf5_path: str, - features_transform: dict, - feat: str - ): - # This function returns the feature specified read from the hdf5 file, - # after applying manually features_transform dict. It returns its mean - # and its std after having applied eventual transformations. - # Multi-channels features are returned as an array with multiple channels. + hdf5_path: str, + features_transform: dict, + feat: str +): + """ + Return the specified feature read from the HDF5 file, after manually + applying the features_transform dict, along with its mean and std + after any transformations have been applied. + Multi-channel features are returned as an array with multiple channels.
+ """ + with h5py.File(hdf5_path, 'r') as f: entry_names = [entry for entry, _ in f.items()] @@ -122,9 +124,9 @@ def _compute_features_with_get( return features_dict def _check_inherited_params( - inherited_params: List[str], - dataset_train: Union[GraphDataset, GridDataset], - dataset_test: Union[GraphDataset, GridDataset], + inherited_params: list[str], + dataset_train: GraphDataset | GridDataset, + dataset_test: GraphDataset | GridDataset, ): dataset_train_vars = vars(dataset_train) dataset_test_vars = vars(dataset_test) diff --git a/tests/test_query.py b/tests/test_query.py index 4c941ab17..b117153a1 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -1,7 +1,6 @@ import os import shutil from tempfile import mkdtemp, mkstemp -from typing import List import h5py import numpy as np @@ -23,8 +22,8 @@ def _check_graph_makes_sense( g: Graph, - node_feature_names: List[str], - edge_feature_names: List[str], + node_feature_names: list[str], + edge_feature_names: list[str], ): assert len(g.nodes) > 0, "no nodes" diff --git a/tests/test_querycollection.py b/tests/test_querycollection.py index 258357cdc..2b713e4ba 100644 --- a/tests/test_querycollection.py +++ b/tests/test_querycollection.py @@ -3,7 +3,6 @@ from shutil import rmtree from tempfile import mkdtemp from types import ModuleType -from typing import List, Union import h5py import pytest @@ -20,7 +19,7 @@ def _querycollection_tester( # pylint: disable=dangerous-default-value query_type: str, n_queries: int = 3, - feature_modules: Union[ModuleType, List[ModuleType]] = [components, contact], + feature_modules: ModuleType | list[ModuleType] = [components, contact], cpu_count: int = 1, combine_output: bool = True, ): @@ -86,14 +85,14 @@ def _querycollection_tester( # pylint: disable=dangerous-default-value def _assert_correct_modules( output_paths: str, - features: str | List[str], + features: str | list[str], absent: str, ): """Helper function to assert inclusion of correct features Args: output_paths (str): output_paths as returned from _querycollection_tester - features (Union[str, List[str]]): feature(s) that should be present + features (str | list[str]]: feature(s) that should be present absent (str): feature that should be absent """ @@ -137,8 +136,9 @@ def test_querycollection_process(): def test_querycollection_process_single_feature_module(): - """ - Tests processing for generating from a single feature module for following input types: ModuleType, List[ModuleType] str, List[str] + """Test processing for generating from a single feature module. + + Tested for following input types: ModuleType, list[ModuleType] str, list[str] """ for query_type in ['ppi', 'srv']: