From b5767f6b31592c30b933b232e0bf84f9d84f85dc Mon Sep 17 00:00:00 2001 From: Dani Bodor Date: Tue, 16 Jan 2024 07:12:47 +0100 Subject: [PATCH 1/7] ci: set up ruff linter --- .github/workflows/linting.yml | 4 +- .prospector.yml | 50 ----------- pyproject.toml | 156 ++++++++++++++++++++++++++++------ 3 files changed, 133 insertions(+), 77 deletions(-) delete mode 100644 .prospector.yml diff --git a/.github/workflows/linting.yml b/.github/workflows/linting.yml index 4c0e1c9f4..8f78bb8b7 100644 --- a/.github/workflows/linting.yml +++ b/.github/workflows/linting.yml @@ -45,5 +45,5 @@ jobs: with: python-version: ${{ matrix.python-version }} extras-require: test - - name: Check style against standards using prospector - run: prospector --die-on-tool-error + - name: Check style against standards using ruff + run: ruff . diff --git a/.prospector.yml b/.prospector.yml deleted file mode 100644 index 75f1e737b..000000000 --- a/.prospector.yml +++ /dev/null @@ -1,50 +0,0 @@ -# prospector configuration file - ---- - -output-format: grouped - -strictness: medium -doc-warnings: false -test-warnings: true -member-warnings: false - -ignore-paths: - - docs - - reduce - -ignore-patterns: - - setup.py - -pyroma: - run: true - # pyroma gives errors in the setup.py file, - # thus we disable here these errors. - # This should not be happening, because - # prospector should be ignoring the setup.py - # file (see ignore-patterns above) - disable: - - PYR10 - - PYR11 - - PYRUNKNOWN - -pycodestyle: - full: true - options: - max-line-length: 159 - -pydocstyle: - disable: [ - # Disable because not part of PEP257 official convention: - # see http://pep257.readthedocs.io/en/latest/error_codes.html - D203, # 1 blank line required before class docstring - D212, # Multi-line docstring summary should start at the first line - D213, # Multi-line docstring summary should start at the second line - D404, # First word of the docstring should not be This - ] - -pylint: - disable: [ - logging-fstring-interpolation, - logging-not-lazy, - ] diff --git a/pyproject.toml b/pyproject.toml index cbc10f9ea..a141860c1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,24 +14,25 @@ keywords = [ "protein-protein interfaces", "missense variants", "deep learning", - "pytorch"] + "pytorch", +] authors = [ - {name = "Giulia Crocioni", email = "g.crocioni@esciencecenter.nl"}, - {name = "Coos Baakman", email = "coos.baakman@radboudumc.nl"}, - {name = "Dani Bodor", email = "d.bodor@esciencecenter.nl"}, - {name = "Daniel Rademaker"}, - {name = "Gayatri Ramakrishnan"}, - {name = "Sven van der Burg"}, - {name = "Li Xue"}, - {name = "Daniil Lepikhov"}, - ] -license = {text = "Apache-2.0 license"} + { name = "Giulia Crocioni", email = "g.crocioni@esciencecenter.nl" }, + { name = "Coos Baakman", email = "coos.baakman@radboudumc.nl" }, + { name = "Dani Bodor", email = "d.bodor@esciencecenter.nl" }, + { name = "Daniel Rademaker" }, + { name = "Gayatri Ramakrishnan" }, + { name = "Sven van der Burg" }, + { name = "Li Xue" }, + { name = "Daniil Lepikhov" }, +] +license = { text = "Apache-2.0 license" } classifiers = [ "Intended Audience :: Developers", "License :: OSI Approved :: Apache Software License", "Natural Language :: English", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.10" + "Programming Language :: Python :: 3.10", ] dependencies = [ "tables >= 3.8.0", @@ -49,15 +50,15 @@ dependencies = [ "tqdm >= 4.63.0", "freesasa >= 2.1.0", "tensorboard >= 0.9.0", - "protobuf >= 3.20.1" + "protobuf >= 3.20.1", + "ruff >= 
0.1.13", + "dill", ] [project.optional-dependencies] # development dependency groups test = [ "pytest >= 7.4.0", - "pylint <= 2.17.5", - "prospector[with_pyroma] <= 1.10.2", "bump2version", "coverage", "pycodestyle", @@ -65,11 +66,7 @@ test = [ "pytest-runner", "coveralls", ] -publishing = [ - "build", - "twine", - "wheel", -] +publishing = ["build", "twine", "wheel"] [project.urls] Documentation = "https://deeprank2.readthedocs.io/en/latest/?badge=latest" @@ -85,8 +82,117 @@ include = ["deeprank2*"] exclude = ["tests*", "*tests.*", "*tests"] [tool.setuptools.package-data] -"*" = [ - "*.xlsx", - "*.param", - "*.top", - "*residue-classes"] +"*" = ["*.xlsx", "*.param", "*.top", "*residue-classes"] + +[tool.ruff] +# Exclude a variety of commonly ignored directories. +extend-exclude = ["docs", "reduce"] +line-length = 159 +select = [ + "F", # Pyflakes + "E", # pycodestyle (error) + "W", # pycodestyle (warning) + "I", # isort + "D", # pydocstyle + "UP", # pyupgrade + "SIM", # simplify + "C4", # flake8-comprehensions + "S", # flake8-bandit + "PGH", # pygrep-hooks + "BLE", # blind-except + "FBT003", # boolean-positional-value-in-call + "B", # flake8-bugbear + "Q", # flake8-quotes + "PLR", # pylint refactoring + "ARG", # flake8-unused-arguments + "SLF001", # Private member accessed + "PIE", # flake8-pie + "RET", # flaske8-return + "PT", # pytest + "TID", # imports + "TCH", # imports + "PD", # pandas + "NPY", # numpy + "PL", # pylint + "RUF", # ruff rtecommendations + "PERF", # performance + "TRY", # try blocks + "ERA", # commented out code + # other linting conventions + "FLY", + "AIR", + "YTT", + "ASYNC", + "A", + "DTZ", + "DJ", + "FA", + "ISC", + "ICN", + "G", + "INP", + "PYI", + "Q", + "RSE102", + "SLOT", + "INT", + # The following are unrealistic for this code base + # "PTH" # flake8-use-pathlib + # "ANN", # annotations + # "N", # naming conventions + # "C901", # mccabe complexity +] +ignore = [ + "PLR0912", # Too many branches, + "PLR0913", #Too many arguments in function definition + "B028", # No explicit `stacklevel` keyword argument found in + "PLR2004", # Magic value used in comparison + "S105", # Possible hardcoded password + "S311", # insecure random generators + "PT011", # pytest-raises-too-broad + "SIM108", # Use ternary operator + "TRY003", # Long error messages + # Missing docstrings Documentation + "D100", # Missing module docstring + "D101", # Missing docstring in public class + "D102", # Missing docstring in public method + "D103", # Missing docstring in public function + "D104", # Missing public package docstring + "D105", # Missing docstring in magic method + "D107", # Missing docstring in `__init__` + # Rules irrelevant to the Google style + "D203", # 1 blank line required before class docstring + "D204", + "D212", # Multi-line docstring summary should start at the first line + "D213", # Multi-line docstring summary should start at the second line + "D215", + "D400", + "D401", + "D404", # First word of the docstring should not be This + "D406", + "D407", + "D408", + "D409", + "D413", +] + +# Allow autofix for all enabled rules (when `--fix`) is provided. 
+fixable = ["ALL"] +unfixable = [ + "F401", +] # unused imports (it's annoying if they automatically disappear while editing code + +[tool.ruff.lint.per-file-ignores] +"tests/*" = ["S101"] + +[tool.ruff.lint] +extend-safe-fixes = [ + "D415", # First line should end with a period, question mark, or exclamation point + "D300", # Use triple double quotes `"""` + "D200", # One-line docstring should fit on one line + "TCH", # type checking only imports + "ISC001", +] + +[tool.ruff.isort] +known-first-party = ["deeprank2"] From 01ccd9094b0b69c858162a3cf15d3ad71376fb83 Mon Sep 17 00:00:00 2001 From: Dani Bodor Date: Tue, 16 Jan 2024 07:13:12 +0100 Subject: [PATCH 2/7] docs: reference new linter in documentation --- CONTRIBUTING.rst | 2 ++ README.dev.md | 20 ++++++++++++-------- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst index 31d422f57..f6a3dfc95 100644 --- a/CONTRIBUTING.rst +++ b/CONTRIBUTING.rst @@ -37,6 +37,8 @@ You want to make some kind of change to the code base #. if needed, fork the repository to your own Github profile and create your own feature branch off of the latest main commit. While working on your feature branch, make sure to stay up to date with the main branch by pulling in changes, possibly from the 'upstream' repository (follow the instructions `here `__ and `here `__); #. make sure the existing tests still work by running ``python setup.py test``; #. add your own tests (if necessary); +#. ensure the code is correctly linted (`ruff .`) and formatted (`ruff format .`); +#. see our `developer's readme `` for detailed information on our style conventions, etc.; #. update or expand the documentation; #. `push `_ your feature branch to (your fork of) the DeepRank2 repository on GitHub; #. create the pull request, e.g. following the instructions `here `__. diff --git a/README.dev.md b/README.dev.md index 46069c62a..7b00c6847 100644 --- a/README.dev.md +++ b/README.dev.md @@ -42,9 +42,13 @@ coverage report `coverage` can also generate output in HTML and other formats; see `coverage help` for more information. -## Linting +## Linting and Formatting -We use [prospector](https://pypi.org/project/prospector/) with pyroma for linting. For running it locally, use `prospector` or `prospector ` for specific files/folders. +We use [ruff](https://docs.astral.sh/ruff/) for linting, sorting imports and formatting of python (notebook) files. The configurations of `ruff` are set in [pyproject.toml](pyproject.toml) file. + +If you are using VS code, please install and activate the [Ruff extension](https://marketplace.visualstudio.com/items?itemName=charliermarsh.ruff) to automatically format and check linting. + +Otherwise, please ensure check both linting (`ruff fix .`) and formatting (`ruff format .`) before requesting a review. ## Versioning @@ -56,26 +60,26 @@ We use a [Git Flow](https://nvie.com/posts/a-successful-git-branching-model/)-in - `main` — this branch contains production (stable) code. All development code is merged into `main` in sometime. - `dev` — this branch contains pre-production code. When the features are finished then they are merged into `dev`. During the development cycle, three main supporting branches are used: -- Feature branches - Branches that branch off from `dev` and must merge into `dev`: used to develop new features for the upcoming releases. +- Feature branches - Branches that branch off from `dev` and must merge into `dev`: used to develop new features for the upcoming releases. 
- Hotfix branches - Branches that branch off from `main` and must merge into `main` and `dev`: necessary to act immediately upon an undesired status of `main`.
- Release branches - Branches that branch off from `dev` and must merge into `main` and `dev`: support preparation of a new production release. They allow many minor bugs to be fixed and preparation of meta-data for a release.

### Development conventions

- Branching
  - When creating a new branch, please use the following convention: `__`.
  - Always branch from `dev` branch, unless there is the need to fix an undesired status of `main`. See above for more details about the branching workflow adopted.
- Pull Requests
  - When creating a pull request, please use the following convention: `: `. Example _types_ are `fix:`, `feat:`, `build:`, `chore:`, `ci:`, `docs:`, `style:`, `refactor:`, `perf:`, `test:`, and others based on the [Angular convention](https://github.com/angular/angular/blob/22b96b9/CONTRIBUTING.md#-commit-message-guidelines).

## Making a release

1. Branch from `dev` and prepare the branch for the release (e.g., removing the unnecessary dev files such as the current one, fix minor bugs if necessary).
2. [Bump the version](https://github.com/DeepRank/deeprank2/blob/dev/README.dev.md#versioning).
3. Verify that the information in `CITATION.cff` is correct (update the release date), and that `.zenodo.json` contains equivalent data (a sketch of such a check is shown below, after this list).
5. Merge the release branch into `main` (and `dev`), and [run the tests](https://github.com/DeepRank/deeprank2/blob/dev/README.dev.md#running-the-tests).
6. Go to https://github.com/DeepRank/deeprank2/releases and draft a new release; create a new tag for the release, generate release notes automatically and adjust them, and finally publish the release as latest. This will trigger [a GitHub action](https://github.com/DeepRank/deeprank2/actions/workflows/release.yml) that will take care of publishing the package on PyPI.
7. Update the DOI in `CITATION.cff` with the one corresponding to the new release.
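Step 3 above is easy to get wrong by hand. The snippet below is a minimal sketch of how the two metadata files could be compared; it assumes PyYAML is installed, that both files sit in the repository root, and that `.zenodo.json` uses `title`, `version` and `creators` keys (adjust the field names if the actual files differ). It is not part of the release tooling, only an illustration.

```python
# Minimal consistency check between CITATION.cff and .zenodo.json (illustrative only).
import json
from pathlib import Path

import yaml  # PyYAML, assumed to be available


def check_citation_metadata(repo_root: str = ".") -> None:
    root = Path(repo_root)
    cff = yaml.safe_load((root / "CITATION.cff").read_text())  # CITATION.cff is YAML
    zenodo = json.loads((root / ".zenodo.json").read_text())

    # Fields that should stay in sync between the two files.
    for cff_key, zenodo_key in [("title", "title"), ("version", "version")]:
        if cff.get(cff_key) != zenodo.get(zenodo_key):
            print(f"Mismatch for {cff_key!r}: {cff.get(cff_key)!r} vs {zenodo.get(zenodo_key)!r}")

    # Authors live under different keys; compare only how many entries each file lists.
    n_cff = len(cff.get("authors", []))
    n_zenodo = len(zenodo.get("creators", []))
    if n_cff != n_zenodo:
        print(f"Author count differs: {n_cff} in CITATION.cff vs {n_zenodo} in .zenodo.json")


if __name__ == "__main__":
    check_citation_metadata()
```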
## UML From c56854e2b974986fae4ae4fe6ccb52e44fec538c Mon Sep 17 00:00:00 2001 From: Dani Bodor Date: Tue, 16 Jan 2024 07:13:43 +0100 Subject: [PATCH 3/7] style: VS code automatic linting/formatting --- .vscode/settings.json | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index 4b74dfbcf..f0f387d5e 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,15 +1,19 @@ { + // Python "[python]": { + "editor.formatOnSave": true, "editor.codeActionsOnSave": { - "source.organizeImports": true + "source.fixAll": "explicit" }, - "files.trimTrailingWhitespace": true, + "editor.defaultFormatter": "charliermarsh.ruff" }, + "autoDocstring.docstringFormat": "google", - "python.linting.prospectorEnabled": true, + // Notebooks "notebook.lineNumbers": "on", - - "[*.yml]": { - "files.trimTrailingWhitespace": true, + "notebook.formatOnSave.enabled": true, + "notebook.codeActionsOnSave": { + "notebook.source.fixAll": "explicit", }, + "notebook.diff.ignoreMetadata": true, } From bf66cac8eb4698c7f47c0be76ccd31bc21f9ec16 Mon Sep 17 00:00:00 2001 From: Dani Bodor Date: Tue, 16 Jan 2024 07:14:31 +0100 Subject: [PATCH 4/7] style: implement new linter/formatter throughout code base --- deeprank2/__init__.py | 2 +- deeprank2/dataset.py | 591 ++++++------ deeprank2/domain/aminoacidlist.py | 387 ++++---- deeprank2/domain/edgestorage.py | 12 +- deeprank2/domain/gridstorage.py | 2 - deeprank2/domain/losstypes.py | 56 +- deeprank2/domain/nodestorage.py | 68 +- deeprank2/domain/targetstorage.py | 12 +- deeprank2/features/components.py | 7 +- deeprank2/features/conservation.py | 6 +- deeprank2/features/contact.py | 39 +- deeprank2/features/exposure.py | 22 +- deeprank2/features/irc.py | 29 +- deeprank2/features/secondary_structure.py | 67 +- deeprank2/features/surfacearea.py | 78 +- deeprank2/molstruct/aminoacid.py | 12 +- deeprank2/molstruct/atom.py | 24 +- deeprank2/molstruct/pair.py | 10 +- deeprank2/molstruct/residue.py | 39 +- deeprank2/molstruct/structure.py | 28 +- deeprank2/neuralnets/cnn/model3d.py | 11 +- deeprank2/neuralnets/gnn/alignmentnet.py | 36 +- deeprank2/neuralnets/gnn/foutnet.py | 20 +- deeprank2/neuralnets/gnn/ginet.py | 27 +- deeprank2/neuralnets/gnn/ginet_nocluster.py | 23 +- deeprank2/neuralnets/gnn/naive_gnn.py | 11 +- deeprank2/neuralnets/gnn/sgat.py | 20 +- deeprank2/query.py | 190 ++-- deeprank2/tools/target.py | 63 +- deeprank2/trainer.py | 413 +++++---- deeprank2/utils/buildgraph.py | 49 +- deeprank2/utils/community_pooling.py | 34 +- deeprank2/utils/earlystopping.py | 33 +- deeprank2/utils/exporters.py | 166 ++-- deeprank2/utils/graph.py | 152 ++-- deeprank2/utils/grid.py | 86 +- deeprank2/utils/parsing/__init__.py | 62 +- deeprank2/utils/parsing/patch.py | 16 +- deeprank2/utils/parsing/pssm.py | 13 +- deeprank2/utils/parsing/residue.py | 27 +- deeprank2/utils/parsing/top.py | 10 +- deeprank2/utils/parsing/vdwparam.py | 11 +- deeprank2/utils/pssmdata.py | 9 +- docs/conf.py | 138 +-- tests/__init__.py | 0 tests/domain/__init__.py | 0 tests/domain/test_aminoacidlist.py | 9 +- tests/domain/test_forcefield.py | 19 +- tests/features/__init__.py | 42 +- tests/features/test_components.py | 4 +- tests/features/test_conservation.py | 10 +- tests/features/test_contact.py | 169 ++-- tests/features/test_exposure.py | 12 +- tests/features/test_irc.py | 19 +- tests/features/test_secondary_structure.py | 80 +- tests/features/test_surfacearea.py | 22 +- tests/molstruct/__init__.py | 0 tests/molstruct/test_structure.py 
| 7 +- tests/perf/__init__.py | 0 tests/perf/ppi_perf.py | 80 +- tests/perf/srv_perf.py | 160 ++-- tests/test_dataset.py | 947 +++++++++++--------- tests/test_integration.py | 183 ++-- tests/test_query.py | 178 ++-- tests/test_querycollection.py | 206 +++-- tests/test_set_lossfunction.py | 132 +-- tests/test_trainer.py | 546 ++++++----- tests/tools/__init__.py | 0 tests/tools/test_target.py | 37 +- tests/utils/__init__.py | 0 tests/utils/test_buildgraph.py | 14 +- tests/utils/test_community_pooling.py | 20 +- tests/utils/test_earlystopping.py | 15 +- tests/utils/test_exporters.py | 58 +- tests/utils/test_graph.py | 89 +- tests/utils/test_grid.py | 4 +- tests/utils/test_pssmdata.py | 7 +- 77 files changed, 3149 insertions(+), 3031 deletions(-) create mode 100644 tests/__init__.py create mode 100644 tests/domain/__init__.py create mode 100644 tests/molstruct/__init__.py create mode 100644 tests/perf/__init__.py create mode 100644 tests/tools/__init__.py create mode 100644 tests/utils/__init__.py diff --git a/deeprank2/__init__.py b/deeprank2/__init__.py index b62a3e51a..4eabd0b3f 100644 --- a/deeprank2/__init__.py +++ b/deeprank2/__init__.py @@ -1 +1 @@ -__version__ = "2.1.2" \ No newline at end of file +__version__ = "2.1.2" diff --git a/deeprank2/dataset.py b/deeprank2/dataset.py index c78ef9d67..0aace279f 100644 --- a/deeprank2/dataset.py +++ b/deeprank2/dataset.py @@ -7,7 +7,7 @@ import re import sys import warnings -from typing import Literal, Union +from typing import Literal import h5py import matplotlib.pyplot as plt @@ -23,28 +23,31 @@ from deeprank2.domain import nodestorage as Nfeat from deeprank2.domain import targetstorage as targets +# ruff: noqa: PYI051 (redundant-literal-union), the literal is a special case, while the str is generic + _log = logging.getLogger(__name__) class DeeprankDataset(Dataset): - def __init__(self, # pylint: disable=too-many-arguments - hdf5_path: str | list[str], - subset: list[str] | None, - train_source: str | GridDataset | GraphDataset | None, - target: str | None, - target_transform: bool | None, - target_filter: dict[str, str] | None, - task: str | None, - classes: list[str] | list[int] | list[float] | None, - use_tqdm: bool, - root: str, - check_integrity: bool + def __init__( + self, + hdf5_path: str | list[str], + subset: list[str] | None, + train_source: str | GridDataset | GraphDataset | None, + target: str | None, + target_transform: bool | None, + target_filter: dict[str, str] | None, + task: str | None, + classes: list[str] | list[int] | list[float] | None, + use_tqdm: bool, + root: str, + check_integrity: bool, ): - """Parent class of :class:`GridDataset` and :class:`GraphDataset` which inherits from :class:`torch_geometric.data.dataset.Dataset`. + """Parent class of :class:`GridDataset` and :class:`GraphDataset`. + This class inherits from :class:`torch_geometric.data.dataset.Dataset`. More detailed information about the parameters can be found in :class:`GridDataset` and :class:`GraphDataset`. """ - super().__init__(root) if isinstance(hdf5_path, str): @@ -82,41 +85,47 @@ def __init__(self, # pylint: disable=too-many-arguments # get the device self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - def _check_and_inherit_train(self, data_type: Union[GridDataset, GraphDataset], inherited_params): - """Check if the pre-trained model or training set provided are valid for validation and/or testing, and inherit the parameters. 
- """ - if isinstance(self.train_source, str): # pylint: disable=too-many-nested-blocks + def _check_and_inherit_train( + self, + data_type: GridDataset | GraphDataset, + inherited_params, + ): + """Check if the pre-trained model or training set provided are valid for validation and/or testing, and inherit the parameters.""" + if isinstance(self.train_source, str): try: if torch.cuda.is_available(): data = torch.load(self.train_source) else: - data = torch.load(self.train_source, map_location=torch.device('cpu')) + data = torch.load(self.train_source, map_location=torch.device("cpu")) if data["data_type"] is not data_type: - raise TypeError (f"""The pre-trained model has been trained with data of type {data["data_type"]}, but you are trying - to define a {data_type}-class validation/testing dataset. Please provide a valid DeepRank2 - model trained with {data_type}-class type data, or define the dataset using the appropriate class.""") + raise TypeError( + f"The pre-trained model has been trained with data of type {data['data_type']}, but you are trying \n\t" + f"to define a {data_type}-class validation/testing dataset. Please provide a valid DeepRank2 \n\t" + f"model trained with {data_type}-class type data, or define the dataset using the appropriate class." + ) if data_type is GraphDataset: self.train_means = data["means"] self.train_devs = data["devs"] # convert strings in 'transform' key to lambda functions if data["features_transform"]: - for _, key in data["features_transform"].items(): - if key['transform'] is None: + for key in data["features_transform"].values(): + if key["transform"] is None: continue - key['transform'] = eval(key['transform']) # pylint: disable=eval-used + key["transform"] = eval(key["transform"]) # noqa: S307, PGH001 (suspicious-eval-usage) except pickle.UnpicklingError as e: - raise ValueError("""The path provided to `train_source` is not a valid DeepRank2 pre-trained model. - Please provide a valid path to a DeepRank2 pre-trained model.""") from e + raise ValueError("The path provided to `train_source` is not a valid DeepRank2 pre-trained model.") from e elif isinstance(self.train_source, data_type): data = self.train_source if data_type is GraphDataset: self.train_means = self.train_source.means self.train_devs = self.train_source.devs else: - raise TypeError(f"""The train data provided is type: {type(self.train_source)} - Please provide a valid training {data_type} or the path to a valid DeepRank2 pre-trained model.""") + raise TypeError( + f"The train data provided is invalid: {type(self.train_source)}.\n\t" + f"Please provide a valid training {data_type} or the path to a valid DeepRank2 pre-trained model." 
+ ) - #match parameters with the ones in the training set + # match parameters with the ones in the training set self._check_inherited_params(inherited_params, data) def _check_hdf5_files(self): @@ -130,7 +139,7 @@ def _check_hdf5_files(self): if len(entry_names) == 0: _log.info(f" -> {hdf5_path} is empty ") to_be_removed.append(hdf5_path) - except Exception as e: + except Exception as e: # noqa: BLE001, PERF203 (blind-except, try-except-in-loop) _log.error(e) _log.info(f" -> {hdf5_path} is corrupted ") to_be_removed.append(hdf5_path) @@ -139,7 +148,6 @@ def _check_hdf5_files(self): self.hdf5_paths.remove(hdf5_path) def _check_task_and_classes(self, task: str, classes: str | None = None): - if self.target in [targets.IRMSD, targets.LRMSD, targets.FNAT, targets.DOCKQ]: self.task = targets.REGRESS @@ -150,23 +158,21 @@ def _check_task_and_classes(self, task: str, classes: str | None = None): self.task = task if self.task not in [targets.CLASSIF, targets.REGRESS] and self.target is not None: - raise ValueError( - f"User target detected: {self.target} -> The task argument must be 'classif' or 'regress', currently set as {self.task}") + raise ValueError(f"User target detected: {self.target} -> The task argument must be 'classif' or 'regress', currently set as {self.task}") if task != self.task and task is not None: - warnings.warn(f"Target {self.target} expects {self.task}, but was set to task {task} by user.\n" + - f"User set task is ignored and {self.task} will be used.") + warnings.warn( + f"Target {self.target} expects {self.task}, but was set to task {task} by user.\nUser set task is ignored and {self.task} will be used." + ) if self.task == targets.CLASSIF: if classes is None: self.classes = [0, 1] - _log.info(f'Target classes set to: {self.classes}') + _log.info(f"Target classes set to: {self.classes}") else: self.classes = classes - self.classes_to_index = { - class_: index for index, class_ in enumerate(self.classes) - } + self.classes_to_index = {class_: index for index, class_ in enumerate(self.classes)} else: self.classes = None self.classes_to_index = None @@ -176,24 +182,25 @@ def _check_inherited_params( inherited_params: list[str], data: dict | GraphDataset | GridDataset, ): - """"Check if the parameters for validation and/or testing are the same as in the pre-trained model or training set provided. + """Check if the parameters for validation and/or testing are the same as in the pre-trained model or training set provided. Args: inherited_params (List[str]): List of parameters that need to be checked for inheritance. data (Union[dict, class:`GraphDataset`, class:`GridDataset`]): The parameters in `inherited_param` will be inherited from the information contained in `data`. """ - self_vars = vars(self) if not isinstance(data, dict): data = vars(data) for param in inherited_params: - if (self_vars[param] != data[param]): - if (self_vars[param] != self.default_vars[param]): - _log.warning(f"The {param} parameter set here is: {self_vars[param]}, " + - f"which is not equivalent to the one in the training phase: {data[param]}./n" + - f"Overwriting {param} parameter with the one used in the training phase.") + if self_vars[param] != data[param]: + if self_vars[param] != self.default_vars[param]: + _log.warning( + f"The {param} parameter set here is: {self_vars[param]}, " + f"which is not equivalent to the one in the training phase: {data[param]}./n" + f"Overwriting {param} parameter with the one used in the training phase." 
+ ) setattr(self, param, data[param]) def _create_index_entries(self): @@ -224,14 +231,13 @@ def _create_index_entries(self): else: entry_names = [entry_name for entry_name in self.subset if entry_name in list(hdf5_file.keys())] - #skip self._filter_targets when target_filter is None, improve performance using list comprehension. + # skip self._filter_targets when target_filter is None, improve performance using list comprehension. if self.target_filter is None: self.index_entries += [(hdf5_path, entry_name) for entry_name in entry_names] else: - self.index_entries += [(hdf5_path, entry_name) for entry_name in entry_names \ - if self._filter_targets(hdf5_file[entry_name])] + self.index_entries += [(hdf5_path, entry_name) for entry_name in entry_names if self._filter_targets(hdf5_file[entry_name])] - except Exception: + except Exception: # noqa: BLE001 (blind-except) _log.exception(f"on {hdf5_path}") def _filter_targets(self, grp: h5py.Group) -> bool: @@ -249,33 +255,28 @@ def _filter_targets(self, grp: h5py.Group) -> bool: Raises: ValueError: If an unsuported condition is provided. """ - if self.target_filter is None: return True for target_name, target_condition in self.target_filter.items(): - present_target_names = list(grp[targets.VALUES].keys()) if target_name in present_target_names: - # If we have a given target_condition, see if it's met. if isinstance(target_condition, str): - operation = target_condition target_value = grp[targets.VALUES][target_name][()] for operator_string in [">", "<", "==", "<=", ">=", "!="]: operation = operation.replace(operator_string, f"{target_value}" + operator_string) - if not eval(operation): # pylint: disable=eval-used + if not eval(operation): # noqa: S307, PGH001 (suspicious-eval-usage) return False elif target_condition is not None: raise ValueError("Conditions not supported", target_condition) else: - _log.warning(f" :Filter {target_name} not found for entry {grp}\n" - f" :Filter options are: {present_target_names}") + _log.warning(f" :Filter {target_name} not found for entry {grp}\n :Filter options are: {present_target_names}") return True def len(self) -> int: @@ -286,8 +287,8 @@ def len(self) -> int: """ return len(self.index_entries) - def hdf5_to_pandas( # noqa: MC0001, pylint: disable=too-many-locals - self + def hdf5_to_pandas( + self, ) -> pd.DataFrame: """Loads features data from the HDF5 files into a Pandas DataFrame in the attribute `df` of the class. @@ -295,13 +296,11 @@ def hdf5_to_pandas( # noqa: MC0001, pylint: disable=too-many-locals :class:`pd.DataFrame`: Pandas DataFrame containing the selected features as columns per all data points in hdf5_path files. 
""" - df_final = pd.DataFrame() for fname in self.hdf5_paths: - with h5py.File(fname, 'r') as f: - - entry_name = list(f.keys())[0] + with h5py.File(fname, "r") as f: + entry_name = next(iter(f.keys())) if self.subset is not None: entry_names = [entry for entry, _ in f.items() if entry in self.subset] @@ -309,48 +308,44 @@ def hdf5_to_pandas( # noqa: MC0001, pylint: disable=too-many-locals entry_names = [entry for entry, _ in f.items()] df_dict = {} - df_dict['id'] = entry_names + df_dict["id"] = entry_names for feat_type in self.features_dict: for feat in self.features_dict[feat_type]: # reset transform for each feature transform = None if self.features_transform: - transform = self.features_transform.get('all', {}).get('transform') + transform = self.features_transform.get("all", {}).get("transform") if (transform is None) and (feat in self.features_transform): - transform = self.features_transform.get(feat, {}).get('transform') - #Check the number of channels the features have + transform = self.features_transform.get(feat, {}).get("transform") + # Check the number of channels the features have if f[entry_name][feat_type][feat][()].ndim == 2: for i in range(f[entry_name][feat_type][feat][:].shape[1]): - df_dict[feat + '_' + str(i)] = [f[entry_name][feat_type][feat][:][:,i] for entry_name in entry_names] - #apply transformation for each channel in this feature + df_dict[feat + "_" + str(i)] = [f[entry_name][feat_type][feat][:][:, i] for entry_name in entry_names] + # apply transformation for each channel in this feature if transform: - df_dict[feat + '_' + str(i)] = [transform(row) for row in df_dict[feat + '_' + str(i)]] + df_dict[feat + "_" + str(i)] = [transform(row) for row in df_dict[feat + "_" + str(i)]] else: df_dict[feat] = [ - f[entry_name][feat_type][feat][:] - if f[entry_name][feat_type][feat][()].ndim == 1 - else f[entry_name][feat_type][feat][()] for entry_name in entry_names] - #apply transformation + f[entry_name][feat_type][feat][:] if f[entry_name][feat_type][feat][()].ndim == 1 else f[entry_name][feat_type][feat][()] + for entry_name in entry_names + ] + # apply transformation if transform: - df_dict[feat]=[transform(row) for row in df_dict[feat]] - - df = pd.DataFrame(data=df_dict) - - df_final = pd.concat([df_final, df]) + df_dict[feat] = [transform(row) for row in df_dict[feat]] - df_final.reset_index(drop=True, inplace=True) - self.df = df_final + df_temp = pd.DataFrame(data=df_dict) + df_concat = pd.concat([df_final, df_temp]) + self.df = df_concat.reset_index(drop=True) + return self.df - return df_final - - def save_hist( # pylint: disable=too-many-arguments, too-many-branches, useless-suppression - self, - features: str | list[str], - fname: str = 'features_hist.png', - bins: int | list[float] | str = 10, - figsize: tuple = (15, 15), - log: bool = False + def save_hist( + self, + features: str | list[str], + fname: str = "features_hist.png", + bins: int | list[float] | str = 10, + figsize: tuple = (15, 15), + log: bool = False, ): """After having generated a pd.DataFrame using hdf5_to_pandas method, histograms of the features can be saved in an image. 
@@ -373,57 +368,64 @@ def save_hist( # pylint: disable=too-many-arguments, too-many-branches, useless- if not isinstance(features, list): features = [features] - features_df = [col for feat in features for col in self.df.columns.values.tolist() if feat in col] + features_df = [col for feat in features for col in self.df.columns.to_numpy().tolist() if feat in col] means = [ - round(np.concatenate(self.df[feat].values).mean(), 1) if isinstance(self.df[feat].values[0], np.ndarray) \ - else round(self.df[feat].values.mean(), 1) \ - for feat in features_df] + round(np.concatenate(self.df[feat].to_numpy()).mean(), 1) + if isinstance(self.df[feat].to_numpy()[0], np.ndarray) + else round(self.df[feat].to_numpy().mean(), 1) + for feat in features_df + ] devs = [ - round(np.concatenate(self.df[feat].values).std(), 1) if isinstance(self.df[feat].values[0], np.ndarray) \ - else round(self.df[feat].values.std(), 1) \ - for feat in features_df] + round(np.concatenate(self.df[feat].to_numpy()).std(), 1) + if isinstance(self.df[feat].to_numpy()[0], np.ndarray) + else round(self.df[feat].to_numpy().std(), 1) + for feat in features_df + ] if len(features_df) > 1: - fig, axs = plt.subplots(len(features_df), figsize=figsize) for row, feat in enumerate(features_df): - if isinstance(self.df[feat].values[0], np.ndarray): + if isinstance(self.df[feat].to_numpy()[0], np.ndarray): if log: - log_data = np.log(np.concatenate(self.df[feat].values)) + log_data = np.log(np.concatenate(self.df[feat].to_numpy())) log_data[log_data == -np.inf] = 0 axs[row].hist(log_data, bins=bins) else: - axs[row].hist(np.concatenate(self.df[feat].values), bins=bins) + axs[row].hist(np.concatenate(self.df[feat].to_numpy()), bins=bins) + elif log: + log_data = np.log(self.df[feat].to_numpy()) + log_data[log_data == -np.inf] = 0 + axs[row].hist(log_data, bins=bins) else: - if log: - log_data = np.log(self.df[feat].values) - log_data[log_data == -np.inf] = 0 - axs[row].hist(log_data, bins=bins) - else: - axs[row].hist(self.df[feat].values, bins=bins) - axs[row].set(xlabel=f'{feat} (mean {means[row]}, std {devs[row]})', ylabel='Count') + axs[row].hist(self.df[feat].to_numpy(), bins=bins) + axs[row].set( + xlabel=f"{feat} (mean {means[row]}, std {devs[row]})", + ylabel="Count", + ) fig.tight_layout() elif len(features_df) == 1: fig = plt.figure(figsize=figsize) ax = fig.add_subplot(111) - if isinstance(self.df[features_df[0]].values[0], np.ndarray): + if isinstance(self.df[features_df[0]].to_numpy()[0], np.ndarray): if log: - log_data = np.log(np.concatenate(self.df[features_df[0]].values)) + log_data = np.log(np.concatenate(self.df[features_df[0]].to_numpy())) log_data[log_data == -np.inf] = 0 ax.hist(log_data, bins=bins) else: - ax.hist(np.concatenate(self.df[features_df[0]].values), bins=bins) + ax.hist(np.concatenate(self.df[features_df[0]].to_numpy()), bins=bins) + elif log: + log_data = np.log(self.df[features_df[0]].to_numpy()) + log_data[log_data == -np.inf] = 0 + ax.hist(log_data, bins=bins) else: - if log: - log_data = np.log(self.df[features_df[0]].values) - log_data[log_data == -np.inf] = 0 - ax.hist(log_data, bins=bins) - else: - ax.hist(self.df[features_df[0]].values, bins=bins) - ax.set(xlabel=f'{features_df[0]} (mean {means[0]}, std {devs[0]})', ylabel='Count') + ax.hist(self.df[features_df[0]].values, bins=bins) + ax.set( + xlabel=f"{features_df[0]} (mean {means[0]}, std {devs[0]})", + ylabel="Count", + ) else: raise ValueError("Please provide valid features names. 
They must be present in the current :class:`DeeprankDataset` children instance.") @@ -433,13 +435,18 @@ def save_hist( # pylint: disable=too-many-arguments, too-many-branches, useless- plt.close(fig) def _compute_mean_std(self): - - means = {col: round(np.nanmean(np.concatenate(self.df[col].values)), 1) if isinstance(self.df[col].values[0], np.ndarray) \ - else round(np.nanmean(self.df[col].values), 1) \ - for col in self.df.columns[1:]} - devs = {col: round(np.nanstd(np.concatenate(self.df[col].values)), 1) if isinstance(self.df[col].values[0], np.ndarray) \ - else round(np.nanstd(self.df[col].values), 1) \ - for col in self.df.columns[1:]} + means = { + col: round(np.nanmean(np.concatenate(self.df[col].values)), 1) + if isinstance(self.df[col].to_numpy()[0], np.ndarray) + else round(np.nanmean(self.df[col].to_numpy()), 1) + for col in self.df.columns[1:] + } + devs = { + col: round(np.nanstd(np.concatenate(self.df[col].to_numpy())), 1) + if isinstance(self.df[col].to_numpy()[0], np.ndarray) + else round(np.nanstd(self.df[col].to_numpy()), 1) + for col in self.df.columns[1:] + } self.means = means self.devs = devs @@ -451,7 +458,7 @@ def _compute_mean_std(self): class GridDataset(DeeprankDataset): - def __init__( # pylint: disable=too-many-arguments, too-many-locals + def __init__( self, hdf5_path: str | list, subset: list[str] | None = None, @@ -464,7 +471,7 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals classes: list[str] | list[int] | list[float] | None = None, use_tqdm: bool = True, root: str = "./", - check_integrity: bool = True + check_integrity: bool = True, ): """Class to load the .HDF5 files data into grids. @@ -511,19 +518,33 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals check_integrity (bool, optional): Whether to check the integrity of the hdf5 files. Defaults to True. """ - super().__init__(hdf5_path, subset, train_source, target, target_transform, - target_filter, task, classes, use_tqdm, root, check_integrity) - self.default_vars = { - k: v.default - for k, v in inspect.signature(self.__init__).parameters.items() - if v.default is not inspect.Parameter.empty - } + super().__init__( + hdf5_path, + subset, + train_source, + target, + target_transform, + target_filter, + task, + classes, + use_tqdm, + root, + check_integrity, + ) + self.default_vars = {k: v.default for k, v in inspect.signature(self.__init__).parameters.items() if v.default is not inspect.Parameter.empty} self.default_vars["classes_to_index"] = None self.features = features self.target_transform = target_transform if train_source is not None: - self.inherited_params = ["features", "target", "target_transform", "task", "classes", "classes_to_index"] + self.inherited_params = [ + "features", + "target", + "target_transform", + "task", + "classes", + "classes_to_index", + ] self._check_and_inherit_train(GridDataset, self.inherited_params) self._check_features() @@ -533,9 +554,9 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals try: fname, mol = self.index_entries[0] - except IndexError as exc: - raise IndexError("No entries found in the dataset. Please check the dataset parameters.") from exc - with h5py.File(fname, 'r') as f5: + except IndexError as e: + raise IndexError("No entries found in the dataset. 
Please check the dataset parameters.") from e + with h5py.File(fname, "r") as f5: grp = f5[mol] possible_targets = grp[targets.VALUES].keys() if self.target is None: @@ -552,40 +573,37 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals self.features_dict[targets.VALUES] = self.target def _check_features(self): - """Checks if the required features exist""" - + """Checks if the required features exist.""" hdf5_path = self.hdf5_paths[0] # read available features with h5py.File(hdf5_path, "r") as f: - mol_key = list(f.keys())[0] + mol_key = next(iter(f.keys())) if isinstance(self.features, list): - self.features = [GRID_PARTIAL_FEATURE_NAME_PATTERN.match(feature_name).group(1) - if GRID_PARTIAL_FEATURE_NAME_PATTERN.match(feature_name) is not None - else feature_name for feature_name in self.features] # be sure to remove the dimension number suffix - self.features = list(set(self.features)) # remove duplicates + self.features = [ + GRID_PARTIAL_FEATURE_NAME_PATTERN.match(feature_name).group(1) + if GRID_PARTIAL_FEATURE_NAME_PATTERN.match(feature_name) is not None + else feature_name + for feature_name in self.features + ] # remove the dimension number suffix + self.features = list(set(self.features)) # remove duplicates available_features = list(f[f"{mol_key}/{gridstorage.MAPPED_FEATURES}"].keys()) - available_features = [key for key in available_features if key[0] != '_'] # ignore metafeatures + available_features = [key for key in available_features if key[0] != "_"] # ignore metafeatures hdf5_matching_feature_names = [] # feature names that match with the requested list of names unpartial_feature_names = [] # feature names without their dimension number suffix for feature_name in available_features: - partial_feature_match = GRID_PARTIAL_FEATURE_NAME_PATTERN.match(feature_name) if partial_feature_match is not None: # there's a dimension number in the feature name - unpartial_feature_name = partial_feature_match.group(1) if self.features == "all" or (isinstance(self.features, list) and unpartial_feature_name in self.features): - hdf5_matching_feature_names.append(feature_name) unpartial_feature_names.append(unpartial_feature_name) else: # no numbers, it's a one-dimensional feature name - if self.features == "all" or (isinstance(self.features, list) and feature_name in self.features): - hdf5_matching_feature_names.append(feature_name) unpartial_feature_names.append(feature_name) @@ -611,11 +629,12 @@ def _check_features(self): # raise error if any features are missing if len(missing_features) > 0: raise ValueError( - f"Not all features could be found in the file {hdf5_path} under entry {mol_key}.\ - \nMissing features are: {missing_features} \ - \nCheck feature_modules passed to the preprocess function. \ - \nProbably, the feature wasn't generated during the preprocessing step. \ - Available features: {available_features}") + f"Not all features could be found in the file {hdf5_path} under entry {mol_key}.\n\t" + f"Missing features are: {missing_features}.\n\t" + "Check feature_modules passed to the preprocess function.\n\t" + "Probably, the feature wasn't generated during the preprocessing step.\n\t" + f"Available features: {available_features}" + ) def get(self, idx: int) -> Data: """Gets one grid item from its unique index. @@ -626,7 +645,6 @@ def get(self, idx: int) -> Data: Returns: :class:`torch_geometric.data.data.Data`: item with tensors x, y if present, entry_names. 
""" - file_path, entry_name = self.index_entries[idx] return self.load_one_grid(file_path, entry_name) @@ -640,36 +658,34 @@ def load_one_grid(self, hdf5_path: str, entry_name: str) -> Data: Returns: :class:`torch_geometric.data.data.Data`: item with tensors x, y if present, entry_names. """ - - feature_data = [] - - with h5py.File(hdf5_path, 'r') as hdf5_file: + with h5py.File(hdf5_path, "r") as hdf5_file: grp = hdf5_file[entry_name] mapped_features_group = grp[gridstorage.MAPPED_FEATURES] - for feature_name in self.features: - if feature_name[0] != '_': # ignore metafeatures - feature_data.append(mapped_features_group[feature_name][:]) - x=torch.tensor(np.expand_dims(np.array(feature_data), axis=0), dtype=torch.float) + + feature_data = [mapped_features_group[feature_name][:] for feature_name in self.features if feature_name[0] != "_"] + x = torch.tensor(np.expand_dims(np.array(feature_data), axis=0), dtype=torch.float) # target if self.target is None: y = None + elif targets.VALUES in grp and self.target in grp[targets.VALUES]: + y = torch.tensor([grp[targets.VALUES][self.target][()]], dtype=torch.float) + + if self.task == targets.REGRESS and self.target_transform is True: + y = torch.sigmoid(torch.log(y)) + elif self.task is not targets.REGRESS and self.target_transform is True: + raise ValueError( + f'Sigmoid transformation not possible for {self.task} tasks. Please change `task` to "regress" or set `target_transform` to `False`.' + ) else: - if targets.VALUES in grp and self.target in grp[targets.VALUES]: - y = torch.tensor([grp[targets.VALUES][self.target][()]], dtype=torch.float) - - if self.task == targets.REGRESS and self.target_transform is True: - y = torch.sigmoid(torch.log(y)) - elif self.task is not targets.REGRESS and self.target_transform is True: - raise ValueError(f"Sigmoid transformation is not possible for {self.task} tasks. \ - Please change `task` to \"regress\" or set `target_transform` to `False`.") - else: - y = None - possible_targets = grp[targets.VALUES].keys() - if self.train_source is None: - raise ValueError(f"Target {self.target} missing in entry {entry_name} in file {hdf5_path}, possible targets are {possible_targets}." + - "\n Use the query class to add more target values to input data.") + y = None + possible_targets = grp[targets.VALUES].keys() + if self.train_source is None: + raise ValueError( + f"Target {self.target} missing in entry {entry_name} in file {hdf5_path}, possible targets are {possible_targets}.\n\t" + "Use the query class to add more target values to input data." + ) # Wrap up the data in this object, for the collate_fn to handle it properly: data = Data(x=x, y=y) @@ -679,7 +695,7 @@ def load_one_grid(self, hdf5_path: str, entry_name: str) -> Data: class GraphDataset(DeeprankDataset): - def __init__( # noqa: MC0001, pylint: disable=too-many-arguments, too-many-locals + def __init__( self, hdf5_path: str | list, subset: list[str] | None = None, @@ -765,15 +781,21 @@ def __init__( # noqa: MC0001, pylint: disable=too-many-arguments, too-many-local check_integrity (bool, optional): Whether to check the integrity of the hdf5 files. Defaults to True. 
""" - - super().__init__(hdf5_path, subset, train_source, target, target_transform, - target_filter, task, classes, use_tqdm, root, check_integrity) - - self.default_vars = { - k: v.default - for k, v in inspect.signature(self.__init__).parameters.items() - if v.default is not inspect.Parameter.empty - } + super().__init__( + hdf5_path, + subset, + train_source, + target, + target_transform, + target_filter, + task, + classes, + use_tqdm, + root, + check_integrity, + ) + + self.default_vars = {k: v.default for k, v in inspect.signature(self.__init__).parameters.items() if v.default is not inspect.Parameter.empty} self.default_vars["classes_to_index"] = None self.node_features = node_features self.edge_features = edge_features @@ -782,8 +804,16 @@ def __init__( # noqa: MC0001, pylint: disable=too-many-arguments, too-many-local self.features_transform = features_transform if train_source is not None: - self.inherited_params = ["node_features", "edge_features", "features_transform", "target", - "target_transform", "task", "classes", "classes_to_index"] + self.inherited_params = [ + "node_features", + "edge_features", + "features_transform", + "target", + "target_transform", + "task", + "classes", + "classes_to_index", + ] self._check_and_inherit_train(GraphDataset, self.inherited_params) self._check_features() @@ -793,9 +823,9 @@ def __init__( # noqa: MC0001, pylint: disable=too-many-arguments, too-many-local try: fname, mol = self.index_entries[0] - except IndexError as exc: - raise IndexError("No entries found in the dataset. Please check the dataset parameters.") from exc - with h5py.File(fname, 'r') as f5: + except IndexError as e: + raise IndexError("No entries found in the dataset. Please check the dataset parameters.") from e + with h5py.File(fname, "r") as f5: grp = f5[mol] possible_targets = grp[targets.VALUES].keys() if self.target is None: @@ -834,11 +864,10 @@ def get(self, idx: int) -> Data: Returns: :class:`torch_geometric.data.data.Data`: item with tensors x, y if present, edge_index, edge_attr, pos, entry_names. """ - fname, mol = self.index_entries[idx] return self.load_one_graph(fname, mol) - def load_one_graph(self, fname: str, entry_name: str) -> Data: # pylint: disable = too-many-locals # noqa: MC0001 + def load_one_graph(self, fname: str, entry_name: str) -> Data: # noqa: PLR0915 (too-many-statements) """Loads one graph. Args: @@ -848,50 +877,48 @@ def load_one_graph(self, fname: str, entry_name: str) -> Data: # pylint: disabl Returns: :class:`torch_geometric.data.data.Data`: item with tensors x, y if present, edge_index, edge_attr, pos, entry_names. 
""" - - with h5py.File(fname, 'r') as f5: + with h5py.File(fname, "r") as f5: grp = f5[entry_name] # node features if len(self.node_features) > 0: node_data = () for feat in self.node_features: - # resetting transformation and standardization for each feature transform = None standard = None - if feat[0] != '_': # ignore metafeatures + if feat[0] != "_": # ignore metafeatures vals = grp[f"{Nfeat.NODE}/{feat}"][()] # get feat transformation and standardization - if (self.features_transform is not None): - transform = self.features_transform.get('all', {}).get('transform') - standard = self.features_transform.get('all', {}).get('standardize') + if self.features_transform is not None: + transform = self.features_transform.get("all", {}).get("transform") + standard = self.features_transform.get("all", {}).get("standardize") # if no transformation is set for all features, check if one is set for the current feature if (transform is None) and (feat in self.features_transform): - transform = self.features_transform.get(feat, {}).get('transform') + transform = self.features_transform.get(feat, {}).get("transform") # if no standardization is set for all features, check if one is set for the current feature if (standard is None) and (feat in self.features_transform): - standard = self.features_transform.get(feat, {}).get('standardize') + standard = self.features_transform.get(feat, {}).get("standardize") # apply transformation if transform: with warnings.catch_warnings(record=True) as w: vals = transform(vals) - if (len(w) > 0): - raise ValueError(f"Invalid value occurs in {entry_name}, file {fname}," - f"when applying {transform} for feature {feat}." - f"Please change the transformation function for {feat}.") + if len(w) > 0: + raise ValueError( + f"Invalid value occurs in {entry_name}, file {fname},when applying {transform} for feature {feat}.\n\t" + f"Please change the transformation function for {feat}." 
+ ) - if vals.ndim == 1: # features with only one channel + if vals.ndim == 1: # features with only one channel vals = vals.reshape(-1, 1) if standard: - vals = (vals-self.means[feat])/self.devs[feat] - else: - if standard: - reshaped_mean = [mean_value for mean_key, mean_value in self.means.items() if feat in mean_key] - reshaped_dev = [dev_value for dev_key, dev_value in self.devs.items() if feat in dev_key] - vals = (vals - reshaped_mean)/reshaped_dev + vals = (vals - self.means[feat]) / self.devs[feat] + elif standard: + reshaped_mean = [mean_value for mean_key, mean_value in self.means.items() if feat in mean_key] + reshaped_dev = [dev_value for dev_key, dev_value in self.devs.items() if feat in dev_key] + vals = (vals - reshaped_mean) / reshaped_dev node_data += (vals,) x = torch.tensor(np.hstack(node_data), dtype=torch.float) else: @@ -913,42 +940,41 @@ def load_one_graph(self, fname: str, entry_name: str) -> Data: # pylint: disabl if len(self.edge_features) > 0: edge_data = () for feat in self.edge_features: - # resetting transformation and standardization for each feature transform = None standard = None - if feat[0] != '_': # ignore metafeatures + if feat[0] != "_": # ignore metafeatures vals = grp[f"{Efeat.EDGE}/{feat}"][()] # get feat transformation and standardization - if (self.features_transform is not None): - transform = self.features_transform.get('all', {}).get('transform') - standard = self.features_transform.get('all', {}).get('standardize') + if self.features_transform is not None: + transform = self.features_transform.get("all", {}).get("transform") + standard = self.features_transform.get("all", {}).get("standardize") # if no transformation is set for all features, check if one is set for the current feature if (transform is None) and (feat in self.features_transform): - transform = self.features_transform.get(feat, {}).get('transform') + transform = self.features_transform.get(feat, {}).get("transform") # if no standardization is set for all features, check if one is set for the current feature if (standard is None) and (feat in self.features_transform): - standard = self.features_transform.get(feat, {}).get('standardize') + standard = self.features_transform.get(feat, {}).get("standardize") # apply transformation if transform: with warnings.catch_warnings(record=True) as w: vals = transform(vals) - if (len(w) > 0): - raise ValueError(f"Invalid value occurs in {entry_name}, file {fname}," - f"when applying {transform} for feature {feat}." - f"Please change the transformation function for {feat}.") + if len(w) > 0: + raise ValueError( + f"Invalid value occurs in {entry_name}, file {fname}, when applying {transform} for feature {feat}.\n\t" + f"Please change the transformation function for {feat}." 
+ ) if vals.ndim == 1: vals = vals.reshape(-1, 1) if standard: - vals = (vals-self.means[feat])/self.devs[feat] - else: - if standard: - reshaped_mean = [mean_value for mean_key, mean_value in self.means.items() if feat in mean_key] - reshaped_dev = [dev_value for dev_key, dev_value in self.devs.items() if feat in dev_key] - vals = (vals - reshaped_mean)/reshaped_dev + vals = (vals - self.means[feat]) / self.devs[feat] + elif standard: + reshaped_mean = [mean_value for mean_key, mean_value in self.means.items() if feat in mean_key] + reshaped_dev = [dev_value for dev_key, dev_value in self.devs.items() if feat in dev_key] + vals = (vals - reshaped_mean) / reshaped_dev edge_data += (vals,) edge_data = np.hstack(edge_data) edge_data = np.vstack((edge_data, edge_data)) @@ -959,22 +985,24 @@ def load_one_graph(self, fname: str, entry_name: str) -> Data: # pylint: disabl # target if self.target is None: y = None - else: - if targets.VALUES in grp and self.target in grp[targets.VALUES]: - y = torch.tensor([grp[f"{targets.VALUES}/{self.target}"][()]], dtype=torch.float).contiguous() + elif targets.VALUES in grp and self.target in grp[targets.VALUES]: + y = torch.tensor([grp[f"{targets.VALUES}/{self.target}"][()]], dtype=torch.float).contiguous() - if self.task == targets.REGRESS and self.target_transform is True: - y = torch.sigmoid(torch.log(y)) - elif self.task is not targets.REGRESS and self.target_transform is True: - raise ValueError(f"Sigmoid transformation is not possible for {self.task} tasks. \ - Please change `task` to \"regress\" or set `target_transform` to `False`.") + if self.task == targets.REGRESS and self.target_transform is True: + y = torch.sigmoid(torch.log(y)) + elif self.task is not targets.REGRESS and self.target_transform is True: + raise ValueError( + f'Sigmoid transformation not possible for {self.task} tasks. Please change `task` to "regress" or set `target_transform` to `False`.' + ) - else: - y = None - possible_targets = grp[targets.VALUES].keys() - if self.train_source is None: - raise ValueError(f"Target {self.target} missing in entry {entry_name} in file {fname}, possible targets are {possible_targets}." + - "\n Use the query class to add more target values to input data.") + else: + y = None + possible_targets = grp[targets.VALUES].keys() + if self.train_source is None: + raise ValueError( + f"Target {self.target} missing in entry {entry_name} in file {fname}, possible targets are {possible_targets}.\n\t" + "Use the query class to add more target values to input data." 
+ ) # positions pos = torch.tensor(grp[f"{Nfeat.NODE}/{Nfeat.POSITION}/"][()], dtype=torch.float).contiguous() @@ -982,22 +1010,21 @@ def load_one_graph(self, fname: str, entry_name: str) -> Data: # pylint: disabl # cluster cluster0 = None cluster1 = None - if self.clustering_method is not None: - if 'clustering' in grp.keys(): - if self.clustering_method in grp["clustering"].keys(): - if ( - "depth_0" in grp[f"clustering/{self.clustering_method}"].keys() and - "depth_1" in grp[f"clustering/{self.clustering_method}"].keys() - ): - - cluster0 = torch.tensor( - grp["clustering/" + self.clustering_method + "/depth_0"][()], dtype=torch.long) - cluster1 = torch.tensor( - grp["clustering/" + self.clustering_method + "/depth_1"][()], dtype=torch.long) - else: - _log.warning("no clusters detected") + if self.clustering_method is not None and "clustering" in grp: + if self.clustering_method in grp["clustering"]: + if "depth_0" in grp[f"clustering/{self.clustering_method}"] and "depth_1" in grp[f"clustering/{self.clustering_method}"]: + cluster0 = torch.tensor( + grp["clustering/" + self.clustering_method + "/depth_0"][()], + dtype=torch.long, + ) + cluster1 = torch.tensor( + grp["clustering/" + self.clustering_method + "/depth_1"][()], + dtype=torch.long, + ) else: - _log.warning(f"no clustering/{self.clustering_method} detected") + _log.warning("no clusters detected") + else: + _log.warning(f"no clustering/{self.clustering_method} detected") # load data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y, pos=pos) @@ -1009,18 +1036,18 @@ def load_one_graph(self, fname: str, entry_name: str) -> Data: # pylint: disabl return data - def _check_features(self): #pylint: disable=too-many-branches - """Checks if the required features exist""" + def _check_features(self): + """Checks if the required features exist.""" f = h5py.File(self.hdf5_paths[0], "r") - mol_key = list(f.keys())[0] + mol_key = next(iter(f.keys())) # read available node features self.available_node_features = list(f[f"{mol_key}/{Nfeat.NODE}/"].keys()) - self.available_node_features = [key for key in self.available_node_features if key[0] != '_'] # ignore metafeatures + self.available_node_features = [key for key in self.available_node_features if key[0] != "_"] # ignore metafeatures # read available edge features self.available_edge_features = list(f[f"{mol_key}/{Efeat.EDGE}/"].keys()) - self.available_edge_features = [key for key in self.available_edge_features if key[0] != '_'] # ignore metafeatures + self.available_edge_features = [key for key in self.available_edge_features if key[0] != "_"] # ignore metafeatures f.close() @@ -1059,8 +1086,10 @@ def _check_features(self): #pylint: disable=too-many-branches # raise error if any features are missing if missing_node_features + missing_edge_features: miss_node_error, miss_edge_error = "", "" - _log.info("\nCheck feature_modules passed to the preprocess function.\ - Probably, the feature wasn't generated during the preprocessing step.") + _log.info( + "\nCheck feature_modules passed to the preprocess function.\ + Probably, the feature wasn't generated during the preprocessing step." 
+ ) if missing_node_features: _log.info(f"\nAvailable node features: {self.available_node_features}\n") miss_node_error = f"\nMissing node features: {missing_node_features} \ @@ -1070,18 +1099,14 @@ def _check_features(self): #pylint: disable=too-many-branches miss_edge_error = f"\nMissing edge features: {missing_edge_features} \ \nAvailable edge features: {self.available_edge_features}" raise ValueError( - f"Not all features could be found in the file {self.hdf5_paths[0]}.\ - \nCheck feature_modules passed to the preprocess function. \ - \nProbably, the feature wasn't generated during the preprocessing step. \ - {miss_node_error}{miss_edge_error}") + f"Not all features could be found in the file {self.hdf5_paths[0]}.\n\t" + "Check feature_modules passed to the preprocess function.\n\t" + "Probably, the feature wasn't generated during the preprocessing step.\n\t" + f"{miss_node_error}{miss_edge_error}" + ) -def save_hdf5_keys( - f_src_path: str, - src_ids: list[str], - f_dest_path: str, - hardcopy = False - ): +def save_hdf5_keys(f_src_path: str, src_ids: list[str], f_dest_path: str, hardcopy=False): """Save references to keys in src_ids in a new .HDF5 file. Args: @@ -1096,9 +1121,9 @@ def save_hdf5_keys( if not all(isinstance(d, str) for d in src_ids): raise TypeError("data_ids should be a list containing strings.") - with h5py.File(f_dest_path,'w') as f_dest, h5py.File(f_src_path,'r') as f_src: + with h5py.File(f_dest_path, "w") as f_dest, h5py.File(f_src_path, "r") as f_src: for key in src_ids: if hardcopy: - f_src.copy(f_src[key],f_dest) + f_src.copy(f_src[key], f_dest) else: f_dest[key] = h5py.ExternalLink(f_src_path, "/" + key) diff --git a/deeprank2/domain/aminoacidlist.py b/deeprank2/domain/aminoacidlist.py index 9906e2a32..b49a96c77 100644 --- a/deeprank2/domain/aminoacidlist.py +++ b/deeprank2/domain/aminoacidlist.py @@ -38,290 +38,312 @@ "Alanine", "ALA", "A", - charge = 0, - polarity = Polarity.NONPOLAR, - size = 1, - mass = 71.1, - pI = 6.00, - hydrogen_bond_donors = 0, - hydrogen_bond_acceptors = 0, - index = 0) + charge=0, + polarity=Polarity.NONPOLAR, + size=1, + mass=71.1, + pI=6.00, + hydrogen_bond_donors=0, + hydrogen_bond_acceptors=0, + index=0, +) cysteine = AminoAcid( "Cysteine", "CYS", "C", - charge = 0, - polarity = Polarity.POLAR, # source 3: "special case"; source 5: nonpolar + charge=0, + polarity=Polarity.POLAR, # source 3: "special case"; source 5: nonpolar # polarity of C is generally considered ambiguous: https://chemistry.stackexchange.com/questions/143142/why-is-the-amino-acid-cysteine-classified-as-polar - size = 2, - mass = 103.2, - pI = 5.07, - hydrogen_bond_donors = 0, - hydrogen_bond_acceptors = 0, - index = 1) + size=2, + mass=103.2, + pI=5.07, + hydrogen_bond_donors=0, + hydrogen_bond_acceptors=0, + index=1, +) selenocysteine = AminoAcid( "Selenocysteine", "SEC", "U", - charge = 0, - polarity = Polarity.POLAR, # source 3: "special case" - size = 2, # source: https://en.wikipedia.org/wiki/Selenocysteine - mass = 150.0, # only from source 3 - pI = 5.47, # only from source 3 - hydrogen_bond_donors = 1, # unconfirmed - hydrogen_bond_acceptors = 2, # unconfirmed - index = cysteine.index) + charge=0, + polarity=Polarity.POLAR, # source 3: "special case" + size=2, # source: https://en.wikipedia.org/wiki/Selenocysteine + mass=150.0, # only from source 3 + pI=5.47, # only from source 3 + hydrogen_bond_donors=1, # unconfirmed + hydrogen_bond_acceptors=2, # unconfirmed + index=cysteine.index, +) aspartate = AminoAcid( "Aspartate", "ASP", "D", - charge = -1, - 
polarity = Polarity.NEGATIVE, - size = 4, - mass = 115.1, - pI = 2.77, - hydrogen_bond_donors = 0, - hydrogen_bond_acceptors = 4, - index = 2) + charge=-1, + polarity=Polarity.NEGATIVE, + size=4, + mass=115.1, + pI=2.77, + hydrogen_bond_donors=0, + hydrogen_bond_acceptors=4, + index=2, +) glutamate = AminoAcid( "Glutamate", "GLU", "E", - charge = -1, - polarity = Polarity.NEGATIVE, - size = 5, - mass = 129.1, - pI = 3.22, - hydrogen_bond_donors = 0, - hydrogen_bond_acceptors = 4, - index = 3) + charge=-1, + polarity=Polarity.NEGATIVE, + size=5, + mass=129.1, + pI=3.22, + hydrogen_bond_donors=0, + hydrogen_bond_acceptors=4, + index=3, +) phenylalanine = AminoAcid( "Phenylalanine", "PHE", "F", - charge = 0, - polarity = Polarity.NONPOLAR, - size = 7, - mass = 147.2, - pI = 5.48, - hydrogen_bond_donors = 0, - hydrogen_bond_acceptors = 0, - index = 4) + charge=0, + polarity=Polarity.NONPOLAR, + size=7, + mass=147.2, + pI=5.48, + hydrogen_bond_donors=0, + hydrogen_bond_acceptors=0, + index=4, +) glycine = AminoAcid( "Glycine", "GLY", "G", - charge = 0, - polarity = Polarity.NONPOLAR, # source 3: "special case" - size = 0, - mass = 57.1, - pI = 5.97, - hydrogen_bond_donors = 0, - hydrogen_bond_acceptors = 0, - index = 5) + charge=0, + polarity=Polarity.NONPOLAR, # source 3: "special case" + size=0, + mass=57.1, + pI=5.97, + hydrogen_bond_donors=0, + hydrogen_bond_acceptors=0, + index=5, +) histidine = AminoAcid( "Histidine", "HIS", "H", - charge = 1, - polarity = Polarity.POSITIVE, - size = 6, - mass = 137.1, - pI = 7.59, - hydrogen_bond_donors = 1, - hydrogen_bond_acceptors = 1, + charge=1, + polarity=Polarity.POSITIVE, + size=6, + mass=137.1, + pI=7.59, + hydrogen_bond_donors=1, + hydrogen_bond_acceptors=1, # both position 7 and 10 can serve as either donor or acceptor (depending on tautomer), but any single His will have exactly one donor and one acceptor # (see https://foldit.fandom.com/wiki/Histidine) - index = 6) + index=6, +) isoleucine = AminoAcid( "Isoleucine", "ILE", "I", - charge = 0, - polarity = Polarity.NONPOLAR, - size = 4, - mass = 113.2, - pI = 6.02, - hydrogen_bond_donors = 0, - hydrogen_bond_acceptors = 0, - index = 7) + charge=0, + polarity=Polarity.NONPOLAR, + size=4, + mass=113.2, + pI=6.02, + hydrogen_bond_donors=0, + hydrogen_bond_acceptors=0, + index=7, +) lysine = AminoAcid( "Lysine", "LYS", "K", - charge = 1, - polarity = Polarity.POSITIVE, - size = 5, - mass = 128.2, - pI = 9.74, # 9.60 in source 3 - hydrogen_bond_donors = 3, - hydrogen_bond_acceptors = 0, - index = 8) + charge=1, + polarity=Polarity.POSITIVE, + size=5, + mass=128.2, + pI=9.74, # 9.60 in source 3 + hydrogen_bond_donors=3, + hydrogen_bond_acceptors=0, + index=8, +) pyrrolysine = AminoAcid( "Pyrrolysine", "PYL", "O", - charge = 0, # unconfirmed - polarity = Polarity.POLAR, # based on having both H-bond donors and acceptors - size = 13, # source: https://en.wikipedia.org/wiki/Pyrrolysine - mass = 255.32, # from source 3 - pI = 7.394, # rough estimate from https://rstudio-pubs-static.s3.amazonaws.com/846259_7a9236df54e6410a972621590ecdcfcb.html - hydrogen_bond_donors = 1, # unconfirmed - hydrogen_bond_acceptors = 4, # unconfirmed - index = lysine.index) + charge=0, # unconfirmed + polarity=Polarity.POLAR, # based on having both H-bond donors and acceptors + size=13, # source: https://en.wikipedia.org/wiki/Pyrrolysine + mass=255.32, # from source 3 + pI=7.394, # rough estimate from https://rstudio-pubs-static.s3.amazonaws.com/846259_7a9236df54e6410a972621590ecdcfcb.html + hydrogen_bond_donors=1, # 
unconfirmed + hydrogen_bond_acceptors=4, # unconfirmed + index=lysine.index, +) leucine = AminoAcid( "Leucine", "LEU", "L", - charge = 0, - polarity = Polarity.NONPOLAR, - size = 4, - mass = 113.2, - pI = 5.98, - hydrogen_bond_donors = 0, - hydrogen_bond_acceptors = 0, - index = 9) + charge=0, + polarity=Polarity.NONPOLAR, + size=4, + mass=113.2, + pI=5.98, + hydrogen_bond_donors=0, + hydrogen_bond_acceptors=0, + index=9, +) methionine = AminoAcid( "Methionine", "MET", "M", - charge = 0, - polarity = Polarity.NONPOLAR, - size = 4, - mass = 131.2, - pI = 5.74, - hydrogen_bond_donors = 0, - hydrogen_bond_acceptors = 0, - index = 10) + charge=0, + polarity=Polarity.NONPOLAR, + size=4, + mass=131.2, + pI=5.74, + hydrogen_bond_donors=0, + hydrogen_bond_acceptors=0, + index=10, +) asparagine = AminoAcid( "Asparagine", "ASN", "N", - charge = 0, - polarity = Polarity.POLAR, - size = 4, - mass = 114.1, - pI = 5.41, - hydrogen_bond_donors = 2, - hydrogen_bond_acceptors = 2, - index = 11) + charge=0, + polarity=Polarity.POLAR, + size=4, + mass=114.1, + pI=5.41, + hydrogen_bond_donors=2, + hydrogen_bond_acceptors=2, + index=11, +) proline = AminoAcid( "Proline", "PRO", "P", - charge = 0, - polarity = Polarity.NONPOLAR, - size = 3, - mass = 97.1, - pI = 6.30, - hydrogen_bond_donors = 0, - hydrogen_bond_acceptors = 0, - index = 12) + charge=0, + polarity=Polarity.NONPOLAR, + size=3, + mass=97.1, + pI=6.30, + hydrogen_bond_donors=0, + hydrogen_bond_acceptors=0, + index=12, +) glutamine = AminoAcid( "Glutamine", "GLN", "Q", - charge = 0, - polarity = Polarity.POLAR, - size = 5, - mass = 128.1, - pI = 5.65, - hydrogen_bond_donors = 2, - hydrogen_bond_acceptors = 2, - index = 13) + charge=0, + polarity=Polarity.POLAR, + size=5, + mass=128.1, + pI=5.65, + hydrogen_bond_donors=2, + hydrogen_bond_acceptors=2, + index=13, +) arginine = AminoAcid( "Arginine", "ARG", "R", - charge = 1, - polarity = Polarity.POSITIVE, - size = 7, - mass = 156.2, - pI = 10.76, - hydrogen_bond_donors = 5, - hydrogen_bond_acceptors = 0, - index = 14) + charge=1, + polarity=Polarity.POSITIVE, + size=7, + mass=156.2, + pI=10.76, + hydrogen_bond_donors=5, + hydrogen_bond_acceptors=0, + index=14, +) serine = AminoAcid( "Serine", "SER", "S", - charge = 0, - polarity = Polarity.POLAR, - size = 2, - mass = 87.1, - pI = 5.68, - hydrogen_bond_donors = 1, - hydrogen_bond_acceptors = 2, - index = 15) + charge=0, + polarity=Polarity.POLAR, + size=2, + mass=87.1, + pI=5.68, + hydrogen_bond_donors=1, + hydrogen_bond_acceptors=2, + index=15, +) threonine = AminoAcid( "Threonine", "THR", "T", - charge = 0, - polarity = Polarity.POLAR, - size = 3, - mass = 101.1, - pI = 5.60, # 6.16 in source 2 - hydrogen_bond_donors = 1, - hydrogen_bond_acceptors = 2, - index = 16) + charge=0, + polarity=Polarity.POLAR, + size=3, + mass=101.1, + pI=5.60, # 6.16 in source 2 + hydrogen_bond_donors=1, + hydrogen_bond_acceptors=2, + index=16, +) valine = AminoAcid( "Valine", "VAL", "V", - charge = 0, - polarity = Polarity.NONPOLAR, - size = 3, - mass = 99.1, - pI = 5.96, - hydrogen_bond_donors = 0, - hydrogen_bond_acceptors = 0, - index = 17) + charge=0, + polarity=Polarity.NONPOLAR, + size=3, + mass=99.1, + pI=5.96, + hydrogen_bond_donors=0, + hydrogen_bond_acceptors=0, + index=17, +) tryptophan = AminoAcid( "Tryptophan", "TRP", "W", - charge = 0, - polarity = Polarity.NONPOLAR, # source 4: polar - size = 10, - mass = 186.2, - pI = 5.89, - hydrogen_bond_donors = 1, - hydrogen_bond_acceptors = 0, - index = 18) + charge=0, + polarity=Polarity.NONPOLAR, # source 4: polar 
+ size=10, + mass=186.2, + pI=5.89, + hydrogen_bond_donors=1, + hydrogen_bond_acceptors=0, + index=18, +) tyrosine = AminoAcid( "Tyrosine", "TYR", "Y", - charge = -0., - polarity = Polarity.POLAR, # source 3: nonpolar - size = 8, - mass = 163.2, - pI = 5.66, - hydrogen_bond_donors = 1, - hydrogen_bond_acceptors = 1, - index = 19) + charge=-0.0, + polarity=Polarity.POLAR, # source 3: nonpolar + size=8, + mass=163.2, + pI=5.66, + hydrogen_bond_donors=1, + hydrogen_bond_acceptors=1, + index=19, +) # Including selenocysteine and pyrrolysine in the future will require some work to be done on the package. @@ -348,22 +370,23 @@ valine, # selenocysteine, # pyrrolysine, - ] +] amino_acids_by_code = {amino_acid.three_letter_code: amino_acid for amino_acid in amino_acids} amino_acids_by_letter = {amino_acid.one_letter_code: amino_acid for amino_acid in amino_acids} amino_acids_by_name = {amino_acid.name: amino_acid for amino_acid in amino_acids} + def convert_aa_nomenclature(aa: str, output_type: int | None = None): try: if len(aa) == 1: - aa: AminoAcid = [entry for entry in amino_acids if entry.one_letter_code.lower() == aa.lower()][0] + aa: AminoAcid = next(entry for entry in amino_acids if entry.one_letter_code.lower() == aa.lower()) elif len(aa) == 3: - aa: AminoAcid = [entry for entry in amino_acids if entry.three_letter_code.lower() == aa.lower()][0] + aa: AminoAcid = next(entry for entry in amino_acids if entry.three_letter_code.lower() == aa.lower()) else: - aa: AminoAcid = [entry for entry in amino_acids if entry.name.lower() == aa.lower()][0] + aa: AminoAcid = next(entry for entry in amino_acids if entry.name.lower() == aa.lower()) except IndexError as e: - raise ValueError(f'{aa} is not a valid amino acid.') from e + raise ValueError(f"{aa} is not a valid amino acid.") from e if not output_type: return aa.name @@ -371,4 +394,4 @@ def convert_aa_nomenclature(aa: str, output_type: int | None = None): return aa.three_letter_code if output_type == 1: return aa.one_letter_code - raise ValueError(f'output_type {output_type} not recognized. Must be set to None (amino acid name), 1 (one letter code), or 3 (three letter code).') + raise ValueError(f"output_type {output_type} not recognized. 
Must be set to None (amino acid name), 1 (one letter code), or 3 (three letter code).") diff --git a/deeprank2/domain/edgestorage.py b/deeprank2/domain/edgestorage.py index dd2b6d1e5..ac739dee3 100644 --- a/deeprank2/domain/edgestorage.py +++ b/deeprank2/domain/edgestorage.py @@ -6,11 +6,11 @@ INDEX = "_index" ## generic features -DISTANCE = "distance" # float; former FEATURENAME_EDGEDISTANCE -SAMECHAIN = "same_chain" # bool; former FEATURE_EDGE_INTERACTIONTYPE & FEATURENAME_EDGESAMECHAIN -SAMERES = "same_res" # bool +DISTANCE = "distance" # float; former FEATURENAME_EDGEDISTANCE +SAMECHAIN = "same_chain" # bool; former FEATURE_EDGE_INTERACTIONTYPE & FEATURENAME_EDGESAMECHAIN +SAMERES = "same_res" # bool ## interactions -COVALENT = "covalent" # bool; former FEATURENAME_COVALENT -ELEC = "electrostatic" # float; former FEATURENAME_EDGECOULOMB -VDW = "vanderwaals" # float; former FEATURENAME_EDGEVANDERWAALS +COVALENT = "covalent" # bool; former FEATURENAME_COVALENT +ELEC = "electrostatic" # float; former FEATURENAME_EDGECOULOMB +VDW = "vanderwaals" # float; former FEATURENAME_EDGEVANDERWAALS diff --git a/deeprank2/domain/gridstorage.py b/deeprank2/domain/gridstorage.py index 14ec7a2dc..62aa63734 100644 --- a/deeprank2/domain/gridstorage.py +++ b/deeprank2/domain/gridstorage.py @@ -1,3 +1 @@ - - MAPPED_FEATURES = "mapped_features" diff --git a/deeprank2/domain/losstypes.py b/deeprank2/domain/losstypes.py index a04dad057..74510ba61 100644 --- a/deeprank2/domain/losstypes.py +++ b/deeprank2/domain/losstypes.py @@ -1,31 +1,41 @@ from torch import nn -regression_losses = (nn.L1Loss, - nn.SmoothL1Loss, - nn.MSELoss, - nn.HuberLoss) +regression_losses = ( + nn.L1Loss, + nn.SmoothL1Loss, + nn.MSELoss, + nn.HuberLoss, +) -binary_classification_losses = (nn.SoftMarginLoss, - nn.BCELoss, - nn.BCEWithLogitsLoss) +binary_classification_losses = ( + nn.SoftMarginLoss, + nn.BCELoss, + nn.BCEWithLogitsLoss, +) -multi_classification_losses = (nn.CrossEntropyLoss, - nn.NLLLoss, - nn.PoissonNLLLoss, - nn.GaussianNLLLoss, - nn.KLDivLoss, - nn.MultiLabelMarginLoss, - nn.MultiLabelSoftMarginLoss) +multi_classification_losses = ( + nn.CrossEntropyLoss, + nn.NLLLoss, + nn.PoissonNLLLoss, + nn.GaussianNLLLoss, + nn.KLDivLoss, + nn.MultiLabelMarginLoss, + nn.MultiLabelSoftMarginLoss, +) -other_losses = (nn.HingeEmbeddingLoss, - nn.CosineEmbeddingLoss, - nn.MarginRankingLoss, - nn.TripletMarginLoss, - nn.CTCLoss) +other_losses = ( + nn.HingeEmbeddingLoss, + nn.CosineEmbeddingLoss, + nn.MarginRankingLoss, + nn.TripletMarginLoss, + nn.CTCLoss, +) classification_losses = multi_classification_losses + binary_classification_losses -classification_tested = (nn.CrossEntropyLoss, - nn.NLLLoss, - nn.BCELoss, - nn.BCEWithLogitsLoss) +classification_tested = ( + nn.CrossEntropyLoss, + nn.NLLLoss, + nn.BCELoss, + nn.BCEWithLogitsLoss, +) diff --git a/deeprank2/domain/nodestorage.py b/deeprank2/domain/nodestorage.py index e214f1520..bce77488b 100644 --- a/deeprank2/domain/nodestorage.py +++ b/deeprank2/domain/nodestorage.py @@ -3,8 +3,8 @@ ## metafeatures NAME = "_name" -CHAINID = "_chain_id" # str; former FEATURENAME_CHAIN (was not assigned, but supposedly numeric, now a str) -POSITION = "_position" # list[3xfloat]; former FEATURENAME_POSITION +CHAINID = "_chain_id" # str; former FEATURENAME_CHAIN (was not assigned, but supposedly numeric, now a str) +POSITION = "_position" # list[3xfloat]; former FEATURENAME_POSITION ## atom core features ATOMTYPE = "atom_type" @@ -12,50 +12,50 @@ PDBOCCUPANCY = "pdb_occupancy" ## residue core 
features -RESTYPE = "res_type" # AminoAcid object; former FEATURENAME_AMINOACID -RESCHARGE = "res_charge" # float(<0); former FEATURENAME_CHARGE (was not assigned) -POLARITY = "polarity" # Polarity object; former FEATURENAME_POLARITY -RESSIZE = "res_size" # int; former FEATURENAME_SIZE +RESTYPE = "res_type" # AminoAcid object; former FEATURENAME_AMINOACID +RESCHARGE = "res_charge" # float(<0); former FEATURENAME_CHARGE (was not assigned) +POLARITY = "polarity" # Polarity object; former FEATURENAME_POLARITY +RESSIZE = "res_size" # int; former FEATURENAME_SIZE RESMASS = "res_mass" RESPI = "res_pI" -HBDONORS = "hb_donors" # int; former FEATURENAME_HYDROGENBONDDONORS -HBACCEPTORS = "hb_acceptors"# int; former FEATURENAME_HYDROGENBONDACCEPTORS +HBDONORS = "hb_donors" # int; former FEATURENAME_HYDROGENBONDDONORS +HBACCEPTORS = "hb_acceptors" # int; former FEATURENAME_HYDROGENBONDACCEPTORS ## variant residue features -VARIANTRES = "variant_res" # AminoAcid object; former FEATURENAME_VARIANTAMINOACID -DIFFCHARGE = "diff_charge" # float -DIFFSIZE = "diff_size" # int; former FEATURENAME_SIZEDIFFERENCE +VARIANTRES = "variant_res" # AminoAcid object; former FEATURENAME_VARIANTAMINOACID +DIFFCHARGE = "diff_charge" # float +DIFFSIZE = "diff_size" # int; former FEATURENAME_SIZEDIFFERENCE DIFFMASS = "diff_mass" DIFFPI = "diff_pI" -DIFFPOLARITY = "diff_polarity" # [type?]; former FEATURENAME_POLARITYDIFFERENCE -DIFFHBDONORS = "diff_hb_donors" # int; former FEATURENAME_HYDROGENBONDDONORSDIFFERENCE -DIFFHBACCEPTORS = "diff_hb_acceptors" # int; former FEATURENAME_HYDROGENBONDACCEPTORSDIFFERENCE +DIFFPOLARITY = "diff_polarity" # [type?]; former FEATURENAME_POLARITYDIFFERENCE +DIFFHBDONORS = "diff_hb_donors" # int; former FEATURENAME_HYDROGENBONDDONORSDIFFERENCE +DIFFHBACCEPTORS = "diff_hb_acceptors" # int; former FEATURENAME_HYDROGENBONDACCEPTORSDIFFERENCE ## conservation features -PSSM = "pssm" # list[20xint]; former FEATURENAME_PSSM -INFOCONTENT = "info_content" # float; former FEATURENAME_INFORMATIONCONTENT -CONSERVATION = "conservation" # int; former FEATURENAME_PSSMWILDTYPE -DIFFCONSERVATION = "diff_conservation" # int; former FEATURENAME_PSSMDIFFERENCE & FEATURENAME_CONSERVATIONDIFFERENCE +PSSM = "pssm" # list[20xint]; former FEATURENAME_PSSM +INFOCONTENT = "info_content" # float; former FEATURENAME_INFORMATIONCONTENT +CONSERVATION = "conservation" # int; former FEATURENAME_PSSMWILDTYPE +DIFFCONSERVATION = "diff_conservation" # int; former FEATURENAME_PSSMDIFFERENCE & FEATURENAME_CONSERVATIONDIFFERENCE ## protein context features -RESDEPTH = "res_depth" # float; former FEATURENAME_RESIDUEDEPTH -HSE = "hse" # list[3xfloat]; former FEATURENAME_HALFSPHEREEXPOSURE -SASA = "sasa" # float; former FEATURENAME_SASA -BSA = "bsa" # float; former FEATURENAME_BURIEDSURFACEAREA -SECSTRUCT = "sec_struct" #secondary structure +RESDEPTH = "res_depth" # float; former FEATURENAME_RESIDUEDEPTH +HSE = "hse" # list[3xfloat]; former FEATURENAME_HALFSPHEREEXPOSURE +SASA = "sasa" # float; former FEATURENAME_SASA +BSA = "bsa" # float; former FEATURENAME_BURIEDSURFACEAREA +SECSTRUCT = "sec_struct" # secondary structure ## inter-residue contacts (IRCs) -IRC_NONNON = 'irc_nonpolar_nonpolar' # int -IRC_NONPOL = 'irc_nonpolar_polar' # int -IRC_NONNEG = 'irc_nonpolar_negative' # int -IRC_NONPOS = 'irc_nonpolar_positive' # int -IRC_POLPOL = 'irc_polar_polar' # int -IRC_POLNEG = 'irc_polar_negative' # int -IRC_POLPOS = 'irc_polar_positive' # int -IRC_NEGNEG = 'irc_negative_negative' # int -IRC_NEGPOS = 'irc_negative_positive' # int 
-IRC_POSPOS = 'irc_positive_positive' # int -IRCTOTAL = 'irc_total' # int +IRC_NONNON = "irc_nonpolar_nonpolar" # int +IRC_NONPOL = "irc_nonpolar_polar" # int +IRC_NONNEG = "irc_nonpolar_negative" # int +IRC_NONPOS = "irc_nonpolar_positive" # int +IRC_POLPOL = "irc_polar_polar" # int +IRC_POLNEG = "irc_polar_negative" # int +IRC_POLPOS = "irc_polar_positive" # int +IRC_NEGNEG = "irc_negative_negative" # int +IRC_NEGPOS = "irc_negative_positive" # int +IRC_POSPOS = "irc_positive_positive" # int +IRCTOTAL = "irc_total" # int IRC_FEATURES = [ IRC_NONNON, diff --git a/deeprank2/domain/targetstorage.py b/deeprank2/domain/targetstorage.py index 922731d63..229a69eec 100644 --- a/deeprank2/domain/targetstorage.py +++ b/deeprank2/domain/targetstorage.py @@ -8,11 +8,11 @@ CAPRI = "capri_class" ## regression tasks -IRMSD = 'irmsd' -LRMSD = 'lrmsd' -FNAT = 'fnat' -DOCKQ = 'dockq' +IRMSD = "irmsd" +LRMSD = "lrmsd" +FNAT = "fnat" +DOCKQ = "dockq" ## task names -REGRESS = 'regress' -CLASSIF = 'classif' +REGRESS = "regress" +CLASSIF = "classif" diff --git a/deeprank2/features/components.py b/deeprank2/features/components.py index 73a77b76b..d4844badb 100644 --- a/deeprank2/features/components.py +++ b/deeprank2/features/components.py @@ -10,12 +10,12 @@ _log = logging.getLogger(__name__) -def add_features( # pylint: disable=unused-argument - pdb_path: str, + +def add_features( + pdb_path: str, # noqa: ARG001 (unused argument) graph: Graph, single_amino_acid_variant: SingleResidueVariant | None = None, ): - for node in graph.nodes: if isinstance(node.id, Residue): residue = node.id @@ -38,7 +38,6 @@ def add_features( # pylint: disable=unused-argument node.features[Nfeat.HBDONORS] = residue.amino_acid.hydrogen_bond_donors node.features[Nfeat.HBACCEPTORS] = residue.amino_acid.hydrogen_bond_acceptors - if single_amino_acid_variant is not None: wildtype = single_amino_acid_variant.wildtype_amino_acid variant = single_amino_acid_variant.variant_amino_acid diff --git a/deeprank2/features/conservation.py b/deeprank2/features/conservation.py index 8ab5ad56e..09e002eb9 100644 --- a/deeprank2/features/conservation.py +++ b/deeprank2/features/conservation.py @@ -1,4 +1,3 @@ - import numpy as np from deeprank2.domain import nodestorage as Nfeat @@ -8,12 +7,11 @@ from deeprank2.utils.graph import Graph -def add_features( # pylint: disable=unused-argument - pdb_path: str, +def add_features( + pdb_path: str, # noqa: ARG001 (unused argument) graph: Graph, single_amino_acid_variant: SingleResidueVariant | None = None, ): - profile_amino_acid_order = sorted(amino_acids, key=lambda aa: aa.three_letter_code) for node in graph.nodes: diff --git a/deeprank2/features/contact.py b/deeprank2/features/contact.py index 31436bd2d..927c66208 100644 --- a/deeprank2/features/contact.py +++ b/deeprank2/features/contact.py @@ -18,11 +18,14 @@ covalent_cutoff = 2.1 cutoff_13 = 3.6 cutoff_14 = 4.2 +EPSILON0 = 1.0 +COULOMB_CONSTANT = 332.0636 -def _get_nonbonded_energy( #pylint: disable=too-many-locals + +def _get_nonbonded_energy( atoms: list[Atom], distances: NDArray[np.float64], - ) -> tuple [NDArray[np.float64], NDArray[np.float64]]: +) -> tuple[NDArray[np.float64], NDArray[np.float64]]: """Calculates all pairwise electrostatic (Coulomb) and Van der Waals (Lennard Jones) potential energies between all atoms in the structure. Warning: there's no distance cutoff here. The radius of influence is assumed to infinite. 
@@ -37,10 +40,7 @@ def _get_nonbonded_energy( #pylint: disable=too-many-locals Tuple [NDArray[np.float64], NDArray[np.float64]]: matrices in same format as `distances` containing all pairwise electrostatic potential energies and all pairwise Van der Waals potential energies """ - # ELECTROSTATIC POTENTIAL - EPSILON0 = 1.0 - COULOMB_CONSTANT = 332.0636 charges = [atomic_forcefield.get_charge(atom) for atom in atoms] E_elec = np.expand_dims(charges, axis=1) * np.expand_dims(charges, axis=0) * COULOMB_CONSTANT / (EPSILON0 * distances) @@ -48,21 +48,20 @@ def _get_nonbonded_energy( #pylint: disable=too-many-locals # calculate main vdw energies sigmas = [atomic_forcefield.get_vanderwaals_parameters(atom).sigma_main for atom in atoms] epsilons = [atomic_forcefield.get_vanderwaals_parameters(atom).epsilon_main for atom in atoms] - mean_sigmas = 0.5 * np.add.outer(sigmas,sigmas) - geomean_eps = np.sqrt(np.multiply.outer(epsilons,epsilons)) # sqrt(eps1*eps2) + mean_sigmas = 0.5 * np.add.outer(sigmas, sigmas) + geomean_eps = np.sqrt(np.multiply.outer(epsilons, epsilons)) # sqrt(eps1*eps2) E_vdw = 4.0 * geomean_eps * ((mean_sigmas / distances) ** 12 - (mean_sigmas / distances) ** 6) # calculate vdw energies for 1-4 pairs sigmas = [atomic_forcefield.get_vanderwaals_parameters(atom).sigma_14 for atom in atoms] epsilons = [atomic_forcefield.get_vanderwaals_parameters(atom).epsilon_14 for atom in atoms] - mean_sigmas = 0.5 * np.add.outer(sigmas,sigmas) - geomean_eps = np.sqrt(np.multiply.outer(epsilons,epsilons)) # sqrt(eps1*eps2) + mean_sigmas = 0.5 * np.add.outer(sigmas, sigmas) + geomean_eps = np.sqrt(np.multiply.outer(epsilons, epsilons)) # sqrt(eps1*eps2) E_vdw_14pairs = 4.0 * geomean_eps * ((mean_sigmas / distances) ** 12 - (mean_sigmas / distances) ** 6) - # Fix energies for close contacts on same chain chains = [atom.residue.chain.id for atom in atoms] - chain_matrix = [[chain_1==chain_2 for chain_2 in chains] for chain_1 in chains] + chain_matrix = [[chain_1 == chain_2 for chain_2 in chains] for chain_1 in chains] pair_14 = np.logical_and(distances < cutoff_14, chain_matrix) pair_13 = np.logical_and(distances < cutoff_13, chain_matrix) @@ -70,16 +69,14 @@ def _get_nonbonded_energy( #pylint: disable=too-many-locals E_vdw[pair_13] = 0 E_elec[pair_13] = 0 - return E_elec, E_vdw -def add_features( # pylint: disable=unused-argument, too-many-locals - pdb_path: str, +def add_features( + pdb_path: str, # noqa: ARG001 (unused argument) graph: Graph, - single_amino_acid_variant: SingleResidueVariant | None = None, + single_amino_acid_variant: SingleResidueVariant | None = None, # noqa: ARG001 (unused argument) ): - # assign each atoms (from all edges) a unique index all_atoms = set() if isinstance(graph.edges[0].id, AtomicContact): @@ -90,11 +87,10 @@ def add_features( # pylint: disable=unused-argument, too-many-locals elif isinstance(graph.edges[0].id, ResidueContact): for edge in graph.edges: contact = edge.id - for atom in (contact.residue1.atoms + contact.residue2.atoms): + for atom in contact.residue1.atoms + contact.residue2.atoms: all_atoms.add(atom) else: - raise TypeError( - f"Unexpected edge type: {type(graph.edges[0].id)}") + raise TypeError(f"Unexpected edge type: {type(graph.edges[0].id)}") all_atoms = list(all_atoms) atom_dict = {atom: i for i, atom in enumerate(all_atoms)} @@ -104,7 +100,10 @@ def add_features( # pylint: disable=unused-argument, too-many-locals warnings.simplefilter("ignore") positions = [atom.position for atom in all_atoms] interatomic_distances = 
distance_matrix(positions, positions) - interatomic_electrostatic_energy, interatomic_vanderwaals_energy = _get_nonbonded_energy(all_atoms, interatomic_distances) + ( + interatomic_electrostatic_energy, + interatomic_vanderwaals_energy, + ) = _get_nonbonded_energy(all_atoms, interatomic_distances) # assign features for edge in graph.edges: diff --git a/deeprank2/features/exposure.py b/deeprank2/features/exposure.py index 9b86423e6..a48a6b428 100644 --- a/deeprank2/features/exposure.py +++ b/deeprank2/features/exposure.py @@ -17,13 +17,13 @@ _log = logging.getLogger(__name__) -def handle_sigint(sig, frame): # pylint: disable=unused-argument - print('SIGINT received, terminating.') +def handle_sigint(sig, frame): # noqa: ARG001 (unused argument) + print("SIGINT received, terminating.") sys.exit() -def handle_timeout(sig, frame): - raise TimeoutError('Timed out!') +def handle_timeout(sig, frame): # noqa: ARG001 (unused argument) + raise TimeoutError("Timed out!") def space_if_none(value): @@ -32,18 +32,17 @@ def space_if_none(value): return value -def add_features( # pylint: disable=unused-argument +def add_features( pdb_path: str, graph: Graph, - single_amino_acid_variant: SingleResidueVariant | None = None, + single_amino_acid_variant: SingleResidueVariant | None = None, # noqa: ARG001 (unused argument) ): - signal.signal(signal.SIGINT, handle_sigint) signal.signal(signal.SIGALRM, handle_timeout) with warnings.catch_warnings(record=PDBConstructionWarning): parser = PDBParser() - structure = parser.get_structure('_tmp', pdb_path) + structure = parser.get_structure("_tmp", pdb_path) bio_model = structure[0] try: @@ -51,7 +50,7 @@ def add_features( # pylint: disable=unused-argument surface = get_surface(bio_model) signal.alarm(0) except TimeoutError as e: - raise TimeoutError('Bio.PDB.ResidueDepth.get_surface timed out.') from e + raise TimeoutError("Bio.PDB.ResidueDepth.get_surface timed out.") from e # These can only be calculated per residue, not per atom. # So for atomic graphs, every atom gets its residue's value. @@ -67,7 +66,10 @@ def add_features( # pylint: disable=unused-argument bio_residue = bio_model[residue.chain.id][residue.number] node.features[Nfeat.RESDEPTH] = residue_depth(bio_residue, surface) - hse_key = (residue.chain.id, (" ", residue.number, space_if_none(residue.insertion_code))) + hse_key = ( + residue.chain.id, + (" ", residue.number, space_if_none(residue.insertion_code)), + ) if hse_key in hse: node.features[Nfeat.HSE] = np.array(hse[hse_key], dtype=np.float64) diff --git a/deeprank2/features/irc.py b/deeprank2/features/irc.py index 4d7583f9d..6f3ddc8dd 100644 --- a/deeprank2/features/irc.py +++ b/deeprank2/features/irc.py @@ -14,7 +14,7 @@ def _id_from_residue(residue: tuple[str, int, str]) -> str: - """Create and id from pdb2sql rendered residues that is similar to the id of residue nodes + """Create and id from pdb2sql rendered residues that is similar to the id of residue nodes. Args: residue (tuple): Input residue as rendered by pdb2sql: ( str(), int(), str( ) @@ -23,22 +23,20 @@ def _id_from_residue(residue: tuple[str, int, str]) -> str: Returns: str: Output id in form of ''. For example: 'A27'. """ - return residue[0] + str(residue[1]) class _ContactDensity: - """Internal class that holds contact density information for a given residue. 
- """ + """Internal class that holds contact density information for a given residue.""" def __init__(self, residue: tuple[str, int, str], polarity: Polarity): self.res = residue self.polarity = polarity self.id = _id_from_residue(self.res) self.densities = {pol: 0 for pol in Polarity} - self.densities['total'] = 0 + self.densities["total"] = 0 self.connections = {pol: [] for pol in Polarity} - self.connections['all'] = [] + self.connections["all"] = [] def get_IRCs(pdb_path: str, chains: list[str], cutoff: float = 5.5) -> dict[str, _ContactDensity]: @@ -54,14 +52,14 @@ def get_IRCs(pdb_path: str, chains: list[str], cutoff: float = 5.5) -> dict[str, keys: ids of residues in form returned by id_from_residue. items: _ContactDensity objects, containing all contact density information for the residue. """ - residue_contacts: dict[str, _ContactDensity] = {} sql = pdb2sql.interface(pdb_path) pdb2sql_contacts = sql.get_contact_residues( cutoff=cutoff, - chain1=chains[0], chain2=chains[1], - return_contact_pairs=True + chain1=chains[0], + chain2=chains[1], + return_contact_pairs=True, ) for chain1_res, chain2_residues in pdb2sql_contacts.items(): @@ -83,9 +81,9 @@ def get_IRCs(pdb_path: str, chains: list[str], cutoff: float = 5.5) -> dict[str, continue # skip keys that are not an amino acid # populate densities and connections for chain1_res - residue_contacts[contact1_id].densities['total'] += 1 + residue_contacts[contact1_id].densities["total"] += 1 residue_contacts[contact1_id].densities[aa2.polarity] += 1 - residue_contacts[contact1_id].connections['all'].append(chain2_res) + residue_contacts[contact1_id].connections["all"].append(chain2_res) residue_contacts[contact1_id].connections[aa2.polarity].append(chain2_res) # add chain2_res to residue_contact dict if it doesn't exist yet @@ -94,9 +92,9 @@ def get_IRCs(pdb_path: str, chains: list[str], cutoff: float = 5.5) -> dict[str, residue_contacts[contact2_id] = _ContactDensity(chain2_res, aa2.polarity) # populate densities and connections for chain2_res - residue_contacts[contact2_id].densities['total'] += 1 + residue_contacts[contact2_id].densities["total"] += 1 residue_contacts[contact2_id].densities[aa1.polarity] += 1 - residue_contacts[contact2_id].connections['all'].append(chain1_res) + residue_contacts[contact2_id].connections["all"].append(chain1_res) residue_contacts[contact2_id].connections[aa1.polarity].append(chain1_res) return residue_contacts @@ -107,10 +105,9 @@ def add_features( graph: Graph, single_amino_acid_variant: SingleResidueVariant | None = None, ): - if not single_amino_acid_variant: # VariantQueries do not use this feature polarity_pairs = list(combinations(Polarity, 2)) - polarity_pair_string = [f'irc_{x[0].name.lower()}_{x[1].name.lower()}' for x in polarity_pairs] + polarity_pair_string = [f"irc_{x[0].name.lower()}_{x[1].name.lower()}" for x in polarity_pairs] total_contacts = 0 residue_contacts = get_IRCs(pdb_path, graph.get_all_chains()) @@ -132,7 +129,7 @@ def add_features( # load correct values to IRC features try: - node.features[Nfeat.IRCTOTAL] = residue_contacts[contact_id].densities['total'] + node.features[Nfeat.IRCTOTAL] = residue_contacts[contact_id].densities["total"] for i, pair in enumerate(polarity_pairs): if residue_contacts[contact_id].polarity == pair[0]: node.features[polarity_pair_string[i]] = residue_contacts[contact_id].densities[pair[1]] diff --git a/deeprank2/features/secondary_structure.py b/deeprank2/features/secondary_structure.py index c893f9cfd..d3a2182bf 100644 --- 
a/deeprank2/features/secondary_structure.py +++ b/deeprank2/features/secondary_structure.py @@ -12,15 +12,15 @@ class DSSPError(Exception): - "Raised if DSSP fails to produce an output" + """Raised if DSSP fails to produce an output.""" class SecondarySctructure(Enum): - "a value to express a secondary a residue's secondary structure type" + """Value to express a secondary a residue's secondary structure type.""" - HELIX = 0 # 'GHI' - STRAND = 1 # 'BE' - COIL = 2 # ' -STP' + HELIX = 0 # 'GHI' + STRAND = 1 # 'BE' + COIL = 2 # ' -STP' @property def onehot(self): @@ -38,45 +38,45 @@ def _get_records(lines: list[str]): def _check_pdb(pdb_path: str): fix_pdb = False - with open(pdb_path, encoding='utf-8') as f: + with open(pdb_path, encoding="utf-8") as f: lines = f.readlines() # check for HEADER firstline = lines[0] - if not firstline.startswith('HEADER'): + if not firstline.startswith("HEADER"): fix_pdb = True - if firstline.startswith('EXPDTA'): - lines = [f'HEADER {firstline}'] + lines[1:] + if firstline.startswith("EXPDTA"): + lines = [f"HEADER {firstline}"] + lines[1:] else: - lines = ['HEADER \n'] + lines + lines = ["HEADER \n", *lines] # check for CRYST1 record existing_records = _get_records(lines) - if 'CRYST1' not in existing_records: + if "CRYST1" not in existing_records: fix_pdb = True - dummy_CRYST1 = 'CRYST1 00.000 00.000 00.000 00.00 00.00 00.00 X 00 00 0 00\n' + dummy_CRYST1 = "CRYST1 00.000 00.000 00.000 00.00 00.00 00.00 X 00 00 0 00\n" lines = [lines[0]] + [dummy_CRYST1] + lines[1:] # check for unnumbered REMARK lines for i, line in enumerate(lines): - if line.startswith('REMARK'): + if line.startswith("REMARK"): try: int(line.split()[1]) except ValueError: fix_pdb = True - lines[i] = f'REMARK 999 {line[7:]}' + lines[i] = f"REMARK 999 {line[7:]}" if fix_pdb: - with open(pdb_path, 'w', encoding='utf-8') as f: + with open(pdb_path, "w", encoding="utf-8") as f: f.writelines(lines) def _classify_secstructure(subtype: str): - if subtype in 'GHI': + if subtype in "GHI": return SecondarySctructure.HELIX - if subtype in 'BE': + if subtype in "BE": return SecondarySctructure.STRAND - if subtype in ' -STP': + if subtype in " -STP": return SecondarySctructure.COIL return None @@ -90,21 +90,21 @@ def _get_secstructure(pdb_path: str) -> dict: Returns: dict: A dictionary containing secondary structure information for each chain and residue. 
""" - # Execute DSSP and read the output _check_pdb(pdb_path) p = PDBParser(QUIET=True) model = p.get_structure(Path(pdb_path).stem, pdb_path)[0] - # pylint: disable=raise-missing-from try: - dssp = DSSP(model, pdb_path, dssp='mkdssp') - except Exception as e: # improperly formatted pdb files raise: `Exception: DSSP failed to produce an output` - pdb_format_link = 'https://www.wwpdb.org/documentation/file-format-content/format33/sect1.html#Order' - raise DSSPError(f'DSSP has raised the following exception: {e}.\ - \nThis is likely due to an improrperly formatted pdb file: {pdb_path}.\ - \nSee {pdb_format_link} for guidance on how to format your pdb files.\ - \nAlternatively, turn off secondary_structure feature module during QueryCollection.process().') + dssp = DSSP(model, pdb_path, dssp="mkdssp") + except Exception as e: # noqa: BLE001 (blind-except), namely: # improperly formatted pdb files raise: `Exception: DSSP failed to produce an output` + pdb_format_link = "https://www.wwpdb.org/documentation/file-format-content/format33/sect1.html#Order" + raise DSSPError( + f"DSSP has raised the following exception: {e}.\n\t" + f"This is likely due to an improrperly formatted pdb file: {pdb_path}.\n\t" + f"See {pdb_format_link} for guidance on how to format your pdb files.\n\t" + "Alternatively, turn off secondary_structure feature module during QueryCollection.process()." + ) from e chain_ids = [dssp_key[0] for dssp_key in dssp.property_keys] res_numbers = [dssp_key[1][1] for dssp_key in dssp.property_keys] @@ -120,12 +120,11 @@ def _get_secstructure(pdb_path: str) -> dict: return sec_structure_dict -def add_features( # pylint: disable=unused-argument +def add_features( pdb_path: str, graph: Graph, - single_amino_acid_variant: SingleResidueVariant | None = None, + single_amino_acid_variant: SingleResidueVariant | None = None, # noqa: ARG001 (unused argument) ): - sec_structure_features = _get_secstructure(pdb_path) for node in graph.nodes: @@ -140,9 +139,9 @@ def add_features( # pylint: disable=unused-argument chain_id = residue.chain.id res_num = residue.number - # pylint: disable=raise-missing-from try: node.features[Nfeat.SECSTRUCT] = _classify_secstructure(sec_structure_features[chain_id][res_num]).onehot - except AttributeError: - raise ValueError(f'Unknown secondary structure type ({sec_structure_features[chain_id][res_num]}) ' + - f'detected on chain {chain_id} residues {res_num}.') + except AttributeError as e: + raise ValueError( + f"Unknown secondary structure type ({sec_structure_features[chain_id][res_num]}) detected on chain {chain_id} residues {res_num}." 
+ ) from e diff --git a/deeprank2/features/surfacearea.py b/deeprank2/features/surfacearea.py index 022bbccf7..67cc64920 100644 --- a/deeprank2/features/surfacearea.py +++ b/deeprank2/features/surfacearea.py @@ -8,8 +8,6 @@ from deeprank2.molstruct.residue import Residue, SingleResidueVariant from deeprank2.utils.graph import Graph -# pylint: disable=c-extension-no-member - freesasa.setVerbosity(freesasa.nowarnings) logging.getLogger(__name__) @@ -22,13 +20,13 @@ def add_sasa(pdb_path: str, graph: Graph): if isinstance(node.id, Residue): residue = node.id selection = (f"residue, (resi {residue.number_string}) and (chain {residue.chain.id})",) - area = freesasa.selectArea(selection, structure, result)['residue'] + area = freesasa.selectArea(selection, structure, result)["residue"] elif isinstance(node.id, Atom): atom = node.id residue = atom.residue selection = (f"atom, (name {atom.name}) and (resi {residue.number_string}) and (chain {residue.chain.id})",) - area = freesasa.selectArea(selection, structure, result)['atom'] + area = freesasa.selectArea(selection, structure, result)["atom"] else: raise TypeError(f"Unexpected node type: {type(node.id)}") @@ -38,9 +36,7 @@ def add_sasa(pdb_path: str, graph: Graph): node.features[Nfeat.SASA] = area - def add_bsa(graph: Graph): - sasa_complete_structure = freesasa.Structure() sasa_chain_structures = {} @@ -52,12 +48,24 @@ def add_bsa(graph: Graph): sasa_chain_structures[chain_id] = freesasa.Structure() for atom in residue.atoms: - sasa_chain_structures[chain_id].addAtom(atom.name, atom.residue.amino_acid.three_letter_code, - atom.residue.number, atom.residue.chain.id, - atom.position[0], atom.position[1], atom.position[2]) - sasa_complete_structure.addAtom(atom.name, atom.residue.amino_acid.three_letter_code, - atom.residue.number, atom.residue.chain.id, - atom.position[0], atom.position[1], atom.position[2]) + sasa_chain_structures[chain_id].addAtom( + atom.name, + atom.residue.amino_acid.three_letter_code, + atom.residue.number, + atom.residue.chain.id, + atom.position[0], + atom.position[1], + atom.position[2], + ) + sasa_complete_structure.addAtom( + atom.name, + atom.residue.amino_acid.three_letter_code, + atom.residue.number, + atom.residue.chain.id, + atom.position[0], + atom.position[1], + atom.position[2], + ) elif isinstance(node.id, Atom): atom = node.id @@ -66,52 +74,58 @@ def add_bsa(graph: Graph): if chain_id not in sasa_chain_structures: sasa_chain_structures[chain_id] = freesasa.Structure() - sasa_chain_structures[chain_id].addAtom(atom.name, atom.residue.amino_acid.three_letter_code, - atom.residue.number, atom.residue.chain.id, - atom.position[0], atom.position[1], atom.position[2]) - sasa_complete_structure.addAtom(atom.name, atom.residue.amino_acid.three_letter_code, - atom.residue.number, atom.residue.chain.id, - atom.position[0], atom.position[1], atom.position[2]) + sasa_chain_structures[chain_id].addAtom( + atom.name, + atom.residue.amino_acid.three_letter_code, + atom.residue.number, + atom.residue.chain.id, + atom.position[0], + atom.position[1], + atom.position[2], + ) + sasa_complete_structure.addAtom( + atom.name, + atom.residue.amino_acid.three_letter_code, + atom.residue.number, + atom.residue.chain.id, + atom.position[0], + atom.position[1], + atom.position[2], + ) area_key = "atom" - selection = (f"atom, (name {atom.name}) and (resi {atom.residue.number_string}) and (chain {atom.residue.chain.id})") + selection = f"atom, (name {atom.name}) and (resi {atom.residue.number_string}) and (chain {atom.residue.chain.id})" 
else: raise TypeError(f"Unexpected node type: {type(node.id)}") sasa_complete_result = freesasa.calc(sasa_complete_structure) - sasa_chain_results = {chain_id: freesasa.calc(structure) - for chain_id, structure in sasa_chain_structures.items()} + sasa_chain_results = {chain_id: freesasa.calc(structure) for chain_id, structure in sasa_chain_structures.items()} for node in graph.nodes: if isinstance(node.id, Residue): residue = node.id chain_id = residue.chain.id area_key = "residue" - selection = ("residue, (resi %s) and (chain %s)" % (residue.number_string, residue.chain.id),) # pylint: disable=consider-using-f-string + selection = (f"residue, (resi {residue.number_string}) and (chain {residue.chain.id})",) elif isinstance(node.id, Atom): atom = node.id chain_id = atom.residue.chain.id area_key = "atom" - selection = ("atom, (name %s) and (resi %s) and (chain %s)" % \ - (atom.name, atom.residue.number_string, atom.residue.chain.id),) # pylint: disable=consider-using-f-string + selection = (f"atom, (name {atom.name}) and (resi {atom.residue.number_string}) and (chain {atom.residue.chain.id})",) - area_monomer = freesasa.selectArea(selection, sasa_chain_structures[chain_id], \ - sasa_chain_results[chain_id])[area_key] + area_monomer = freesasa.selectArea(selection, sasa_chain_structures[chain_id], sasa_chain_results[chain_id])[area_key] area_multimer = freesasa.selectArea(selection, sasa_complete_structure, sasa_complete_result)[area_key] node.features[Nfeat.BSA] = area_monomer - area_multimer -def add_features( # pylint: disable=unused-argument +def add_features( pdb_path: str, graph: Graph, - single_amino_acid_variant: SingleResidueVariant | None = None, + single_amino_acid_variant: SingleResidueVariant | None = None, # noqa: ARG001 (unused argument) ): - - """calculates the Buried Surface Area (BSA) and the Solvent Accessible Surface Area (SASA): - BSA: the area of the protein, that only gets exposed in monomeric state""" - + """Calculates the Buried Surface Area (BSA) and the Solvent Accessible Surface Area (SASA).""" # BSA add_bsa(graph) diff --git a/deeprank2/molstruct/aminoacid.py b/deeprank2/molstruct/aminoacid.py index 189b5a564..8400f9c70 100644 --- a/deeprank2/molstruct/aminoacid.py +++ b/deeprank2/molstruct/aminoacid.py @@ -20,9 +20,7 @@ def onehot(self): class AminoAcid: - """An amino acid represents the type of `Residue` in a `PDBStructure`.""" - - def __init__( # pylint: disable=too-many-arguments + def __init__( self, name: str, three_letter_code: str, @@ -36,7 +34,8 @@ def __init__( # pylint: disable=too-many-arguments hydrogen_bond_acceptors: int, index: int, ): - """ + """An amino acid represents the type of `Residue` in a `PDBStructure`. + Args: name (str): Full name of the amino acid. three_letter_code (str): Three-letter code of the amino acid (as in PDB). @@ -50,7 +49,6 @@ def __init__( # pylint: disable=too-many-arguments hydrogen_bond_acceptors (int): Number of hydrogen bond acceptors. index (int): The rank of the amino acid, used for computing one-hot encoding. """ - # amino acid nomenclature self._name = name self._three_letter_code = three_letter_code @@ -111,9 +109,7 @@ def hydrogen_bond_acceptors(self) -> int: @property def onehot(self) -> NDArray: if self._index is None: - raise ValueError( - f"Amino acid {self._name} index is not set, thus no onehot can be computed." 
- ) + raise ValueError(f"Amino acid {self._name} index is not set, thus no onehot can be computed.") # 20 canonical amino acids # selenocysteine and pyrrolysine are indexed as cysteine and lysine, respectively a = np.zeros(20) diff --git a/deeprank2/molstruct/atom.py b/deeprank2/molstruct/atom.py index 6bdef04de..7088ee4b4 100644 --- a/deeprank2/molstruct/atom.py +++ b/deeprank2/molstruct/atom.py @@ -1,17 +1,21 @@ from __future__ import annotations from enum import Enum +from typing import TYPE_CHECKING import numpy as np -from numpy.typing import NDArray -from deeprank2.molstruct.residue import Residue +if TYPE_CHECKING: + from numpy.typing import NDArray + + from deeprank2.molstruct.residue import Residue class AtomicElement(Enum): """One-hot encoding of the atomic element (or atom type).""" + C = 1 - O = 2 # noqa: pycodestyle + O = 2 # noqa: E741 (ambiguous-variable-name) N = 3 S = 4 P = 5 @@ -25,9 +29,7 @@ def onehot(self) -> np.array: class Atom: - """One atom in a PDBStructure.""" - - def __init__( # pylint: disable=too-many-arguments + def __init__( self, residue: Residue, name: str, @@ -35,7 +37,8 @@ def __init__( # pylint: disable=too-many-arguments position: NDArray, occupancy: float, ): - """ + """One atom in a PDBStructure. + Args: residue (:class:`Residue`): The residue that this atom belongs to. name (str): Pdb atom name. @@ -44,7 +47,7 @@ def __init__( # pylint: disable=too-many-arguments occupancy (float): Pdb occupancy value. This represents the proportion of structures where the atom is detected at a given position. Sometimes a single atom can be detected at multiple positions. In that case separate structures exist where sum(occupancy) == 1. - Note that only the highest occupancy atom is used by deeprank2 (see tools.pdb._add_atom_to_residue) + Note that only the highest occupancy atom is used by deeprank2 (see tools.pdb._add_atom_to_residue). """ self._residue = residue self._name = name @@ -53,9 +56,8 @@ def __init__( # pylint: disable=too-many-arguments self._occupancy = occupancy def __eq__(self, other) -> bool: - if isinstance (other, Atom): - return (self._residue == other._residue - and self._name == other._name) + if isinstance(other, Atom): + return self._residue == other._residue and self._name == other._name return NotImplemented def __hash__(self) -> hash: diff --git a/deeprank2/molstruct/pair.py b/deeprank2/molstruct/pair.py index bdf7f5423..41b732693 100644 --- a/deeprank2/molstruct/pair.py +++ b/deeprank2/molstruct/pair.py @@ -6,10 +6,9 @@ class Pair: - """A hashable, comparable object for any set of two inputs where order doesn't matter.""" - def __init__(self, item1: Any, item2: Any): - """ + """A hashable, comparable object for any set of two inputs where order doesn't matter. + Args: item1 (Any object): The pair's first object, must be convertable to string. item2 (Any object): The pair's second object, must be convertable to string. 
@@ -28,8 +27,7 @@ def __hash__(self) -> hash: def __eq__(self, other) -> bool: """Compare the pairs as sets, so the order doesn't matter.""" if isinstance(other, Pair): - return (self.item1 == other.item1 and self.item2 == other.item2 - or self.item1 == other.item2 and self.item2 == other.item1) + return self.item1 == other.item1 and self.item2 == other.item2 or self.item1 == other.item2 and self.item2 == other.item1 return NotImplemented def __iter__(self): @@ -37,7 +35,7 @@ def __iter__(self): return iter([self.item1, self.item2]) def __repr__(self) -> str: - return (str(self.item1) + str(self.item2)) + return str(self.item1) + str(self.item2) class Contact(Pair, ABC): diff --git a/deeprank2/molstruct/residue.py b/deeprank2/molstruct/residue.py index 552d1a015..2b1412676 100644 --- a/deeprank2/molstruct/residue.py +++ b/deeprank2/molstruct/residue.py @@ -3,24 +3,18 @@ from typing import TYPE_CHECKING import numpy as np -from numpy.typing import NDArray - -from deeprank2.molstruct.aminoacid import AminoAcid -from deeprank2.molstruct.structure import Chain -from deeprank2.utils.pssmdata import PssmRow if TYPE_CHECKING: + from numpy.typing import NDArray + + from deeprank2.molstruct.aminoacid import AminoAcid from deeprank2.molstruct.atom import Atom + from deeprank2.molstruct.structure import Chain + from deeprank2.utils.pssmdata import PssmRow class Residue: - """One protein residue in a `PDBStructure`. - - A `Residue` is the basic building block of proteins and protein complex, - here represented by `PDBStructures`. - Each residue is of a certain `AminoAcid` type and consists of multiple - `Atom`s. - """ + """One protein residue in a `PDBStructure`.""" def __init__( self, @@ -29,7 +23,11 @@ def __init__( amino_acid: AminoAcid | None = None, insertion_code: str | None = None, ): - """ + """One protein residue in a `PDBStructure`. + + A `Residue` is the basic building block of proteins and protein complex, here represented by `PDBStructures`. + Each `Residue` is of a certain `AminoAcid` type and consists of multiple `Atom`s. + Args: chain (:class:`Chain`): The chain that this residue belongs to. number (int): the residue number @@ -37,7 +35,6 @@ def __init__( Defaults to None. insertion_code (str, optional): The pdb insertion code, if any. Defaults to None. 
""" - self._chain = chain self._number = number self._amino_acid = amino_acid @@ -46,10 +43,7 @@ def __init__( def __eq__(self, other) -> bool: if isinstance(other, Residue): - return (self._chain == other._chain - and self._number == other._number - and self._insertion_code == other._insertion_code - ) + return self._chain == other._chain and self._number == other._number and self._insertion_code == other._insertion_code return NotImplemented def __hash__(self) -> hash: @@ -59,7 +53,7 @@ def get_pssm(self) -> PssmRow: """Load pssm info linked to the residue.""" pssm = self._chain.pssm if pssm is None: - raise FileNotFoundError(f'No pssm file found for Chain {self._chain}.') + raise FileNotFoundError(f"No pssm file found for Chain {self._chain}.") return pssm[self] @property @@ -116,16 +110,15 @@ def get_center(self) -> NDArray: return alphas[0].position if len(self.atoms) == 0: - raise ValueError(f"cannot get the center position from {self}, because it has no atoms") + raise ValueError(f"Cannot get the center position from {self}, because it has no atoms") return np.mean([atom.position for atom in self.atoms], axis=0) class SingleResidueVariant: - """A single residue mutation of a PDBStrcture.""" - def __init__(self, residue: Residue, variant_amino_acid: AminoAcid): - """ + """A single residue mutation of a PDBStrcture. + Args: residue (Residue): the `Residue` object from the PDBStructure that is mutated. variant_amino_acid (AminoAcid): the amino acid that the `Residue` is mutated into. diff --git a/deeprank2/molstruct/structure.py b/deeprank2/molstruct/structure.py index f1f5e93d7..a5dcbc401 100644 --- a/deeprank2/molstruct/structure.py +++ b/deeprank2/molstruct/structure.py @@ -2,24 +2,21 @@ from typing import TYPE_CHECKING -from deeprank2.utils.pssmdata import PssmRow - if TYPE_CHECKING: from deeprank2.molstruct.atom import Atom from deeprank2.molstruct.residue import Residue + from deeprank2.utils.pssmdata import PssmRow class PDBStructure: - """A proitein or protein complex structure.. - - A `PDBStructure` can contain one or multiple `Chains`, i.e. separate - molecular entities (individual proteins). - One PDBStructure consists of a number of `Residue`s, each of which is of a - particular `AminoAcid` type and in turn consists of a number of `Atom`s. - """ + """.""" def __init__(self, id_: str | None = None): - """ + """A proitein or protein complex structure. + + A `PDBStructure` can contain one or multiple `Chains`, i.e. separate molecular entities (individual proteins). + One PDBStructure consists of a number of `Residue`s, each of which is of a particular `AminoAcid` type and in turn consists of a number of `Atom`s. + Args: id_ (str, optional): An unique identifier for this structure, can be the pdb accession code. Defaults to None. @@ -46,7 +43,7 @@ def get_chain(self, chain_id: str) -> Chain: def add_chain(self, chain: Chain): if chain.id in self._chains: - raise ValueError(f"duplicate chain: {chain.id}") + raise ValueError(f"Duplicate chain: {chain.id}") self._chains[chain.id] = chain @property @@ -75,9 +72,9 @@ class Chain: def __init__(self, model: PDBStructure, id_: str | None): """One chain of a PDBStructure. - Args: - model (:class:`PDBStructure`): The model that this chain is part of. - id_ (str): The pdb identifier of this chain. + Args: + model (:class:`PDBStructure`): The model that this chain is part of. + id_ (str): The pdb identifier of this chain. 
""" self._model = model self._id = id_ @@ -123,8 +120,7 @@ def get_atoms(self) -> list[Atom]: def __eq__(self, other) -> bool: if isinstance(other, Chain): - return (self._model == other._model - and self._id == other._id) + return self._model == other._model and self._id == other._id return NotImplemented def __hash__(self) -> hash: diff --git a/deeprank2/neuralnets/cnn/model3d.py b/deeprank2/neuralnets/cnn/model3d.py index cbae80b4d..45805691b 100644 --- a/deeprank2/neuralnets/cnn/model3d.py +++ b/deeprank2/neuralnets/cnn/model3d.py @@ -22,7 +22,6 @@ class CnnRegression(torch.nn.Module): - def __init__(self, num_features: int, box_shape: tuple[int]): super().__init__() @@ -47,14 +46,14 @@ def _forward_features(self, x): x = self.convlayer_001(x) x = F.relu(self.convlayer_002(x)) x = self.convlayer_003(x) - return x + return x # noqa:RET504 (unnecessary-assign) def forward(self, data): x = self._forward_features(data.x) x = x.view(x.size(0), -1) x = F.relu(self.fclayer_000(x)) x = self.fclayer_001(x) - return x + return x # noqa:RET504 (unnecessary-assign) ###################################################################### @@ -74,8 +73,8 @@ def forward(self, data): # fc layer 1: fc | input 84 output 1 post None # ---------------------------------------------------------------------- -class CnnClassification(torch.nn.Module): +class CnnClassification(torch.nn.Module): def __init__(self, num_features, box_shape): super().__init__() @@ -99,11 +98,11 @@ def _forward_features(self, x): x = self.convlayer_001(x) x = F.relu(self.convlayer_002(x)) x = self.convlayer_003(x) - return x + return x # noqa:RET504 (unnecessary-assign) def forward(self, data): x = self._forward_features(data.x) x = x.view(x.size(0), -1) x = F.relu(self.fclayer_000(x)) x = self.fclayer_001(x) - return x + return x # noqa:RET504 (unnecessary-assign) diff --git a/deeprank2/neuralnets/gnn/alignmentnet.py b/deeprank2/neuralnets/gnn/alignmentnet.py index 5cd3bfcbd..a53cfabbd 100644 --- a/deeprank2/neuralnets/gnn/alignmentnet.py +++ b/deeprank2/neuralnets/gnn/alignmentnet.py @@ -3,15 +3,16 @@ __author__ = "Daniel-Tobias Rademaker" + class GNNLayer(nn.Module): - def __init__( # pylint: disable=too-many-arguments + def __init__( self, nmb_edge_projection, nmb_hidden_attr, nmb_output_features, message_vector_length, nmb_mlp_neurons, - act_fn=nn.SiLU(), + act_fn=nn.SiLU(), # noqa: B008 (function-call-in-default-argument) is_last_layer=True, ): super().__init__() @@ -54,11 +55,8 @@ def __init__( # pylint: disable=too-many-arguments # and node attributes in order to create a 'message vector'between those # nodes def edge_model(self, edge_attr, hidden_features_source, hidden_features_target): - cat = torch.cat( - [edge_attr, hidden_features_source, hidden_features_target], dim=1 - ) - output = self.edge_mlp(cat) - return output + cat = torch.cat([edge_attr, hidden_features_source, hidden_features_target], dim=1) + return self.edge_mlp(cat) # A function that updates the node-attributes. 
Assumed that submessages # are already summed @@ -88,9 +86,7 @@ def update_nodes(self, edges, edge_attr, hidden_features, steps=1): # It is possible to run input through the same same layer multiple # times for _ in range(steps): - node_pair_messages = self.edge_model( - edge_attr, h[row], h[col] - ) # get all atom-pair messages + node_pair_messages = self.edge_model(edge_attr, h[row], h[col]) # get all atom-pair messages # sum all messages per node to single message vector messages = self.sum_messages(edges, node_pair_messages, len(h)) # Use the messages to update the node-attributes @@ -107,7 +103,7 @@ def output(self, hidden_features, get_attention=True): class SuperGNN(nn.Module): - def __init__( # pylint: disable=too-many-arguments + def __init__( self, nmb_edge_attr, nmb_node_attr, @@ -117,7 +113,7 @@ def __init__( # pylint: disable=too-many-arguments nmb_gnn_layers, nmb_output_features, message_vector_length, - act_fn=nn.SiLU(), + act_fn=nn.SiLU(), # noqa: B008 (function-call-in-default-argument) ): super().__init__() @@ -164,21 +160,18 @@ def preprocess(self, edge_attr, node_attr): # Runs data through layers and return output. Potentially, attention can # also be returned - def run_through_network( - self, edges, edge_attr, node_attr, with_output_attention=False - ): + def run_through_network(self, edges, edge_attr, node_attr, with_output_attention=False): edge_attr, node_attr = self.preprocess(edge_attr, node_attr) for layer in self.modlist: node_attr = layer.update_nodes(edges, edge_attr, node_attr) if with_output_attention: - representations, attention = self.modlist[-1].output(node_attr, True) + representations, attention = self.modlist[-1].output(node_attr, True) # noqa: FBT003 (boolean-positional-value-in-call) return representations, attention - representations = self.modlist[-1].output(node_attr, True) - return representations + return self.modlist[-1].output(node_attr, True) # noqa: FBT003 (boolean-positional-value-in-call) class AlignmentGNN(SuperGNN): - def __init__( # pylint: disable=too-many-arguments + def __init__( self, nmb_edge_attr, nmb_node_attr, @@ -188,7 +181,7 @@ def __init__( # pylint: disable=too-many-arguments nmb_mlp_neurons, nmb_gnn_layers, nmb_edge_projection, - act_fn=nn.SiLU(), + act_fn=nn.SiLU(), # noqa: B008 (function-call-in-default-argument) ): super().__init__( nmb_edge_attr, @@ -204,5 +197,4 @@ def __init__( # pylint: disable=too-many-arguments # Run over all layers, and return the ouput vectors def forward(self, edges, edge_attr, node_attr): - representations = self.run_through_network(edges, edge_attr, node_attr) - return representations + return self.run_through_network(edges, edge_attr, node_attr) diff --git a/deeprank2/neuralnets/gnn/foutnet.py b/deeprank2/neuralnets/gnn/foutnet.py index 9fa03648c..216ef6b9d 100644 --- a/deeprank2/neuralnets/gnn/foutnet.py +++ b/deeprank2/neuralnets/gnn/foutnet.py @@ -1,20 +1,20 @@ import torch import torch.nn.functional as F -from deeprank2.utils.community_pooling import (community_pooling, - get_preloaded_cluster) from torch import nn from torch.nn import Parameter from torch_geometric.nn import max_pool_x from torch_geometric.nn.inits import uniform from torch_scatter import scatter_mean +from deeprank2.utils.community_pooling import community_pooling, get_preloaded_cluster + class FoutLayer(torch.nn.Module): + """FoutLayer. - """ This layer is described by eq. (1) of Protein Interface Predition using Graph Convolutional Network - by Alex Fout et al. NIPS 2018 + by Alex Fout et al. NIPS 2018. 
Args: in_channels (int): Size of each input sample. @@ -24,7 +24,6 @@ class FoutLayer(torch.nn.Module): """ def __init__(self, in_channels: int, out_channels: int, bias: bool = True): - super().__init__() self.in_channels = in_channels @@ -48,7 +47,6 @@ def reset_parameters(self): uniform(size, self.bias) def forward(self, x, edge_index): - num_node = len(x) # alpha = x * Wc @@ -78,7 +76,12 @@ def __repr__(self): class FoutNet(torch.nn.Module): - def __init__(self, input_shape, output_shape=1, input_shape_edge=None): # pylint: disable=unused-argument + def __init__( + self, + input_shape, + output_shape=1, + input_shape_edge=None, # noqa: ARG002 (unused argument) + ): super().__init__() self.conv1 = FoutLayer(input_shape, 16) @@ -90,7 +93,6 @@ def __init__(self, input_shape, output_shape=1, input_shape_edge=None): # pylint self.clustering = "mcl" def forward(self, data): - act = nn.Tanhshrink() act = F.relu # act = nn.LeakyReLU(0.25) @@ -111,5 +113,5 @@ def forward(self, data): x = self.fc2(x) # x = F.dropout(x, training=self.training) - return x + return x # noqa:RET504 (unnecessary-assign) # return F.relu(x) diff --git a/deeprank2/neuralnets/gnn/ginet.py b/deeprank2/neuralnets/gnn/ginet.py index 778ff4dbb..95d897aff 100644 --- a/deeprank2/neuralnets/gnn/ginet.py +++ b/deeprank2/neuralnets/gnn/ginet.py @@ -1,39 +1,32 @@ import torch import torch.nn.functional as F -from deeprank2.utils.community_pooling import (community_pooling, - get_preloaded_cluster) from torch import nn from torch_geometric.nn import max_pool_x from torch_geometric.nn.inits import uniform from torch_scatter import scatter_mean, scatter_sum +from deeprank2.utils.community_pooling import community_pooling, get_preloaded_cluster + class GINetConvLayer(torch.nn.Module): def __init__(self, in_channels, out_channels, number_edge_features=1, bias=False): - super().__init__() self.in_channels = in_channels self.out_channels = out_channels self.fc = nn.Linear(self.in_channels, self.out_channels, bias=bias) - self.fc_edge_attr = nn.Linear( - number_edge_features, number_edge_features, bias=bias - ) - self.fc_attention = nn.Linear( - 2 * self.out_channels + number_edge_features, 1, bias=bias - ) + self.fc_edge_attr = nn.Linear(number_edge_features, number_edge_features, bias=bias) + self.fc_attention = nn.Linear(2 * self.out_channels + number_edge_features, 1, bias=bias) self.reset_parameters() def reset_parameters(self): - size = self.in_channels uniform(size, self.fc.weight) uniform(size, self.fc_attention.weight) uniform(size, self.fc_edge_attr.weight) def forward(self, x, edge_index, edge_attr): - row, col = edge_index num_node = len(x) edge_attr = edge_attr.unsqueeze(-1) if edge_attr.dim() == 1 else edge_attr @@ -53,7 +46,7 @@ def forward(self, x, edge_index, edge_attr): out = torch.zeros(num_node, self.out_channels).to(alpha.device) z = scatter_sum(h, row, dim=0, out=out) - return z + return z # noqa:RET504 (unnecessary-assign) def __repr__(self): return f"{self.__class__.__name__}({self.in_channels}, {self.out_channels})" @@ -93,16 +86,12 @@ def forward(self, data): # INTERNAL INTERACTION GRAPH # first conv block - data_ext.x = act( - self.conv1_ext(data_ext.x, data_ext.edge_index, data_ext.edge_attr) - ) + data_ext.x = act(self.conv1_ext(data_ext.x, data_ext.edge_index, data_ext.edge_attr)) cluster = get_preloaded_cluster(data_ext.cluster0, data_ext.batch) data_ext = community_pooling(cluster, data_ext) # second conv block - data_ext.x = act( - self.conv2_ext(data_ext.x, data_ext.edge_index, data_ext.edge_attr) - ) + 
data_ext.x = act(self.conv2_ext(data_ext.x, data_ext.edge_index, data_ext.edge_attr)) cluster = get_preloaded_cluster(data_ext.cluster1, data_ext.batch) x_ext, batch_ext = max_pool_x(cluster, data_ext.x, data_ext.batch) @@ -115,4 +104,4 @@ def forward(self, data): x = F.dropout(x, self.dropout, training=self.training) x = self.fc2(x) - return x + return x # noqa:RET504 (unnecessary-assign) diff --git a/deeprank2/neuralnets/gnn/ginet_nocluster.py b/deeprank2/neuralnets/gnn/ginet_nocluster.py index c84dbf14d..849d617f3 100644 --- a/deeprank2/neuralnets/gnn/ginet_nocluster.py +++ b/deeprank2/neuralnets/gnn/ginet_nocluster.py @@ -7,30 +7,23 @@ class GINetConvLayer(torch.nn.Module): def __init__(self, in_channels, out_channels, number_edge_features=1, bias=False): - super().__init__() self.in_channels = in_channels self.out_channels = out_channels self.fc = nn.Linear(self.in_channels, self.out_channels, bias=bias) - self.fc_edge_attr = nn.Linear( - number_edge_features, number_edge_features, bias=bias - ) - self.fc_attention = nn.Linear( - 2 * self.out_channels + number_edge_features, 1, bias=bias - ) + self.fc_edge_attr = nn.Linear(number_edge_features, number_edge_features, bias=bias) + self.fc_attention = nn.Linear(2 * self.out_channels + number_edge_features, 1, bias=bias) self.reset_parameters() def reset_parameters(self): - size = self.in_channels uniform(size, self.fc.weight) uniform(size, self.fc_attention.weight) uniform(size, self.fc_edge_attr.weight) def forward(self, x, edge_index, edge_attr): - row, col = edge_index num_node = len(x) edge_attr = edge_attr.unsqueeze(-1) if edge_attr.dim() == 1 else edge_attr @@ -50,7 +43,7 @@ def forward(self, x, edge_index, edge_attr): out = torch.zeros(num_node, self.out_channels).to(alpha.device) z = scatter_sum(h, row, dim=0, out=out) - return z + return z # noqa:RET504 (unnecessary-assign) def __repr__(self): return f"{self.__class__.__name__}({self.in_channels}, {self.out_channels})" @@ -85,14 +78,10 @@ def forward(self, data): # INTERNAL INTERACTION GRAPH # first conv block - data_ext.x = act( - self.conv1_ext(data_ext.x, data_ext.edge_index, data_ext.edge_attr) - ) + data_ext.x = act(self.conv1_ext(data_ext.x, data_ext.edge_index, data_ext.edge_attr)) # second conv block - data_ext.x = act( - self.conv2_ext(data_ext.x, data_ext.edge_index, data_ext.edge_attr) - ) + data_ext.x = act(self.conv2_ext(data_ext.x, data_ext.edge_index, data_ext.edge_attr)) # FC x = scatter_mean(data.x, data.batch, dim=0) @@ -103,4 +92,4 @@ def forward(self, data): x = F.dropout(x, self.dropout, training=self.training) x = self.fc2(x) - return x + return x # noqa:RET504 (unnecessary-assign) diff --git a/deeprank2/neuralnets/gnn/naive_gnn.py b/deeprank2/neuralnets/gnn/naive_gnn.py index 338000e12..873f3005c 100644 --- a/deeprank2/neuralnets/gnn/naive_gnn.py +++ b/deeprank2/neuralnets/gnn/naive_gnn.py @@ -6,7 +6,6 @@ class NaiveConvolutionalLayer(Module): - def __init__(self, count_node_features, count_edge_features): super().__init__() message_size = 32 @@ -27,13 +26,13 @@ def forward(self, node_features, edge_node_indices, edge_features): message_sums_per_node = scatter_sum(messages_per_neighbour, node0_indices, dim=0, out=out) # update nodes node_input = torch.cat([node_features, message_sums_per_node], dim=1) - node_output = self._node_mlp(node_input) - return node_output + return self._node_mlp(node_input) -class NaiveNetwork(Module): +class NaiveNetwork(Module): def __init__(self, input_shape: int, output_shape: int, input_shape_edge: int): - """ + 
"""NaiveNetwork. + Args: input_shape (int): Number of node input features. output_shape (int): Number of output value per graph. @@ -51,4 +50,4 @@ def forward(self, data): means_per_graph_external = scatter_mean(external_updated2_node_features, data.batch, dim=0) graph_input = means_per_graph_external z = self._graph_mlp(graph_input) - return z + return z # noqa:RET504 (unnecessary-assign) diff --git a/deeprank2/neuralnets/gnn/sgat.py b/deeprank2/neuralnets/gnn/sgat.py index 1e81eba5e..e321ff465 100644 --- a/deeprank2/neuralnets/gnn/sgat.py +++ b/deeprank2/neuralnets/gnn/sgat.py @@ -1,17 +1,17 @@ import torch import torch.nn.functional as F -from deeprank2.utils.community_pooling import (community_pooling, - get_preloaded_cluster) from torch import nn from torch.nn import Parameter from torch_geometric.nn import max_pool_x from torch_geometric.nn.inits import uniform from torch_scatter import scatter_mean +from deeprank2.utils.community_pooling import community_pooling, get_preloaded_cluster + class SGraphAttentionLayer(torch.nn.Module): + """SGraphAttentionLayer. - """ This is a new layer that is similar to the graph attention network but simpler z_i = 1 / Ni \\Sum_j a_ij * [x_i || x_j] * W + b_i || is the concatenation operator: [1,2,3] || [4,5,6] = [1,2,3,4,5,6] @@ -23,10 +23,9 @@ class SGraphAttentionLayer(torch.nn.Module): out_channels (int): Size of each output sample. bias (bool, optional): If set to :obj:`False`, the layer will not learn an additive bias. Defaults to True. - """ + """ # noqa: D301 (escape-sequence-in-docstring) def __init__(self, in_channels: int, out_channels: int, bias: bool = True, undirected=True): - super().__init__() self.in_channels = in_channels @@ -48,7 +47,6 @@ def reset_parameters(self): uniform(size, self.bias) def forward(self, x, edge_index, edge_attr): - row, col = edge_index num_node = len(x) edge_attr = edge_attr.unsqueeze(-1) if edge_attr.dim() == 1 else edge_attr @@ -83,7 +81,12 @@ def __repr__(self): class SGAT(torch.nn.Module): - def __init__(self, input_shape, output_shape=1, input_shape_edge=None): # pylint: disable=unused-argument + def __init__( + self, + input_shape, + output_shape=1, + input_shape_edge=None, # noqa: ARG002 (unused argument) + ): super().__init__() self.conv1 = SGraphAttentionLayer(input_shape, 16) @@ -95,7 +98,6 @@ def __init__(self, input_shape, output_shape=1, input_shape_edge=None): # pylint self.clustering = "mcl" def forward(self, data): - act = nn.Tanhshrink() act = F.relu # act = nn.LeakyReLU(0.25) @@ -116,5 +118,5 @@ def forward(self, data): x = self.fc2(x) # x = F.dropout(x, training=self.training) - return x + return x # noqa:RET504 (unnecessary-assign) # return F.relu(x) diff --git a/deeprank2/query.py b/deeprank2/query.py index f0fefd113..484074157 100644 --- a/deeprank2/query.py +++ b/deeprank2/query.py @@ -5,13 +5,14 @@ import pkgutil import re import warnings +from collections.abc import Iterator from dataclasses import MISSING, dataclass, field, fields from functools import partial from glob import glob from multiprocessing import Pool from random import randrange from types import ModuleType -from typing import Iterator, Literal +from typing import Literal import h5py import numpy as np @@ -23,15 +24,14 @@ from deeprank2.molstruct.aminoacid import AminoAcid from deeprank2.molstruct.residue import Residue, SingleResidueVariant from deeprank2.molstruct.structure import PDBStructure -from deeprank2.utils.buildgraph import (get_contact_atoms, get_structure, - get_surrounding_residues) +from 
deeprank2.utils.buildgraph import get_contact_atoms, get_structure, get_surrounding_residues from deeprank2.utils.graph import Graph from deeprank2.utils.grid import Augmentation, GridSettings, MapMethod from deeprank2.utils.parsing.pssm import parse_pssm _log = logging.getLogger(__name__) -VALID_RESOLUTIONS = ['atom', 'residue'] +VALID_RESOLUTIONS = ["atom", "residue"] @dataclass(repr=False, kw_only=True) @@ -56,7 +56,7 @@ class Query: """ pdb_path: str - resolution: Literal['residue', 'atom'] + resolution: Literal["residue", "atom"] chain_ids: list[str] | str pssm_paths: dict[str, str] = field(default_factory=dict) targets: dict[str, float] = field(default_factory=dict) @@ -68,10 +68,10 @@ def __post_init__(self): self._model_id = os.path.splitext(os.path.basename(self.pdb_path))[0] self.variant = None # not used for PPI, overwritten for SRV - if self.resolution == 'residue': + if self.resolution == "residue": self.max_edge_length = 10 if not self.max_edge_length else self.max_edge_length self.influence_radius = 10 if not self.influence_radius else self.influence_radius - elif self.resolution == 'atom': + elif self.resolution == "atom": self.max_edge_length = 4.5 if not self.max_edge_length else self.max_edge_length self.influence_radius = 4.5 if not self.influence_radius else self.influence_radius else: @@ -97,7 +97,7 @@ def _load_structure(self) -> PDBStructure: try: structure = get_structure(pdb, self.model_id) finally: - pdb._close() # pylint: disable=protected-access + pdb._close() # noqa: SLF001 (private-member-access) # read the pssm if self._pssm_required: self._load_pssm_data(structure) @@ -109,10 +109,10 @@ def _load_pssm_data(self, structure: PDBStructure): for chain in structure.chains: if chain.id in self.pssm_paths: pssm_path = self.pssm_paths[chain.id] - with open(pssm_path, "rt", encoding="utf-8") as f: + with open(pssm_path, encoding="utf-8") as f: chain.pssm = parse_pssm(f, chain) - def _check_pssm(self, verbosity: Literal[0,1,2] = 0): + def _check_pssm(self, verbosity: Literal[0, 1, 2] = 0): """Checks whether information stored in pssm file matches the corresponding pdb file. 
Args: @@ -128,18 +128,16 @@ def _check_pssm(self, verbosity: Literal[0,1,2] = 0): ValueError: Raised if info between pdb file and pssm file doesn't match or if no pssms were provided """ if not self.pssm_paths: - raise ValueError('No pssm paths provided for conservation feature module.') + raise ValueError("No pssm paths provided for conservation feature module.") # load residues from pssm and pdb files pssm_file_residues = {} for chain, pssm_path in self.pssm_paths.items(): - with open(pssm_path, encoding='utf-8') as f: + with open(pssm_path, encoding="utf-8") as f: lines = f.readlines()[1:] for line in lines: pssm_file_residues[chain + line.split()[0].zfill(4)] = convert_aa_nomenclature(line.split()[1], 3) - pdb_file_residues = {res[0] + str(res[2]).zfill(4): res[1] - for res in pdb2sql.pdb2sql(self.pdb_path).get_residues() - if res[0] in self.pssm_paths} + pdb_file_residues = {res[0] + str(res[2]).zfill(4): res[1] for res in pdb2sql.pdb2sql(self.pdb_path).get_residues() if res[0] in self.pssm_paths} # list errors mismatches = [] @@ -148,21 +146,21 @@ def _check_pssm(self, verbosity: Literal[0,1,2] = 0): try: if pdb_file_residues[residue] != pssm_file_residues[residue]: mismatches.append(residue) - except KeyError: + except KeyError: # noqa: PERF203 (try-except-in-loop) missing_entries.append(residue) # generate error message if len(mismatches) + len(missing_entries) > 0: - error_message = f'Amino acids in PSSM files do not match pdb file for {os.path.split(self.pdb_path)[1]}.' + error_message = f"Amino acids in PSSM files do not match pdb file for {os.path.split(self.pdb_path)[1]}." if verbosity: if len(mismatches) > 0: - error_message = error_message + f'\n\t{len(mismatches)} entries are incorrect.' + error_message = error_message + f"\n\t{len(mismatches)} entries are incorrect." if verbosity == 2: - error_message = error_message[-1] + f':\n\t{missing_entries}' + error_message = error_message[-1] + f":\n\t{missing_entries}" if len(missing_entries) > 0: - error_message = error_message + f'\n\t{len(missing_entries)} entries are missing.' + error_message = error_message + f"\n\t{len(missing_entries)} entries are missing." if verbosity == 2: - error_message = error_message[-1] + f':\n\t{missing_entries}' + error_message = error_message[-1] + f":\n\t{missing_entries}" # raise exception (or warning) if not self.suppress_pssm_errors: @@ -174,6 +172,7 @@ def _check_pssm(self, verbosity: Literal[0,1,2] = 0): def model_id(self) -> str: """The ID of the model, usually a .PDB accession code.""" return self._model_id + @model_id.setter def model_id(self, value: str): self._model_id = value @@ -194,12 +193,9 @@ def build( Returns: :class:`Graph`: The resulting :class:`Graph` object with all the features and targets. """ - if not isinstance(feature_modules, list): feature_modules = [feature_modules] - feature_modules = [importlib.import_module('deeprank2.features.' + module) - if isinstance(module, str) else module - for module in feature_modules] + feature_modules = [importlib.import_module("deeprank2.features." 
+ module) if isinstance(module, str) else module for module in feature_modules] self._pssm_required = conservation in feature_modules graph = self._build_helper() @@ -212,6 +208,7 @@ def build( def _build_helper(self) -> Graph: raise NotImplementedError("Must be defined in child classes.") + def get_query_id(self) -> str: raise NotImplementedError("Must be defined in child classes.") @@ -251,8 +248,7 @@ def __post_init__(self): super().__post_init__() # calls __post_init__ of parents if len(self.chain_ids) != 1: - raise ValueError("`chain_ids` must contain exactly 1 chain for `SingleResidueVariantQuery` objects, " - + f"but {len(self.chain_ids)} were given.") + raise ValueError("`chain_ids` must contain exactly 1 chain for `SingleResidueVariantQuery` objects, " + f"but {len(self.chain_ids)} were given.") self.variant_chain_id = self.chain_ids[0] @property @@ -264,10 +260,11 @@ def residue_id(self) -> str: def get_query_id(self) -> str: """Returns the string representing the complete query ID.""" - return (f"{self.resolution}-srv:" - + f"{self.variant_chain_id}:{self.residue_id}:" - + f"{self.wildtype_amino_acid.name}->{self.variant_amino_acid.name}:{self.model_id}" - ) + return ( + f"{self.resolution}-srv:" + f"{self.variant_chain_id}:{self.residue_id}:" + f"{self.wildtype_amino_acid.name}->{self.variant_amino_acid.name}:{self.model_id}" + ) def _build_helper(self) -> Graph: """Helper function to build a graph for SRV queries. @@ -275,32 +272,34 @@ def _build_helper(self) -> Graph: Returns: :class:`Graph`: The resulting :class:`Graph` object with all the features and targets. """ - # load .PDB structure structure = self._load_structure() # find the variant residue and its surroundings variant_residue: Residue = None for residue in structure.get_chain(self.variant_chain_id).residues: - if ( - residue.number == self.variant_residue_number - and residue.insertion_code == self.insertion_code - ): + if residue.number == self.variant_residue_number and residue.insertion_code == self.insertion_code: variant_residue = residue break if variant_residue is None: - raise ValueError( - f"Residue not found in {self.pdb_path}: {self.variant_chain_id} {self.residue_id}" - ) + raise ValueError(f"Residue not found in {self.pdb_path}: {self.variant_chain_id} {self.residue_id}") self.variant = SingleResidueVariant(variant_residue, self.variant_amino_acid) - residues = get_surrounding_residues(structure, variant_residue, self.influence_radius) + residues = get_surrounding_residues( + structure, + variant_residue, + self.influence_radius, + ) # build the graph - if self.resolution == 'residue': - graph = Graph.build_graph(residues, self.get_query_id(), self.max_edge_length) - elif self.resolution == 'atom': + if self.resolution == "residue": + graph = Graph.build_graph( + residues, + self.get_query_id(), + self.max_edge_length, + ) + elif self.resolution == "atom": residues.append(variant_residue) - atoms = set([]) + atoms = set() for residue in residues: if residue.amino_acid is not None: for atom in residue.atoms: @@ -338,14 +337,15 @@ def __post_init__(self): super().__post_init__() if len(self.chain_ids) != 2: - raise ValueError("`chain_ids` must contain exactly 2 chains for `ProteinProteinInterfaceQuery` objects, " - + f"but {len(self.chain_ids)} was/were given.") + raise ValueError( + "`chain_ids` must contain exactly 2 chains for `ProteinProteinInterfaceQuery` objects, " + f"but {len(self.chain_ids)} was/were given." 
+ ) def get_query_id(self) -> str: """Returns the string representing the complete query ID.""" return ( f"{self.resolution}-ppi:" # resolution and query type (ppi for protein protein interface) - + f"{self.chain_ids[0]}-{self.chain_ids[1]}:{self.model_id}" + f"{self.chain_ids[0]}-{self.chain_ids[1]}:{self.model_id}" ) def _build_helper(self) -> Graph: @@ -354,22 +354,29 @@ def _build_helper(self) -> Graph: Returns: :class:`Graph`: The resulting :class:`Graph` object with all the features and targets. """ - # find the atoms near the contact interface contact_atoms = get_contact_atoms( self.pdb_path, self.chain_ids, - self.influence_radius + self.influence_radius, ) if len(contact_atoms) == 0: - raise ValueError("no contact atoms found") + raise ValueError("No contact atoms found") # build the graph - if self.resolution == 'atom': - graph = Graph.build_graph(contact_atoms, self.get_query_id(), self.max_edge_length) - elif self.resolution == 'residue': + if self.resolution == "atom": + graph = Graph.build_graph( + contact_atoms, + self.get_query_id(), + self.max_edge_length, + ) + elif self.resolution == "residue": residues_selected = list({atom.residue for atom in contact_atoms}) - graph = Graph.build_graph(residues_selected, self.get_query_id(), self.max_edge_length) + graph = Graph.build_graph( + residues_selected, + self.get_query_id(), + self.max_edge_length, + ) graph.center = np.mean([atom.position for atom in contact_atoms], axis=0) structure = contact_atoms[0].residue.chain.model @@ -418,10 +425,9 @@ def add( verbose(bool): For logging query IDs added. Defaults to `False`. warn_duplicate (bool): Log a warning before renaming if a duplicate query is identified. Defaults to `True`. """ - query_id = query.get_query_id() if verbose: - _log.info(f'Adding query with ID {query_id}.') + _log.info(f"Adding query with ID {query_id}.") if query_id not in self._ids_count: self._ids_count[query_id] = 1 @@ -430,7 +436,7 @@ def add( new_id = query.model_id + "_" + str(self._ids_count[query_id]) query.model_id = new_id if warn_duplicate: - _log.warning(f'Query with ID {query_id} has already been added to the collection. Renaming it as {query.get_query_id()}') + _log.warning(f"Query with ID {query_id} has already been added to the collection. Renaming it as {query.get_query_id()}") self._queries.append(query) @@ -440,7 +446,6 @@ def export_dict(self, dataset_path: str): Args: dataset_path (str): The path where to save the list of queries. 
""" - with open(dataset_path, "wb") as pkl_file: pickle.dump(self, pkl_file) @@ -459,35 +464,48 @@ def __len__(self) -> int: return len(self._queries) def _process_one_query(self, query: Query): - """Only one process may access an hdf5 file at a time""" - + """Only one process may access an hdf5 file at a time.""" try: output_path = f"{self._prefix}-{os.getpid()}.hdf5" graph = query.build(self._feature_modules) graph.write_to_hdf5(output_path) if self._grid_settings is not None and self._grid_map_method is not None: - graph.write_as_grid_to_hdf5(output_path, self._grid_settings, self._grid_map_method) + graph.write_as_grid_to_hdf5( + output_path, + self._grid_settings, + self._grid_map_method, + ) for _ in range(self._grid_augmentation_count): # repeat with random augmentation axis, angle = pdb2sql.transform.get_rot_axis_angle(randrange(100)) augmentation = Augmentation(axis, angle) - graph.write_as_grid_to_hdf5(output_path, self._grid_settings, self._grid_map_method, augmentation) + graph.write_as_grid_to_hdf5( + output_path, + self._grid_settings, + self._grid_map_method, + augmentation, + ) except (ValueError, AttributeError, KeyError, TimeoutError) as e: - _log.warning(f'\nGraph/Query with ID {query.get_query_id()} ran into an Exception ({e.__class__.__name__}: {e}),' - ' and it has not been written to the hdf5 file. More details below:') + _log.warning( + f"\nGraph/Query with ID {query.get_query_id()} ran into an Exception ({e.__class__.__name__}: {e})," + " and it has not been written to the hdf5 file. More details below:" + ) _log.exception(e) - def process( # pylint: disable=too-many-arguments, too-many-locals, dangerous-default-value + def process( self, prefix: str = "processed-queries", - feature_modules: list[ModuleType, str] | ModuleType | str | Literal['all'] = [components, contact], + feature_modules: list[ModuleType, str] | ModuleType | str | Literal["all"] = [ # noqa: B006, PYI051 (mutable-argument-default, redundant-literal-union) + components, + contact, + ], cpu_count: int | None = None, combine_output: bool = True, grid_settings: GridSettings | None = None, grid_map_method: MapMethod | None = None, - grid_augmentation_count: int = 0 + grid_augmentation_count: int = 0, ) -> list[str]: """Render queries into graphs (and optionally grids). @@ -514,18 +532,17 @@ def process( # pylint: disable=too-many-arguments, too-many-locals, dangerous-de Returns: list[str]: The list of paths of the generated HDF5 files. 
""" - # set defaults - self._prefix = "processed-queries" if not prefix else re.sub('.hdf5$', '', prefix) # scrape extension if present + self._prefix = "processed-queries" if not prefix else re.sub(".hdf5$", "", prefix) # scrape extension if present max_cpus = os.cpu_count() self._cpu_count = max_cpus if cpu_count is None else min(cpu_count, max_cpus) if cpu_count and self._cpu_count < cpu_count: - _log.warning(f'\nTried to set {cpu_count} CPUs, but only {max_cpus} are present in the system.') - _log.info(f'\nNumber of CPUs for processing the queries set to: {self._cpu_count}.') + _log.warning(f"\nTried to set {cpu_count} CPUs, but only {max_cpus} are present in the system.") + _log.info(f"\nNumber of CPUs for processing the queries set to: {self._cpu_count}.") self._feature_modules = self._set_feature_modules(feature_modules) - _log.info(f'\nSelected feature modules: {self._feature_modules}.') + _log.info(f"\nSelected feature modules: {self._feature_modules}.") self._grid_settings = grid_settings self._grid_map_method = grid_map_method @@ -534,16 +551,16 @@ def process( # pylint: disable=too-many-arguments, too-many-locals, dangerous-de raise ValueError(f"`grid_augmentation_count` cannot be negative, but was given as {grid_augmentation_count}") self._grid_augmentation_count = grid_augmentation_count - _log.info(f'Creating pool function to process {len(self)} queries...') + _log.info(f"Creating pool function to process {len(self)} queries...") pool_function = partial(self._process_one_query) with Pool(self._cpu_count) as pool: - _log.info('Starting pooling...\n') + _log.info("Starting pooling...\n") pool.map(pool_function, self.queries) output_paths = glob(f"{prefix}-*.hdf5") if combine_output: for output_path in output_paths: - with h5py.File(f"{prefix}.hdf5",'a') as f_dest, h5py.File(output_path,'r') as f_src: + with h5py.File(f"{prefix}.hdf5", "a") as f_dest, h5py.File(output_path, "r") as f_src: for key, value in f_src.items(): _log.debug(f"copy {key} from {output_path} to {prefix}.hdf5") f_src.copy(value, f_dest) @@ -552,27 +569,24 @@ def process( # pylint: disable=too-many-arguments, too-many-locals, dangerous-de return output_paths - def _set_feature_modules( - self, - feature_modules: list[ModuleType, str] | ModuleType | str | Literal['all'] - ) -> list[str]: + def _set_feature_modules(self, feature_modules: list[ModuleType, str] | ModuleType | str | Literal["all"]) -> list[str]: # noqa: PYI051 (redundant-literal-union) """Convert `feature_modules` to list[str] irrespective of input type. Raises: TypeError: if an invalid input type is passed. """ - - if feature_modules == 'all': + if feature_modules == "all": return [modname for _, modname, _ in pkgutil.iter_modules(deeprank2.features.__path__)] if isinstance(feature_modules, ModuleType): return [os.path.basename(feature_modules.__file__)[:-3]] if isinstance(feature_modules, str): - return [re.sub('.py$', '', feature_modules)] # scrape extension if present + return [re.sub(".py$", "", feature_modules)] # scrape extension if present if isinstance(feature_modules, list): - invalid_inputs = [type(el) for el in feature_modules if not isinstance(el, (str, ModuleType))] + invalid_inputs = [type(el) for el in feature_modules if not isinstance(el, str | ModuleType)] if invalid_inputs: - raise TypeError(f'`feature_modules` contains invalid input ({invalid_inputs}). 
Only `str` and `ModuleType` are accepted.') - return [re.sub('.py$', '', m) if isinstance(m, str) - else os.path.basename(m.__file__)[:-3] # for ModuleTypes - for m in feature_modules] - raise TypeError(f'`feature_modules` has received an invalid input type: {type(feature_modules)}. Only `str` and `ModuleType` are accepted.') + raise TypeError(f"`feature_modules` contains invalid input ({invalid_inputs}). Only `str` and `ModuleType` are accepted.") + return [ + re.sub(".py$", "", m) if isinstance(m, str) else os.path.basename(m.__file__)[:-3] # for ModuleTypes + for m in feature_modules + ] + raise TypeError(f"`feature_modules` has received an invalid input type: {type(feature_modules)}. Only `str` and `ModuleType` are accepted.") diff --git a/deeprank2/tools/target.py b/deeprank2/tools/target.py index 52e365220..059712c7d 100644 --- a/deeprank2/tools/target.py +++ b/deeprank2/tools/target.py @@ -31,64 +31,54 @@ def add_target( 1ATN_xxx-3 0 1ATN_xxx-4 0 """ - target_dict = {} labels = np.loadtxt(target_list, delimiter=sep, usecols=[0], dtype=str) values = np.loadtxt(target_list, delimiter=sep, usecols=[1]) - for label, value in zip(labels, values): + for label, value in zip(labels, values, strict=True): target_dict[label] = value - # if a directory is provided if os.path.isdir(graph_path): graphs = glob.glob(f"{graph_path}/*.hdf5") - - # if a single file is provided elif os.path.isfile(graph_path): graphs = [graph_path] - - # if a list of file is provided + elif isinstance(graph_path, list): + graphs = graph_path else: - assert isinstance(graph_path, list) - assert os.path.isfile(graph_path[0]) + raise TypeError("Incorrect input passed.") for hdf5 in graphs: print(hdf5) + if not os.path.isfile(hdf5): + raise FileNotFoundError(f"File {hdf5} not found.") + try: f5 = h5py.File(hdf5, "a") - - for model, _ in target_dict.items(): + for model in target_dict: if model not in f5: - raise ValueError( - f"{hdf5} does not contain an entry named {model}" - ) - + raise ValueError(f"{hdf5} does not contain an entry named {model}.") # noqa: TRY301 (raise-within-try) try: model_gp = f5[model] - if targets.VALUES not in model_gp: model_gp.create_group(targets.VALUES) - group = f5[f"{model}/{targets.VALUES}/"] - - if target_name in group.keys(): + if target_name in group: # Delete the target if it already existed del group[target_name] - # Create the target group.create_dataset(target_name, data=target_dict[model]) - - except BaseException: + except BaseException: # noqa: BLE001 (blind-except) print(f"no graph for {model}") - f5.close() - except BaseException: + except BaseException: # noqa: BLE001 (blind-except) print(f"no graph for {hdf5}") -def compute_ppi_scores(pdb_path: str, reference_pdb_path: str) -> dict[str, float | int]: - +def compute_ppi_scores( + pdb_path: str, + reference_pdb_path: str, +) -> dict[str, float | int]: """Compute structure similarity scores for the input docking model and return them as a dictionary. The computed scores are: `lrmsd` (ligand root mean square deviation), `irmsd` (interface rmsd), @@ -102,20 +92,19 @@ def compute_ppi_scores(pdb_path: str, reference_pdb_path: str) -> dict[str, floa Returns: a dictionary containing values for lrmsd, irmsd, fnat, dockq, binary, capri_class. 
""" - ref_name = os.path.splitext(os.path.basename(reference_pdb_path))[0] - sim = StructureSimilarity(pdb_path, reference_pdb_path, enforce_residue_matching=False) + sim = StructureSimilarity( + pdb_path, + reference_pdb_path, + enforce_residue_matching=False, + ) scores = {} # Input pre-computed zone files if os.path.exists(ref_name + ".lzone"): - scores[targets.LRMSD] = sim.compute_lrmsd_fast( - method="svd", lzone=ref_name + ".lzone" - ) - scores[targets.IRMSD] = sim.compute_irmsd_fast( - method="svd", izone=ref_name + ".izone" - ) + scores[targets.LRMSD] = sim.compute_lrmsd_fast(method="svd", lzone=ref_name + ".lzone") + scores[targets.IRMSD] = sim.compute_irmsd_fast(method="svd", izone=ref_name + ".izone") # Compute zone files else: @@ -123,13 +112,11 @@ def compute_ppi_scores(pdb_path: str, reference_pdb_path: str) -> dict[str, floa scores[targets.IRMSD] = sim.compute_irmsd_fast(method="svd") scores[targets.FNAT] = sim.compute_fnat_fast() - scores[targets.DOCKQ] = sim.compute_DockQScore( - scores[targets.FNAT], scores[targets.LRMSD], scores[targets.IRMSD] - ) + scores[targets.DOCKQ] = sim.compute_DockQScore(scores[targets.FNAT], scores[targets.LRMSD], scores[targets.IRMSD]) scores[targets.BINARY] = scores[targets.IRMSD] < 4.0 scores[targets.CAPRI] = 4 - for thr, val in zip([4.0, 2.0, 1.0], [3, 2, 1]): + for thr, val in zip([4.0, 2.0, 1.0], [3, 2, 1], strict=True): if scores[targets.IRMSD] < thr: scores[targets.CAPRI] = val diff --git a/deeprank2/trainer.py b/deeprank2/trainer.py index c5f0fc030..430e6cf01 100644 --- a/deeprank2/trainer.py +++ b/deeprank2/trainer.py @@ -16,30 +16,29 @@ from deeprank2.dataset import GraphDataset, GridDataset from deeprank2.domain import losstypes as losses from deeprank2.domain import targetstorage as targets -from deeprank2.utils.community_pooling import (community_detection, - community_pooling) +from deeprank2.utils.community_pooling import community_detection, community_pooling from deeprank2.utils.earlystopping import EarlyStopping -from deeprank2.utils.exporters import (HDF5OutputExporter, OutputExporter, - OutputExporterCollection) +from deeprank2.utils.exporters import HDF5OutputExporter, OutputExporter, OutputExporterCollection +# ruff: noqa: PYI041 (redundant-numeric-union), they are used differently in this module _log = logging.getLogger(__name__) -class Trainer(): - def __init__( # pylint: disable=too-many-arguments # noqa: MC0001 - self, - neuralnet = None, - dataset_train: GraphDataset | GridDataset | None = None, - dataset_val: GraphDataset | GridDataset | None = None, - dataset_test: GraphDataset | GridDataset | None = None, - val_size: float | int | None = None, - test_size: float | int | None = None, - class_weights: bool = False, - pretrained_model: str | None = None, - cuda: bool = False, - ngpu: int = 0, - output_exporters: list[OutputExporter] | None = None, - ): +class Trainer: + def __init__( # noqa: PLR0915 (too-many-statements) + self, + neuralnet=None, + dataset_train: GraphDataset | GridDataset | None = None, + dataset_val: GraphDataset | GridDataset | None = None, + dataset_test: GraphDataset | GridDataset | None = None, + val_size: float | int | None = None, + test_size: float | int | None = None, + class_weights: bool = False, + pretrained_model: str | None = None, + cuda: bool = False, + ngpu: int = 0, + output_exporters: list[OutputExporter] | None = None, + ): """Class from which the network is trained, evaluated and tested. 
Args: @@ -69,8 +68,7 @@ def __init__( # pylint: disable=too-many-arguments # noqa: MC0001 self.neuralnet = neuralnet self.pretrained_model = pretrained_model - self._init_datasets(dataset_train, dataset_val, dataset_test, - val_size, test_size) + self._init_datasets(dataset_train, dataset_val, dataset_test, val_size, test_size) self.cuda = cuda self.ngpu = ngpu @@ -87,14 +85,16 @@ def __init__( # pylint: disable=too-many-arguments # noqa: MC0001 and that you are running on GPUs.\n --> To turn CUDA off set cuda=False in Trainer.\n --> Aborting the experiment \n\n' - """) + """ + ) raise ValueError( """ --> CUDA not detected: Make sure that CUDA is installed and that you are running on GPUs.\n --> To turn CUDA off set cuda=False in Trainer.\n --> Aborting the experiment \n\n' - """) + """ + ) else: self.device = torch.device("cpu") if self.ngpu > 0: @@ -103,16 +103,18 @@ def __init__( # pylint: disable=too-many-arguments # noqa: MC0001 --> CUDA not detected. Set cuda=True in Trainer to turn CUDA on.\n --> Aborting the experiment \n\n - """) + """ + ) raise ValueError( """ --> CUDA not detected. Set cuda=True in Trainer to turn CUDA on.\n --> Aborting the experiment \n\n - """) + """ + ) _log.info(f"Device set to {self.device}.") - if self.device.type == 'cuda': + if self.device.type == "cuda": _log.info(f"CUDA device name is {torch.cuda.get_device_name(0)}.") _log.info(f"Number of GPUs set to {self.ngpu}.") @@ -144,7 +146,7 @@ def __init__( # pylint: disable=too-many-arguments # noqa: MC0001 # clustering the datasets if self.clustering_method is not None: - if self.clustering_method in ('mcl', 'louvain'): + if self.clustering_method in ("mcl", "louvain"): _log.info("Loading clusters") self._precluster(self.dataset_train) @@ -152,15 +154,12 @@ def __init__( # pylint: disable=too-many-arguments # noqa: MC0001 self._precluster(self.dataset_val) else: _log.warning("No validation dataset given. Randomly splitting training set in training set and validation set.") - self.dataset_train, self.dataset_val = _divide_dataset( - self.dataset_train, splitsize=self.val_size) + self.dataset_train, self.dataset_val = _divide_dataset(self.dataset_train, splitsize=self.val_size) if self.dataset_test is not None: self._precluster(self.dataset_test) else: - raise ValueError( - f"Invalid node clustering method: {self.clustering_method}\n\t" - "Please set clustering_method to 'mcl', 'louvain' or None. Default to 'mcl' \n\t") + raise ValueError(f"Invalid node clustering method: {self.clustering_method}. 
Please set clustering_method to 'mcl', 'louvain' or None.") else: if self.neuralnet is None: @@ -181,9 +180,9 @@ def _init_output_exporters(self, output_exporters: list[OutputExporter] | None): if output_exporters is not None: self._output_exporters = OutputExporterCollection(*output_exporters) else: - self._output_exporters = OutputExporterCollection(HDF5OutputExporter('./output')) + self._output_exporters = OutputExporterCollection(HDF5OutputExporter("./output")) - def _init_datasets( # pylint: disable=too-many-arguments + def _init_datasets( self, dataset_train: GraphDataset | GridDataset, dataset_val: GraphDataset | GridDataset | None, @@ -191,7 +190,6 @@ def _init_datasets( # pylint: disable=too-many-arguments val_size: int | float | None, test_size: int | float | None, ): - self._check_dataset_equivalence(dataset_train, dataset_val, dataset_test) self.dataset_train = dataset_train @@ -214,7 +212,6 @@ def _init_datasets( # pylint: disable=too-many-arguments _log.warning("Validation dataset was provided to Trainer; val_size parameter is ignored.") def _init_from_dataset(self, dataset: GraphDataset | GridDataset): - if isinstance(dataset, GraphDataset): self.clustering_method = dataset.clustering_method self.node_features = dataset.node_features @@ -243,14 +240,12 @@ def _init_from_dataset(self, dataset: GraphDataset | GridDataset): def _load_model(self): """Loads the neural network model.""" - self._put_model_to_device(self.dataset_train) self.configure_optimizers() self.set_lossfunction() def _check_dataset_equivalence(self, dataset_train, dataset_val, dataset_test): """Check dataset_train type and train_source parameter settings.""" - # dataset_train is None when pretrained_model is set if dataset_train is None: # only check the test dataset @@ -259,49 +254,50 @@ def _check_dataset_equivalence(self, dataset_train, dataset_val, dataset_test): else: # Make sure train dataset has valid type if not isinstance(dataset_train, GraphDataset) and not isinstance(dataset_train, GridDataset): - raise TypeError(f"""train dataset is not the right type {type(dataset_train)} - Make sure it's either GraphDataset or GridDataset""") + raise TypeError(f"train dataset is not the right type {type(dataset_train)}. Make sure it's either GraphDataset or GridDataset") if dataset_val is not None: - self._check_dataset_value(dataset_train, dataset_val, type_dataset = "valid") + self._check_dataset_value( + dataset_train, + dataset_val, + type_dataset="valid", + ) if dataset_test is not None: - self._check_dataset_value(dataset_train, dataset_test, type_dataset = "test") + self._check_dataset_value( + dataset_train, + dataset_test, + type_dataset="test", + ) def _check_dataset_value(self, dataset_train, dataset_check, type_dataset): """Check valid/test dataset settings.""" - # Check train_source parameter in valid/test is set. if dataset_check.train_source is None: - raise ValueError(f"""{type_dataset} dataset has train_source parameter set to None. - Make sure to set it as a valid training data source.""") + raise ValueError(f"{type_dataset} dataset has train_source parameter set to None. Make sure to set it as a valid training data source.") # Check train_source parameter in valid/test is equivalent to train which passed to Trainer. if dataset_check.train_source != dataset_train: - raise ValueError(f"""{type_dataset} dataset has different train_source parameter compared to the one given in Trainer. 
- Make sure to assign equivalent train_source in Trainer""") + raise ValueError( + f"{type_dataset} dataset has different train_source parameter from Trainer. Make sure to assign equivalent train_source in Trainer." + ) def _load_pretrained_model(self): - """ - Loads pretrained model - """ - - self.test_loader = DataLoader( - self.dataset_test, - pin_memory=self.cuda) + """Loads pretrained model.""" + self.test_loader = DataLoader(self.dataset_test, pin_memory=self.cuda) _log.info("Testing set loaded\n") self._put_model_to_device(self.dataset_test) # load the model and the optimizer state - self.optimizer = self.optimizer(self.model.parameters(), lr=self.lr, weight_decay = self.weight_decay) + self.optimizer = self.optimizer( + self.model.parameters(), + lr=self.lr, + weight_decay=self.weight_decay, + ) self.optimizer.load_state_dict(self.opt_loaded_state_dict) self.model.load_state_dict(self.model_load_state_dict) def _precluster(self, dataset: GraphDataset): - """Pre-clusters nodes of the graphs - - Args: - dataset (:class:`GraphDataset`) - """ + """Pre-clusters nodes of the graphs.""" for fname, mol in tqdm(dataset.index_entries): data = dataset.load_one_graph(fname, mol) @@ -310,7 +306,7 @@ def _precluster(self, dataset: GraphDataset): try: _log.info(f"deleting {mol}") del f5[mol] - except BaseException: + except BaseException: # noqa: BLE001 (blind-except) _log.info(f"{mol} not found") f5.close() continue @@ -323,21 +319,17 @@ def _precluster(self, dataset: GraphDataset): del clust_grp[self.clustering_method.lower()] method_grp = clust_grp.create_group(self.clustering_method.lower()) - cluster = community_detection( - data.edge_index, data.num_nodes, method=self.clustering_method - ) + cluster = community_detection(data.edge_index, data.num_nodes, method=self.clustering_method) method_grp.create_dataset("depth_0", data=cluster.cpu()) data = community_pooling(cluster, data) - cluster = community_detection( - data.edge_index, data.num_nodes, method=self.clustering_method - ) + cluster = community_detection(data.edge_index, data.num_nodes, method=self.clustering_method) method_grp.create_dataset("depth_1", data=cluster.cpu()) f5.close() def _put_model_to_device(self, dataset: GraphDataset | GridDataset): """ - Puts the model on the available device + Puts the model on the available device. Args: dataset (:class:`GraphDataset` | :class:`GridDataset`): GraphDataset object. 
@@ -345,7 +337,6 @@ def _put_model_to_device(self, dataset: GraphDataset | GridDataset): Raises: ValueError: Incorrect output shape """ - # regression mode if self.task == targets.REGRESS: self.output_shape = 1 @@ -361,22 +352,15 @@ def _put_model_to_device(self, dataset: GraphDataset | GridDataset): target_shape = None if isinstance(dataset, GraphDataset): - num_node_features = dataset.get(0).num_features num_edge_features = len(dataset.edge_features) - self.model = self.neuralnet( - num_node_features, - self.output_shape, - num_edge_features - ).to(self.device) + self.model = self.neuralnet(num_node_features, self.output_shape, num_edge_features).to(self.device) elif isinstance(dataset, GridDataset): _, num_features, box_width, box_height, box_depth = dataset.get(0).x.shape - self.model = self.neuralnet(num_features, - (box_width, box_height, box_depth) - ).to(self.device) + self.model = self.neuralnet(num_features, (box_width, box_height, box_depth)).to(self.device) else: raise TypeError(type(dataset)) @@ -388,12 +372,18 @@ def _put_model_to_device(self, dataset: GraphDataset | GridDataset): # check for compatibility for output_exporter in self._output_exporters: if not output_exporter.is_compatible_with(self.output_shape, target_shape): - raise ValueError(f"""output exporter of type {type(output_exporter)}\n - is not compatible with output shape {self.output_shape}\n - and target shape {target_shape}""") - - def configure_optimizers(self, optimizer = None, lr: float = 0.001, weight_decay: float = 1e-05): + raise ValueError( + f"Output exporter of type {type(output_exporter)}\n\t" + f"is not compatible with output shape {self.output_shape}\n\t" + f"and target shape {target_shape}." + ) + def configure_optimizers( + self, + optimizer: torch.optim = None, + lr: float = 0.001, + weight_decay: float = 1e-05, + ): """ Configure optimizer and its main parameters. @@ -406,7 +396,6 @@ def configure_optimizers(self, optimizer = None, lr: float = 0.001, weight_decay weight_decay (float, optional): Weight decay (L2 penalty). Weight decay is fundamental for GNNs, otherwise, parameters can become too big and the gradient may explode. Defaults to 1e-05. """ - self.lr = lr self.weight_decay = weight_decay @@ -414,14 +403,17 @@ def configure_optimizers(self, optimizer = None, lr: float = 0.001, weight_decay self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr, weight_decay=weight_decay) else: try: - self.optimizer = optimizer(self.model.parameters(), lr = lr, weight_decay = weight_decay) + self.optimizer = optimizer(self.model.parameters(), lr=lr, weight_decay=weight_decay) except Exception as e: _log.error(e) _log.info("Invalid optimizer. Please use only optimizers classes from torch.optim package.") - raise e - - def set_lossfunction(self, lossfunction = None, override_invalid: bool = False): #pylint: disable=too-many-locals # noqa: MC0001 + raise + def set_lossfunction( + self, + lossfunction=None, + override_invalid: bool = False, + ): """ Set the loss function. @@ -438,18 +430,21 @@ def set_lossfunction(self, lossfunction = None, override_invalid: bool = False): invalid for the task do no longer automaticallt raise an exception. Defaults to False. 
""" - default_regression_loss = nn.MSELoss default_classification_loss = nn.CrossEntropyLoss def _invalid_loss(): if override_invalid: - _log.warning(f'The provided loss function ({lossfunction}) is not appropriate for {self.task} tasks.\n\t' + - 'You have set override_invalid to True, so the training will run with this loss function nonetheless.\n\t' + - 'This will likely cause other errors or exceptions down the line.') + _log.warning( + f"The provided loss function ({lossfunction}) is not appropriate for {self.task} tasks.\n\t" + "You have set override_invalid to True, so the training will run with this loss function nonetheless.\n\t" + "This will likely cause other errors or exceptions down the line." + ) else: - invalid_loss_error = (f'The provided loss function ({lossfunction}) is not appropriate for {self.task} tasks.\n\t' + - 'If you want to use this loss function anyway, set override_invalid to True.') + invalid_loss_error = ( + f"The provided loss function ({lossfunction}) is not appropriate for {self.task} tasks.\n\t" + "If you want to use this loss function anyway, set override_invalid to True." + ) _log.error(invalid_loss_error) raise ValueError(invalid_loss_error) @@ -465,34 +460,37 @@ def _invalid_loss(): if self.task == targets.REGRESS: if lossfunction is None: lossfunction = default_regression_loss - _log.info(f'No loss function provided, the default loss function for {self.task} tasks is used: {lossfunction}') - else: - if custom_loss: - custom_loss_warning = ( f'The provided loss function ({lossfunction}) is not part of the default list.\n\t' + - f'Please ensure that this loss function is appropriate for {self.task} tasks.\n\t') - _log.warning(custom_loss_warning) - elif lossfunction not in losses.regression_losses: - _invalid_loss() + _log.info(f"No loss function provided, the default loss function for {self.task} tasks is used: {lossfunction}") + elif custom_loss: + custom_loss_warning = ( + f"The provided loss function ({lossfunction}) is not part of the default list.\n\t" + f"Please ensure that this loss function is appropriate for {self.task} tasks.\n\t" + ) + _log.warning(custom_loss_warning) + elif lossfunction not in losses.regression_losses: + _invalid_loss() self.lossfunction = lossfunction() # Set classification loss elif self.task == targets.CLASSIF: if lossfunction is None: lossfunction = default_classification_loss - _log.info(f'No loss function provided, the default loss function for {self.task} tasks is used: {lossfunction}') - else: - if custom_loss: - custom_loss_warning = ( f'The provided loss function ({lossfunction}) is not part of the default list.\n\t' + - f'Please ensure that this loss function is appropriate for {self.task} tasks.\n\t') - _log.warning(custom_loss_warning) - elif lossfunction not in losses.classification_losses: - _invalid_loss() + _log.info(f"No loss function provided, the default loss function for {self.task} tasks is used: {lossfunction}") + elif custom_loss: + custom_loss_warning = ( + f"The provided loss function ({lossfunction}) is not part of the default list.\n\t" + f"Please ensure that this loss function is appropriate for {self.task} tasks.\n\t" + ) + _log.warning(custom_loss_warning) + elif lossfunction not in losses.classification_losses: + _invalid_loss() + if not self.class_weights: self.lossfunction = lossfunction() else: - self.lossfunction = lossfunction # weights will be set in the train() method + self.lossfunction = lossfunction # weights will be set in the train() method - def train( # pylint: 
disable=too-many-arguments, too-many-branches, too-many-locals # noqa: MC0001 + def train( # noqa: PLR0915 (too-many-statements) self, nepoch: int = 1, batch_size: int = 32, @@ -503,7 +501,7 @@ def train( # pylint: disable=too-many-arguments, too-many-branches, too-many-loc validate: bool = False, num_workers: int = 0, best_model: bool = True, - filename: str | None = 'model.pth.tar' + filename: str | None = "model.pth.tar", ): """ Performs the training of the model. @@ -544,7 +542,7 @@ def train( # pylint: disable=too-many-arguments, too-many-branches, too-many-loc batch_size=self.batch_size_train, shuffle=self.shuffle, num_workers=num_workers, - pin_memory=self.cuda + pin_memory=self.cuda, ) _log.info("Training set loaded\n") @@ -554,26 +552,23 @@ def train( # pylint: disable=too-many-arguments, too-many-branches, too-many-loc batch_size=self.batch_size_train, shuffle=self.shuffle, num_workers=num_workers, - pin_memory=self.cuda + pin_memory=self.cuda, ) _log.info("Validation set loaded\n") else: self.valid_loader = None _log.info("No validation set provided\n") _log.warning( - "Training data will be used both for learning and model selection, which may lead to overfitting." + - "\nIt is usually preferable to use a validation set during the training phase.") + "Training data will be used both for learning and model selection, which may lead to overfitting.\n" + "It is usually preferable to use a validation set during the training phase." + ) # Assign weights to each class if self.task == targets.CLASSIF and self.class_weights: - targets_all = [] - for batch in self.train_loader: - targets_all.append(batch.y) + targets_all = [batch.y for batch in self.train_loader] targets_all = torch.cat(targets_all).squeeze().tolist() - self.weights = torch.tensor( - [targets_all.count(i) for i in self.classes], dtype=torch.float32 - ) + self.weights = torch.tensor([targets_all.count(i) for i in self.classes], dtype=torch.float32) _log.info(f"class occurences: {self.weights}") self.weights = 1.0 / self.weights self.weights = self.weights / self.weights.sum() @@ -582,8 +577,10 @@ def train( # pylint: disable=too-many-arguments, too-many-branches, too-many-loc try: self.lossfunction = self.lossfunction(weight=self.weights.to(self.device)) # Check whether loss allows for weighted classes except TypeError as e: - weight_error = (f"Loss function {self.lossfunction} does not allow for weighted classes.\n\t" + - "Please use a different loss function or set class_weights to False.\n") + weight_error = ( + f"Loss function {self.lossfunction} does not allow for weighted classes.\n\t" + "Please use a different loss function or set class_weights to False.\n" + ) _log.error(weight_error) raise ValueError(weight_error) from e else: @@ -594,14 +591,19 @@ def train( # pylint: disable=too-many-arguments, too-many-branches, too-many-loc saved_model = False if earlystop_patience or earlystop_maxgap: - early_stopping = EarlyStopping(patience=earlystop_patience, maxgap=earlystop_maxgap, min_epoch=min_epoch, trace_func=_log.info) + early_stopping = EarlyStopping( + patience=earlystop_patience, + maxgap=earlystop_maxgap, + min_epoch=min_epoch, + trace_func=_log.info, + ) else: early_stopping = None with self._output_exporters: # Number of epochs self.nepoch = nepoch - _log.info('Epoch 0:') + _log.info("Epoch 0:") self._eval(self.train_loader, 0, "training") if validate: if self.valid_loader is None: @@ -610,7 +612,7 @@ def train( # pylint: disable=too-many-arguments, too-many-branches, too-many-loc # Loop over epochs for 
epoch in range(1, nepoch + 1): - _log.info(f'Epoch {epoch}:') + _log.info(f"Epoch {epoch}:") # Set the module in training mode self.model.train() @@ -621,12 +623,11 @@ def train( # pylint: disable=too-many-arguments, too-many-branches, too-many-loc if validate: loss_ = self._eval(self.valid_loader, epoch, "validation") valid_losses.append(loss_) - if best_model: - if min(valid_losses) == loss_: - checkpoint_model = self._save_model() - saved_model = True - self.epoch_saved_model = epoch - _log.info(f'Best model saved at epoch # {self.epoch_saved_model}.') + if best_model and min(valid_losses) == loss_: + checkpoint_model = self._save_model() + saved_model = True + self.epoch_saved_model = epoch + _log.info(f"Best model saved at epoch # {self.epoch_saved_model}.") # check early stopping criteria (in validation case only) if early_stopping: # compare last validation and training loss @@ -634,24 +635,23 @@ def train( # pylint: disable=too-many-arguments, too-many-branches, too-many-loc if early_stopping.early_stop: break - else: - # if no validation set, save the best performing model on the training set - if best_model: - if min(train_losses) == loss_: - checkpoint_model = self._save_model() - saved_model = True - self.epoch_saved_model = epoch - _log.info(f'Best model saved at epoch # {self.epoch_saved_model}.') + elif best_model: # if no validation set, save the best performing model on the training set + if min(train_losses) == loss_: + checkpoint_model = self._save_model() + saved_model = True + self.epoch_saved_model = epoch + _log.info(f"Best model saved at epoch # {self.epoch_saved_model}.") # Save the last model if best_model is False or not saved_model: checkpoint_model = self._save_model() self.epoch_saved_model = epoch - _log.info(f'Last model saved at epoch # {self.epoch_saved_model}.') + _log.info(f"Last model saved at epoch # {self.epoch_saved_model}.") if not saved_model: - warnings.warn("A model has been saved but the validation and/or the training losses were NaN;" + - "\n\ttry to increase the cutoff distance during the data processing or the number of data points " + - "during the training.") + warnings.warn( + "A model has been saved but the validation and/or the training losses were NaN;\n\t" + "try to increase the cutoff distance during the data processing or the number of data points during the training." + ) # Now that the training loop is over, save the model if filename: @@ -663,16 +663,15 @@ def train( # pylint: disable=too-many-arguments, too-many-branches, too-many-loc def _epoch(self, epoch_number: int, pass_name: str) -> float | None: """ - Runs a single epoch + Runs a single epoch. Args: - epoch_number (int) + epoch_number (int): the current epoch number pass_name (str): 'training', 'validation' or 'testing' Returns: Running loss. 
""" - sum_of_losses = 0 count_predictions = 0 target_vals = [] @@ -681,7 +680,7 @@ def _epoch(self, epoch_number: int, pass_name: str) -> float | None: t0 = time() for data_batch in self.train_loader: if self.cuda: - data_batch = data_batch.to(self.device, non_blocking=True) + data_batch = data_batch.to(self.device, non_blocking=True) # noqa: PLW2901 (redefined-loop-name) self.optimizer.zero_grad() pred = self.model(data_batch) pred, data_batch.y = self._format_output(pred, data_batch.y) @@ -712,20 +711,25 @@ def _epoch(self, epoch_number: int, pass_name: str) -> float | None: epoch_loss = None self._output_exporters.process( - pass_name, epoch_number, entry_names, outputs, target_vals, epoch_loss) + pass_name, + epoch_number, + entry_names, + outputs, + target_vals, + epoch_loss, + ) self._log_epoch_data(pass_name, epoch_loss, dt) return epoch_loss - def _eval( # pylint: disable=too-many-locals - self, - loader: DataLoader, - epoch_number: int, - pass_name: str - ) -> float | None: - + def _eval( + self, + loader: DataLoader, + epoch_number: int, + pass_name: str, + ) -> float | None: """ - Evaluates the model + Evaluates the model. Args: loader (Dataloader): Data to evaluate on. @@ -735,7 +739,6 @@ def _eval( # pylint: disable=too-many-locals Returns: Running loss. """ - # Sets the module in evaluation mode self.model.eval() loss_func = self.lossfunction @@ -747,7 +750,7 @@ def _eval( # pylint: disable=too-many-locals t0 = time() for data_batch in loader: if self.cuda: - data_batch = data_batch.to(self.device, non_blocking=True) + data_batch = data_batch.to(self.device, non_blocking=True) # noqa: PLW2901 (redefined-loop-name) pred = self.model(data_batch) pred, y = self._format_output(pred, data_batch.y) @@ -779,7 +782,13 @@ def _eval( # pylint: disable=too-many-locals eval_loss = None self._output_exporters.process( - pass_name, epoch_number, entry_names, outputs, target_vals, eval_loss) + pass_name, + epoch_number, + entry_names, + outputs, + target_vals, + eval_loss, + ) self._log_epoch_data(pass_name, eval_loss, dt) return eval_loss @@ -787,40 +796,38 @@ def _eval( # pylint: disable=too-many-locals @staticmethod def _log_epoch_data(stage: str, loss: float, time: float): """ - Prints the data of each epoch + Prints the data of each epoch. Args: stage (str): Train or valid. loss (float): Loss during that epoch. time (float): Timing of the epoch. """ - _log.info(f'{stage} loss {loss} | time {time}') + _log.info(f"{stage} loss {loss} | time {time}") def _format_output(self, pred, target=None): - - """ - Format the network output depending on the task (classification/regression). 
- """ - + """Format the network output depending on the task (classification/regression).""" if (self.task == targets.CLASSIF) and (target is not None): # For categorical cross entropy, the target must be a one-dimensional tensor # of class indices with type long and the output should have raw, unnormalized values - target = torch.tensor( - [self.classes_to_index[x] if isinstance(x, str) else self.classes_to_index[int(x)] for x in target] - ) - if isinstance(self.lossfunction, (nn.BCELoss, nn.BCEWithLogitsLoss)): + target = torch.tensor([self.classes_to_index[x] if isinstance(x, str) else self.classes_to_index[int(x)] for x in target]) + if isinstance(self.lossfunction, nn.BCELoss | nn.BCEWithLogitsLoss): # # pred must be in (0,1) range and target must be float with same shape as pred # pred = F.softmax(pred) # target = torch.tensor( # [[0,1] if x == [1] else [1,0] for x in target] # ).float() - raise ValueError('BCELoss and BCEWithLogitsLoss are currently not supported.\n\t' + - 'For further details see: https://github.com/DeepRank/deeprank2/issues/318') + raise ValueError( + "BCELoss and BCEWithLogitsLoss are currently not supported.\n\t" + "For further details see: https://github.com/DeepRank/deeprank2/issues/318" + ) if isinstance(self.lossfunction, losses.classification_losses) and not isinstance(self.lossfunction, losses.classification_tested): - raise ValueError(f'{self.lossfunction} is currently not supported.\n\t' + - f'Supported loss functions for classification: {losses.classification_tested}.\n\t' + - 'Implementation of other loss functions requires adaptation of Trainer._format_output.') + raise ValueError( + f"{self.lossfunction} is currently not supported.\n\t" + f"Supported loss functions for classification: {losses.classification_tested}.\n\t" + "Implementation of other loss functions requires adaptation of Trainer._format_output." + ) elif self.task == targets.REGRESS: pred = pred.reshape(-1) @@ -833,7 +840,8 @@ def _format_output(self, pred, target=None): def test( self, batch_size: int = 32, - num_workers: int = 0): + num_workers: int = 0, + ): """ Performs the testing of the model. @@ -844,11 +852,7 @@ def test( Defaults to 0. """ if (not self.pretrained_model) and (not self.model_load_state_dict): - raise ValueError( - """ - No pretrained model provided and no training performed. - Please provide a pretrained model or train the model before testing.\n - """) + raise ValueError("No pretrained model provided and no training performed. 
Please provide a pretrained model or train the model before testing.") self.batch_size_test = batch_size @@ -859,7 +863,7 @@ def test( self.dataset_test, batch_size=self.batch_size_test, num_workers=num_workers, - pin_memory=self.cuda + pin_memory=self.cuda, ) _log.info("Testing set loaded\n") else: @@ -867,19 +871,15 @@ def test( raise ValueError("No test dataset provided.") with self._output_exporters: - # Run test self._eval(self.test_loader, self.epoch_saved_model, "testing") def _load_params(self): - """ - Loads the parameters of a pretrained model - """ - + """Loads the parameters of a pretrained model.""" if torch.cuda.is_available(): state = torch.load(self.pretrained_model) else: - state = torch.load(self.pretrained_model, map_location=torch.device('cpu')) + state = torch.load(self.pretrained_model, map_location=torch.device("cpu")) self.data_type = state["data_type"] self.model_load_state_dict = state["model_state"] @@ -921,12 +921,12 @@ def _save_model(self): features_transform_to_save = copy.deepcopy(self.features_transform) # prepare transform dictionary for being saved if features_transform_to_save: - for _, key in features_transform_to_save.items(): - if key['transform'] is None: + for key in features_transform_to_save.values(): + if key["transform"] is None: continue - str_expr = inspect.getsource(key['transform']) - match = re.search(r'\'transform\':.*(lambda.*).*,.*\'standardize\'.*', str_expr).group(1) - key['transform'] = match + str_expr = inspect.getsource(key["transform"]) + match = re.search(r"\'transform\':.*(lambda.*).*,.*\'standardize\'.*", str_expr).group(1) + key["transform"] = match state = { "data_type": self.data_type, @@ -957,48 +957,47 @@ def _save_model(self): "means": self.means, "devs": self.devs, "cuda": self.cuda, - "ngpu": self.ngpu + "ngpu": self.ngpu, } - return state + return state # noqa:RET504 (unnecessary-assign) def _divide_dataset( dataset: GraphDataset | GridDataset, splitsize: float | int | None = None, ) -> tuple[GraphDataset, GraphDataset] | tuple[GridDataset, GridDataset]: - - """Divides the dataset into a training set and an evaluation set + """Divides the dataset into a training set and an evaluation set. Args: dataset (:class:`GraphDataset` | :class:`GridDataset`): Input dataset to be split into training and validation data. splitsize (float | int | None, optional): Fraction of dataset (if float) or number of datapoints (if int) to use for validation. Defaults to None. """ - if splitsize is None: splitsize = 0.25 full_size = len(dataset) # find number of datapoints to include in training dataset - if isinstance (splitsize, float): + if isinstance(splitsize, float): n_split = int(splitsize * full_size) - elif isinstance (splitsize, int): + elif isinstance(splitsize, int): n_split = splitsize else: - raise TypeError (f"type(splitsize) must be float, int or None ({type(splitsize)} detected.)") + raise TypeError(f"type(splitsize) must be float, int or None ({type(splitsize)} detected.)") # raise exception if no training data or negative validation size if n_split >= full_size or n_split < 0: - raise ValueError (f"invalid splitsize: {n_split}\n\t" + - f"splitsize must be a float between 0 and 1 OR an int smaller than the size of the dataset ({full_size} datapoints)") + raise ValueError( + f"Invalid splitsize: {n_split}. splitsize must be a float between 0 and 1 OR an int smaller than the size of the dataset ({full_size} datapoints)" + ) if splitsize == 0: # i.e. 
the fraction of splitsize was so small that it rounded to <1 datapoint dataset_main = dataset dataset_split = None else: indices = np.arange(full_size) - np.random.shuffle(indices) + np.random.default_rng().shuffle(indices) dataset_main = copy.deepcopy(dataset) dataset_main.index_entries = [dataset.index_entries[i] for i in indices[n_split:]] diff --git a/deeprank2/utils/buildgraph.py b/deeprank2/utils/buildgraph.py index 83511d05a..1e49377bd 100644 --- a/deeprank2/utils/buildgraph.py +++ b/deeprank2/utils/buildgraph.py @@ -21,19 +21,17 @@ def _add_atom_to_residue(atom: Atom, residue: Residue): If no matching atom is found, add the current atom to the residue. If there's another atom with the same name, choose the one with the highest occupancy. """ - for other_atom in residue.atoms: - if other_atom.name == atom.name: - if other_atom.occupancy < atom.occupancy: - other_atom.change_altloc(atom) - return + if other_atom.name == atom.name and other_atom.occupancy < atom.occupancy: + other_atom.change_altloc(atom) + return residue.add_atom(atom) def _add_atom_data_to_structure( structure: PDBStructure, pdb_obj: pdb2sql_object, - **kwargs + **kwargs, ): """This subroutine retrieves pdb2sql atomic data for `PDBStructure` objects as defined in DeepRank2. @@ -41,14 +39,13 @@ def _add_atom_data_to_structure( Args: structure (:class:`PDBStructure`): The structure to which this atom should be added to. - pdb (pdb2sql_object): The `pdb2sql` object to retrieve the data from. + pdb_obj (pdb2sql_object): The `pdb2sql` object to retrieve the data from. kwargs: as required by the get function for the `pdb2sql` object. """ - pdb2sql_columns = "x,y,z,name,altLoc,occ,element,chainID,resSeq,resName,iCode" - data_keys = pdb2sql_columns.split(sep=',') + data_keys = pdb2sql_columns.split(sep=",") for data_values in pdb_obj.get(pdb2sql_columns, **kwargs): - atom_data = dict(zip(data_keys, data_values)) + atom_data = dict(zip(data_keys, data_values, strict=True)) # exit function if this atom is already part of the structure if atom_data["altLoc"] not in (None, "", "A"): @@ -84,7 +81,7 @@ def get_structure(pdb_obj: pdb2sql_object, id_: str) -> PDBStructure: """Builds a structure from rows in a pdb file. Args: - pdb (pdb2sql object): The pdb structure that we're investigating. + pdb_obj (pdb2sql object): The pdb structure that we're investigating. id_ (str): Unique id for the pdb structure. Returns: @@ -98,10 +95,9 @@ def get_structure(pdb_obj: pdb2sql_object, id_: str) -> PDBStructure: def get_contact_atoms( pdb_path: str, chain_ids: list[str], - influence_radius: float + influence_radius: float, ) -> list[Atom]: """Gets the contact atoms from pdb2sql and wraps them in python objects.""" - interface = pdb2sql_interface(pdb_path) pdb_name = os.path.splitext(os.path.basename(pdb_path))[0] structure = PDBStructure(f"contact_atoms_{pdb_name}") @@ -115,7 +111,7 @@ def get_contact_atoms( pdb_rowID = atom_indexes[chain_ids[0]] + atom_indexes[chain_ids[1]] _add_atom_data_to_structure(structure, interface, rowID=pdb_rowID) finally: - interface._close() # pylint: disable=protected-access + interface._close() # noqa: SLF001 (private-member-access) return structure.get_atoms() @@ -139,7 +135,6 @@ def get_residue_contact_pairs( Returns: list[Pair]: The pairs of contacting residues. 
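A side note on the `_divide_dataset` helper touched above: `splitsize` may be given either as a fraction (float) or as an absolute number of datapoints (int), and the indices are now shuffled through NumPy's `default_rng` Generator rather than the legacy `np.random.shuffle`. A rough standalone sketch of that convention, using a hypothetical `split_indices` helper rather than the actual dataset classes:

    import numpy as np

    def split_indices(n_items: int, splitsize: float | int = 0.25) -> tuple[list[int], list[int]]:
        """Return (main, held_out) index lists, following the float-or-int convention above."""
        n_split = int(splitsize * n_items) if isinstance(splitsize, float) else int(splitsize)
        if n_split >= n_items or n_split < 0:
            raise ValueError(f"invalid splitsize: {n_split}")
        indices = np.arange(n_items)
        np.random.default_rng().shuffle(indices)  # modern Generator API
        return indices[n_split:].tolist(), indices[:n_split].tolist()

    main_idx, held_out_idx = split_indices(10, 0.3)  # 7 training indices, 3 held-out indices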
""" - # Find out which residues are pairs interface = pdb2sql_interface(pdb_path) try: @@ -150,10 +145,10 @@ def get_residue_contact_pairs( return_contact_pairs=True, ) finally: - interface._close() # pylint: disable=protected-access + interface._close() # noqa: SLF001 (private-member-access) # Map to residue objects - residue_pairs = set([]) + residue_pairs = set() for residue_key1, residue_contacts in contact_residues.items(): residue1 = _get_residue_from_key(structure, residue_key1) for residue_key2 in residue_contacts: @@ -168,20 +163,13 @@ def _get_residue_from_key( residue_key: tuple[str, int, str], ) -> Residue: """Returns a residue object given a pdb2sql-formatted residue key.""" - residue_chain_id, residue_number, residue_name = residue_key chain = structure.get_chain(residue_chain_id) for residue in chain.residues: - if ( - residue.number == residue_number - and residue.amino_acid is not None - and residue.amino_acid.three_letter_code == residue_name - ): + if residue.number == residue_number and residue.amino_acid is not None and residue.amino_acid.three_letter_code == residue_name: return residue - raise ValueError( - f"Residue ({residue_key}) not found in {structure.id}." - ) + raise ValueError(f"Residue ({residue_key}) not found in {structure.id}.") def get_surrounding_residues( @@ -199,13 +187,16 @@ def get_surrounding_residues( Returns: list[:class:`Residue`]: The surrounding residues. """ - structure_atoms = structure.get_atoms() structure_atom_positions = [atom.position for atom in structure_atoms] residue_atom_positions = [atom.position for atom in residue.atoms] - pairwise_distances = distance_matrix(structure_atom_positions, residue_atom_positions, p=2) + pairwise_distances = distance_matrix( + structure_atom_positions, + residue_atom_positions, + p=2, + ) - surrounding_residues = set([]) + surrounding_residues = set() for structure_atom_index, structure_atom in enumerate(structure_atoms): shortest_distance = np.min(pairwise_distances[structure_atom_index, :]) if shortest_distance < radius: diff --git a/deeprank2/utils/community_pooling.py b/deeprank2/utils/community_pooling.py index 802617651..14b8e1de7 100644 --- a/deeprank2/utils/community_pooling.py +++ b/deeprank2/utils/community_pooling.py @@ -13,27 +13,30 @@ def plot_graph(graph, cluster): - pos = nx.spring_layout(graph, iterations=200) nx.draw(graph, pos, node_color=cluster) plt.show() def get_preloaded_cluster(cluster, batch): - nbatch = torch.max(batch) + 1 for ib in range(1, nbatch): cluster[batch == ib] += torch.max(cluster[batch == ib - 1]) + 1 return cluster -def community_detection_per_batch( # pylint: disable=too-many-locals - edge_index, batch, num_nodes: int, edge_attr=None, method: str = "mcl" +def community_detection_per_batch( + edge_index, + batch, + num_nodes: int, + edge_attr=None, + method: str = "mcl", ): """Detects clusters of nodes based on the edge attributes (distances). Args: edge_index (Tensor): Edge index. + batch (?): ? num_nodes (int): Number of nodes. edge_attr (Tensor, optional): Edge attributes. Defaults to None. method (str, optional): Method. Defaults to "mcl". 
@@ -44,7 +47,6 @@ def community_detection_per_batch( # pylint: disable=too-many-locals Returns: cluster Tensor """ - # make the networkX graph g = nx.Graph() g.add_nodes_from(range(num_nodes)) @@ -60,7 +62,6 @@ def community_detection_per_batch( # pylint: disable=too-many-locals cluster, ncluster = [], 0 for ib in range(num_batch): - index = torch.tensor(all_index)[batch == ib].tolist() subg = g.subgraph(index) @@ -90,7 +91,12 @@ def community_detection_per_batch( # pylint: disable=too-many-locals return torch.tensor(cluster).to(device) -def community_detection(edge_index, num_nodes: int, edge_attr=None, method: str = "mcl"): # pylint: disable=too-many-locals +def community_detection( + edge_index, + num_nodes: int, + edge_attr=None, + method: str = "mcl", +): """Detects clusters of nodes based on the edge attributes (distances). Args: @@ -106,7 +112,6 @@ def community_detection(edge_index, num_nodes: int, edge_attr=None, method: str cluster Tensor Examples: - >>> import torch >>> from torch_geometric.data import Data, Batch >>> edge_index = torch.tensor([[0, 1, 1, 2, 3, 4, 4, 5], @@ -137,7 +142,6 @@ def community_detection(edge_index, num_nodes: int, edge_attr=None, method: str # detect the communities using MCL detection if method == "mcl": - matrix = nx.to_scipy_sparse_array(g) # run MCL with default parameters @@ -170,7 +174,6 @@ def community_pooling(cluster, data): pooled features tensor Examples: - >>> import torch >>> from torch_geometric.data import Data, Batch >>> edge_index = torch.tensor([[0, 1, 1, 2, 3, 4, 4, 5], @@ -184,7 +187,6 @@ def community_pooling(cluster, data): >>> cluster = community_detection(batch.edge_index, batch.num_nodes) >>> new_batch = community_pooling(cluster, batch) """ - # determine what the batches has as attributes has_internal_edges = hasattr(data, "internal_edge_index") has_pos2d = hasattr(data, "pos2d") @@ -193,9 +195,9 @@ def community_pooling(cluster, data): if has_internal_edges: warnings.warn( - """Internal edges are not supported anymore. - You should probably prepare the hdf5 file with - a more up to date version of this software.""", DeprecationWarning) + """Internal edges are not supported anymore. Please prepare the hdf5 file with a more up to date version of this software.""", + DeprecationWarning, + ) cluster, perm = consecutive_cluster(cluster) cluster = cluster.to(data.x.device) @@ -218,9 +220,7 @@ def community_pooling(cluster, data): # pool batch if hasattr(data, "batch"): batch = None if data.batch is None else pool_batch(perm, data.batch) - data = Batch( - batch=batch, x=x, edge_index=edge_index, edge_attr=edge_attr, pos=pos - ) + data = Batch(batch=batch, x=x, edge_index=edge_index, edge_attr=edge_attr, pos=pos) if has_cluster: data.cluster0 = c0 diff --git a/deeprank2/utils/earlystopping.py b/deeprank2/utils/earlystopping.py index aa6acbaf5..02e9bffd7 100644 --- a/deeprank2/utils/earlystopping.py +++ b/deeprank2/utils/earlystopping.py @@ -1,8 +1,8 @@ -from typing import Callable +from collections.abc import Callable class EarlyStopping: - def __init__( # pylint: disable=too-many-arguments + def __init__( self, patience: int = 10, delta: float = 0, @@ -11,9 +11,9 @@ def __init__( # pylint: disable=too-many-arguments verbose: bool = True, trace_func: Callable = print, ): - """ - Terminate training if validation loss doesn't improve after a given patience - or if a maximum gap between validation and training loss is reached. + """Terminate training upon trigger. 
+ + Triggered if validation loss doesn't improve after a given patience or if a maximum gap between validation and training loss is reached. Args: patience (int, optional): How long to wait after last time validation loss improved. @@ -29,7 +29,6 @@ def __init__( # pylint: disable=too-many-arguments trace_func (Callable, optional): Function used for recording EarlyStopping status. Defaults to print. """ - self.patience = patience self.delta = delta self.maxgap = maxgap @@ -55,17 +54,19 @@ def __call__(self, epoch, val_loss, train_loss=None): self.counter += 1 if self.verbose: if self.delta: - extra_trace = f'more than {self.delta} ' + extra_trace = f"more than {self.delta} " else: - extra_trace = '' - self.trace_func(f'Validation loss did not decrease {extra_trace}({self.val_loss_min:.6f} --> {val_loss:.6f}). '+ - f'EarlyStopping counter: {self.counter} out of {self.patience}') + extra_trace = "" + self.trace_func( + f"Validation loss did not decrease {extra_trace}({self.val_loss_min:.6f} --> {val_loss:.6f}). " + f"EarlyStopping counter: {self.counter} out of {self.patience}" + ) if self.counter >= self.patience: - self.trace_func(f'EarlyStopping activated at epoch # {epoch} because patience of {self.patience} has been reached.') + self.trace_func(f"EarlyStopping activated at epoch # {epoch} because patience of {self.patience} has been reached.") self.early_stop = True else: if self.verbose: - self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).') + self.trace_func(f"Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).") self.best_score = score self.counter = 0 @@ -79,13 +80,13 @@ def __call__(self, epoch, val_loss, train_loss=None): raise ValueError("Cannot compute gap because no train_loss is provided to EarlyStopping.") gap = val_loss - train_loss if gap > self.maxgap: - self.trace_func(f'EarlyStopping activated at epoch # {epoch} due to overfitting. ' + - f'The difference between validation and training loss of {gap} exceeds the maximum allowed ({self.maxgap})') + self.trace_func( + f"EarlyStopping activated at epoch # {epoch} due to overfitting. 
" + f"The difference between validation and training loss of {gap} exceeds the maximum allowed ({self.maxgap})" + ) self.early_stop = True - - # This module is modified from https://github.com/Bjarten/early-stopping-pytorch, under the following license: diff --git a/deeprank2/utils/exporters.py b/deeprank2/utils/exporters.py index b8ce1863b..a262059c7 100644 --- a/deeprank2/utils/exporters.py +++ b/deeprank2/utils/exporters.py @@ -4,7 +4,7 @@ from math import sqrt import pandas as pd -from matplotlib import pyplot +from matplotlib import pyplot as plt from sklearn.metrics import roc_auc_score from torch import argmax, tensor from torch.nn.functional import cross_entropy @@ -16,8 +16,7 @@ class OutputExporter: """The class implements a general exporter to be called when a neural network generates outputs.""" - def __init__(self, directory_path: str = None): - + def __init__(self, directory_path: str | None = None): if directory_path is None: directory_path = "./output" self._directory_path = directory_path @@ -26,24 +25,29 @@ def __init__(self, directory_path: str = None): os.makedirs(self._directory_path) def __enter__(self): - "overridable" + """Overridable.""" return self def __exit__(self, exception_type, exception, traceback): - "overridable" - pass # pylint: disable=unnecessary-pass + """Overridable.""" - def process(self, pass_name: str, epoch_number: int, # pylint: disable=too-many-arguments - entry_names: list[str], output_values: list, target_values: list, loss: float): - "the entry_names, output_values, target_values MUST have the same length" - pass # pylint: disable=unnecessary-pass + def process( + self, + pass_name: str, + epoch_number: int, + entry_names: list[str], + output_values: list, + target_values: list, + loss: float, + ): + """The entry_names, output_values, target_values MUST have the same length.""" - def is_compatible_with( # pylint: disable=unused-argument + def is_compatible_with( self, - output_data_shape: int, - target_data_shape: int | None = None, + output_data_shape: int, # noqa: ARG002 (unused argument) + target_data_shape: int | None = None, # noqa: ARG002 (unused argument) ) -> bool: - "true if this exporter can work with the given data shapes" + """True if this exporter can work with the given data shapes.""" return True @@ -63,10 +67,24 @@ def __exit__(self, exception_type, exception, traceback): for output_exporter in self._output_exporters: output_exporter.__exit__(exception_type, exception, traceback) - def process(self, pass_name: str, epoch_number: int, # pylint: disable=too-many-arguments - entry_names: list[str], output_values: list, target_values: list, loss: float): + def process( + self, + pass_name: str, + epoch_number: int, + entry_names: list[str], + output_values: list, + target_values: list, + loss: float, + ): for output_exporter in self._output_exporters: - output_exporter.process(pass_name, epoch_number, entry_names, output_values, target_values, loss) + output_exporter.process( + pass_name, + epoch_number, + entry_names, + output_values, + target_values, + loss, + ) def __iter__(self): return iter(self._output_exporters) @@ -93,12 +111,22 @@ def __enter__(self): def __exit__(self, exception_type, exception, traceback): self._writer.__exit__(exception_type, exception, traceback) - def process(self, pass_name: str, epoch_number: int, # pylint: disable=too-many-arguments, too-many-locals - entry_names: list[str], output_values: list, target_values: list, loss: float): - "write to tensorboard" - + def process( + self, + pass_name: str, 
+ epoch_number: int, + entry_names: list[str], + output_values: list, + target_values: list, + loss: float, # noqa: ARG002 (unused argument) + ): + """Write to tensorboard.""" ce_loss = cross_entropy(tensor(output_values), tensor(target_values)).item() - self._writer.add_scalar(f"{pass_name} cross entropy loss", ce_loss, epoch_number) + self._writer.add_scalar( + f"{pass_name} cross entropy loss", + ce_loss, + epoch_number, + ) probabilities = [] fp, fn, tp, tn = 0, 0, 0, 0 @@ -115,10 +143,10 @@ def process(self, pass_name: str, epoch_number: int, # pylint: disable=too-many- elif prediction_value <= 0.0 and target_value <= 0.0: tn += 1 - elif prediction_value > 0.0 and target_value <= 0.0: # pylint: disable=chained-comparison + elif target_value <= 0.0 < prediction_value: fp += 1 - elif prediction_value <= 0.0 and target_value > 0.0: # pylint: disable=chained-comparison + elif prediction_value <= 0.0 < target_value: fn += 1 mcc_numerator = tn * tp - fp * fn @@ -145,19 +173,16 @@ def is_compatible_with( target_data_shape: int | None = None, ) -> bool: """For regression, target data is needed and output data must be a list of two-dimensional values.""" - return output_data_shape == 2 and target_data_shape == 1 class ScatterPlotExporter(OutputExporter): - """An output exporter that can make scatter plots, containing every single data point. + def __init__(self, directory_path: str, epoch_interval: int = 1): + """An output exporter that can make scatter plots, containing every single data point. - On the X-axis: targets values - On the Y-axis: output values - """ + On the X-axis: targets values + On the Y-axis: output values - def __init__(self, directory_path: str, epoch_interval: int = 1): - """ Args: directory_path (str): Where to store the plots. epoch_interval (int, optional): How often to make a plot, 5 means: every 5 epochs. Defaults to 1. 
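As background for the exporter hunk above: `TensorboardBinaryClassificationExporter.process` accumulates tp/tn/fp/fn counts and derives the Matthews correlation coefficient from them, although only the numerator appears in this hunk. For reference, the standard MCC formula, checked here on made-up counts (the zero-denominator guard is part of this sketch, not necessarily of the exporter):

    from math import sqrt

    tp, tn, fp, fn = 40, 45, 5, 10  # toy confusion-matrix counts

    mcc_numerator = tn * tp - fp * fn
    mcc_denominator = sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    mcc = 0.0 if mcc_denominator == 0 else mcc_numerator / mcc_denominator  # guard a degenerate matrix

    print(round(mcc, 3))  # ~0.704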
@@ -178,41 +203,49 @@ def get_filename(self, epoch_number): @staticmethod def _get_color(pass_name): - pass_name = pass_name.lower().strip() - if pass_name in ("train", "training"): return "blue" - if pass_name in ("eval", "valid", "validation"): return "red" - if pass_name == ("test", "testing"): return "green" - return random.choice(["yellow", "cyan", "magenta"]) @staticmethod - def _plot(epoch_number: int, data: dict[str, tuple[list[float], list[float]]], png_path: str): - - pyplot.title(f"Epoch {epoch_number}") + def _plot( + epoch_number: int, + data: dict[str, tuple[list[float], list[float]]], + png_path: str, + ): + plt.title(f"Epoch {epoch_number}") for pass_name, (truth_values, prediction_values) in data.items(): - pyplot.scatter(truth_values, prediction_values, color=ScatterPlotExporter._get_color(pass_name), label=pass_name) + plt.scatter( + truth_values, + prediction_values, + color=ScatterPlotExporter._get_color(pass_name), + label=pass_name, + ) - pyplot.xlabel("truth") - pyplot.ylabel("prediction") + plt.xlabel("truth") + plt.ylabel("prediction") - pyplot.legend() - pyplot.savefig(png_path) - pyplot.close() + plt.legend() + plt.savefig(png_path) + plt.close() - def process(self, pass_name: str, epoch_number: int, # pylint: disable=too-many-arguments - entry_names: list[str], output_values: list, target_values: list, loss: float): + def process( + self, + pass_name: str, + epoch_number: int, + entry_names: list[str], # noqa: ARG002 (unused argument) + output_values: list, + target_values: list, + loss: float, # noqa: ARG002 (unused argument) + ): """Make the plot, if the epoch matches with the interval.""" - if epoch_number % self._epoch_interval == 0: - if epoch_number not in self._plot_data: self._plot_data[epoch_number] = {} @@ -227,7 +260,6 @@ def is_compatible_with( target_data_shape: int | None = None, ) -> bool: """For regression, target data is needed and output data must be a list of one-dimensional values.""" - return output_data_shape == 1 and target_data_shape == 1 @@ -248,44 +280,56 @@ class HDF5OutputExporter(OutputExporter): """ def __init__(self, directory_path: str): - self.phase = None super().__init__(directory_path) def __enter__(self): - - self.d = {'phase': [], 'epoch': [], 'entry': [], 'output': [], 'target': [], 'loss': []} + self.d = { + "phase": [], + "epoch": [], + "entry": [], + "output": [], + "target": [], + "loss": [], + } self.df = pd.DataFrame(data=self.d) return self def __exit__(self, exception_type, exception, traceback): - if self.phase is not None: if self.phase == "validation": self.phase = "training" self.df.to_hdf( - os.path.join(self._directory_path, 'output_exporter.hdf5'), + os.path.join(self._directory_path, "output_exporter.hdf5"), key=self.phase, - mode='a') + mode="a", + ) - def process( # pylint: disable=too-many-arguments + def process( self, pass_name: str, epoch_number: int, entry_names: list[str], output_values: list, target_values: list, - loss: float): - + loss: float, + ): self.phase = pass_name pass_name = [pass_name] * len(output_values) loss = [loss] * len(output_values) epoch_number = [epoch_number] * len(output_values) - d_epoch = {'phase': pass_name, 'epoch': epoch_number, 'entry': entry_names, 'output': output_values, 'target': target_values, 'loss': loss} + d_epoch = { + "phase": pass_name, + "epoch": epoch_number, + "entry": entry_names, + "output": output_values, + "target": target_values, + "loss": loss, + } df_epoch = pd.DataFrame(data=d_epoch) self.df = pd.concat([self.df, df_epoch]) - 
self.df.reset_index(drop=True, inplace=True) + self.df = self.df.reset_index(drop=True) diff --git a/deeprank2/utils/graph.py b/deeprank2/utils/graph.py index 5d2f540c9..a30832775 100644 --- a/deeprank2/utils/graph.py +++ b/deeprank2/utils/graph.py @@ -2,12 +2,11 @@ import logging import os -from typing import Callable +from typing import TYPE_CHECKING import h5py import numpy as np import pdb2sql.transform -from numpy.typing import NDArray from scipy.spatial import distance_matrix from deeprank2.domain import edgestorage as Efeat @@ -18,6 +17,11 @@ from deeprank2.molstruct.residue import Residue from deeprank2.utils.grid import Augmentation, Grid, GridSettings, MapMethod +if TYPE_CHECKING: + from collections.abc import Callable + + from numpy.typing import NDArray + _log = logging.getLogger(__name__) @@ -26,9 +30,7 @@ def __init__(self, id_: Contact): self.id = id_ self.features = {} - def add_feature( - self, feature_name: str, feature_function: Callable[[Contact], float] - ): + def add_feature(self, feature_name: str, feature_function: Callable[[Contact], float]): feature_value = feature_function(self.id) self.features[feature_name] = feature_value @@ -43,12 +45,7 @@ def position2(self) -> np.array: def has_nan(self) -> bool: """Whether there are any NaN values in the edge's features.""" - - for feature_data in self.features.values(): - if np.any(np.isnan(feature_data)): - return True - - return False + return any(np.any(np.isnan(feature_data)) for feature_data in self.features.values()) class Node: @@ -69,11 +66,7 @@ def type(self): def has_nan(self) -> bool: """Whether there are any NaN values in the node's features.""" - - for feature_data in self.features.values(): - if np.any(np.isnan(feature_data)): - return True - return False + return any(np.any(np.isnan(feature_data)) for feature_data in self.features.values()) def add_feature( self, @@ -84,9 +77,7 @@ def add_feature( if len(feature_value.shape) != 1: shape_s = "x".join(feature_value.shape) - raise ValueError( - f"Expected a 1-dimensional array for feature {feature_name}, but got {shape_s}" - ) + raise ValueError(f"Expected a 1-dimensional array for feature {feature_name}, but got {shape_s}") self.features[feature_name] = feature_value @@ -130,29 +121,29 @@ def edges(self) -> list[Node]: def has_nan(self) -> bool: """Whether there are any NaN values in the graph's features.""" - for node in self._nodes.values(): if node.has_nan(): return True + return any(edge.has_nan() for edge in self._edges.values()) - for edge in self._edges.values(): - if edge.has_nan(): - return True - - return False - - def _map_point_features(self, grid: Grid, method: MapMethod, # pylint: disable=too-many-arguments - feature_name: str, points: list[NDArray], - values: list[float | NDArray], - augmentation: Augmentation | None = None): - + def _map_point_features( + self, + grid: Grid, + method: MapMethod, + feature_name: str, + points: list[NDArray], + values: list[float | NDArray], + augmentation: Augmentation | None = None, + ): points = np.stack(points, axis=0) if augmentation is not None: - points = pdb2sql.transform.rot_xyz_around_axis(points, - augmentation.axis, - augmentation.angle, - self.center) + points = pdb2sql.transform.rot_xyz_around_axis( + points, + augmentation.axis, + augmentation.angle, + self.center, + ) for point_index in range(points.shape[0]): position = points[point_index] @@ -160,41 +151,58 @@ def _map_point_features(self, grid: Grid, method: MapMethod, # pylint: disable= grid.map_feature(position, feature_name, value, method) 
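A note on the augmentation branch of `_map_point_features` above: before mapping, the stacked points can be rotated by a given angle around an axis through the graph center, via `pdb2sql.transform.rot_xyz_around_axis`. The geometric idea, sketched with SciPy purely for illustration (axis, angle, points and center are invented; this is not the pdb2sql implementation):

    import numpy as np
    from scipy.spatial.transform import Rotation

    points = np.array([[1.0, 0.0, 0.0], [0.0, 2.0, 0.0]])  # toy coordinates
    center = np.array([0.0, 0.0, 0.0])                     # rotation center (graph center)
    axis = np.array([0.0, 0.0, 1.0])                       # unit rotation axis
    angle = np.pi / 2                                      # rotation angle in radians

    # rotate the points about `axis` through `center`
    rotation = Rotation.from_rotvec(angle * axis)
    rotated = rotation.apply(points - center) + center

    print(np.round(rotated, 3))  # approximately [[0, 1, 0], [-2, 0, 0]]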
- def map_to_grid(self, grid: Grid, method: MapMethod, augmentation: Augmentation | None = None): - + def map_to_grid( + self, + grid: Grid, + method: MapMethod, + augmentation: Augmentation | None = None, + ): # order edge features by xyz point points = [] feature_values = {} for edge in self._edges.values(): - points += [edge.position1, edge.position2] for feature_name, feature_value in edge.features.items(): - feature_values[feature_name] = feature_values.get(feature_name, []) + [feature_value, feature_value] + feature_values[feature_name] = feature_values.get(feature_name, []) + [ # noqa: RUF005 (collection-literal-concatenation) + feature_value, + feature_value, + ] # map edge features to grid for feature_name, values in feature_values.items(): - self._map_point_features(grid, method, feature_name, points, values, augmentation) + self._map_point_features( + grid, + method, + feature_name, + points, + values, + augmentation, + ) # order node features by xyz point points = [] feature_values = {} for node in self._nodes.values(): - points.append(node.position) for feature_name, feature_value in node.features.items(): - feature_values[feature_name] = feature_values.get(feature_name, []) + [feature_value] + feature_values[feature_name] = feature_values.get(feature_name, []) + [feature_value] # noqa: RUF005 (collection-literal-concatenation) # map node features to grid for feature_name, values in feature_values.items(): - self._map_point_features(grid, method, feature_name, points, values, augmentation) + self._map_point_features( + grid, + method, + feature_name, + points, + values, + augmentation, + ) - def write_to_hdf5(self, hdf5_path: str): # pylint: disable=too-many-locals + def write_to_hdf5(self, hdf5_path: str): """Write a featured graph to an hdf5 file, according to deeprank standards.""" - with h5py.File(hdf5_path, "a") as hdf5_file: - # create groups to hold data graph_group = hdf5_file.require_group(self.id) node_features_group = graph_group.create_group(Nfeat.NODE) @@ -208,29 +216,23 @@ def write_to_hdf5(self, hdf5_path: str): # pylint: disable=too-many-locals # store node features node_key_list = list(self._nodes.keys()) - first_node_data = list(self._nodes.values())[0].features + first_node_data = next(iter(self._nodes.values())).features node_feature_names = list(first_node_data.keys()) for node_feature_name in node_feature_names: + node_feature_data = [node.features[node_feature_name] for node in self._nodes.values()] - node_feature_data = [ - node.features[node_feature_name] for node in self._nodes.values() - ] - - node_features_group.create_dataset( - node_feature_name, data=node_feature_data - ) + node_features_group.create_dataset(node_feature_name, data=node_feature_data) # identify edges edge_indices = [] edge_names = [] - first_edge_data = list(self._edges.values())[0].features + first_edge_data = next(iter(self._edges.values())).features edge_feature_names = list(first_edge_data.keys()) edge_feature_data = {name: [] for name in edge_feature_names} for edge_id, edge in self._edges.items(): - id1, id2 = edge_id node_index1 = node_key_list.index(id1) node_index2 = node_key_list.index(id2) @@ -239,21 +241,15 @@ def write_to_hdf5(self, hdf5_path: str): # pylint: disable=too-many-locals edge_names.append(f"{id1}-{id2}") for edge_feature_name in edge_feature_names: - edge_feature_data[edge_feature_name].append( - edge.features[edge_feature_name] - ) + edge_feature_data[edge_feature_name].append(edge.features[edge_feature_name]) # store edge names and indices - 
edge_feature_group.create_dataset( - Efeat.NAME, data=np.array(edge_names).astype("S") - ) + edge_feature_group.create_dataset(Efeat.NAME, data=np.array(edge_names).astype("S")) edge_feature_group.create_dataset(Efeat.INDEX, data=edge_indices) # store edge features for edge_feature_name in edge_feature_names: - edge_feature_group.create_dataset( - edge_feature_name, data=edge_feature_data[edge_feature_name] - ) + edge_feature_group.create_dataset(edge_feature_name, data=edge_feature_data[edge_feature_name]) # store target values score_group = graph_group.create_group(targets.VALUES) @@ -262,15 +258,11 @@ def write_to_hdf5(self, hdf5_path: str): # pylint: disable=too-many-locals @staticmethod def _find_unused_augmentation_name(unaugmented_id: str, hdf5_path: str) -> str: - prefix = f"{unaugmented_id}_" - entry_names_taken = [] if os.path.isfile(hdf5_path): - with h5py.File(hdf5_path, 'r') as hdf5_file: - for entry_name in hdf5_file: - if entry_name.startswith(prefix): - entry_names_taken.append(entry_name) + with h5py.File(hdf5_path, "r") as hdf5_file: + entry_names_taken = [entry_name for entry_name in hdf5_file if entry_name.startswith(prefix)] augmentation_count = 0 chosen_name = f"{prefix}{augmentation_count:03}" @@ -281,12 +273,12 @@ def _find_unused_augmentation_name(unaugmented_id: str, hdf5_path: str) -> str: return chosen_name def write_as_grid_to_hdf5( - self, hdf5_path: str, + self, + hdf5_path: str, settings: GridSettings, method: MapMethod, - augmentation: Augmentation | None = None + augmentation: Augmentation | None = None, ) -> str: - id_ = self.id if augmentation is not None: id_ = self._find_unused_augmentation_name(id_, hdf5_path) @@ -297,8 +289,7 @@ def write_as_grid_to_hdf5( grid.to_hdf5(hdf5_path) # store target values - with h5py.File(hdf5_path, 'a') as hdf5_file: - + with h5py.File(hdf5_path, "a") as hdf5_file: grp = hdf5_file[id_] targets_group = grp.require_group(targets.VALUES) @@ -312,16 +303,15 @@ def write_as_grid_to_hdf5( def get_all_chains(self) -> list[str]: if isinstance(self.nodes[0].id, Residue): - chains = set(str(res.chain).split()[1] for res in [node.id for node in self.nodes]) + chains = {str(res.chain).split()[1] for res in [node.id for node in self.nodes]} elif isinstance(self.nodes[0].id, Atom): - chains = set(str(res.chain).split()[1] for res in [node.id.residue for node in self.nodes]) + chains = {str(res.chain).split()[1] for res in [node.id.residue for node in self.nodes]} else: return None return list(chains) - @staticmethod - def build_graph( # pylint: disable=too-many-locals + def build_graph( nodes: list[Atom] | list[Residue], graph_id: str, max_edge_length: float, @@ -340,7 +330,6 @@ def build_graph( # pylint: disable=too-many-locals Raises: TypeError: if `nodes` argument contains a mix of different types. 
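Regarding `Graph.build_graph` above: it connects every pair of input nodes whose positions lie within `max_edge_length` of each other; the pair search itself is outside this hunk. A plausible sketch of how such pairs can be found with `scipy.spatial.distance_matrix` (the variable names are illustrative, not taken from the actual implementation):

    import numpy as np
    from scipy.spatial import distance_matrix

    positions = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [5.0, 0.0, 0.0]])  # toy node positions
    max_edge_length = 2.0

    distances = distance_matrix(positions, positions, p=2)
    index_pairs = np.argwhere(distances < max_edge_length)  # includes self-pairs (i == i)

    # keep only genuine edges, mirroring the `if index1 != index2` filter in the body below
    edges = [(int(i), int(j)) for i, j in index_pairs if i != j]
    print(edges)  # [(0, 1), (1, 0)]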
""" - if all(isinstance(node, Atom) for node in nodes): atoms = nodes NodeContact = AtomicContact @@ -370,7 +359,6 @@ def build_graph( # pylint: disable=too-many-locals for index1, index2 in index_pairs: if index1 != index2: - node1 = Node(nodes[index1]) node2 = Node(nodes[index2]) contact = NodeContact(node1.id, node2.id) diff --git a/deeprank2/utils/grid.py b/deeprank2/utils/grid.py index 9a8071e86..2c8f8a55f 100644 --- a/deeprank2/utils/grid.py +++ b/deeprank2/utils/grid.py @@ -55,10 +55,10 @@ class GridSettings: def __init__( self, points_counts: list[int], - sizes: list[float] + sizes: list[float], ): - assert len(points_counts) == 3 - assert len(sizes) == 3 + if len(points_counts) != 3 or len(sizes) != 3: + raise ValueError("Incorrect grid dimensions.") self._points_counts = points_counts self._sizes = sizes @@ -77,27 +77,24 @@ def points_counts(self) -> list[int]: class Grid: - """ - An instance of this class holds everything that the grid is made of: + """A 3D (volumetric) representation of a `Graph`. + + A Grid contains the following information: + - coordinates of points - names of features - - feature values on each point + - feature values on each point. """ def __init__(self, id_: str, center: list[float], settings: GridSettings): self.id = id_ - self._center = np.array(center) - self._settings = settings - self._set_mesh(self._center, settings) - self._features = {} def _set_mesh(self, center: NDArray, settings: GridSettings): """Builds the grid points.""" - half_size_x = settings.sizes[0] / 2 half_size_y = settings.sizes[1] / 2 half_size_z = settings.sizes[2] / 2 @@ -114,9 +111,7 @@ def _set_mesh(self, center: NDArray, settings: GridSettings): max_z = min_z + (settings.points_counts[2] - 1.0) * settings.resolutions[2] self._zs = np.linspace(min_z, max_z, num=settings.points_counts[2]) - self._ygrid, self._xgrid, self._zgrid = np.meshgrid( - self._ys, self._xs, self._zs - ) + self._ygrid, self._xgrid, self._zgrid = np.meshgrid(self._ys, self._xs, self._zs) @property def center(self) -> NDArray: @@ -155,7 +150,6 @@ def add_feature_values(self, feature_name: str, data: NDArray): This method may be called repeatedly to add on to existing grid point values. 
""" - if feature_name not in self._features: self._features[feature_name] = data else: @@ -164,42 +158,33 @@ def add_feature_values(self, feature_name: str, data: NDArray): def _get_mapped_feature_gaussian( self, position: NDArray, - value: float + value: float, ) -> NDArray: - beta = 1.0 fx, fy, fz = position - distances = np.sqrt( - (self.xgrid - fx) ** 2 + (self.ygrid - fy) ** 2 + (self.zgrid - fz) ** 2 - ) + distances = np.sqrt((self.xgrid - fx) ** 2 + (self.ygrid - fy) ** 2 + (self.zgrid - fz) ** 2) return value * np.exp(-beta * distances) - def _get_mapped_feature_fast_gaussian( - self, position: NDArray, value: float - ) -> NDArray: - + def _get_mapped_feature_fast_gaussian(self, position: NDArray, value: float) -> NDArray: beta = 1.0 cutoff = 5.0 * beta fx, fy, fz = position - distances = np.sqrt( - (self.xgrid - fx) ** 2 + (self.ygrid - fy) ** 2 + (self.zgrid - fz) ** 2 - ) + distances = np.sqrt((self.xgrid - fx) ** 2 + (self.ygrid - fy) ** 2 + (self.zgrid - fz) ** 2) data = np.zeros(distances.shape) - data[distances < cutoff] = value * np.exp( - -beta * distances[distances < cutoff] - ) + data[distances < cutoff] = value * np.exp(-beta * distances[distances < cutoff]) return data def _get_mapped_feature_bsp_line( - self, position: NDArray, value: float + self, + position: NDArray, + value: float, ) -> NDArray: - order = 4 fx, fy, fz = position @@ -211,10 +196,11 @@ def _get_mapped_feature_bsp_line( return value * bsp_data - def _get_mapped_feature_nearest_neighbour( # pylint: disable=too-many-locals - self, position: NDArray, value: float + def _get_mapped_feature_nearest_neighbour( + self, + position: NDArray, + value: float, ) -> NDArray: - fx, _, _ = position distances_x = np.abs(self.xs - fx) distances_y = np.abs(self.ys - fx) @@ -239,9 +225,7 @@ def _get_mapped_feature_nearest_neighbour( # pylint: disable=too-many-locals weight_products = list(itertools.product(weights_x, weights_y, weights_z)) weights = [np.sum(p) for p in weight_products] - neighbour_data = np.zeros( - (self.xs.shape[0], self.ys.shape[0], self.zs.shape[0]) - ) + neighbour_data = np.zeros((self.xs.shape[0], self.ys.shape[0], self.zs.shape[0])) for point_index, point in enumerate(points): weight = weights[point_index] @@ -250,7 +234,11 @@ def _get_mapped_feature_nearest_neighbour( # pylint: disable=too-many-locals return neighbour_data - def _get_atomic_density_koes(self, position: NDArray, vanderwaals_radius: float) -> NDArray: + def _get_atomic_density_koes( + self, + position: NDArray, + vanderwaals_radius: float, + ) -> NDArray: """Function to map individual atomic density on the grid. The formula is equation (1) of the Koes paper @@ -259,20 +247,19 @@ def _get_atomic_density_koes(self, position: NDArray, vanderwaals_radius: float) Returns: NDArray: The mapped density. 
""" - - distances = np.sqrt(np.square(self.xgrid - position[0]) + - np.square(self.ygrid - position[1]) + - np.square(self.zgrid - position[2])) + distances = np.sqrt(np.square(self.xgrid - position[0]) + np.square(self.ygrid - position[1]) + np.square(self.zgrid - position[2])) density_data = np.zeros(distances.shape) indices_close = distances < vanderwaals_radius indices_far = (distances >= vanderwaals_radius) & (distances < 1.5 * vanderwaals_radius) - density_data[indices_close] = np.exp(-2.0 * np.square(distances[indices_close]) / np.square(vanderwaals_radius)) - density_data[indices_far] = 4.0 / np.square(np.e) / np.square(vanderwaals_radius) * np.square(distances[indices_far]) - \ - 12.0 / np.square(np.e) / vanderwaals_radius * distances[indices_far] + \ - 9.0 / np.square(np.e) + density_data[indices_close] = np.exp(-2.0 * np.square(distances[indices_close]) / np.square(vanderwaals_radius)) + density_data[indices_far] = ( + 4.0 / np.square(np.e) / np.square(vanderwaals_radius) * np.square(distances[indices_far]) + - 12.0 / np.square(np.e) / vanderwaals_radius * distances[indices_far] + + 9.0 / np.square(np.e) + ) return density_data @@ -287,7 +274,6 @@ def map_feature( The feature_value should either be a single number or a one-dimensional array. """ - # determine whether we're dealing with a single number of multiple numbers: index_names_values = [] if isinstance(feature_value, float): @@ -303,7 +289,6 @@ def map_feature( # map the data to the grid for index_name, value in index_names_values: - if method == MapMethod.GAUSSIAN: grid_data = self._get_mapped_feature_gaussian(position, value) @@ -321,9 +306,7 @@ def map_feature( def to_hdf5(self, hdf5_path: str): """Write the grid data to hdf5, according to deeprank standards.""" - with h5py.File(hdf5_path, "a") as hdf5_file: - # create a group to hold everything grid_group = hdf5_file.require_group(self.id) @@ -337,7 +320,6 @@ def to_hdf5(self, hdf5_path: str): # store grid features features_group = grid_group.require_group(gridstorage.MAPPED_FEATURES) for feature_name, feature_data in self.features.items(): - features_group.create_dataset( feature_name, data=feature_data, diff --git a/deeprank2/utils/parsing/__init__.py b/deeprank2/utils/parsing/__init__.py index 675918cd3..8211b9e80 100644 --- a/deeprank2/utils/parsing/__init__.py +++ b/deeprank2/utils/parsing/__init__.py @@ -10,36 +10,33 @@ _log = logging.getLogger(__name__) -_forcefield_directory_path = os.path.realpath(os.path.join(os.path.dirname(__file__), '../../domain/forcefield')) +_forcefield_directory_path = os.path.realpath(os.path.join(os.path.dirname(__file__), "../../domain/forcefield")) + class AtomicForcefield: def __init__(self): - top_path = os.path.join( - _forcefield_directory_path, - "protein-allhdg5-5_new.top") - with open(top_path, 'rt', encoding = 'utf-8') as f: + top_path = os.path.join(_forcefield_directory_path, "protein-allhdg5-5_new.top") + with open(top_path, encoding="utf-8") as f: self._top_rows = {(row.residue_name, row.atom_name): row for row in TopParser.parse(f)} patch_path = os.path.join(_forcefield_directory_path, "patch.top") - with open(patch_path, 'rt', encoding = 'utf-8') as f: + with open(patch_path, encoding="utf-8") as f: self._patch_actions = PatchParser.parse(f) - residue_class_path = os.path.join( - _forcefield_directory_path, "residue-classes") - with open(residue_class_path, 'rt', encoding = 'utf-8') as f: + residue_class_path = os.path.join(_forcefield_directory_path, "residue-classes") + with open(residue_class_path, 
encoding="utf-8") as f: self._residue_class_criteria = ResidueClassParser.parse(f) - param_path = os.path.join( - _forcefield_directory_path, - "protein-allhdg5-4_new.param") - with open(param_path, 'rt', encoding = 'utf-8') as f: + param_path = os.path.join(_forcefield_directory_path, "protein-allhdg5-4_new.param") + with open(param_path, encoding="utf-8") as f: self._vanderwaals_parameters = ParamParser.parse(f) def _find_matching_residue_class(self, residue: Residue): for criterium in self._residue_class_criteria: if criterium.matches( - residue.amino_acid.three_letter_code, [ - atom.name for atom in residue.atoms]): + residue.amino_acid.three_letter_code, + [atom.name for atom in residue.atoms], + ): return criterium.class_name return None @@ -49,8 +46,9 @@ def get_vanderwaals_parameters(self, atom: Atom): if atom.residue.amino_acid is None: _log.warning(f"no amino acid for {atom}; three letter code set to XXX") - residue_name = 'XXX' - else: residue_name = atom.residue.amino_acid.three_letter_code + residue_name = "XXX" + else: + residue_name = atom.residue.amino_acid.three_letter_code type_ = None @@ -63,25 +61,21 @@ def get_vanderwaals_parameters(self, atom: Atom): residue_class = self._find_matching_residue_class(atom.residue) if residue_class is not None: for action in self._patch_actions: - if action.type in [PatchActionType.MODIFY, PatchActionType.ADD] and \ - residue_class == action.selection.residue_type and "TYPE" in action: - + if action.type in [PatchActionType.MODIFY, PatchActionType.ADD] and residue_class == action.selection.residue_type and "TYPE" in action: type_ = action["TYPE"] - if type_ is None: # pylint: disable=no-else-return + if type_ is None: _log.warning(f"Atom {atom} is unknown to the forcefield; vanderwaals_parameters set to (0.0, 0.0, 0.0, 0.0)") return VanderwaalsParam(0.0, 0.0, 0.0, 0.0) - else: - return self._vanderwaals_parameters[type_] - + return self._vanderwaals_parameters[type_] def get_charge(self, atom: Atom): - """ - Args: - atom(Atom): the atom to get the charge for - Returns(float): the charge of the given atom - """ + """Get the charge of a given `Atom`. + Args: + atom(Atom): the atom to get the charge for + Returns(float): the charge of the given atom. 
+ """ atom_name = atom.name amino_acid_code = atom.residue.amino_acid.three_letter_code @@ -96,17 +90,13 @@ def get_charge(self, atom: Atom): residue_class = self._find_matching_residue_class(atom.residue) if residue_class is not None: for action in self._patch_actions: - if action.type in [ - PatchActionType.MODIFY, - PatchActionType.ADD] and residue_class == action.selection.residue_type: - + if action.type in [PatchActionType.MODIFY, PatchActionType.ADD] and residue_class == action.selection.residue_type: charge = float(action["CHARGE"]) - if charge is None: # pylint: disable=no-else-return + if charge is None: _log.warning(f"Atom {atom} is unknown to the forcefield; charge is set to 0.0") return 0.0 - else: - return charge + return charge atomic_forcefield = AtomicForcefield() diff --git a/deeprank2/utils/parsing/patch.py b/deeprank2/utils/parsing/patch.py index b03df3a0f..728ff94db 100644 --- a/deeprank2/utils/parsing/patch.py +++ b/deeprank2/utils/parsing/patch.py @@ -30,9 +30,7 @@ def __getitem__(self, key): class PatchParser: STRING_VAR_PATTERN = re.compile(r"([A-Z]+)=([A-Z0-9]+)") NUMBER_VAR_PATTERN = re.compile(r"([A-Z]+)=(\-?[0-9]+\.[0-9]+)") - ACTION_PATTERN = re.compile( - r"^([A-Z]{3,4})\s+([A-Z]+)\s+ATOM\s+([A-Z0-9]{1,3})\s+(.*)$" - ) + ACTION_PATTERN = re.compile(r"^([A-Z]{3,4})\s+([A-Z]+)\s+ATOM\s+([A-Z0-9]{1,3})\s+(.*)$") @staticmethod def _parse_action_type(s): @@ -40,18 +38,18 @@ def _parse_action_type(s): if type_.name == s: return type_ - raise ValueError(f"unmatched residue action: {repr(s)}") + raise ValueError(f"Unmatched residue action: {s!r}") @staticmethod def parse(file_): result = [] for line in file_: - if line.startswith("#") or line.startswith("!") or len(line.strip()) == 0: + if line.startswith(("#", "!")) or len(line.strip()) == 0: continue m = PatchParser.ACTION_PATTERN.match(line) if not m: - raise ValueError(f"Unmatched patch action: {repr(line)}") + raise ValueError(f"Unmatched patch action: {line!r}") residue_type = m.group(1) action_type = PatchParser._parse_action_type(m.group(2)) @@ -63,9 +61,5 @@ def parse(file_): for w in PatchParser.NUMBER_VAR_PATTERN.finditer(m.group(4)): kwargs[w.group(1)] = float(w.group(2)) - result.append( - PatchAction( - action_type, PatchSelection(residue_type, atom_name), kwargs - ) - ) + result.append(PatchAction(action_type, PatchSelection(residue_type, atom_name), kwargs)) return result diff --git a/deeprank2/utils/parsing/pssm.py b/deeprank2/utils/parsing/pssm.py index 069dc1074..26a99f1d0 100644 --- a/deeprank2/utils/parsing/pssm.py +++ b/deeprank2/utils/parsing/pssm.py @@ -13,17 +13,14 @@ def parse_pssm(file_: TextIO, chain: Chain) -> PssmTable: file_ (python text file object): The pssm file. chain (:class:`Chain`): The chain that the pssm file represents, residues from this chain must match the pssm file. - Returns + Returns: PssmTable: The position-specific scoring table, parsed from the pssm file. """ - conservation_rows = {} # Read the pssm header. header = next(file_).split() - column_indices = { - column_name.strip(): index for index, column_name in enumerate(header) - } + column_indices = {column_name.strip(): index for index, column_name in enumerate(header)} for line in file_: row = line.split() @@ -35,7 +32,6 @@ def parse_pssm(file_: TextIO, chain: Chain) -> PssmTable: # exceptions. 
pdb_residue_number_string = row[column_indices["pdbresi"]] if pdb_residue_number_string[-1].isalpha(): - pdb_residue_number = int(pdb_residue_number_string[:-1]) pdb_insertion_code = pdb_residue_number_string[-1] else: @@ -47,10 +43,7 @@ def parse_pssm(file_: TextIO, chain: Chain) -> PssmTable: # Build the pssm row information_content = float(row[column_indices["IC"]]) - conservations = { - amino_acid: float(row[column_indices[amino_acid.one_letter_code]]) - for amino_acid in amino_acids - } + conservations = {amino_acid: float(row[column_indices[amino_acid.one_letter_code]]) for amino_acid in amino_acids} conservation_rows[residue] = PssmRow(conservations, information_content) diff --git a/deeprank2/utils/parsing/residue.py b/deeprank2/utils/parsing/residue.py index 4609029aa..b37b734fe 100644 --- a/deeprank2/utils/parsing/residue.py +++ b/deeprank2/utils/parsing/residue.py @@ -17,31 +17,22 @@ def __init__( self.absent_atom_names = absent_atom_names def matches(self, amino_acid_name: str, atom_names: list[str]) -> bool: - # check the amino acid name - if self.amino_acid_names != "all": - if not any( - - amino_acid_name == crit_amino_acid_name - for crit_amino_acid_name in self.amino_acid_names - - ): - - return False + if self.amino_acid_names != "all" and not any(amino_acid_name == crit_amino_acid_name for crit_amino_acid_name in self.amino_acid_names): + return False # check the atom names that should be absent if any(atom_name in self.absent_atom_names for atom_name in atom_names): - return False # check the atom names that should be present if not all(atom_name in atom_names for atom_name in self.present_atom_names): - return False # all checks passed return True + class ResidueClassParser: _RESIDUE_CLASS_PATTERN = re.compile(r"([A-Z]{3,4}) *\: *name *\= *(all|[A-Z]{3})") _RESIDUE_ATOMS_PATTERN = re.compile(r"(present|absent)\(([A-Z0-9\, ]+)\)") @@ -52,16 +43,14 @@ def parse(file_): for line in file_: match = ResidueClassParser._RESIDUE_CLASS_PATTERN.match(line) if not match: - raise ValueError(f"unparsable line: '{line}'") + raise ValueError(f"Unparsable line: '{line}'") class_name = match.group(1) amino_acid_names = ResidueClassParser._parse_amino_acids(match.group(2)) present_atom_names = [] absent_atom_names = [] - for match in ResidueClassParser._RESIDUE_ATOMS_PATTERN.finditer( - line[match.end() :] - ): + for match in ResidueClassParser._RESIDUE_ATOMS_PATTERN.finditer(line[match.end() :]): # noqa: B020 (loop-variable-overrides-iterator) atom_names = [name.strip() for name in match.group(2).split(",")] if match.group(1) == "present": present_atom_names.extend(atom_names) @@ -69,11 +58,7 @@ def parse(file_): elif match.group(1) == "absent": absent_atom_names.extend(atom_names) - result.append( - ResidueClassCriterium( - class_name, amino_acid_names, present_atom_names, absent_atom_names - ) - ) + result.append(ResidueClassCriterium(class_name, amino_acid_names, present_atom_names, absent_atom_names)) return result @staticmethod diff --git a/deeprank2/utils/parsing/top.py b/deeprank2/utils/parsing/top.py index 27b58d140..84b6f8fa1 100644 --- a/deeprank2/utils/parsing/top.py +++ b/deeprank2/utils/parsing/top.py @@ -4,6 +4,7 @@ logging.getLogger(__name__) + class TopRowObject: def __init__( self, @@ -18,11 +19,10 @@ def __init__( def __getitem__(self, key): return self.kwargs[key] + class TopParser: _VAR_PATTERN = re.compile(r"([^\s]+)\s*=\s*([^\s\(\)]+|\(.*\))") - _LINE_PATTERN = re.compile( - r"^([A-Z0-9]{3})\s+atom\s+([A-Z0-9]{1,4})\s+(.+)\s+end\s*(\s+\!\s+[ _A-Za-z0-9]+)?$" - ) + 
_LINE_PATTERN = re.compile(r"^([A-Z0-9]{3})\s+atom\s+([A-Z0-9]{1,4})\s+(.+)\s+end\s*(\s+\!\s+[ _A-Za-z0-9]+)?$") _NUMBER_PATTERN = re.compile(r"\-?[0-9]+(\.[0-9]+)?") @staticmethod @@ -39,9 +39,7 @@ def parse(file_): kwargs = {} for w in TopParser._VAR_PATTERN.finditer(m.group(3)): - kwargs[w.group(1).lower().strip()] = TopParser._parse_value( - w.group(2).strip() - ) + kwargs[w.group(1).lower().strip()] = TopParser._parse_value(w.group(2).strip()) result.append(TopRowObject(residue_name, atom_name, kwargs)) diff --git a/deeprank2/utils/parsing/vdwparam.py b/deeprank2/utils/parsing/vdwparam.py index fdc0f6989..73a594436 100644 --- a/deeprank2/utils/parsing/vdwparam.py +++ b/deeprank2/utils/parsing/vdwparam.py @@ -1,10 +1,11 @@ class VanderwaalsParam: def __init__( - self, - epsilon_main: float, - sigma_main: float, - epsilon_14: float, - sigma_14: float): + self, + epsilon_main: float, + sigma_main: float, + epsilon_14: float, + sigma_14: float, + ): self.epsilon_main = epsilon_main self.sigma_main = sigma_main self.epsilon_14 = epsilon_14 diff --git a/deeprank2/utils/pssmdata.py b/deeprank2/utils/pssmdata.py index b10a27ce1..49134d1ac 100644 --- a/deeprank2/utils/pssmdata.py +++ b/deeprank2/utils/pssmdata.py @@ -4,7 +4,11 @@ class PssmRow: """Holds data for one position-specific scoring matrix row.""" - def __init__(self, conservations: dict[AminoAcid, float], information_content: float): + def __init__( + self, + conservations: dict[AminoAcid, float], + information_content: float, + ): self._conservations = conservations self._information_content = information_content @@ -37,5 +41,4 @@ def __getitem__(self, residue) -> PssmRow: def update(self, other): """Can be used to merge two non-overlapping scoring tables.""" - - self._rows.update(other._rows) # pylint: disable=protected-access + self._rows.update(other._rows) # noqa: SLF001 (private-member-access) diff --git a/docs/conf.py b/docs/conf.py index a3255c1f2..35f734134 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -13,48 +13,50 @@ # All configuration values have a default; values that are commented out # serve to show the default. -import configparser +import configparser # noqa: F401 (unused-import) # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. 
import os -import toml import sys +import toml + autodoc_mock_imports = [ - 'numpy', - 'scipy', - 'h5py', - 'sklearn', - 'scipy.signal', - 'torch', - 'torch.utils', - 'torch.utils.data', - 'matplotlib', - 'matplotlib.pyplot', - 'torch.autograd', - 'torch.nn', - 'torch.optim', - 'torch.cuda', - 'torch.distributions', - 'torch_sparse', - 'torch_scatter', - 'torch_cluster', - 'torch-spline-conv', - 'pdb2sql', - 'networkx', - 'mendeleev', - 'pandas', - 'tqdm', - 'horovod', - 'numba', - 'Bio', - 'torch_geometric', - 'community', - 'markov_clustering'] - -sys.path.insert(0, os.path.abspath('.')) -sys.path.insert(0, os.path.abspath('../')) + "numpy", + "scipy", + "h5py", + "sklearn", + "scipy.signal", + "torch", + "torch.utils", + "torch.utils.data", + "matplotlib", + "matplotlib.pyplot", + "torch.autograd", + "torch.nn", + "torch.optim", + "torch.cuda", + "torch.distributions", + "torch_sparse", + "torch_scatter", + "torch_cluster", + "torch-spline-conv", + "pdb2sql", + "networkx", + "mendeleev", + "pandas", + "tqdm", + "horovod", + "numba", + "Bio", + "torch_geometric", + "community", + "markov_clustering", +] + +sys.path.insert(0, os.path.abspath(".")) +sys.path.insert(0, os.path.abspath("../")) # -- General configuration ------------------------------------------------ @@ -67,32 +69,32 @@ # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.doctest', - 'sphinx.ext.intersphinx', - 'sphinx.ext.todo', - 'sphinx.ext.coverage', - 'sphinx.ext.mathjax', - 'sphinx.ext.ifconfig', - 'sphinx.ext.napoleon', - 'sphinx.ext.viewcode', - 'myst_parser' + "sphinx.ext.autodoc", + "sphinx.ext.doctest", + "sphinx.ext.intersphinx", + "sphinx.ext.todo", + "sphinx.ext.coverage", + "sphinx.ext.mathjax", + "sphinx.ext.ifconfig", + "sphinx.ext.napoleon", + "sphinx.ext.viewcode", + "myst_parser", ] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # The suffix(es) of source filenames. # You can specify multiple suffix as a list of string: # -source_suffix = ['.rst', '.md'] +source_suffix = [".rst", ".md"] # The master toctree document. -master_doc = 'index' +master_doc = "index" # General information about the project. -project = 'deeprank2' -author = 'Sven van der Burg, Giulia Crocioni, Dani Bodor' +project = "deeprank2" +author = "Sven van der Burg, Giulia Crocioni, Dani Bodor" copyright = f"2022, {author}" # The version info for the project you're documenting, acts as replacement for @@ -100,9 +102,9 @@ # built documents. # # The short X.Y version. -with open('./../pyproject.toml', 'r') as f: +with open("./../pyproject.toml", "r") as f: toml_file = toml.load(f) - version = toml_file['project']['version'] + version = toml_file["project"]["version"] # The full version, including alpha/beta/rc tags. release = version @@ -111,15 +113,15 @@ # # This is also used if you do content translation via gettext catalogs. # Usually you set "language" from the command line for these cases. -language = 'en' +language = "en" # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This patterns also effect to html_static_path and html_extra_path -exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] # The name of the Pygments (syntax highlighting) style to use. 
-pygments_style = 'sphinx' +pygments_style = "sphinx" # If true, `todo` and `todoList` produce output, else they produce nothing. todo_include_todos = False @@ -136,7 +138,7 @@ # else: # html_theme = 'classic' -html_theme = 'sphinx_rtd_theme' +html_theme = "sphinx_rtd_theme" # html_logo = "qmctorch_white.png" # Theme options are theme-specific and customize the look and feel of a theme @@ -151,7 +153,7 @@ # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +html_static_path = ["_static"] # Custom sidebar templates, must be a dictionary that maps document names # to template names. @@ -159,11 +161,11 @@ # This is required for the alabaster theme # refs: http://alabaster.readthedocs.io/en/latest/installation.html#sidebars html_sidebars = { - '**': [ - 'globaltoc.html', - 'relations.html', # needs 'show_related': True theme option to display - 'sourcelink.html', - 'searchbox.html', + "**": [ + "globaltoc.html", + "relations.html", # needs 'show_related': True theme option to display + "sourcelink.html", + "searchbox.html", ] } @@ -171,14 +173,14 @@ # -- Options for HTMLHelp output ------------------------------------------ # Output file base name for HTML help builder. -htmlhelp_basename = 'deeprank2' +htmlhelp_basename = "deeprank2" # Example configuration for intersphinx: refer to the Python standard library. intersphinx_mapping = { - 'python': ('https://docs.python.org/', None), - 'numpy': ('http://docs.scipy.org/doc/numpy/', None), - 'pytorch': ('http://pytorch.org/docs/1.4.0/', None), + "python": ("https://docs.python.org/", None), + "numpy": ("http://docs.scipy.org/doc/numpy/", None), + "pytorch": ("http://pytorch.org/docs/1.4.0/", None), } -autoclass_content = 'init' -autodoc_member_order = 'bysource' +autoclass_content = "init" +autodoc_member_order = "bysource" diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/domain/__init__.py b/tests/domain/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/domain/test_aminoacidlist.py b/tests/domain/test_aminoacidlist.py index 7bafb2c89..f734edd36 100644 --- a/tests/domain/test_aminoacidlist.py +++ b/tests/domain/test_aminoacidlist.py @@ -1,7 +1,6 @@ import numpy as np -from deeprank2.domain.aminoacidlist import (amino_acids, cysteine, lysine, - pyrrolysine, selenocysteine) +from deeprank2.domain.aminoacidlist import amino_acids, cysteine, lysine, pyrrolysine, selenocysteine # Exceptions selenocysteine and pyrrolysine are due to them having the same index as their canonical counterpart. # This is not an issue while selenocysteine and pyrrolysine are not part of amino_acids. 
@@ -11,9 +10,9 @@ [lysine, pyrrolysine], ] -def test_all_different_onehot(): - for aa1, aa2 in zip(amino_acids, amino_acids): +def test_all_different_onehot(): + for aa1, aa2 in zip(amino_acids, amino_acids, strict=True): if aa1 == aa2: continue @@ -23,4 +22,4 @@ def test_all_different_onehot(): if (aa1 in EXCEPTIONS[0] and aa2 in EXCEPTIONS[0]) or (aa1 in EXCEPTIONS[1] and aa2 in EXCEPTIONS[1]): assert np.all(aa1.onehot == aa2.onehot) else: - raise AssertionError(f"one-hot index {aa1.index} is occupied by both {aa1} and {aa2}") from e + raise AssertionError(f"One-hot index {aa1.index} is occupied by both {aa1} and {aa2}") from e diff --git a/tests/domain/test_forcefield.py b/tests/domain/test_forcefield.py index 11a9ec522..ef4196c3a 100644 --- a/tests/domain/test_forcefield.py +++ b/tests/domain/test_forcefield.py @@ -1,30 +1,29 @@ -from deeprank2.domain.aminoacidlist import arginine, glutamate -from deeprank2.utils.buildgraph import get_structure from pdb2sql import pdb2sql +from deeprank2.domain.aminoacidlist import arginine, glutamate +from deeprank2.utils.buildgraph import get_structure from deeprank2.utils.parsing import atomic_forcefield def test_atomic_forcefield(): - pdb = pdb2sql("tests/data/pdb/101M/101M.pdb") try: structure = get_structure(pdb, "101M") finally: - pdb._close() # pylint: disable=protected-access + pdb._close() # noqa: SLF001 (private member accessed) # The arginine C-zeta should get a positive charge - arg = [r for r in structure.get_chain("A").residues if r.amino_acid == arginine][0] - cz = [a for a in arg.atoms if a.name == "CZ"][0] + arg = next(r for r in structure.get_chain("A").residues if r.amino_acid == arginine) + cz = next(a for a in arg.atoms if a.name == "CZ") assert atomic_forcefield.get_charge(cz) == 0.640 # The glutamate O-epsilon should get a negative charge - glu = [r for r in structure.get_chain("A").residues if r.amino_acid == glutamate][0] - oe2 = [a for a in glu.atoms if a.name == "OE2"][0] + glu = next(r for r in structure.get_chain("A").residues if r.amino_acid == glutamate) + oe2 = next(a for a in glu.atoms if a.name == "OE2") assert atomic_forcefield.get_charge(oe2) == -0.800 # The forcefield should treat terminal oxygen differently - oxt = [a for a in structure.get_atoms() if a.name == "OXT"][0] - o = [a for a in oxt.residue.atoms if a.name == "O"][0] + oxt = next(a for a in structure.get_atoms() if a.name == "OXT") + o = next(a for a in oxt.residue.atoms if a.name == "O") assert atomic_forcefield.get_charge(oxt) == -0.800 assert atomic_forcefield.get_charge(o) == -0.800 diff --git a/tests/features/__init__.py b/tests/features/__init__.py index c5c60aa90..8521d2b35 100644 --- a/tests/features/__init__.py +++ b/tests/features/__init__.py @@ -6,33 +6,29 @@ from deeprank2.molstruct.aminoacid import AminoAcid from deeprank2.molstruct.residue import Residue, SingleResidueVariant from deeprank2.molstruct.structure import Chain, PDBStructure -from deeprank2.utils.buildgraph import (get_residue_contact_pairs, - get_structure, - get_surrounding_residues) +from deeprank2.utils.buildgraph import get_residue_contact_pairs, get_structure, get_surrounding_residues from deeprank2.utils.graph import Graph from deeprank2.utils.parsing.pssm import parse_pssm def _get_residue(chain: Chain, number: int) -> Residue: - """ Get the Residue from its Chain and number - """ + """Get the Residue from its Chain and number.""" for residue in chain.residues: if residue.number == number: return residue raise ValueError(f"Not found: {number}") -def build_testgraph( # 
pylint: disable=too-many-locals, too-many-arguments # noqa:MC0001 +def build_testgraph( pdb_path: str, - detail: Literal['atom', 'residue'], + detail: Literal["atom", "residue"], influence_radius: float, max_edge_length: float, central_res: int | None = None, - variant: AminoAcid |None = None, + variant: AminoAcid | None = None, chain_ids: str | tuple[str, str] | None = None, ) -> tuple[Graph, SingleResidueVariant | None]: - - """ Creates a Graph object for feature tests. + """Creates a Graph object for feature tests. Args: pdb_path (str): Path of pdb file. @@ -54,29 +50,24 @@ def build_testgraph( # pylint: disable=too-many-locals, too-many-arguments # noq Graph: As generated by Graph.build_graph SingleResidueVariant: returns None if central_res is None """ - pdb = pdb2sql(pdb_path) try: structure: PDBStructure = get_structure(pdb, Path(pdb_path).stem) finally: - pdb._close() # pylint: disable=protected-access + pdb._close() # noqa: SLF001 (private member accessed) if not central_res: - nodes = set([]) + nodes = set() if not chain_ids: chains = (structure.chains[0].id, structure.chains[1].id) else: chains = [structure.get_chain(chain_id) for chain_id in chain_ids] - for residue1, residue2 in get_residue_contact_pairs( - pdb_path, structure, - chains[0], chains[1], - influence_radius - ): - if detail == 'residue': + for residue1, residue2 in get_residue_contact_pairs(pdb_path, structure, chains[0], chains[1], influence_radius): + if detail == "residue": nodes.add(residue1) nodes.add(residue2) - elif detail == 'atom': + elif detail == "atom": for atom in residue1.atoms: nodes.add(atom) for atom in residue2.atoms: @@ -95,14 +86,17 @@ def build_testgraph( # pylint: disable=too-many-locals, too-many-arguments # noq surrounding_residues = list(get_surrounding_residues(structure, residue, influence_radius)) try: - with open(f"tests/data/pssm/{structure.id}/{structure.id}.{chain.id}.pdb.pssm", "rt", encoding="utf-8") as f: + with open( + f"tests/data/pssm/{structure.id}/{structure.id}.{chain.id}.pdb.pssm", + encoding="utf-8", + ) as f: chain.pssm = parse_pssm(f, chain) except FileNotFoundError: pass - if detail == 'residue': + if detail == "residue": return Graph.build_graph(surrounding_residues, structure.id, max_edge_length), SingleResidueVariant(residue, variant) - if detail == 'atom': - atoms = set(atom for residue in surrounding_residues for atom in residue.atoms) + if detail == "atom": + atoms = {atom for residue in surrounding_residues for atom in residue.atoms} return Graph.build_graph(list(atoms), structure.id, max_edge_length), SingleResidueVariant(residue, variant) raise TypeError('detail must be "atom" or "residue"') diff --git a/tests/features/test_components.py b/tests/features/test_components.py index 07c1c7009..f00b91c41 100644 --- a/tests/features/test_components.py +++ b/tests/features/test_components.py @@ -11,7 +11,7 @@ def test_atom_features(): pdb_path = "tests/data/pdb/101M/101M.pdb" graph, _ = build_testgraph( pdb_path=pdb_path, - detail='atom', + detail="atom", influence_radius=10, max_edge_length=10, central_res=25, @@ -25,7 +25,7 @@ def test_aminoacid_features(): pdb_path = "tests/data/pdb/101M/101M.pdb" graph, variant = build_testgraph( pdb_path=pdb_path, - detail='residue', + detail="residue", influence_radius=10, max_edge_length=10, central_res=25, diff --git a/tests/features/test_conservation.py b/tests/features/test_conservation.py index 30923716a..a15a43bfd 100644 --- a/tests/features/test_conservation.py +++ b/tests/features/test_conservation.py @@ -12,7 +12,7 @@ 
def test_conservation_residue(): pdb_path = "tests/data/pdb/101M/101M.pdb" graph, variant = build_testgraph( pdb_path=pdb_path, - detail='residue', + detail="residue", influence_radius=10, max_edge_length=10, central_res=25, @@ -26,14 +26,14 @@ def test_conservation_residue(): Nfeat.CONSERVATION, Nfeat.INFOCONTENT, ): - assert np.any([node.features[feature_name] != 0.0 for node in graph.nodes]), f'all 0s found for {feature_name}' + assert np.any([node.features[feature_name] != 0.0 for node in graph.nodes]), f"all 0s found for {feature_name}" def test_conservation_atom(): pdb_path = "tests/data/pdb/101M/101M.pdb" graph, variant = build_testgraph( pdb_path=pdb_path, - detail='atom', + detail="atom", influence_radius=10, max_edge_length=10, central_res=25, @@ -47,14 +47,14 @@ def test_conservation_atom(): Nfeat.CONSERVATION, Nfeat.INFOCONTENT, ): - assert np.any([node.features[feature_name] != 0.0 for node in graph.nodes]), f'all 0s found for {feature_name}' + assert np.any([node.features[feature_name] != 0.0 for node in graph.nodes]), f"all 0s found for {feature_name}" def test_no_pssm_file_error(): pdb_path = "tests/data/pdb/1CRN/1CRN.pdb" graph, variant = build_testgraph( pdb_path=pdb_path, - detail='residue', + detail="residue", influence_radius=10, max_edge_length=10, central_res=17, diff --git a/tests/features/test_contact.py b/tests/features/test_contact.py index 95d40dba0..ab1ef06f5 100644 --- a/tests/features/test_contact.py +++ b/tests/features/test_contact.py @@ -4,8 +4,7 @@ from pdb2sql import pdb2sql from deeprank2.domain import edgestorage as Efeat -from deeprank2.features.contact import (add_features, covalent_cutoff, - cutoff_13, cutoff_14) +from deeprank2.features.contact import add_features, covalent_cutoff, cutoff_13, cutoff_14 from deeprank2.molstruct.atom import Atom from deeprank2.molstruct.pair import AtomicContact, ResidueContact from deeprank2.molstruct.structure import Chain @@ -19,9 +18,7 @@ def _get_atom(chain: Chain, residue_number: int, atom_name: str) -> Atom: for atom in residue.atoms: if atom.name == atom_name: return atom - raise ValueError( - f"Not found: chain {chain.id} residue {residue_number} atom {atom_name}" - ) + raise ValueError(f"Not found: chain {chain.id} residue {residue_number} atom {atom_name}") def _wrap_in_graph(edge: Edge): @@ -30,23 +27,22 @@ def _wrap_in_graph(edge: Edge): return g -def _get_contact( # pylint: disable=too-many-arguments +def _get_contact( pdb_id: str, residue_num1: int, atom_name1: str, residue_num2: int, atom_name2: str, residue_level: bool = False, - chains: tuple[str,str] = None, + chains: tuple[str, str] | None = None, ) -> Edge: - pdb_path = f"tests/data/pdb/{pdb_id}/{pdb_id}.pdb" pdb = pdb2sql(pdb_path) try: structure = get_structure(pdb, pdb_id) finally: - pdb._close() # pylint: disable=protected-access + pdb._close() # noqa: SLF001 (private member accessed) if not chains: chains = [structure.chains[0], structure.chains[0]] @@ -56,159 +52,124 @@ def _get_contact( # pylint: disable=too-many-arguments if not residue_level: contact = AtomicContact( _get_atom(chains[0], residue_num1, atom_name1), - _get_atom(chains[1], residue_num2, atom_name2) + _get_atom(chains[1], residue_num2, atom_name2), ) else: - contact = ResidueContact( - chains[0].residues[residue_num1], - chains[1].residues[residue_num2] - ) + contact = ResidueContact(chains[0].residues[residue_num1], chains[1].residues[residue_num2]) edge_obj = Edge(contact) add_features(pdb_path, _wrap_in_graph(edge_obj)) - assert not np.isnan(edge_obj.features[Efeat.VDW]), 
'isnan vdw' - assert not np.isnan(edge_obj.features[Efeat.ELEC]), 'isnan electrostatic' - assert not np.isnan(edge_obj.features[Efeat.DISTANCE]), 'isnan distance' - assert not np.isnan(edge_obj.features[Efeat.SAMECHAIN]), 'isnan samechain' - assert not np.isnan(edge_obj.features[Efeat.COVALENT]), 'isnan covalent' + assert not np.isnan(edge_obj.features[Efeat.VDW]), "isnan vdw" + assert not np.isnan(edge_obj.features[Efeat.ELEC]), "isnan electrostatic" + assert not np.isnan(edge_obj.features[Efeat.DISTANCE]), "isnan distance" + assert not np.isnan(edge_obj.features[Efeat.SAMECHAIN]), "isnan samechain" + assert not np.isnan(edge_obj.features[Efeat.COVALENT]), "isnan covalent" if not residue_level: - assert not np.isnan(edge_obj.features[Efeat.SAMERES]), 'isnan sameres' + assert not np.isnan(edge_obj.features[Efeat.SAMERES]), "isnan sameres" return edge_obj def test_covalent_pair(): - """MET 0: N - CA, covalent pair (at 1.49 A distance). Should have 0 vanderwaals and electrostatic energies. - """ - - edge_covalent = _get_contact('101M', 0, "N", 0, "CA") + """MET 0: N - CA, covalent pair (at 1.49 A distance). Should have 0 vanderwaals and electrostatic energies.""" + edge_covalent = _get_contact("101M", 0, "N", 0, "CA") assert edge_covalent.features[Efeat.DISTANCE] < covalent_cutoff - assert edge_covalent.features[Efeat.VDW] == 0.0, 'nonzero vdw energy for covalent pair' - assert edge_covalent.features[Efeat.ELEC] == 0.0, 'nonzero electrostatic energy for covalent pair' - assert edge_covalent.features[Efeat.COVALENT] == 1.0, 'covalent pair not recognized as covalent' + assert edge_covalent.features[Efeat.VDW] == 0.0, "nonzero vdw energy for covalent pair" + assert edge_covalent.features[Efeat.ELEC] == 0.0, "nonzero electrostatic energy for covalent pair" + assert edge_covalent.features[Efeat.COVALENT] == 1.0, "covalent pair not recognized as covalent" def test_13_pair(): - """MET 0: N - CB, 1-3 pair (at 2.47 A distance). Should have 0 vanderwaals and electrostatic energies. - """ - - edge_13 = _get_contact('101M', 0, "N", 0, "CB") + """MET 0: N - CB, 1-3 pair (at 2.47 A distance). Should have 0 vanderwaals and electrostatic energies.""" + edge_13 = _get_contact("101M", 0, "N", 0, "CB") assert edge_13.features[Efeat.DISTANCE] < cutoff_13 - assert edge_13.features[Efeat.VDW] == 0.0, 'nonzero vdw energy for 1-3 pair' - assert edge_13.features[Efeat.ELEC] == 0.0, 'nonzero electrostatic energy for 1-3 pair' - assert edge_13.features[Efeat.COVALENT] == 0.0, '1-3 pair recognized as covalent' + assert edge_13.features[Efeat.VDW] == 0.0, "nonzero vdw energy for 1-3 pair" + assert edge_13.features[Efeat.ELEC] == 0.0, "nonzero electrostatic energy for 1-3 pair" + assert edge_13.features[Efeat.COVALENT] == 0.0, "1-3 pair recognized as covalent" def test_very_close_opposing_chains(): - """ChainA THR 118 O - ChainB ARG 30 NH1 (3.55 A). Should have non-zero energy despite close contact, because opposing chains. - """ - - opposing_edge = _get_contact('1A0Z', 118, "O", 30, "NH1", chains=('A', 'B')) + """ChainA THR 118 O - ChainB ARG 30 NH1 (3.55 A). Should have non-zero energy despite close contact, because opposing chains.""" + opposing_edge = _get_contact("1A0Z", 118, "O", 30, "NH1", chains=("A", "B")) assert opposing_edge.features[Efeat.DISTANCE] < cutoff_13 assert opposing_edge.features[Efeat.ELEC] != 0.0 assert opposing_edge.features[Efeat.VDW] != 0.0 def test_14_pair(): - """MET 0: N - CG, 1-4 pair (at 4.12 A distance). Should have non-zero electrostatic energy and small non-zero vdw energy. 
- """ - - edge_14 = _get_contact('101M', 0, "CA", 0, "SD") + """MET 0: N - CG, 1-4 pair (at 4.12 A distance). Should have non-zero electrostatic energy and small non-zero vdw energy.""" + edge_14 = _get_contact("101M", 0, "CA", 0, "SD") assert edge_14.features[Efeat.DISTANCE] > cutoff_13 assert edge_14.features[Efeat.DISTANCE] < cutoff_14 - assert edge_14.features[Efeat.VDW] != 0.0, '1-4 pair with 0 vdw energy' - assert abs(edge_14.features[Efeat.VDW]) < 0.1, '1-4 pair with large vdw energy' - assert edge_14.features[Efeat.ELEC] != 0.0, '1-4 pair with 0 electrostatic' - assert edge_14.features[Efeat.COVALENT] == 0.0, '1-4 pair recognized as covalent' + assert edge_14.features[Efeat.VDW] != 0.0, "1-4 pair with 0 vdw energy" + assert abs(edge_14.features[Efeat.VDW]) < 0.1, "1-4 pair with large vdw energy" + assert edge_14.features[Efeat.ELEC] != 0.0, "1-4 pair with 0 electrostatic" + assert edge_14.features[Efeat.COVALENT] == 0.0, "1-4 pair recognized as covalent" def test_14dist_opposing_chains(): - """ChainA PRO 114 CA - ChainB HIS 116 CD2 (3.62 A). Should have non-zero energy despite close contact, because opposing chains. + """ChainA PRO 114 CA - ChainB HIS 116 CD2 (3.62 A). + + Should have non-zero energy despite close contact, because opposing chains. E_vdw for this pair if they were on the same chain: 0.018 - E_vdw for this pair on opposing chains: 0.146 + E_vdw for this pair on opposing chains: 0.146. """ - - opposing_edge = _get_contact('1A0Z', 114, "CA", 116, "CD2", chains=('A', 'B')) + opposing_edge = _get_contact("1A0Z", 114, "CA", 116, "CD2", chains=("A", "B")) assert opposing_edge.features[Efeat.DISTANCE] > cutoff_13 assert opposing_edge.features[Efeat.DISTANCE] < cutoff_14 - assert opposing_edge.features[Efeat.ELEC] > 1.0, f'electrostatic: {opposing_edge.features[Efeat.ELEC]}' - assert opposing_edge.features[Efeat.VDW] > 0.1, f'vdw: {opposing_edge.features[Efeat.VDW]}' + assert opposing_edge.features[Efeat.ELEC] > 1.0, f"electrostatic: {opposing_edge.features[Efeat.ELEC]}" + assert opposing_edge.features[Efeat.VDW] > 0.1, f"vdw: {opposing_edge.features[Efeat.VDW]}" def test_vanderwaals_negative(): - """MET 0 N - ASP 27 CB, very far (29.54 A). Should have negative vanderwaals energy. - """ - - edge_far = _get_contact('101M', 0, "N", 27, "CB") + """MET 0 N - ASP 27 CB, very far (29.54 A). Should have negative vanderwaals energy.""" + edge_far = _get_contact("101M", 0, "N", 27, "CB") assert edge_far.features[Efeat.VDW] < 0.0 def test_vanderwaals_morenegative(): - """MET 0 N - PHE 138 CG, intermediate distance (12.69 A). Should have more negative vanderwaals energy than the far interaction. - """ - - edge_intermediate = _get_contact('101M', 0, "N", 138, "CG") - edge_far = _get_contact('101M', 0, "N", 27, "CB") + """MET 0 N - PHE 138 CG, intermediate distance (12.69 A). Should have more negative vanderwaals energy than the far interaction.""" + edge_intermediate = _get_contact("101M", 0, "N", 138, "CG") + edge_far = _get_contact("101M", 0, "N", 27, "CB") assert edge_intermediate.features[Efeat.VDW] < edge_far.features[Efeat.VDW] def test_edge_distance(): - """Check the edge distances. 
- """ + """Check the edge distances.""" + edge_close = _get_contact("101M", 0, "N", 0, "CA") + edge_intermediate = _get_contact("101M", 0, "N", 138, "CG") + edge_far = _get_contact("101M", 0, "N", 27, "CB") - edge_close = _get_contact('101M', 0, "N", 0, "CA") - edge_intermediate = _get_contact('101M', 0, "N", 138, "CG") - edge_far = _get_contact('101M', 0, "N", 27, "CB") - - assert ( - edge_close.features[Efeat.DISTANCE] - < edge_intermediate.features[Efeat.DISTANCE] - ), 'close distance > intermediate distance' - assert ( - edge_far.features[Efeat.DISTANCE] - > edge_intermediate.features[Efeat.DISTANCE] - ), 'far distance < intermediate distance' + assert edge_close.features[Efeat.DISTANCE] < edge_intermediate.features[Efeat.DISTANCE], "close distance > intermediate distance" + assert edge_far.features[Efeat.DISTANCE] > edge_intermediate.features[Efeat.DISTANCE], "far distance < intermediate distance" def test_attractive_electrostatic_close(): - """ARG 139 CZ - GLU 136 OE2, very close (5.60 A). Should have attractive electrostatic energy. - """ - - close_attracting_edge = _get_contact('101M', 139, "CZ", 136, "OE2") + """ARG 139 CZ - GLU 136 OE2, very close (5.60 A). Should have attractive electrostatic energy.""" + close_attracting_edge = _get_contact("101M", 139, "CZ", 136, "OE2") assert close_attracting_edge.features[Efeat.ELEC] < 0.0 def test_attractive_electrostatic_far(): - """ARG 139 CZ - ASP 20 OD2, far (24.26 A). Should have attractive more electrostatic energy than above. - """ - - far_attracting_edge = _get_contact('101M', 139, "CZ", 20, "OD2") - close_attracting_edge = _get_contact('101M', 139, "CZ", 136, "OE2") - assert ( - far_attracting_edge.features[Efeat.ELEC] < 0.0 - ), 'far electrostatic > 0' - assert ( - far_attracting_edge.features[Efeat.ELEC] - > close_attracting_edge.features[Efeat.ELEC] - ), 'far electrostatic <= close electrostatic' + """ARG 139 CZ - ASP 20 OD2, far (24.26 A). Should have attractive more electrostatic energy than above.""" + far_attracting_edge = _get_contact("101M", 139, "CZ", 20, "OD2") + close_attracting_edge = _get_contact("101M", 139, "CZ", 136, "OE2") + assert far_attracting_edge.features[Efeat.ELEC] < 0.0, "far electrostatic > 0" + assert far_attracting_edge.features[Efeat.ELEC] > close_attracting_edge.features[Efeat.ELEC], "far electrostatic <= close electrostatic" def test_repulsive_electrostatic(): - """GLU 109 OE2 - GLU 105 OE1 (9.64 A). Should have repulsive electrostatic energy. - """ - - opposing_edge = _get_contact('101M', 109, "OE2", 105, "OE1") + """GLU 109 OE2 - GLU 105 OE1 (9.64 A). Should have repulsive electrostatic energy.""" + opposing_edge = _get_contact("101M", 109, "OE2", 105, "OE1") assert opposing_edge.features[Efeat.ELEC] > 0.0 def test_residue_contact(): - """Check that we can calculate features for residue contacts. 
- """ - - res_edge = _get_contact('101M', 0, '', 1, '', residue_level = True) - assert res_edge.features[Efeat.DISTANCE] > 0.0, 'distance <= 0' - assert res_edge.features[Efeat.DISTANCE] < 1e5, 'distance > 1e5' - assert res_edge.features[Efeat.ELEC] != 0.0, 'electrostatic == 0' - assert res_edge.features[Efeat.VDW] != 0.0, 'vanderwaals == 0' - assert res_edge.features[Efeat.COVALENT] == 1.0, 'neighboring residues not seen as covalent' + """Check that we can calculate features for residue contacts.""" + res_edge = _get_contact("101M", 0, "", 1, "", residue_level=True) + assert res_edge.features[Efeat.DISTANCE] > 0.0, "distance <= 0" + assert res_edge.features[Efeat.DISTANCE] < 1e5, "distance > 1e5" + assert res_edge.features[Efeat.ELEC] != 0.0, "electrostatic == 0" + assert res_edge.features[Efeat.VDW] != 0.0, "vanderwaals == 0" + assert res_edge.features[Efeat.COVALENT] == 1.0, "neighboring residues not seen as covalent" diff --git a/tests/features/test_exposure.py b/tests/features/test_exposure.py index 3807f2f30..39d7129ca 100644 --- a/tests/features/test_exposure.py +++ b/tests/features/test_exposure.py @@ -8,20 +8,16 @@ def _run_assertions(graph: Graph): - assert np.any( - node.features[Nfeat.HSE] != 0.0 for node in graph.nodes - ), 'hse' + assert np.any(node.features[Nfeat.HSE] != 0.0 for node in graph.nodes), "hse" - assert np.any( - node.features[Nfeat.RESDEPTH] != 0.0 for node in graph.nodes - ), 'resdepth' + assert np.any(node.features[Nfeat.RESDEPTH] != 0.0 for node in graph.nodes), "resdepth" def test_exposure_residue(): pdb_path = "tests/data/pdb/1ATN/1ATN_1w.pdb" graph, _ = build_testgraph( pdb_path=pdb_path, - detail='residue', + detail="residue", influence_radius=8.5, max_edge_length=8.5, ) @@ -33,7 +29,7 @@ def test_exposure_atom(): pdb_path = "tests/data/pdb/1ak4/1ak4.pdb" graph, _ = build_testgraph( pdb_path=pdb_path, - detail='atom', + detail="atom", influence_radius=4.5, max_edge_length=4.5, ) diff --git a/tests/features/test_irc.py b/tests/features/test_irc.py index d1fd204f9..3505d52b4 100644 --- a/tests/features/test_irc.py +++ b/tests/features/test_irc.py @@ -8,26 +8,19 @@ def _run_assertions(graph: Graph): - assert not np.any( - [np.isnan(node.features[Nfeat.IRCTOTAL]) - for node in graph.nodes] - ), 'nan found' - assert np.any( - [node.features[Nfeat.IRCTOTAL] > 0 - for node in graph.nodes] - ), 'no contacts' + assert not np.any([np.isnan(node.features[Nfeat.IRCTOTAL]) for node in graph.nodes]), "nan found" + assert np.any([node.features[Nfeat.IRCTOTAL] > 0 for node in graph.nodes]), "no contacts" assert np.all( - [node.features[Nfeat.IRCTOTAL] == sum(node.features[IRCtype] for IRCtype in Nfeat.IRC_FEATURES[:-1]) - for node in graph.nodes] - ), 'incorrect total' + [node.features[Nfeat.IRCTOTAL] == sum(node.features[IRCtype] for IRCtype in Nfeat.IRC_FEATURES[:-1]) for node in graph.nodes] + ), "incorrect total" def test_irc_residue(): pdb_path = "tests/data/pdb/1ATN/1ATN_1w.pdb" graph, _ = build_testgraph( pdb_path=pdb_path, - detail='residue', + detail="residue", influence_radius=8.5, max_edge_length=8.5, ) @@ -39,7 +32,7 @@ def test_irc_atom(): pdb_path = "tests/data/pdb/1A0Z/1A0Z.pdb" graph, _ = build_testgraph( pdb_path=pdb_path, - detail='residue', + detail="residue", influence_radius=4.5, max_edge_length=4.5, ) diff --git a/tests/features/test_secondary_structure.py b/tests/features/test_secondary_structure.py index 8dfc6cf13..3af81666b 100644 --- a/tests/features/test_secondary_structure.py +++ b/tests/features/test_secondary_structure.py @@ -1,91 +1,91 @@ 
import numpy as np from deeprank2.domain import nodestorage as Nfeat -from deeprank2.features.secondary_structure import (SecondarySctructure, - _classify_secstructure, - add_features) +from deeprank2.features.secondary_structure import ( + SecondarySctructure, + _classify_secstructure, + add_features, +) from . import build_testgraph def test_secondary_structure_residue(): - test_case = '9api' # properly formatted pdb file + test_case = "9api" # properly formatted pdb file pdb_path = f"tests/data/pdb/{test_case}/{test_case}.pdb" graph, _ = build_testgraph( pdb_path=pdb_path, - detail='residue', + detail="residue", influence_radius=10, max_edge_length=10, ) add_features(pdb_path, graph) # Create a list of node information (residue number, chain ID, and secondary structure features) - node_info_list = [[node.id.number, - node.id.chain.id, - node.features[Nfeat.SECSTRUCT]] - for node in graph.nodes] + node_info_list = [[node.id.number, node.id.chain.id, node.features[Nfeat.SECSTRUCT]] for node in graph.nodes] print(node_info_list) # Check that all nodes have exactly 1 secondary structure type - assert np.all([np.sum(node.features[Nfeat.SECSTRUCT]) == 1.0 for node in graph.nodes]), 'one hot encoding error' + assert np.all([np.sum(node.features[Nfeat.SECSTRUCT]) == 1.0 for node in graph.nodes]), "one hot encoding error" # check ground truth examples residues = [ - (267, 'A', ' ', SecondarySctructure.COIL), - (46, 'A', 'S', SecondarySctructure.COIL), - (104, 'A', 'T', SecondarySctructure.COIL), + (267, "A", " ", SecondarySctructure.COIL), + (46, "A", "S", SecondarySctructure.COIL), + (104, "A", "T", SecondarySctructure.COIL), # (None, '', 'P', SecondarySctructure.COIL), # not found in test file - (194, 'A', 'B', SecondarySctructure.STRAND), - (385, 'B', 'E', SecondarySctructure.STRAND), - (235, 'A', 'G', SecondarySctructure.HELIX), - (263, 'A', 'H', SecondarySctructure.HELIX), + (194, "A", "B", SecondarySctructure.STRAND), + (385, "B", "E", SecondarySctructure.STRAND), + (235, "A", "G", SecondarySctructure.HELIX), + (263, "A", "H", SecondarySctructure.HELIX), # (0, '', 'I', SecondarySctructure.HELIX), # not found in test file ] for res in residues: node_list = [node_info for node_info in node_info_list if (node_info[0] == res[0] and node_info[1] == res[1])] - assert len(node_list) > 0, f'no nodes detected in {res[1]} {res[0]}' + assert len(node_list) > 0, f"no nodes detected in {res[1]} {res[0]}" assert np.all( - [np.array_equal(node_info[2], _classify_secstructure(res[2]).onehot) - for node_info in node_list] - ), f'Ground truth examples: res {res[1]} {res[0]} is not {(res[2])}.' + [np.array_equal(node_info[2], _classify_secstructure(res[2]).onehot) for node_info in node_list] + ), f"Ground truth examples: res {res[1]} {res[0]} is not {(res[2])}." assert np.all( - [np.array_equal(node_info[2], res[3].onehot) - for node_info in node_list] - ), f'Ground truth examples: res {res[1]} {res[0]} is not {res[3]}.' + [np.array_equal(node_info[2], res[3].onehot) for node_info in node_list] + ), f"Ground truth examples: res {res[1]} {res[0]} is not {res[3]}." 
def test_secondary_structure_atom(): - test_case = '1ak4' # ATOM list + test_case = "1ak4" # ATOM list pdb_path = f"tests/data/pdb/{test_case}/{test_case}.pdb" graph, _ = build_testgraph( pdb_path=pdb_path, - detail='atom', + detail="atom", influence_radius=4.5, max_edge_length=4.5, ) add_features(pdb_path, graph) # Create a list of node information (residue number, chain ID, and secondary structure features) - node_info_list = [[node.id.residue.number, - node.id.residue.chain.id, - node.features[Nfeat.SECSTRUCT]] - for node in graph.nodes] + node_info_list = [ + [ + node.id.residue.number, + node.id.residue.chain.id, + node.features[Nfeat.SECSTRUCT], + ] + for node in graph.nodes + ] # check entire DSSP file # residue number @ pos 5-10, chain_id @ pos 11, secondary structure @ pos 16 - with open(f'tests/data/dssp/{test_case}.dssp.txt', encoding="utf8") as file: + with open(f"tests/data/dssp/{test_case}.dssp.txt", encoding="utf8") as file: dssp_lines = [line.rstrip() for line in file] for node in node_info_list: - dssp_line = [line for line in dssp_lines - if (line[5:10] == str(node[0]).rjust(5) and line[11] == node[1])][0] + dssp_line = next(line for line in dssp_lines if (line[5:10] == str(node[0]).rjust(5) and line[11] == node[1])) dssp_code = dssp_line[16] - if dssp_code in [' ', 'S', 'T']: - assert np.array_equal(node[2],SecondarySctructure.COIL.onehot), f'Full file test: res {node[1]}{node[0]} is not a COIL' - elif dssp_code in ['B', 'E']: - assert np.array_equal(node[2],SecondarySctructure.STRAND.onehot), f'Full file test: res {node[1]}{node[0]} is not a STRAND' - elif dssp_code in ['G', 'H', 'I']: - assert np.array_equal(node[2],SecondarySctructure.HELIX.onehot), f'Full file test: res {node[1]}{node[0]} is not a HELIX' + if dssp_code in [" ", "S", "T"]: + assert np.array_equal(node[2], SecondarySctructure.COIL.onehot), f"Full file test: res {node[1]}{node[0]} is not a COIL" + elif dssp_code in ["B", "E"]: + assert np.array_equal(node[2], SecondarySctructure.STRAND.onehot), f"Full file test: res {node[1]}{node[0]} is not a STRAND" + elif dssp_code in ["G", "H", "I"]: + assert np.array_equal(node[2], SecondarySctructure.HELIX.onehot), f"Full file test: res {node[1]}{node[0]} is not a HELIX" else: - raise ValueError(f'Unexpected secondary structure type found at {node[1]}{node[0]}') + raise ValueError(f"Unexpected secondary structure type found at {node[1]}{node[0]}") diff --git a/tests/features/test_surfacearea.py b/tests/features/test_surfacearea.py index a82ef4d33..e4efc3912 100644 --- a/tests/features/test_surfacearea.py +++ b/tests/features/test_surfacearea.py @@ -17,11 +17,7 @@ def _find_residue_node(graph, chain_id, residue_number): def _find_atom_node(graph, chain_id, residue_number, atom_name): for node in graph.nodes: atom = node.id - if ( - atom.residue.chain.id == chain_id - and atom.residue.number == residue_number - and atom.name == atom_name - ): + if atom.residue.chain.id == chain_id and atom.residue.number == residue_number and atom.name == atom_name: return node raise ValueError(f"Not found: {chain_id} {residue_number} {atom_name}") @@ -30,7 +26,7 @@ def test_bsa_residue(): pdb_path = "tests/data/pdb/1ATN/1ATN_1w.pdb" graph, _ = build_testgraph( pdb_path=pdb_path, - detail='residue', + detail="residue", influence_radius=8.5, max_edge_length=8.5, ) @@ -45,7 +41,7 @@ def test_bsa_atom(): pdb_path = "tests/data/pdb/1ATN/1ATN_1w.pdb" graph, _ = build_testgraph( pdb_path=pdb_path, - detail='atom', + detail="atom", influence_radius=4.5, max_edge_length=4.5, ) @@ -60,7 
+56,7 @@ def test_sasa_residue(): pdb_path = "tests/data/pdb/101M/101M.pdb" graph, _ = build_testgraph( pdb_path=pdb_path, - detail='residue', + detail="residue", influence_radius=10, max_edge_length=10, central_res=108, @@ -68,9 +64,7 @@ def test_sasa_residue(): add_features(pdb_path, graph) # check for NaN - assert not any( - np.isnan(node.features[Nfeat.SASA]) for node in graph.nodes - ) + assert not any(np.isnan(node.features[Nfeat.SASA]) for node in graph.nodes) # surface residues should have large area surface_residue_node = _find_residue_node(graph, "A", 105) @@ -85,7 +79,7 @@ def test_sasa_atom(): pdb_path = "tests/data/pdb/101M/101M.pdb" graph, _ = build_testgraph( pdb_path=pdb_path, - detail='atom', + detail="atom", influence_radius=10, max_edge_length=10, central_res=108, @@ -93,9 +87,7 @@ def test_sasa_atom(): add_features(pdb_path, graph) # check for NaN - assert not any( - np.isnan(node.features[Nfeat.SASA]) for node in graph.nodes - ) + assert not any(np.isnan(node.features[Nfeat.SASA]) for node in graph.nodes) # surface atoms should have large area surface_atom_node = _find_atom_node(graph, "A", 105, "OE2") diff --git a/tests/molstruct/__init__.py b/tests/molstruct/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/molstruct/test_structure.py b/tests/molstruct/test_structure.py index 2692cf18e..77fe8d5ec 100644 --- a/tests/molstruct/test_structure.py +++ b/tests/molstruct/test_structure.py @@ -1,4 +1,3 @@ - import pickle from multiprocessing.connection import _ForkingPickler @@ -13,7 +12,7 @@ def _get_structure(path) -> PDBStructure: try: structure = get_structure(pdb, "101M") finally: - pdb._close() # pylint: disable=protected-access + pdb._close() # noqa: SLF001 (private member accessed) assert structure is not None @@ -21,11 +20,10 @@ def _get_structure(path) -> PDBStructure: def test_serialization_pickle(): - structure = _get_structure("tests/data/pdb/101M/101M.pdb") s = pickle.dumps(structure) - loaded_structure = pickle.loads(s) + loaded_structure = pickle.loads(s) # noqa: S301 (suspicious-pickle-usage) assert loaded_structure == structure assert loaded_structure.get_chain("A") == structure.get_chain("A") @@ -35,7 +33,6 @@ def test_serialization_pickle(): def test_serialization_fork(): - structure = _get_structure("tests/data/pdb/101M/101M.pdb") s = _ForkingPickler.dumps(structure) diff --git a/tests/perf/__init__.py b/tests/perf/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/perf/ppi_perf.py b/tests/perf/ppi_perf.py index 3fa221206..36b2f73cd 100644 --- a/tests/perf/ppi_perf.py +++ b/tests/perf/ppi_perf.py @@ -5,22 +5,29 @@ from os import listdir from os.path import isfile, join -import numpy +import numpy as np import pandas as pd -from deeprank2.features import (components, contact, exposure, irc, - secondary_structure, surfacearea) +from deeprank2.features import ( + components, + contact, + exposure, + irc, + secondary_structure, + surfacearea, +) from deeprank2.query import ProteinProteinInterfaceAtomicQuery, QueryCollection from deeprank2.utils.grid import GridSettings, MapMethod #################### PARAMETERS #################### interface_distance_cutoff = 5.5 # max distance in Å between two interacting residues/atoms of two proteins -grid_settings = GridSettings( # None if you don't want grids +grid_settings = GridSettings( # None if you don't want grids # the number of points on the x, y, z edges of the cube - points_counts = [35, 30, 30], + points_counts=[35, 30, 30], # x, y, z sizes of the box 
in Å - sizes = [1.0, 1.0, 1.0]) -grid_map_method = MapMethod.GAUSSIAN # None if you don't want grids + sizes=[1.0, 1.0, 1.0], +) +grid_map_method = MapMethod.GAUSSIAN # None if you don't want grids # grid_settings = None # grid_map_method = None feature_modules = [components, contact, exposure, irc, secondary_structure, surfacearea] @@ -33,19 +40,19 @@ if not os.path.exists(os.path.join(processed_data_path, "atomic")): os.makedirs(os.path.join(processed_data_path, "atomic")) + def get_pdb_files_and_target_data(data_path): csv_data = pd.read_csv(os.path.join(data_path, "BA_values.csv")) - pdb_files = glob.glob(os.path.join(data_path, "pdb", '*.pdb')) + pdb_files = glob.glob(os.path.join(data_path, "pdb", "*.pdb")) pdb_files.sort() - pdb_ids_csv = [pdb_file.split('/')[-1].split('.')[0] for pdb_file in pdb_files] - csv_data_indexed = csv_data.set_index('ID') + pdb_ids_csv = [pdb_file.split("/")[-1].split(".")[0] for pdb_file in pdb_files] + csv_data_indexed = csv_data.set_index("ID") csv_data_indexed = csv_data_indexed.loc[pdb_ids_csv] - bas = csv_data_indexed.measurement_value.values.tolist() + bas = csv_data_indexed.measurement_value.to_numpy().tolist() return pdb_files, bas -if __name__=='__main__': - +if __name__ == "__main__": timings = [] count = 0 pdb_files, bas = get_pdb_files_and_target_data(data_path) @@ -54,32 +61,35 @@ def get_pdb_files_and_target_data(data_path): queries = QueryCollection() queries.add( ProteinProteinInterfaceAtomicQuery( - pdb_path = pdb_file, - chain_id1 = "M", - chain_id2 = "P", - distance_cutoff = interface_distance_cutoff, - targets = { - 'binary': int(float(bas[i]) <= 500), # binary target value - 'BA': bas[i], # continuous target value - })) + pdb_path=pdb_file, + chain_id1="M", + chain_id2="P", + distance_cutoff=interface_distance_cutoff, + targets={ + "binary": int(float(bas[i]) <= 500), # binary target value + "BA": bas[i], # continuous target value + }, + ) + ) start = time.perf_counter() queries.process( - prefix = os.path.join(processed_data_path, "atomic", "proc"), - feature_modules = feature_modules, - cpu_count = cpu_count, - combine_output = False, - grid_settings = grid_settings, - grid_map_method = grid_map_method) + prefix=os.path.join(processed_data_path, "atomic", "proc"), + feature_modules=feature_modules, + cpu_count=cpu_count, + combine_output=False, + grid_settings=grid_settings, + grid_map_method=grid_map_method, + ) end = time.perf_counter() elapsed = end - start timings.append(elapsed) - print(f'Elapsed time: {elapsed:.6f} seconds.\n') + print(f"Elapsed time: {elapsed:.6f} seconds.\n") - timings = numpy.array(timings) + timings = np.array(timings) print(f'The queries processing is done. 
The generated HDF5 files are in {os.path.join(processed_data_path, "atomic")}.') - print(f'Avg: {numpy.mean(timings):.6f} seconds.') - print(f'Std: {numpy.std(timings):.6f} seconds.\n') + print(f"Avg: {np.mean(timings):.6f} seconds.") + print(f"Std: {np.std(timings):.6f} seconds.\n") proc_files_path = os.path.join(processed_data_path, "atomic") proc_files = [f for f in listdir(proc_files_path) if isfile(join(proc_files_path, f))] @@ -87,8 +97,8 @@ def get_pdb_files_and_target_data(data_path): for proc_file in proc_files: file_size = os.path.getsize(os.path.join(proc_files_path, proc_file)) mb_file_size = file_size / (10**6) - print(f'Size of {proc_file}: {mb_file_size} MB.\n') + print(f"Size of {proc_file}: {mb_file_size} MB.\n") mem_sizes.append(mb_file_size) - mem_sizes = numpy.array(mem_sizes) - print(f'Avg: {numpy.mean(mem_sizes):.6f} MB.') - print(f'Std: {numpy.std(mem_sizes):.6f} MB.') + mem_sizes = np.array(mem_sizes) + print(f"Avg: {np.mean(mem_sizes):.6f} MB.") + print(f"Std: {np.std(mem_sizes):.6f} MB.") diff --git a/tests/perf/srv_perf.py b/tests/perf/srv_perf.py index 900c1a1ca..dd1d11f65 100644 --- a/tests/perf/srv_perf.py +++ b/tests/perf/srv_perf.py @@ -5,39 +5,75 @@ from os import listdir from os.path import isfile, join -import numpy +import numpy as np import pandas as pd -from deeprank2.domain.aminoacidlist import (alanine, arginine, asparagine, - aspartate, cysteine, glutamate, - glutamine, glycine, histidine, - isoleucine, leucine, lysine, - methionine, phenylalanine, proline, - serine, threonine, tryptophan, - tyrosine, valine) -from deeprank2.features import (components, contact, exposure, irc, - secondary_structure, surfacearea) +from deeprank2.domain.aminoacidlist import ( + alanine, + arginine, + asparagine, + aspartate, + cysteine, + glutamate, + glutamine, + glycine, + histidine, + isoleucine, + leucine, + lysine, + methionine, + phenylalanine, + proline, + serine, + threonine, + tryptophan, + tyrosine, + valine, +) +from deeprank2.features import ( + components, + contact, + exposure, + irc, + secondary_structure, + surfacearea, +) from deeprank2.query import QueryCollection, SingleResidueVariantResidueQuery from deeprank2.utils.grid import GridSettings, MapMethod -aa_dict = {"ALA": alanine, "CYS": cysteine, "ASP": aspartate, - "GLU": glutamate, "PHE": phenylalanine, "GLY": glycine, - "HIS": histidine, "ILE": isoleucine, "LYS": lysine, - "LEU": leucine, "MET": methionine, "ASN": asparagine, - "PRO": proline, "GLN": glutamine, "ARG": arginine, - "SER": serine, "THR": threonine, "VAL": valine, - "TRP": tryptophan, "TYR": tyrosine - } +aa_dict = { + "ALA": alanine, + "CYS": cysteine, + "ASP": aspartate, + "GLU": glutamate, + "PHE": phenylalanine, + "GLY": glycine, + "HIS": histidine, + "ILE": isoleucine, + "LYS": lysine, + "LEU": leucine, + "MET": methionine, + "ASN": asparagine, + "PRO": proline, + "GLN": glutamine, + "ARG": arginine, + "SER": serine, + "THR": threonine, + "VAL": valine, + "TRP": tryptophan, + "TYR": tyrosine, +} #################### PARAMETERS #################### radius = 10.0 distance_cutoff = 5.5 -grid_settings = GridSettings( # None if you don't want grids +grid_settings = GridSettings( # None if you don't want grids # the number of points on the x, y, z edges of the cube - points_counts = [35, 30, 30], + points_counts=[35, 30, 30], # x, y, z sizes of the box in Å - sizes = [1.0, 1.0, 1.0]) -grid_map_method = MapMethod.GAUSSIAN # None if you don't want grids + sizes=[1.0, 1.0, 1.0], +) +grid_map_method = MapMethod.GAUSSIAN # None if you 
don't want grids # grid_settings = None # grid_map_method = None feature_modules = [components, contact, exposure, irc, surfacearea, secondary_structure] @@ -50,62 +86,70 @@ if not os.path.exists(os.path.join(processed_data_path, "atomic")): os.makedirs(os.path.join(processed_data_path, "atomic")) + def get_pdb_files_and_target_data(data_path): csv_data = pd.read_csv(os.path.join(data_path, "srv_target_values.csv")) # before running this script change .ent to .pdb - pdb_files = glob.glob(os.path.join(data_path, "pdb", '*.pdb')) + pdb_files = glob.glob(os.path.join(data_path, "pdb", "*.pdb")) pdb_files.sort() - pdb_id = [os.path.basename(pdb_file).split('.')[0] for pdb_file in pdb_files] - csv_data['pdb_id'] = csv_data['pdb_file'].apply(lambda x: x.split('.')[0]) - csv_data_indexed = csv_data.set_index('pdb_id') + pdb_id = [os.path.basename(pdb_file).split(".")[0] for pdb_file in pdb_files] + csv_data["pdb_id"] = csv_data["pdb_file"].apply(lambda x: x.split(".")[0]) + csv_data_indexed = csv_data.set_index("pdb_id") csv_data_indexed = csv_data_indexed.loc[pdb_id] - res_numbers = csv_data_indexed.res_number.values.tolist() - res_wildtypes = csv_data_indexed.res_wildtype.values.tolist() - res_variants = csv_data_indexed.res_variant.values.tolist() - targets = csv_data_indexed.target.values.tolist() - pdb_names = csv_data_indexed.index.values.tolist() + res_numbers = csv_data_indexed.res_number.to_numpy().tolist() + res_wildtypes = csv_data_indexed.res_wildtype.to_numpy().tolist() + res_variants = csv_data_indexed.res_variant.to_numpy().tolist() + targets = csv_data_indexed.target.to_numpy().tolist() + pdb_names = csv_data_indexed.index.to_numpy().tolist() pdb_files = [data_path + "/pdb/" + pdb_name + ".pdb" for pdb_name in pdb_names] return pdb_files, res_numbers, res_wildtypes, res_variants, targets -if __name__=='__main__': - +if __name__ == "__main__": timings = [] count = 0 - pdb_files, res_numbers, res_wildtypes, res_variants, targets = get_pdb_files_and_target_data(data_path) + ( + pdb_files, + res_numbers, + res_wildtypes, + res_variants, + targets, + ) = get_pdb_files_and_target_data(data_path) for i, pdb_file in enumerate(pdb_files): queries = QueryCollection() queries.add( SingleResidueVariantResidueQuery( - pdb_path = pdb_file, - chain_id = "A", - residue_number = res_numbers[i], - insertion_code = None, - wildtype_amino_acid = aa_dict[res_wildtypes[i]], - variant_amino_acid = aa_dict[res_variants[i]], - targets = {'binary': targets[i]}, - radius = radius, - distance_cutoff = distance_cutoff, - )) + pdb_path=pdb_file, + chain_id="A", + residue_number=res_numbers[i], + insertion_code=None, + wildtype_amino_acid=aa_dict[res_wildtypes[i]], + variant_amino_acid=aa_dict[res_variants[i]], + targets={"binary": targets[i]}, + radius=radius, + distance_cutoff=distance_cutoff, + ) + ) start = time.perf_counter() queries.process( - prefix = os.path.join(processed_data_path, "atomic", "proc"), - feature_modules = feature_modules, - cpu_count = cpu_count, - combine_output = False, - grid_settings = grid_settings, - grid_map_method = grid_map_method) + prefix=os.path.join(processed_data_path, "atomic", "proc"), + feature_modules=feature_modules, + cpu_count=cpu_count, + combine_output=False, + grid_settings=grid_settings, + grid_map_method=grid_map_method, + ) end = time.perf_counter() elapsed = end - start timings.append(elapsed) - print(f'Elapsed time: {elapsed:.6f} seconds.\n') + print(f"Elapsed time: {elapsed:.6f} seconds.\n") - timings = numpy.array(timings) + timings = np.array(timings) 
print(f'The queries processing is done. The generated HDF5 files are in {os.path.join(processed_data_path, "atomic")}.') - print(f'Avg: {numpy.mean(timings):.6f} seconds.') - print(f'Std: {numpy.std(timings):.6f} seconds.\n') + print(f"Avg: {np.mean(timings):.6f} seconds.") + print(f"Std: {np.std(timings):.6f} seconds.\n") proc_files_path = os.path.join(processed_data_path, "atomic") proc_files = [f for f in listdir(proc_files_path) if isfile(join(proc_files_path, f))] @@ -113,8 +157,8 @@ def get_pdb_files_and_target_data(data_path): for proc_file in proc_files: file_size = os.path.getsize(os.path.join(proc_files_path, proc_file)) mb_file_size = file_size / (10**6) - print(f'Size of {proc_file}: {mb_file_size} MB.\n') + print(f"Size of {proc_file}: {mb_file_size} MB.\n") mem_sizes.append(mb_file_size) - mem_sizes = numpy.array(mem_sizes) - print(f'Avg: {numpy.mean(mem_sizes):.6f} MB.') - print(f'Std: {numpy.std(mem_sizes):.6f} MB.') + mem_sizes = np.array(mem_sizes) + print(f"Avg: {np.mean(mem_sizes):.6f} MB.") + print(f"Std: {np.std(mem_sizes):.6f} MB.") diff --git a/tests/test_dataset.py b/tests/test_dataset.py index a33995194..c5f42d45c 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -15,82 +15,84 @@ from deeprank2.domain import nodestorage as Nfeat from deeprank2.domain import targetstorage as targets -node_feats = [Nfeat.RESTYPE, Nfeat.POLARITY, Nfeat.BSA, Nfeat.RESDEPTH, Nfeat.HSE, Nfeat.INFOCONTENT, Nfeat.PSSM] - -def _compute_features_manually( # noqa: MC0001, pylint: disable=too-many-locals +node_feats = [ + Nfeat.RESTYPE, + Nfeat.POLARITY, + Nfeat.BSA, + Nfeat.RESDEPTH, + Nfeat.HSE, + Nfeat.INFOCONTENT, + Nfeat.PSSM, +] + + +def _compute_features_manually( hdf5_path: str, features_transform: dict, - feat: str + feat: str, ): - """ - This function returns the feature specified read from the hdf5 file, - after applying manually features_transform dict. It returns its mean - and its std after having applied eventual transformations. + """Return specified feature. + + This function returns the feature specified read from the hdf5 file, after applying manually features_transform dict. + It returns its mean and its std after having applied eventual transformations. Multi-channels features are returned as an array with multiple channels. 
""" - - with h5py.File(hdf5_path, 'r') as f: + with h5py.File(hdf5_path, "r") as f: entry_names = [entry for entry, _ in f.items()] - mol_key = list(f.keys())[0] + mol_key = next(iter(f.keys())) # read available node features available_node_features = list(f[f"{mol_key}/{Nfeat.NODE}/"].keys()) - available_node_features = [key for key in available_node_features if key[0] != '_'] # ignore metafeatures + available_node_features = [key for key in available_node_features if key[0] != "_"] # ignore metafeatures # read available edge features available_edge_features = list(f[f"{mol_key}/{Efeat.EDGE}/"].keys()) - available_edge_features = [key for key in available_edge_features if key[0] != '_'] # ignore metafeatures + available_edge_features = [key for key in available_edge_features if key[0] != "_"] # ignore metafeatures - if 'all' in features_transform: - transform = features_transform.get('all', {}).get('transform') + if "all" in features_transform: + transform = features_transform.get("all", {}).get("transform") else: - transform = features_transform.get(feat, {}).get('transform') + transform = features_transform.get(feat, {}).get("transform") if feat in available_node_features: feat_values = [ - f[entry_name][Nfeat.NODE][feat][:] - if f[entry_name][Nfeat.NODE][feat][()].ndim == 1 - else f[entry_name][Nfeat.NODE][feat][()] for entry_name in entry_names] + f[entry_name][Nfeat.NODE][feat][:] if f[entry_name][Nfeat.NODE][feat][()].ndim == 1 else f[entry_name][Nfeat.NODE][feat][()] + for entry_name in entry_names + ] elif feat in available_edge_features: feat_values = [ - f[entry_name][Efeat.EDGE][feat][:] - if f[entry_name][Efeat.EDGE][feat][()].ndim == 1 - else f[entry_name][Efeat.EDGE][feat][()] for entry_name in entry_names] + f[entry_name][Efeat.EDGE][feat][:] if f[entry_name][Efeat.EDGE][feat][()].ndim == 1 else f[entry_name][Efeat.EDGE][feat][()] + for entry_name in entry_names + ] else: - print(f'Feat {feat} not present in the file.') + print(f"Feat {feat} not present in the file.") - #apply transformation + # apply transformation if transform: - feat_values=[transform(row) for row in feat_values] + feat_values = [transform(row) for row in feat_values] arr = np.array(np.concatenate(feat_values)) - mean = np.round(np.nanmean(arr, axis=0), 1) if isinstance(arr[0], np.ndarray) \ - else round(np.nanmean(arr), 1) - dev = np.round(np.nanstd(arr, axis=0), 1) if isinstance(arr[0], np.ndarray) \ - else round(np.nanstd(arr), 1) + mean = np.round(np.nanmean(arr, axis=0), 1) if isinstance(arr[0], np.ndarray) else round(np.nanmean(arr), 1) + dev = np.round(np.nanstd(arr, axis=0), 1) if isinstance(arr[0], np.ndarray) else round(np.nanstd(arr), 1) return arr, mean, dev -def _compute_features_with_get( - hdf5_path: str, - dataset: GraphDataset - ): + +def _compute_features_with_get(hdf5_path: str, dataset: GraphDataset): # This function computes features using the Dataset `get` method, # so as they will be seen by the network. It returns a dictionary # whose keys are the features' names and values are the features' values. 
# Multi-channels features are splitted into different keys - with h5py.File(hdf5_path, 'r') as f5: - grp = f5[list(f5.keys())[0]] + with h5py.File(hdf5_path, "r") as f5: + grp = f5[next(iter(f5.keys()))] # getting all node features values tensor_idx = 0 features_dict = {} for feat in dataset.node_features: vals = grp[f"{Nfeat.NODE}/{feat}"][()] - if vals.ndim == 1: # features with only one channel - arr = [] - for entry_idx in range(len(dataset)): - arr.append(dataset.get(entry_idx).x[:, tensor_idx]) + if vals.ndim == 1: # features with only one channel + arr = [dataset.get(entry_idx).x[:, tensor_idx] for entry_idx in range(len(dataset))] arr = np.concatenate(arr) features_dict[feat] = arr tensor_idx += 1 @@ -101,13 +103,13 @@ def _compute_features_with_get( arr.append(dataset.get(entry_idx).x[:, tensor_idx]) tensor_idx += 1 arr = np.concatenate(arr) - features_dict[feat + f'_{ch}'] = arr + features_dict[feat + f"_{ch}"] = arr # getting all edge features values tensor_idx = 0 for feat in dataset.edge_features: vals = grp[f"{Efeat.EDGE}/{feat}"][()] - if vals.ndim == 1: # features with only one channel + if vals.ndim == 1: # features with only one channel arr = [] for entry_idx in range(len(dataset)): arr.append(dataset.get(entry_idx).edge_attr[:, tensor_idx]) @@ -121,9 +123,10 @@ def _compute_features_with_get( arr.append(dataset.get(entry_idx).edge_attr[:, tensor_idx]) tensor_idx += 1 arr = np.concatenate(arr) - features_dict[feat + f'_{ch}'] = arr + features_dict[feat + f"_{ch}"] = arr return features_dict + def _check_inherited_params( inherited_params: list[str], dataset_train: GraphDataset | GridDataset, @@ -135,28 +138,43 @@ def _check_inherited_params( for param in inherited_params: assert dataset_test_vars[param] == dataset_train_vars[param] + class TestDataSet(unittest.TestCase): def setUp(self): self.hdf5_path = "tests/data/hdf5/1ATN_ppi.hdf5" def test_collates_entry_names_datasets(self): - - for dataset_name, dataset in [("GraphDataset", GraphDataset(self.hdf5_path, - node_features = node_feats, - edge_features = [Efeat.DISTANCE], - target = targets.IRMSD)), - ("GridDataset", GridDataset(self.hdf5_path, - features = [Efeat.VDW], - target = targets.IRMSD))]: - + for dataset_name, dataset in [ + ( + "GraphDataset", + GraphDataset( + self.hdf5_path, + node_features=node_feats, + edge_features=[Efeat.DISTANCE], + target=targets.IRMSD, + ), + ), + ( + "GridDataset", + GridDataset( + self.hdf5_path, + features=[Efeat.VDW], + target=targets.IRMSD, + ), + ), + ]: entry_names = [] for batch_data in DataLoader(dataset, batch_size=2, shuffle=True): entry_names += batch_data.entry_names - assert set(entry_names) == set(['residue-ppi-1ATN_1w:A-B', - 'residue-ppi-1ATN_2w:A-B', - 'residue-ppi-1ATN_3w:A-B', - 'residue-ppi-1ATN_4w:A-B']), f"entry names of {dataset_name} were not collated correctly" + assert set(entry_names) == set( # noqa: C405 (unnecessary-literal-set) + [ + "residue-ppi-1ATN_1w:A-B", + "residue-ppi-1ATN_2w:A-B", + "residue-ppi-1ATN_3w:A-B", + "residue-ppi-1ATN_4w:A-B", + ], + ), f"entry names of {dataset_name} were not collated correctly" def test_datasets(self): dataset_graph = GraphDataset( @@ -164,14 +182,14 @@ def test_datasets(self): subset=None, node_features=node_feats, edge_features=[Efeat.DISTANCE], - target=targets.IRMSD + target=targets.IRMSD, ) dataset_grid = GridDataset( - hdf5_path = self.hdf5_path, - subset = None, - features = [Efeat.DISTANCE, Efeat.COVALENT, Efeat.SAMECHAIN], - target = targets.IRMSD + hdf5_path=self.hdf5_path, + subset=None, + 
features=[Efeat.DISTANCE, Efeat.COVALENT, Efeat.SAMECHAIN], + target=targets.IRMSD, ) assert len(dataset_graph) == 4 @@ -181,74 +199,92 @@ def test_datasets(self): def test_regression_griddataset(self): dataset = GridDataset( - hdf5_path = self.hdf5_path, - features = [Efeat.VDW, Efeat.ELEC], - target = targets.IRMSD + hdf5_path=self.hdf5_path, + features=[Efeat.VDW, Efeat.ELEC], + target=targets.IRMSD, ) assert len(dataset) == 4 # 1 entry, 2 features with grid box dimensions - assert dataset[0].x.shape == (1, 2, 20, 20, 20), f"got features shape {dataset[0].x.shape}" + assert dataset[0].x.shape == ( + 1, + 2, + 20, + 20, + 20, + ), f"got features shape {dataset[0].x.shape}" # 1 entry with rmsd value assert dataset[0].y.shape == (1,) def test_classification_griddataset(self): dataset = GridDataset( - hdf5_path = self.hdf5_path, - features = [Efeat.VDW, Efeat.ELEC], - target = targets.BINARY + hdf5_path=self.hdf5_path, + features=[Efeat.VDW, Efeat.ELEC], + target=targets.BINARY, ) assert len(dataset) == 4 # 1 entry, 2 features with grid box dimensions - assert dataset[0].x.shape == (1, 2, 20, 20, 20), f"got features shape {dataset[0].x.shape}" + assert dataset[0].x.shape == ( + 1, + 2, + 20, + 20, + 20, + ), f"got features shape {dataset[0].x.shape}" # 1 entry with class value assert dataset[0].y.shape == (1,) def test_inherit_info_dataset_train_griddataset(self): - dataset_train = GridDataset( - hdf5_path = self.hdf5_path, - features = [Efeat.VDW, Efeat.ELEC], - target = targets.BINARY, - target_transform = False, - task = targets.CLASSIF, - classes = None + hdf5_path=self.hdf5_path, + features=[Efeat.VDW, Efeat.ELEC], + target=targets.BINARY, + target_transform=False, + task=targets.CLASSIF, + classes=None, ) dataset_test = GridDataset( - hdf5_path = self.hdf5_path, - train_source = dataset_train + hdf5_path=self.hdf5_path, + train_source=dataset_train, ) - _check_inherited_params(dataset_test.inherited_params, dataset_train, dataset_test) + _check_inherited_params( + dataset_test.inherited_params, + dataset_train, + dataset_test, + ) dataset_test = GridDataset( - hdf5_path = self.hdf5_path, - train_source = dataset_train, - features = [Efeat.DISTANCE, Efeat.COVALENT, Efeat.SAMECHAIN], - target = targets.IRMSD, - target_transform = True, - task = targets.REGRESS, - classes = None + hdf5_path=self.hdf5_path, + train_source=dataset_train, + features=[Efeat.DISTANCE, Efeat.COVALENT, Efeat.SAMECHAIN], + target=targets.IRMSD, + target_transform=True, + task=targets.REGRESS, + classes=None, ) - _check_inherited_params(dataset_test.inherited_params, dataset_train, dataset_test) + _check_inherited_params( + dataset_test.inherited_params, + dataset_train, + dataset_test, + ) def test_inherit_info_pretrained_model_griddataset(self): - # Test the inheritance not giving in any parameters pretrained_model = "tests/data/pretrained/testing_grid_model.pth.tar" dataset_test = GridDataset( - hdf5_path = self.hdf5_path, - train_source = pretrained_model + hdf5_path=self.hdf5_path, + train_source=pretrained_model, ) - data = torch.load(pretrained_model, map_location=torch.device('cpu')) + data = torch.load(pretrained_model, map_location=torch.device("cpu")) dataset_test_vars = vars(dataset_test) for param in dataset_test.inherited_params: @@ -256,13 +292,13 @@ def test_inherit_info_pretrained_model_griddataset(self): # Test that even when different parameters from the training data are given, the inheritance works dataset_test = GridDataset( - hdf5_path = self.hdf5_path, - train_source = pretrained_model, - 
features = [Efeat.DISTANCE, Efeat.COVALENT, Efeat.SAMECHAIN], - target = targets.IRMSD, - target_transform = True, - task = targets.REGRESS, - classes = None + hdf5_path=self.hdf5_path, + train_source=pretrained_model, + features=[Efeat.DISTANCE, Efeat.COVALENT, Efeat.SAMECHAIN], + target=targets.IRMSD, + target_transform=True, + task=targets.REGRESS, + classes=None, ) ## features, target, target_transform, task, and classes @@ -277,56 +313,49 @@ def test_no_target_dataset_griddataset(self): pretrained_model = "tests/data/pretrained/testing_grid_model.pth.tar" dataset = GridDataset( - hdf5_path = hdf5_no_target, - train_source = pretrained_model + hdf5_path=hdf5_no_target, + train_source=pretrained_model, ) assert dataset.target is not None assert dataset.get(0).y is None # no target set, training mode - with self.assertRaises(ValueError): - dataset = GridDataset( - hdf5_path = hdf5_no_target, - ) + with pytest.raises(ValueError): + dataset = GridDataset(hdf5_path=hdf5_no_target) # target set, but not present in the file - with self.assertRaises(ValueError): - dataset = GridDataset( - hdf5_path = hdf5_target, - target = 'CAPRI' - ) + with pytest.raises(ValueError): + dataset = GridDataset(hdf5_path=hdf5_target, target="CAPRI") def test_filter_griddataset(self): - # filtering out all values - with self.assertRaises(IndexError): + with pytest.raises(IndexError): GridDataset( hdf5_path=self.hdf5_path, subset=None, target=targets.IRMSD, - target_filter={targets.IRMSD: "<10"} + target_filter={targets.IRMSD: "<10"}, ) # filter our some values dataset = GridDataset( hdf5_path=self.hdf5_path, subset=None, target=targets.IRMSD, - target_filter={targets.IRMSD: ">15"} + target_filter={targets.IRMSD: ">15"}, ) assert len(dataset) == 3 def test_filter_graphdataset(self): - # filtering out all values - with self.assertRaises(IndexError): + with pytest.raises(IndexError): GraphDataset( hdf5_path=self.hdf5_path, subset=None, node_features=node_feats, edge_features=[Efeat.DISTANCE], target=targets.IRMSD, - target_filter={targets.IRMSD: "<10"} + target_filter={targets.IRMSD: "<10"}, ) # filter our some values dataset = GraphDataset( @@ -335,7 +364,7 @@ def test_filter_graphdataset(self): node_features=node_feats, edge_features=[Efeat.DISTANCE], target=targets.IRMSD, - target_filter={targets.IRMSD: ">15"} + target_filter={targets.IRMSD: ">15"}, ) assert len(dataset) == 3 @@ -344,7 +373,7 @@ def test_multi_file_graphdataset(self): hdf5_path=["tests/data/hdf5/train.hdf5", "tests/data/hdf5/valid.hdf5"], node_features=node_feats, edge_features=[Efeat.DISTANCE], - target=targets.BINARY + target=targets.BINARY, ) assert dataset.len() > 0 @@ -353,14 +382,18 @@ def test_multi_file_graphdataset(self): def test_save_external_links_graphdataset(self): n = 2 - with h5py.File("tests/data/hdf5/test.hdf5", 'r') as hdf5: + with h5py.File("tests/data/hdf5/test.hdf5", "r") as hdf5: original_ids = list(hdf5.keys()) - save_hdf5_keys("tests/data/hdf5/test.hdf5", original_ids[:n], "tests/data/hdf5/test_resized.hdf5") + save_hdf5_keys( + "tests/data/hdf5/test.hdf5", + original_ids[:n], + "tests/data/hdf5/test_resized.hdf5", + ) - with h5py.File("tests/data/hdf5/test_resized.hdf5", 'r') as hdf5: + with h5py.File("tests/data/hdf5/test_resized.hdf5", "r") as hdf5: new_ids = list(hdf5.keys()) - assert all(isinstance(hdf5.get(key, getlink=True), h5py.ExternalLink) for key in hdf5.keys()) + assert all(isinstance(hdf5.get(key, getlink=True), h5py.ExternalLink) for key in hdf5) assert len(new_ids) == n for new_id in new_ids: @@ -369,35 
+402,40 @@ def test_save_external_links_graphdataset(self): def test_save_hard_links_graphdataset(self): n = 2 - with h5py.File("tests/data/hdf5/test.hdf5", 'r') as hdf5: + with h5py.File("tests/data/hdf5/test.hdf5", "r") as hdf5: original_ids = list(hdf5.keys()) - save_hdf5_keys("tests/data/hdf5/test.hdf5", original_ids[:n], "tests/data/hdf5/test_resized.hdf5", hardcopy = True) + save_hdf5_keys( + "tests/data/hdf5/test.hdf5", + original_ids[:n], + "tests/data/hdf5/test_resized.hdf5", + hardcopy=True, + ) - with h5py.File("tests/data/hdf5/test_resized.hdf5", 'r') as hdf5: + with h5py.File("tests/data/hdf5/test_resized.hdf5", "r") as hdf5: new_ids = list(hdf5.keys()) - assert all(isinstance(hdf5.get(key, getlink=True), h5py.HardLink) for key in hdf5.keys()) + assert all(isinstance(hdf5.get(key, getlink=True), h5py.HardLink) for key in hdf5) assert len(new_ids) == n for new_id in new_ids: assert new_id in original_ids def test_subset_graphdataset(self): - hdf5 = h5py.File("tests/data/hdf5/train.hdf5", 'r') # contains 44 datapoints + hdf5 = h5py.File("tests/data/hdf5/train.hdf5", "r") # contains 44 datapoints hdf5_keys = list(hdf5.keys()) n = 10 subset = hdf5_keys[:n] dataset_train = GraphDataset( - hdf5_path = "tests/data/hdf5/train.hdf5", - subset = subset, - target = targets.BINARY + hdf5_path="tests/data/hdf5/train.hdf5", + subset=subset, + target=targets.BINARY, ) dataset_test = GraphDataset( - hdf5_path = "tests/data/hdf5/train.hdf5", - subset = subset, - train_source = dataset_train + hdf5_path="tests/data/hdf5/train.hdf5", + subset=subset, + train_source=dataset_train, ) assert n == len(dataset_train) @@ -406,50 +444,51 @@ def test_subset_graphdataset(self): hdf5.close() def test_target_transform_graphdataset(self): - dataset = GraphDataset( - hdf5_path = "tests/data/hdf5/train.hdf5", - target = 'BA', # continuous values --> regression - task = targets.REGRESS, - target_transform = True + hdf5_path="tests/data/hdf5/train.hdf5", + target="BA", # continuous values --> regression + task=targets.REGRESS, + target_transform=True, ) for i in range(len(dataset)): - assert (0 <= dataset.get(i).y <= 1) + assert 0 <= dataset.get(i).y <= 1 def test_invalid_target_transform_graphdataset(self): - dataset = GraphDataset( - hdf5_path = "tests/data/hdf5/train.hdf5", - target = targets.BINARY, # --> classification - target_transform = True # only for regression + hdf5_path="tests/data/hdf5/train.hdf5", + target=targets.BINARY, # --> classification + target_transform=True, # only for regression ) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): dataset.get(0) def test_size_graphdataset(self): - hdf5_paths = ["tests/data/hdf5/train.hdf5", "tests/data/hdf5/valid.hdf5", "tests/data/hdf5/test.hdf5"] + hdf5_paths = [ + "tests/data/hdf5/train.hdf5", + "tests/data/hdf5/valid.hdf5", + "tests/data/hdf5/test.hdf5", + ] dataset = GraphDataset( - hdf5_path = hdf5_paths, - node_features = node_feats, - edge_features = [Efeat.DISTANCE], - target = targets.BINARY + hdf5_path=hdf5_paths, + node_features=node_feats, + edge_features=[Efeat.DISTANCE], + target=targets.BINARY, ) n = 0 for hdf5 in hdf5_paths: - with h5py.File(hdf5, 'r') as hdf5_r: + with h5py.File(hdf5, "r") as hdf5_r: n += len(hdf5_r.keys()) assert len(dataset) == n, f"total data points got was {len(dataset)}" def test_hdf5_to_pandas_graphdataset(self): - hdf5_path = "tests/data/hdf5/train.hdf5" dataset = GraphDataset( - hdf5_path = hdf5_path, - node_features = 'charge', - edge_features = ['distance', 'same_chain'], - target = 
'binary' + hdf5_path=hdf5_path, + node_features="charge", + edge_features=["distance", "same_chain"], + target="binary", ) dataset.hdf5_to_pandas() cols = list(dataset.df.columns) @@ -458,20 +497,17 @@ def test_hdf5_to_pandas_graphdataset(self): # assert dataset and df shapes assert dataset.df.shape[0] == len(dataset) assert dataset.df.shape[1] == 5 - assert cols == ['binary', 'charge', 'distance', 'id', 'same_chain'] + assert cols == ["binary", "charge", "distance", "id", "same_chain"] # assert dataset and df values - with h5py.File(hdf5_path, 'r') as f5: - + with h5py.File(hdf5_path, "r") as f5: # getting nodes values with get() tensor_idx = 0 features_dict = {} for feat in dataset.node_features: - vals = f5[list(f5.keys())[0]][f"{Nfeat.NODE}/{feat}"][()] - if vals.ndim == 1: # features with only one channel - arr = [] - for entry_idx in range(len(dataset)): - arr.append(dataset.get(entry_idx).x[:, tensor_idx]) + vals = f5[next(iter(f5.keys()))][f"{Nfeat.NODE}/{feat}"][()] + if vals.ndim == 1: # features with only one channel + arr = [dataset.get(entry_idx).x[:, tensor_idx] for entry_idx in range(len(dataset))] arr = np.concatenate(arr) features_dict[feat] = arr tensor_idx += 1 @@ -482,7 +518,7 @@ def test_hdf5_to_pandas_graphdataset(self): arr.append(dataset.get(entry_idx).x[:, tensor_idx]) tensor_idx += 1 arr = np.concatenate(arr) - features_dict[feat + f'_{ch}'] = arr + features_dict[feat + f"_{ch}"] = arr for feat, values in features_dict.items(): assert np.allclose(values, np.concatenate(dataset.df[feat].values)) @@ -491,8 +527,8 @@ def test_hdf5_to_pandas_graphdataset(self): tensor_idx = 0 features_dict = {} for feat in dataset.edge_features: - vals = f5[list(f5.keys())[0]][f"{Efeat.EDGE}/{feat}"][()] - if vals.ndim == 1: # features with only one channel + vals = f5[next(iter(f5.keys()))][f"{Efeat.EDGE}/{feat}"][()] + if vals.ndim == 1: # features with only one channel arr = [] for entry_idx in range(len(dataset)): arr.append(dataset.get(entry_idx).edge_attr[:, tensor_idx]) @@ -506,7 +542,7 @@ def test_hdf5_to_pandas_graphdataset(self): arr.append(dataset.get(entry_idx).edge_attr[:, tensor_idx]) tensor_idx += 1 arr = np.concatenate(arr) - features_dict[feat + f'_{ch}'] = arr + features_dict[feat + f"_{ch}"] = arr for feat, values in features_dict.items(): # edge_attr contains stacked edges (doubled) so we test on mean and std @@ -514,53 +550,49 @@ def test_hdf5_to_pandas_graphdataset(self): assert np.float32(round(values.std(), 2)) == np.float32(round(np.concatenate(dataset.df[feat].values).std(), 2)) # assert dataset and df shapes in subset case - with h5py.File(hdf5_path, 'r') as f: + with h5py.File(hdf5_path, "r") as f: keys = list(f.keys()) dataset = GraphDataset( - hdf5_path = hdf5_path, - node_features = 'charge', - edge_features = ['distance', 'same_chain'], - target = 'binary', - subset = keys[2:] + hdf5_path=hdf5_path, + node_features="charge", + edge_features=["distance", "same_chain"], + target="binary", + subset=keys[2:], ) dataset.hdf5_to_pandas() assert dataset.df.shape[0] == len(keys[2:]) def test_save_hist_graphdataset(self): - output_directory = mkdtemp() fname = os.path.join(output_directory, "test.png") hdf5_path = "tests/data/hdf5/test.hdf5" - dataset = GraphDataset( - hdf5_path = hdf5_path, - target = 'binary' - ) + dataset = GraphDataset(hdf5_path=hdf5_path, target="binary") - with self.assertRaises(ValueError): - dataset.save_hist(['non existing feature'], fname = fname) + with pytest.raises(ValueError): + dataset.save_hist(["non existing feature"], 
fname=fname) - dataset.save_hist(['charge', 'binary'], fname = fname) + dataset.save_hist(["charge", "binary"], fname=fname) assert len(os.listdir(output_directory)) > 0 rmtree(output_directory) - def test_logic_train_graphdataset(self):# noqa: MC0001, pylint: disable=too-many-locals + def test_logic_train_graphdataset(self): hdf5_path = "tests/data/hdf5/train.hdf5" # without specifying features_transform in training set dataset_train = GraphDataset( - hdf5_path = hdf5_path, - target = 'binary' + hdf5_path=hdf5_path, + target="binary", ) dataset_test = GraphDataset( - hdf5_path = hdf5_path, - target = 'binary', - train_source = dataset_train + hdf5_path=hdf5_path, + target="binary", + train_source=dataset_train, ) # mean and devs should be None assert dataset_train.means == dataset_test.means @@ -569,42 +601,42 @@ def test_logic_train_graphdataset(self):# noqa: MC0001, pylint: disable=too-many assert dataset_train.devs is None # raise error if dataset_train is of the wrong type - with self.assertRaises(TypeError): - - dataset_train = GridDataset( - hdf5_path = "tests/data/hdf5/1ATN_ppi.hdf5", - target = 'binary' - ) + dataset_train = GridDataset( + hdf5_path="tests/data/hdf5/1ATN_ppi.hdf5", + target="binary", + ) + with pytest.raises(TypeError): GraphDataset( - hdf5_path = hdf5_path, - train_source = dataset_train, - target = 'binary', + hdf5_path=hdf5_path, + train_source=dataset_train, + target="binary", ) - def test_only_transform_graphdataset(self):# noqa: MC0001, pylint: disable=too-many-locals + def test_only_transform_graphdataset(self): # define a features_transform dict for only transformations, # including node (bsa) and edge features (electrostatic), # a multi-channel feature (hse) and a case with transform equals to None (sasa) hdf5_path = "tests/data/hdf5/train.hdf5" - features_transform = {'bsa': {'transform': lambda t: np.log(t+10)}, - 'electrostatic': {'transform': lambda t:np.cbrt(t)}, # pylint: disable=unnecessary-lambda - 'sasa': {'transform': None}, - 'hse': {'transform': lambda t: np.log(t+10)} - } + features_transform = { + "bsa": {"transform": lambda t: np.log(t + 10)}, + "electrostatic": {"transform": lambda t: np.cbrt(t)}, + "sasa": {"transform": None}, + "hse": {"transform": lambda t: np.log(t + 10)}, + } # dataset that has the transformations applied using features_transform dict transf_dataset = GraphDataset( - hdf5_path = hdf5_path, - features_transform = features_transform, - target = 'binary', + hdf5_path=hdf5_path, + features_transform=features_transform, + target="binary", ) # dataset with no transformations applied dataset = GraphDataset( - hdf5_path = hdf5_path, - target = 'binary', + hdf5_path=hdf5_path, + target="binary", ) # transformed features @@ -625,7 +657,7 @@ def test_only_transform_graphdataset(self):# noqa: MC0001, pylint: disable=too-m if orig_feat and (orig_feat in features_transform) and (orig_feat not in checked_features): checked_features.append(orig_feat) - transform = features_transform.get(orig_feat, {}).get('transform') + transform = features_transform.get(orig_feat, {}).get("transform") arr, _, _ = _compute_features_manually(hdf5_path, features_transform, orig_feat) if arr.ndim == 1: # checking that the mean and the std are the same in both the feature computed through @@ -634,56 +666,79 @@ def test_only_transform_graphdataset(self):# noqa: MC0001, pylint: disable=too-m assert np.allclose(np.nanstd(transf_feat_value), np.nanstd(arr)) if transform: # check that the feature mean and std are different in transf_dataset and dataset - 
assert not np.allclose(np.nanmean(transf_feat_value), np.nanmean(features_dict.get(transf_feat_key))) - assert not np.allclose(np.nanstd(transf_feat_value), np.nanstd(features_dict.get(transf_feat_key))) + assert not np.allclose( + np.nanmean(transf_feat_value), + np.nanmean(features_dict.get(transf_feat_key)), + ) + assert not np.allclose( + np.nanstd(transf_feat_value), + np.nanstd(features_dict.get(transf_feat_key)), + ) else: # check that the feature mean and std are the same in transf_dataset and dataset, because # no transformation should be applied - assert np.allclose(np.nanmean(transf_feat_value), np.nanmean(features_dict.get(transf_feat_key))) - assert np.allclose(np.nanstd(transf_feat_value), np.nanstd(features_dict.get(transf_feat_key))) + assert np.allclose( + np.nanmean(transf_feat_value), + np.nanmean(features_dict.get(transf_feat_key)), + ) + assert np.allclose( + np.nanstd(transf_feat_value), + np.nanstd(features_dict.get(transf_feat_key)), + ) else: for i in range(arr.shape[1]): # checking that the mean and the std are the same in both the feature computed through # the get method and the feature computed manually - assert np.allclose(np.nanmean(transf_features_dict.get(orig_feat + f'_{i}')), np.nanmean(arr[:, i])) - assert np.allclose(np.nanstd(transf_features_dict.get(orig_feat + f'_{i}')), np.nanstd(arr[:, i])) + assert np.allclose( + np.nanmean(transf_features_dict.get(orig_feat + f"_{i}")), + np.nanmean(arr[:, i]), + ) + assert np.allclose( + np.nanstd(transf_features_dict.get(orig_feat + f"_{i}")), + np.nanstd(arr[:, i]), + ) if transform: # check that the feature mean and std are different in transf_dataset and dataset assert not np.allclose( - np.nanmean(transf_features_dict.get(orig_feat + f'_{i}')), - np.nanmean(features_dict.get(orig_feat + f'_{i}'))) + np.nanmean(transf_features_dict.get(orig_feat + f"_{i}")), + np.nanmean(features_dict.get(orig_feat + f"_{i}")), + ) assert not np.allclose( - np.nanstd(transf_features_dict.get(orig_feat + f'_{i}')), - np.nanstd(features_dict.get(orig_feat + f'_{i}'))) + np.nanstd(transf_features_dict.get(orig_feat + f"_{i}")), + np.nanstd(features_dict.get(orig_feat + f"_{i}")), + ) else: # check that the feature mean and std are the same in transf_dataset and dataset, because # no transformation should be applied assert np.allclose( - np.nanmean(transf_features_dict.get(orig_feat + f'_{i}')), - np.nanmean(features_dict.get(orig_feat + f'_{i}'))) + np.nanmean(transf_features_dict.get(orig_feat + f"_{i}")), + np.nanmean(features_dict.get(orig_feat + f"_{i}")), + ) assert np.allclose( - np.nanstd(transf_features_dict.get(orig_feat + f'_{i}')), - np.nanstd(features_dict.get(orig_feat + f'_{i}'))) + np.nanstd(transf_features_dict.get(orig_feat + f"_{i}")), + np.nanstd(features_dict.get(orig_feat + f"_{i}")), + ) - assert (sorted(checked_features) == sorted(list(features_transform.keys()))) and (len(checked_features) == len(features_transform.keys())) + assert sorted(checked_features) == sorted(features_transform.keys()) + assert len(checked_features) == len(features_transform.keys()) - def test_only_transform_all_graphdataset(self):# noqa: MC0001, pylint: disable=too-many-locals + def test_only_transform_all_graphdataset(self): # define a features_transform dict for only transformations for `all` features hdf5_path = "tests/data/hdf5/train.hdf5" - features_transform = {'all': {'transform': lambda t: np.log(abs(t)+.01)}} + features_transform = {"all": {"transform": lambda t: np.log(abs(t) + 0.01)}} # dataset that has the 
transformations applied using features_transform dict transf_dataset = GraphDataset( - hdf5_path = hdf5_path, - features_transform = features_transform, - target = 'binary', + hdf5_path=hdf5_path, + features_transform=features_transform, + target="binary", ) # dataset with no transformations applied dataset = GraphDataset( - hdf5_path = hdf5_path, - target = 'binary', + hdf5_path=hdf5_path, + target="binary", ) # transformed features @@ -711,41 +766,61 @@ def test_only_transform_all_graphdataset(self):# noqa: MC0001, pylint: disable=t assert np.allclose(np.nanmean(transf_feat_value), np.nanmean(arr)) assert np.allclose(np.nanstd(transf_feat_value), np.nanstd(arr)) # check that the feature mean and std are different in transf_dataset and dataset - assert not np.allclose(np.nanmean(transf_feat_value), np.nanmean(features_dict.get(transf_feat_key))) - assert not np.allclose(np.nanstd(transf_feat_value), np.nanstd(features_dict.get(transf_feat_key))) + assert not np.allclose( + np.nanmean(transf_feat_value), + np.nanmean(features_dict.get(transf_feat_key)), + ) + assert not np.allclose( + np.nanstd(transf_feat_value), + np.nanstd(features_dict.get(transf_feat_key)), + ) else: for i in range(arr.shape[1]): # checking that the mean and the std are the same in both the feature computed through # the get method and the feature computed manually - assert np.allclose(np.nanmean(transf_features_dict.get(orig_feat + f'_{i}')), np.nanmean(arr[:, i])) - assert np.allclose(np.nanstd(transf_features_dict.get(orig_feat + f'_{i}')), np.nanstd(arr[:, i])) - assert not np.allclose(np.nanmean(transf_features_dict.get(orig_feat + f'_{i}')), np.nanmean(features_dict.get(orig_feat + f'_{i}'))) + assert np.allclose( + np.nanmean(transf_features_dict.get(orig_feat + f"_{i}")), + np.nanmean(arr[:, i]), + ) + assert np.allclose( + np.nanstd(transf_features_dict.get(orig_feat + f"_{i}")), + np.nanstd(arr[:, i]), + ) + assert not np.allclose( + np.nanmean(transf_features_dict.get(orig_feat + f"_{i}")), + np.nanmean(features_dict.get(orig_feat + f"_{i}")), + ) # check that the feature mean and std are different in transf_dataset and dataset - assert not np.allclose(np.nanstd(transf_features_dict.get(orig_feat + f'_{i}')), np.nanstd(features_dict.get(orig_feat + f'_{i}'))) + assert not np.allclose( + np.nanstd(transf_features_dict.get(orig_feat + f"_{i}")), + np.nanstd(features_dict.get(orig_feat + f"_{i}")), + ) - assert (sorted(checked_features) == sorted(features)) and (len(checked_features) == len(features)) + assert sorted(checked_features) == sorted(features) + assert len(checked_features) == len(features) - def test_only_standardize_graphdataset(self): # pylint: disable=too-many-locals + def test_only_standardize_graphdataset(self): # define a features_transform dict for only standardization, # including node (bsa) and edge features (electrostatic), # a multi-channel feature (hse) and a case with standardize False (sasa) hdf5_path = "tests/data/hdf5/train.hdf5" features_transform = { - 'bsa': {'standardize': True}, - 'hse': {'standardize': True}, - 'electrostatic': {'standardize': True}, - 'sasa': {'standardize': False}} + "bsa": {"standardize": True}, + "hse": {"standardize": True}, + "electrostatic": {"standardize": True}, + "sasa": {"standardize": False}, + } transf_dataset = GraphDataset( - hdf5_path = hdf5_path, - features_transform = features_transform, - target = 'binary' + hdf5_path=hdf5_path, + features_transform=features_transform, + target="binary", ) dataset = GraphDataset( - hdf5_path = hdf5_path, - 
target = 'binary' + hdf5_path=hdf5_path, + target="binary", ) # standardized features @@ -766,11 +841,11 @@ def test_only_standardize_graphdataset(self): # pylint: disable=too-many-locals if orig_feat and (orig_feat in features_transform) and (orig_feat not in checked_features): checked_features.append(orig_feat) - standardize = features_transform.get(orig_feat, {}).get('standardize') + standardize = features_transform.get(orig_feat, {}).get("standardize") arr, mean, dev = _compute_features_manually(hdf5_path, features_transform, orig_feat) if standardize: # standardize manually - arr = (arr-mean)/dev + arr = (arr - mean) / dev if arr.ndim == 1: # checking that the mean and the std are the same in both the feature computed through # the get method and the feature computed manually @@ -778,52 +853,78 @@ def test_only_standardize_graphdataset(self): # pylint: disable=too-many-locals assert np.allclose(np.nanstd(transf_feat_value), np.nanstd(arr)) if standardize: # check that the feature mean and std are different in transf_dataset and dataset - assert not np.allclose(np.nanmean(transf_feat_value), np.nanmean(features_dict.get(transf_feat_key))) - assert not np.allclose(np.nanstd(transf_feat_value), np.nanstd(features_dict.get(transf_feat_key))) + assert not np.allclose( + np.nanmean(transf_feat_value), + np.nanmean(features_dict.get(transf_feat_key)), + ) + assert not np.allclose( + np.nanstd(transf_feat_value), + np.nanstd(features_dict.get(transf_feat_key)), + ) else: # check that the feature mean and std are the same in transf_dataset and dataset, because # no transformation should be applied - assert np.allclose(np.nanmean(transf_feat_value), np.nanmean(features_dict.get(transf_feat_key))) - assert np.allclose(np.nanstd(transf_feat_value), np.nanstd(features_dict.get(transf_feat_key))) + assert np.allclose( + np.nanmean(transf_feat_value), + np.nanmean(features_dict.get(transf_feat_key)), + ) + assert np.allclose( + np.nanstd(transf_feat_value), + np.nanstd(features_dict.get(transf_feat_key)), + ) else: for i in range(arr.shape[1]): # checking that the mean and the std are the same in both the feature computed through # the get method and the feature computed manually - assert np.allclose(np.nanmean(transf_features_dict.get(orig_feat + f'_{i}')), np.nanmean(arr[:, i])) - assert np.allclose(np.nanstd(transf_features_dict.get(orig_feat + f'_{i}')), np.nanstd(arr[:, i])) + assert np.allclose( + np.nanmean(transf_features_dict.get(orig_feat + f"_{i}")), + np.nanmean(arr[:, i]), + ) + assert np.allclose( + np.nanstd(transf_features_dict.get(orig_feat + f"_{i}")), + np.nanstd(arr[:, i]), + ) if standardize: # check that the feature mean and std are different in transf_dataset and dataset assert not np.allclose( - np.nanmean(transf_features_dict.get(orig_feat + f'_{i}')), - np.nanmean(features_dict.get(orig_feat + f'_{i}'))) + np.nanmean(transf_features_dict.get(orig_feat + f"_{i}")), + np.nanmean(features_dict.get(orig_feat + f"_{i}")), + ) assert not np.allclose( - np.nanstd(transf_features_dict.get(orig_feat + f'_{i}')), - np.nanstd(features_dict.get(orig_feat + f'_{i}'))) + np.nanstd(transf_features_dict.get(orig_feat + f"_{i}")), + np.nanstd(features_dict.get(orig_feat + f"_{i}")), + ) else: # check that the feature mean and std are the same in transf_dataset and dataset, because # no standardization should be applied - assert np.allclose(np.nanmean(transf_features_dict.get(orig_feat + f'_{i}')), np.nanmean(features_dict.get(orig_feat + f'_{i}'))) - assert 
np.allclose(np.nanstd(transf_features_dict.get(orig_feat + f'_{i}')), np.nanstd(features_dict.get(orig_feat + f'_{i}'))) + assert np.allclose( + np.nanmean(transf_features_dict.get(orig_feat + f"_{i}")), + np.nanmean(features_dict.get(orig_feat + f"_{i}")), + ) + assert np.allclose( + np.nanstd(transf_features_dict.get(orig_feat + f"_{i}")), + np.nanstd(features_dict.get(orig_feat + f"_{i}")), + ) - assert (sorted(checked_features) == sorted(list(features_transform.keys()))) and (len(checked_features) == len(features_transform.keys())) + assert sorted(checked_features) == sorted(features_transform.keys()) + assert len(checked_features) == len(features_transform.keys()) - def test_only_standardize_all_graphdataset(self): # pylint: disable=too-many-locals + def test_only_standardize_all_graphdataset(self): # define a features_transform dict for only standardization for `all` features hdf5_path = "tests/data/hdf5/train.hdf5" - features_transform = { - 'all': {'standardize': True}} + features_transform = {"all": {"standardize": True}} # dataset that has the standardization applied using features_transform dict transf_dataset = GraphDataset( - hdf5_path = hdf5_path, - features_transform = features_transform, - target = 'binary' + hdf5_path=hdf5_path, + features_transform=features_transform, + target="binary", ) # dataset with no standardization applied dataset = GraphDataset( - hdf5_path = hdf5_path, - target = 'binary' + hdf5_path=hdf5_path, + target="binary", ) # standardized features @@ -846,50 +947,66 @@ def test_only_standardize_all_graphdataset(self): # pylint: disable=too-many-loc checked_features.append(orig_feat) arr, mean, dev = _compute_features_manually(hdf5_path, features_transform, orig_feat) # standardize manually - arr = (arr-mean)/dev + arr = (arr - mean) / dev if arr.ndim == 1: # checking that the mean and the std are the same in both the feature computed through # the get method and the feature computed manually assert np.allclose(np.nanmean(transf_feat_value), np.nanmean(arr)) assert np.allclose(np.nanstd(transf_feat_value), np.nanstd(arr)) # check that the feature mean and std are different in transf_dataset and dataset - assert not np.allclose(np.nanmean(transf_feat_value), np.nanmean(features_dict.get(transf_feat_key))) - assert not np.allclose(np.nanstd(transf_feat_value), np.nanstd(features_dict.get(transf_feat_key))) + assert not np.allclose( + np.nanmean(transf_feat_value), + np.nanmean(features_dict.get(transf_feat_key)), + ) + assert not np.allclose( + np.nanstd(transf_feat_value), + np.nanstd(features_dict.get(transf_feat_key)), + ) else: for i in range(arr.shape[1]): # checking that the mean and the std are the same in both the feature computed through # the get method and the feature computed manually - assert np.allclose(np.nanmean(transf_features_dict.get(orig_feat + f'_{i}')), np.nanmean(arr[:,i])) - assert np.allclose(np.nanstd(transf_features_dict.get(orig_feat + f'_{i}')), np.nanstd(arr[:,i])) + assert np.allclose( + np.nanmean(transf_features_dict.get(orig_feat + f"_{i}")), + np.nanmean(arr[:, i]), + ) + assert np.allclose( + np.nanstd(transf_features_dict.get(orig_feat + f"_{i}")), + np.nanstd(arr[:, i]), + ) # check that the feature mean and std are different in transf_dataset and dataset - assert not np.allclose(np.nanmean(transf_features_dict.get(orig_feat + f'_{i}')), np.nanmean(features_dict.get(orig_feat + f'_{i}'))) - assert not np.allclose(np.nanstd(transf_features_dict.get(orig_feat + f'_{i}')), np.nanstd(features_dict.get(orig_feat + f'_{i}'))) 
- - assert (sorted(checked_features) == sorted(features)) and (len(checked_features) == len(features)) - - def test_transform_standardize_graphdataset(self):# noqa: MC0001, pylint: disable=too-many-locals + assert not np.allclose( + np.nanmean(transf_features_dict.get(orig_feat + f"_{i}")), + np.nanmean(features_dict.get(orig_feat + f"_{i}")), + ) + assert not np.allclose( + np.nanstd(transf_features_dict.get(orig_feat + f"_{i}")), + np.nanstd(features_dict.get(orig_feat + f"_{i}")), + ) + + assert sorted(checked_features) == sorted(features) + assert len(checked_features) == len(features) + + def test_transform_standardize_graphdataset(self): # define a features_transform dict for both transformations and standardization, # including node (bsa) and edge features (electrostatic), # a multi-channel feature (hse) hdf5_path = "tests/data/hdf5/train.hdf5" - features_transform = {'bsa': {'transform': lambda t: np.log(t+10), 'standardize': True}, - 'electrostatic': {'transform': lambda t:np.cbrt(t), 'standardize': True}, # pylint: disable=unnecessary-lambda - 'sasa': {'transform': None, 'standardize': False}, - 'hse': {'transform': lambda t: np.log(t+10), 'standardize': False} - } + features_transform = { + "bsa": {"transform": lambda t: np.log(t + 10), "standardize": True}, + "electrostatic": {"transform": lambda t: np.cbrt(t), "standardize": True}, + "sasa": {"transform": None, "standardize": False}, + "hse": {"transform": lambda t: np.log(t + 10), "standardize": False}, + } # dataset that has the transformations applied using features_transform dict - transf_dataset = GraphDataset( - hdf5_path = hdf5_path, - features_transform = features_transform, - target = 'binary' - ) + transf_dataset = GraphDataset(hdf5_path=hdf5_path, features_transform=features_transform, target="binary") # dataset with no transformations applied dataset = GraphDataset( - hdf5_path = hdf5_path, - target = 'binary', + hdf5_path=hdf5_path, + target="binary", ) # transformed features @@ -910,12 +1027,12 @@ def test_transform_standardize_graphdataset(self):# noqa: MC0001, pylint: disabl if orig_feat and (orig_feat in features_transform) and (orig_feat not in checked_features): checked_features.append(orig_feat) - transform = features_transform.get(orig_feat, {}).get('transform') - standardize = features_transform.get(orig_feat, {}).get('standardize') + transform = features_transform.get(orig_feat, {}).get("transform") + standardize = features_transform.get(orig_feat, {}).get("standardize") arr, mean, dev = _compute_features_manually(hdf5_path, features_transform, orig_feat) if standardize: # standardize manually - arr = (arr-mean)/dev + arr = (arr - mean) / dev if arr.ndim == 1: # checking that the mean and the std are the same in both the feature computed through # the get method and the feature computed manually @@ -923,53 +1040,75 @@ def test_transform_standardize_graphdataset(self):# noqa: MC0001, pylint: disabl assert np.allclose(np.nanstd(transf_feat_value), np.nanstd(arr)) if transform or standardize: # check that the feature mean and std are different in transf_dataset and dataset - assert not np.allclose(np.nanmean(transf_feat_value), np.nanmean(features_dict.get(transf_feat_key))) - assert not np.allclose(np.nanstd(transf_feat_value), np.nanstd(features_dict.get(transf_feat_key))) + assert not np.allclose( + np.nanmean(transf_feat_value), + np.nanmean(features_dict.get(transf_feat_key)), + ) + assert not np.allclose( + np.nanstd(transf_feat_value), + np.nanstd(features_dict.get(transf_feat_key)), + ) else: # 
check that the feature mean and std are the same in transf_dataset and dataset - assert np.allclose(np.nanmean(transf_feat_value), np.nanmean(features_dict.get(transf_feat_key))) - assert np.allclose(np.nanstd(transf_feat_value), np.nanstd(features_dict.get(transf_feat_key))) + assert np.allclose( + np.nanmean(transf_feat_value), + np.nanmean(features_dict.get(transf_feat_key)), + ) + assert np.allclose( + np.nanstd(transf_feat_value), + np.nanstd(features_dict.get(transf_feat_key)), + ) else: for i in range(arr.shape[1]): # checking that the mean and the std are the same in both the feature computed through # the get method and the feature computed manually - assert np.allclose(np.nanmean(transf_features_dict.get(orig_feat + f'_{i}')), np.nanmean(arr[:, i])) - assert np.allclose(np.nanstd(transf_features_dict.get(orig_feat + f'_{i}')), np.nanstd(arr[:, i])) + assert np.allclose( + np.nanmean(transf_features_dict.get(orig_feat + f"_{i}")), + np.nanmean(arr[:, i]), + ) + assert np.allclose( + np.nanstd(transf_features_dict.get(orig_feat + f"_{i}")), + np.nanstd(arr[:, i]), + ) if transform or standardize: # check that the feature mean and std are different in transf_dataset and dataset assert not np.allclose( - np.nanmean(transf_features_dict.get(orig_feat + f'_{i}')), - np.nanmean(features_dict.get(orig_feat + f'_{i}'))) + np.nanmean(transf_features_dict.get(orig_feat + f"_{i}")), + np.nanmean(features_dict.get(orig_feat + f"_{i}")), + ) assert not np.allclose( - np.nanstd(transf_features_dict.get(orig_feat + f'_{i}')), - np.nanstd(features_dict.get(orig_feat + f'_{i}'))) + np.nanstd(transf_features_dict.get(orig_feat + f"_{i}")), + np.nanstd(features_dict.get(orig_feat + f"_{i}")), + ) else: # check that the feature mean and std are the same in transf_dataset and dataset, because assert np.allclose( - np.nanmean(transf_features_dict.get(orig_feat + f'_{i}')), - np.nanmean(features_dict.get(orig_feat + f'_{i}'))) + np.nanmean(transf_features_dict.get(orig_feat + f"_{i}")), + np.nanmean(features_dict.get(orig_feat + f"_{i}")), + ) assert np.allclose( - np.nanstd(transf_features_dict.get(orig_feat + f'_{i}')), - np.nanstd(features_dict.get(orig_feat + f'_{i}'))) + np.nanstd(transf_features_dict.get(orig_feat + f"_{i}")), + np.nanstd(features_dict.get(orig_feat + f"_{i}")), + ) - assert (sorted(checked_features) == sorted(list(features_transform.keys()))) and (len(checked_features) == len(features_transform.keys())) + assert sorted(checked_features) == sorted(features_transform.keys()) + assert len(checked_features) == len(features_transform.keys()) def test_features_transform_logic_graphdataset(self): - hdf5_path = "tests/data/hdf5/train.hdf5" - features_transform = {'all': {'transform': lambda t:np.cbrt(t), 'standardize': True}} # pylint: disable=unnecessary-lambda - other_feature_transform = {'all': {'transform': None, 'standardize': False}} + features_transform = {"all": {"transform": lambda t: np.cbrt(t), "standardize": True}} + other_feature_transform = {"all": {"transform": None, "standardize": False}} dataset_train = GraphDataset( - hdf5_path = hdf5_path, - features_transform = features_transform, - target = 'binary' + hdf5_path=hdf5_path, + features_transform=features_transform, + target="binary", ) dataset_test = GraphDataset( - hdf5_path = hdf5_path, - train_source = dataset_train, - target = 'binary' + hdf5_path=hdf5_path, + train_source=dataset_train, + target="binary", ) # features_transform in the test should be the same as in the train @@ -981,10 +1120,10 @@ def 
test_features_transform_logic_graphdataset(self): assert dataset_train.devs is not None dataset_test = GraphDataset( - hdf5_path = hdf5_path, - train_source = dataset_train, - features_transform = other_feature_transform, - target = 'binary' + hdf5_path=hdf5_path, + train_source=dataset_train, + features_transform=other_feature_transform, + target="binary", ) # features_transform setted in the testset should be ignored @@ -993,101 +1132,107 @@ def test_features_transform_logic_graphdataset(self): assert dataset_train.devs == dataset_test.devs def test_invalid_value_features_transform(self): - hdf5_path = "tests/data/hdf5/train.hdf5" - features_transform = {'all': {'transform': lambda t: np.log(t+10), 'standardize': True}} + features_transform = {"all": {"transform": lambda t: np.log(t + 10), "standardize": True}} transf_dataset = GraphDataset( - hdf5_path = hdf5_path, - target = 'binary', - features_transform = features_transform + hdf5_path=hdf5_path, + target="binary", + features_transform=features_transform, ) - with pytest.raises(ValueError): - with warnings.catch_warnings(): - warnings.filterwarnings('ignore', r'divide by zero encountered in divide') + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", r"divide by zero encountered in divide") + with pytest.raises(ValueError): _compute_features_with_get(hdf5_path, transf_dataset) def test_inherit_info_dataset_train_graphdataset(self): hdf5_path = "tests/data/hdf5/train.hdf5" - feature_transform = {'all': {'transform': None, 'standardize': True}} + feature_transform = {"all": {"transform": None, "standardize": True}} dataset_train = GraphDataset( - hdf5_path = hdf5_path, - node_features = ['bsa', 'hb_acceptors', 'hb_donors'], - edge_features = ['covalent', 'distance'], - features_transform = feature_transform, - target = 'binary', - target_transform = False, - task = 'classif', - classes = None + hdf5_path=hdf5_path, + node_features=["bsa", "hb_acceptors", "hb_donors"], + edge_features=["covalent", "distance"], + features_transform=feature_transform, + target="binary", + target_transform=False, + task="classif", + classes=None, ) dataset_test = GraphDataset( - hdf5_path = hdf5_path, - train_source = dataset_train, + hdf5_path=hdf5_path, + train_source=dataset_train, ) - _check_inherited_params(dataset_test.inherited_params, dataset_train, dataset_test) + _check_inherited_params( + dataset_test.inherited_params, + dataset_train, + dataset_test, + ) dataset_test = GraphDataset( - hdf5_path = hdf5_path, - train_source = dataset_train, - node_features = "all", - edge_features = "all", - features_transform = None, - target = 'BA', - target_transform = True, - task = "regress", - classes = None + hdf5_path=hdf5_path, + train_source=dataset_train, + node_features="all", + edge_features="all", + features_transform=None, + target="BA", + target_transform=True, + task="regress", + classes=None, ) - _check_inherited_params(dataset_test.inherited_params, dataset_train, dataset_test) + _check_inherited_params( + dataset_test.inherited_params, + dataset_train, + dataset_test, + ) def test_inherit_info_pretrained_model_graphdataset(self): - hdf5_path = "tests/data/hdf5/test.hdf5" pretrained_model = "tests/data/pretrained/testing_graph_model.pth.tar" dataset_test = GraphDataset( - hdf5_path = hdf5_path, - train_source = pretrained_model + hdf5_path=hdf5_path, + train_source=pretrained_model, ) - data = torch.load(pretrained_model, map_location=torch.device('cpu')) + data = torch.load(pretrained_model, map_location=torch.device("cpu")) 
if data["features_transform"]: - for _, key in data["features_transform"].items(): - if key['transform'] is None: + for key in data["features_transform"].values(): + if key["transform"] is None: continue - key['transform'] = eval(key['transform']) # pylint: disable=eval-used + key["transform"] = eval(key["transform"]) # noqa: S307, PGH001 (suspicious-eval-usage) dataset_test_vars = vars(dataset_test) for param in dataset_test.inherited_params: - if param == 'features_transform': + if param == "features_transform": for item, key in data[param].items(): - assert key['transform'].__code__.co_code == dataset_test_vars[param][item]['transform'].__code__.co_code - assert key['standardize'] == dataset_test_vars[param][item]['standardize'] + assert key["transform"].__code__.co_code == dataset_test_vars[param][item]["transform"].__code__.co_code + assert key["standardize"] == dataset_test_vars[param][item]["standardize"] else: assert dataset_test_vars[param] == data[param] dataset_test = GraphDataset( - hdf5_path = hdf5_path, - train_source = pretrained_model, - node_features = "all", - edge_features = "all", - features_transform = None, - target = 'BA', - target_transform = True, - task = "regress", - classes = None + hdf5_path=hdf5_path, + train_source=pretrained_model, + node_features="all", + edge_features="all", + features_transform=None, + target="BA", + target_transform=True, + task="regress", + classes=None, ) # node_features, edge_features, feature_transform, target, target_transform, task, and classes # in the test should be inherited from the pre-trained model dataset_test_vars = vars(dataset_test) for param in dataset_test.inherited_params: - if param == 'features_transform': + if param == "features_transform": for item, key in data[param].items(): - assert key['transform'].__code__.co_code == dataset_test_vars[param][item]['transform'].__code__.co_code - assert key['standardize'] == dataset_test_vars[param][item]['standardize'] + assert key["transform"].__code__.co_code == dataset_test_vars[param][item]["transform"].__code__.co_code + assert key["standardize"] == dataset_test_vars[param][item]["standardize"] else: assert dataset_test_vars[param] == data[param] @@ -1097,72 +1242,68 @@ def test_no_target_dataset_graphdataset(self): pretrained_model = "tests/data/pretrained/testing_graph_model.pth.tar" dataset = GraphDataset( - hdf5_path = hdf5_no_target, - train_source = pretrained_model + hdf5_path=hdf5_no_target, + train_source=pretrained_model, ) assert dataset.target is not None assert dataset.get(0).y is None # no target set, training mode - with self.assertRaises(ValueError): - dataset = GraphDataset( - hdf5_path = hdf5_no_target - ) + with pytest.raises(ValueError): + dataset = GraphDataset(hdf5_path=hdf5_no_target) # target set, but not present in the file - with self.assertRaises(ValueError): + with pytest.raises(ValueError): dataset = GraphDataset( - hdf5_path = hdf5_target, - target = 'CAPRI' + hdf5_path=hdf5_target, + target="CAPRI", ) def test_incompatible_dataset_train_type(self): dataset_train = GraphDataset( - hdf5_path = "tests/data/hdf5/test.hdf5", - edge_features = [Efeat.DISTANCE, Efeat.COVALENT], - target = targets.BINARY + hdf5_path="tests/data/hdf5/test.hdf5", + edge_features=[Efeat.DISTANCE, Efeat.COVALENT], + target=targets.BINARY, ) # Raise error when val dataset don't have the same data type as train dataset. 
with pytest.raises(TypeError): GridDataset( - hdf5_path = "tests/data/hdf5/1ATN_ppi.hdf5", - train_source = dataset_train + hdf5_path="tests/data/hdf5/1ATN_ppi.hdf5", + train_source=dataset_train, ) def test_invalid_pretrained_model_path(self): - hdf5_graph = "tests/data/hdf5/test.hdf5" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): GraphDataset( - hdf5_path = hdf5_graph, - train_source = hdf5_graph + hdf5_path=hdf5_graph, + train_source=hdf5_graph, ) hdf5_grid = "tests/data/hdf5/1ATN_ppi.hdf5" - with self.assertRaises(ValueError): + with pytest.raises(ValueError): GridDataset( - hdf5_path = hdf5_grid, - train_source = hdf5_grid + hdf5_path=hdf5_grid, + train_source=hdf5_grid, ) def test_invalid_pretrained_model_data_type(self): - hdf5_graph = "tests/data/hdf5/test.hdf5" pretrained_grid_model = "tests/data/pretrained/testing_grid_model.pth.tar" - with self.assertRaises(TypeError): + with pytest.raises(TypeError): GraphDataset( - hdf5_path = hdf5_graph, - train_source = pretrained_grid_model + hdf5_path=hdf5_graph, + train_source=pretrained_grid_model, ) hdf5_grid = "tests/data/hdf5/1ATN_ppi.hdf5" pretrained_graph_model = "tests/data/pretrained/testing_graph_model.pth.tar" - with self.assertRaises(TypeError): + with pytest.raises(TypeError): GridDataset( - hdf5_path = hdf5_grid, - train_source = pretrained_graph_model + hdf5_path=hdf5_grid, + train_source=pretrained_graph_model, ) diff --git a/tests/test_integration.py b/tests/test_integration.py index 579d292bc..1fee2ae35 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -21,26 +21,25 @@ from deeprank2.utils.exporters import HDF5OutputExporter from deeprank2.utils.grid import GridSettings, MapMethod -pdb_path = str("tests/data/pdb/3C8P/3C8P.pdb") -ref_path = str("tests/data/ref/3C8P/3C8P.pdb") -pssm_path1 = str("tests/data/pssm/3C8P/3C8P.A.pdb.pssm") -pssm_path2 = str("tests/data/pssm/3C8P/3C8P.B.pdb.pssm") +pdb_path = "tests/data/pdb/3C8P/3C8P.pdb" +ref_path = "tests/data/ref/3C8P/3C8P.pdb" +pssm_path1 = "tests/data/pssm/3C8P/3C8P.A.pdb.pssm" +pssm_path2 = "tests/data/pssm/3C8P/3C8P.B.pdb.pssm" chain_id1 = "A" chain_id2 = "B" count_queries = 3 -def test_cnn(): # pylint: disable=too-many-locals +def test_cnn(): """ Tests processing several PDB files into their features representation HDF5 file. Then uses HDF5 generated files to train and test a CnnRegression network. 
""" - hdf5_directory = mkdtemp() output_directory = mkdtemp() - model_path = output_directory + 'test.pth.tar' + model_path = output_directory + "test.pth.tar" prefix = os.path.join(hdf5_directory, "test-queries-process") @@ -51,16 +50,18 @@ def test_cnn(): # pylint: disable=too-many-locals for _ in range(count_queries): query = ProteinProteinInterfaceQuery( pdb_path=pdb_path, - resolution='residue', - chain_ids=[chain_id1,chain_id2], + resolution="residue", + chain_ids=[chain_id1, chain_id2], pssm_paths={chain_id1: pssm_path1, chain_id2: pssm_path2}, - targets = all_targets + targets=all_targets, ) queries.add(query) - hdf5_paths = queries.process(prefix = prefix, - grid_settings=GridSettings([20, 20, 20], [20.0, 20.0, 20.0]), - grid_map_method=MapMethod.GAUSSIAN) + hdf5_paths = queries.process( + prefix=prefix, + grid_settings=GridSettings([20, 20, 20], [20.0, 20.0, 20.0]), + grid_map_method=MapMethod.GAUSSIAN, + ) assert len(hdf5_paths) > 0 graph_names = [] @@ -75,19 +76,19 @@ def test_cnn(): # pylint: disable=too-many-locals features = [Nfeat.RESTYPE, Efeat.DISTANCE] dataset_train = GridDataset( - hdf5_path = hdf5_paths, - features = features, - target = targets.BINARY + hdf5_path=hdf5_paths, + features=features, + target=targets.BINARY, ) dataset_val = GridDataset( - hdf5_path = hdf5_paths, - train_source = dataset_train, + hdf5_path=hdf5_paths, + train_source=dataset_train, ) dataset_test = GridDataset( - hdf5_path = hdf5_paths, - train_source = dataset_train, + hdf5_path=hdf5_paths, + train_source=dataset_train, ) output_exporters = [HDF5OutputExporter(output_directory)] @@ -97,29 +98,40 @@ def test_cnn(): # pylint: disable=too-many-locals dataset_train, dataset_val, dataset_test, - output_exporters=output_exporters + output_exporters=output_exporters, ) with warnings.catch_warnings(record=UserWarning): - trainer.train(nepoch=3, batch_size=64, validate=True, best_model=False, filename=model_path) + trainer.train( + nepoch=3, + batch_size=64, + validate=True, + best_model=False, + filename=model_path, + ) - Trainer(CnnClassification, dataset_train, dataset_val, dataset_test, pretrained_model=model_path) + Trainer( + CnnClassification, + dataset_train, + dataset_val, + dataset_test, + pretrained_model=model_path, + ) assert len(os.listdir(output_directory)) > 0 finally: rmtree(hdf5_directory) rmtree(output_directory) -def test_gnn(): # pylint: disable=too-many-locals - """ - Tests processing several PDB files into their features representation HDF5 file. + +def test_gnn(): + """Tests processing several PDB files into their features representation HDF5 file. Then uses HDF5 generated files to train and test a GINet network. 
""" - hdf5_directory = mkdtemp() output_directory = mkdtemp() - model_path = output_directory + 'test.pth.tar' + model_path = output_directory + "test.pth.tar" prefix = os.path.join(hdf5_directory, "test-queries-process") @@ -130,14 +142,14 @@ def test_gnn(): # pylint: disable=too-many-locals for _ in range(count_queries): query = ProteinProteinInterfaceQuery( pdb_path=pdb_path, - resolution='residue', - chain_ids=[chain_id1,chain_id2], + resolution="residue", + chain_ids=[chain_id1, chain_id2], pssm_paths={chain_id1: pssm_path1, chain_id2: pssm_path2}, - targets = all_targets + targets=all_targets, ) queries.add(query) - hdf5_paths = queries.process(prefix = prefix) + hdf5_paths = queries.process(prefix=prefix) assert len(hdf5_paths) > 0 graph_names = [] @@ -152,25 +164,24 @@ def test_gnn(): # pylint: disable=too-many-locals node_features = [Nfeat.RESTYPE] edge_features = [Efeat.DISTANCE] - dataset_train = GraphDataset( - hdf5_path = hdf5_paths, - node_features = node_features, - edge_features = edge_features, - clustering_method = "mcl", - target = targets.BINARY + hdf5_path=hdf5_paths, + node_features=node_features, + edge_features=edge_features, + clustering_method="mcl", + target=targets.BINARY, ) dataset_val = GraphDataset( - hdf5_path = hdf5_paths, - train_source = dataset_train, - clustering_method = "mcl" + hdf5_path=hdf5_paths, + train_source=dataset_train, + clustering_method="mcl", ) dataset_test = GraphDataset( - hdf5_path = hdf5_paths, - train_source = dataset_train, - clustering_method = "mcl" + hdf5_path=hdf5_paths, + train_source=dataset_train, + clustering_method="mcl", ) output_exporters = [HDF5OutputExporter(output_directory)] @@ -180,13 +191,25 @@ def test_gnn(): # pylint: disable=too-many-locals dataset_train, dataset_val, dataset_test, - output_exporters=output_exporters + output_exporters=output_exporters, ) with warnings.catch_warnings(record=UserWarning): - trainer.train(nepoch=3, batch_size=64, validate=True, best_model=False, filename=model_path) + trainer.train( + nepoch=3, + batch_size=64, + validate=True, + best_model=False, + filename=model_path, + ) - Trainer(GINet, dataset_train, dataset_val, dataset_test, pretrained_model=model_path) + Trainer( + GINet, + dataset_train, + dataset_val, + dataset_test, + pretrained_model=model_path, + ) assert len(os.listdir(output_directory)) > 0 @@ -194,7 +217,8 @@ def test_gnn(): # pylint: disable=too-many-locals rmtree(hdf5_directory) rmtree(output_directory) -@pytest.fixture(scope='session') + +@pytest.fixture(scope="session") def hdf5_files_for_nan(tmpdir_factory): # For testing cases in which the loss function is nan for the validation and/or for # the training sets. 
It doesn't matter if the dataset is a GraphDataset or a GridDataset, @@ -203,8 +227,8 @@ def hdf5_files_for_nan(tmpdir_factory): pdb_paths = [ "tests/data/pdb/3C8P/3C8P.pdb", "tests/data/pdb/1A0Z/1A0Z.pdb", - "tests/data/pdb/1ATN/1ATN_1w.pdb" - ] + "tests/data/pdb/1ATN/1ATN_1w.pdb", + ] chain_id1 = "A" chain_id2 = "B" targets_values = [0, 1, 1] @@ -214,56 +238,59 @@ def hdf5_files_for_nan(tmpdir_factory): for idx, pdb_path in enumerate(pdb_paths): query = ProteinProteinInterfaceQuery( pdb_path=pdb_path, - resolution='residue', - chain_ids=[chain_id1,chain_id2], - targets = {targets.BINARY: targets_values[idx]}, + resolution="residue", + chain_ids=[chain_id1, chain_id2], + targets={targets.BINARY: targets_values[idx]}, # A very low radius and edge length helps for not making the network to learn influence_radius=3, - max_edge_length=3 + max_edge_length=3, ) queries.add(query) - hdf5_paths = queries.process(prefix = prefix) - return hdf5_paths + return queries.process(prefix=prefix) -@pytest.mark.parametrize("validate, best_model", [(True, True), (False, True), (False, False), (True, False)]) + +@pytest.mark.parametrize("validate, best_model", [(True, True), (False, True), (False, False), (True, False)]) # noqa: PT006 (pytest-parametrize-names-wrong-type) def test_nan_loss_cases(validate, best_model, hdf5_files_for_nan): mols = [] for fname in hdf5_files_for_nan: - with h5py.File(fname, 'r') as hdf5: - for mol in hdf5.keys(): - mols.append(mol) + with h5py.File(fname, "r") as hdf5: + for mol in hdf5: + mols.append(mol) # noqa: PERF402 (manual-list-copy) dataset_train = GraphDataset( - hdf5_path = hdf5_files_for_nan, - subset = mols[1:], - target = targets.BINARY, - task = targets.CLASSIF - ) + hdf5_path=hdf5_files_for_nan, + subset=mols[1:], + target=targets.BINARY, + task=targets.CLASSIF, + ) dataset_valid = GraphDataset( - hdf5_path = hdf5_files_for_nan, - subset = [mols[0]], - train_source=dataset_train - ) + hdf5_path=hdf5_files_for_nan, + subset=[mols[0]], + train_source=dataset_train, + ) - trainer = Trainer( - NaiveNetwork, - dataset_train, - dataset_valid) + trainer = Trainer(NaiveNetwork, dataset_train, dataset_valid) optimizer = torch.optim.SGD lr = 10000 weight_decay = 10000 trainer.configure_optimizers(optimizer, lr, weight_decay=weight_decay) - w_msg = "A model has been saved but the validation and/or the training losses were NaN;" + \ - "\n\ttry to increase the cutoff distance during the data processing or the number of data points " + \ - "during the training." + w_msg = ( + "A model has been saved but the validation and/or the training losses were NaN;\n\t" + "try to increase the cutoff distance during the data processing or the number of data points " + "during the training." 
+ ) with warnings.catch_warnings(record=True) as w: warnings.simplefilter("ignore", category=pd.errors.PerformanceWarning) trainer.train( - nepoch=5, batch_size=1, validate=validate, - best_model=best_model, filename='test.pth.tar') + nepoch=5, + batch_size=1, + validate=validate, + best_model=best_model, + filename="test.pth.tar", + ) assert len(w) == 1 assert issubclass(w[-1].category, UserWarning) assert w_msg in str(w[-1].message) diff --git a/tests/test_query.py b/tests/test_query.py index 9a3a7ffb8..73e9f2834 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -12,8 +12,11 @@ from deeprank2.domain import nodestorage as Nfeat from deeprank2.domain import targetstorage as targets from deeprank2.features import components, conservation, contact, surfacearea -from deeprank2.query import (ProteinProteinInterfaceQuery, QueryCollection, - SingleResidueVariantQuery) +from deeprank2.query import ( + ProteinProteinInterfaceQuery, + QueryCollection, + SingleResidueVariantQuery, +) from deeprank2.utils.graph import Graph from deeprank2.utils.grid import GridSettings, MapMethod @@ -23,7 +26,6 @@ def _check_graph_makes_sense( node_feature_names: list[str], edge_feature_names: list[str], ): - assert len(g.nodes) > 0, "no nodes" assert Nfeat.POSITION in g.nodes[0].features @@ -44,28 +46,18 @@ def _check_graph_makes_sense( g.write_to_hdf5(tmp_path) with h5py.File(tmp_path, "r") as f5: - grp = f5[list(f5.keys())[0]] + grp = f5[next(iter(f5.keys()))] for feature_name in node_feature_names: - assert ( - grp[f"{Nfeat.NODE}/{feature_name}"][()].size > 0 - ), f"no {feature_name} feature" + assert grp[f"{Nfeat.NODE}/{feature_name}"][()].size > 0, f"no {feature_name} feature" - assert ( - len( - np.nonzero( - grp[f"{Nfeat.NODE}/{feature_name}"][()] - ) - ) - > 0 - ), f"{feature_name}: all zero" + assert len(np.nonzero(grp[f"{Nfeat.NODE}/{feature_name}"][()])) > 0, f"{feature_name}: all zero" assert grp[f"{Efeat.EDGE}/{Efeat.INDEX}"][()].shape[1] == 2, "wrong edge index shape" assert grp[f"{Efeat.EDGE}/{Efeat.INDEX}"].shape[0] > 0, "no edge indices" for feature_name in edge_feature_names: assert ( - grp[f"{Efeat.EDGE}/{feature_name}"][()].shape[0] - == grp[f"{Efeat.EDGE}/{Efeat.INDEX}"].shape[0] + grp[f"{Efeat.EDGE}/{feature_name}"][()].shape[0] == grp[f"{Efeat.EDGE}/{Efeat.INDEX}"].shape[0] ), f"not enough edge {feature_name} feature values" count_edges_hdf5 = grp[f"{Efeat.EDGE}/{Efeat.INDEX}"].shape[0] @@ -77,14 +69,10 @@ def _check_graph_makes_sense( # expecting twice as many edges, because torch is directional count_edges_torch = torch_data_entry.edge_index.shape[1] - assert ( - count_edges_torch == 2 * count_edges_hdf5 - ), f"got {count_edges_torch} edges in output data, hdf5 has {count_edges_hdf5}" + assert count_edges_torch == 2 * count_edges_hdf5, f"got {count_edges_torch} edges in output data, hdf5 has {count_edges_hdf5}" count_edge_features_torch = torch_data_entry.edge_attr.shape[0] - assert ( - count_edge_features_torch == count_edges_torch - ), f"got {count_edge_features_torch} edge feature sets, but {count_edges_torch} edge indices" + assert count_edge_features_torch == count_edges_torch, f"got {count_edge_features_torch} edge feature sets, but {count_edges_torch} edge indices" finally: os.remove(tmp_path) @@ -287,52 +275,60 @@ def test_res_ppi(): def test_augmentation(): qc = QueryCollection() - qc.add(ProteinProteinInterfaceQuery( - pdb_path="tests/data/pdb/3C8P/3C8P.pdb", - resolution="residue", - chain_ids=["A", "B"], - pssm_paths={ - "A": "tests/data/pssm/3C8P/3C8P.A.pdb.pssm", - 
"B": "tests/data/pssm/3C8P/3C8P.B.pdb.pssm", - }, - targets={targets.BINARY: 0}, - )) + qc.add( + ProteinProteinInterfaceQuery( + pdb_path="tests/data/pdb/3C8P/3C8P.pdb", + resolution="residue", + chain_ids=["A", "B"], + pssm_paths={ + "A": "tests/data/pssm/3C8P/3C8P.A.pdb.pssm", + "B": "tests/data/pssm/3C8P/3C8P.B.pdb.pssm", + }, + targets={targets.BINARY: 0}, + ) + ) - qc.add(ProteinProteinInterfaceQuery( - pdb_path="tests/data/pdb/3C8P/3C8P.pdb", - resolution="atom", - chain_ids=["A", "B"], - pssm_paths={ - "A": "tests/data/pssm/3C8P/3C8P.A.pdb.pssm", - "B": "tests/data/pssm/3C8P/3C8P.B.pdb.pssm", - }, - targets={targets.BINARY: 0}, - )) + qc.add( + ProteinProteinInterfaceQuery( + pdb_path="tests/data/pdb/3C8P/3C8P.pdb", + resolution="atom", + chain_ids=["A", "B"], + pssm_paths={ + "A": "tests/data/pssm/3C8P/3C8P.A.pdb.pssm", + "B": "tests/data/pssm/3C8P/3C8P.B.pdb.pssm", + }, + targets={targets.BINARY: 0}, + ) + ) - qc.add(SingleResidueVariantQuery( - pdb_path="tests/data/pdb/101M/101M.pdb", - resolution="residue", - chain_ids="A", - variant_residue_number=25, - insertion_code=None, - wildtype_amino_acid=aa.glycine, - variant_amino_acid=aa.alanine, - pssm_paths={"A": "tests/data/pssm/101M/101M.A.pdb.pssm"}, - targets={targets.BINARY: 0}, - )) + qc.add( + SingleResidueVariantQuery( + pdb_path="tests/data/pdb/101M/101M.pdb", + resolution="residue", + chain_ids="A", + variant_residue_number=25, + insertion_code=None, + wildtype_amino_acid=aa.glycine, + variant_amino_acid=aa.alanine, + pssm_paths={"A": "tests/data/pssm/101M/101M.A.pdb.pssm"}, + targets={targets.BINARY: 0}, + ) + ) - qc.add(SingleResidueVariantQuery( - pdb_path="tests/data/pdb/101M/101M.pdb", - resolution="atom", - chain_ids="A", - variant_residue_number=27, - insertion_code=None, - wildtype_amino_acid=aa.asparagine, - variant_amino_acid=aa.phenylalanine, - pssm_paths={"A": "tests/data/pssm/101M/101M.A.pdb.pssm"}, - targets={targets.BINARY: 0}, - influence_radius=3.0, - )) + qc.add( + SingleResidueVariantQuery( + pdb_path="tests/data/pdb/101M/101M.pdb", + resolution="atom", + chain_ids="A", + variant_residue_number=27, + insertion_code=None, + wildtype_amino_acid=aa.asparagine, + variant_amino_acid=aa.phenylalanine, + pssm_paths={"A": "tests/data/pssm/101M/101M.A.pdb.pssm"}, + targets={targets.BINARY: 0}, + influence_radius=3.0, + ) + ) augmentation_count = 3 grid_settings = GridSettings([20, 20, 20], [20.0, 20.0, 20.0]) @@ -340,20 +336,22 @@ def test_augmentation(): tmp_dir = mkdtemp() try: - qc.process(f"{tmp_dir}/qc", - grid_settings=grid_settings, - grid_map_method=MapMethod.GAUSSIAN, - grid_augmentation_count=augmentation_count) + qc.process( + f"{tmp_dir}/qc", + grid_settings=grid_settings, + grid_map_method=MapMethod.GAUSSIAN, + grid_augmentation_count=augmentation_count, + ) hdf5_path = f"{tmp_dir}/qc.hdf5" assert os.path.isfile(hdf5_path) - with h5py.File(hdf5_path, 'r') as f5: + with h5py.File(hdf5_path, "r") as f5: entry_names = list(f5.keys()) assert len(entry_names) == expected_entry_count, f"Found {len(entry_names)} entries, expected {expected_entry_count}" - dataset = GridDataset(hdf5_path, target = 'binary') + dataset = GridDataset(hdf5_path, target="binary") assert len(dataset) == expected_entry_count, f"Found {len(dataset)} data points, expected {expected_entry_count}" finally: @@ -425,6 +423,7 @@ def test_no_pssm_provided(): with pytest.raises(ValueError): _ = q_empty_dict.build([conservation]) + with pytest.raises(ValueError): _ = q_not_provided.build([conservation]) # no error if conservation module is 
not used @@ -446,16 +445,17 @@ def test_incorrect_pssm_provided(): # missing file q_missing = ProteinProteinInterfaceQuery( - pdb_path="tests/data/pdb/3C8P/3C8P.pdb", - resolution="residue", - chain_ids=["A", "B"], - pssm_paths={ + pdb_path="tests/data/pdb/3C8P/3C8P.pdb", + resolution="residue", + chain_ids=["A", "B"], + pssm_paths={ "A": "tests/data/pssm/3C8P/3C8P.A.pdb.pssm", }, ) with pytest.raises(FileNotFoundError): _ = q_non_existing.build([conservation]) + with pytest.raises(FileNotFoundError): _ = q_missing.build([conservation]) # no error if conservation module is not used @@ -465,23 +465,23 @@ def test_incorrect_pssm_provided(): def test_variant_query_multiple_chains(): q = SingleResidueVariantQuery( - pdb_path = "tests/data/pdb/2g98/pdb2g98.pdb", - resolution = "atom", - chain_ids = "A", - variant_residue_number = 14, - insertion_code = None, - wildtype_amino_acid = aa.arginine, - variant_amino_acid = aa.cysteine, - pssm_paths = {"A": "tests/data/pssm/2g98/2g98.A.pdb.pssm"}, - targets = {targets.BINARY: 1}, - influence_radius = 10.0, - max_edge_length = 4.5, + pdb_path="tests/data/pdb/2g98/pdb2g98.pdb", + resolution="atom", + chain_ids="A", + variant_residue_number=14, + insertion_code=None, + wildtype_amino_acid=aa.arginine, + variant_amino_acid=aa.cysteine, + pssm_paths={"A": "tests/data/pssm/2g98/2g98.A.pdb.pssm"}, + targets={targets.BINARY: 1}, + influence_radius=10.0, + max_edge_length=4.5, ) # at radius 10, chain B is included in graph # no error without conservation module graph = q.build(components) - assert 'B' in graph.get_all_chains() + assert "B" in graph.get_all_chains() # if we rebuild the graph with conservation module it should fail with pytest.raises(FileNotFoundError): _ = q.build(conservation) @@ -489,4 +489,4 @@ def test_variant_query_multiple_chains(): # at radius 7, chain B is not included in graph q.influence_radius = 7.0 graph = q.build(conservation) - assert 'B' not in graph.get_all_chains() + assert "B" not in graph.get_all_chains() diff --git a/tests/test_querycollection.py b/tests/test_querycollection.py index 2b713e4ba..94077731f 100644 --- a/tests/test_querycollection.py +++ b/tests/test_querycollection.py @@ -11,15 +11,14 @@ from deeprank2.domain import nodestorage as Nfeat from deeprank2.domain.aminoacidlist import alanine, phenylalanine from deeprank2.features import components, contact, surfacearea -from deeprank2.query import (ProteinProteinInterfaceQuery, Query, - QueryCollection, SingleResidueVariantQuery) +from deeprank2.query import ProteinProteinInterfaceQuery, Query, QueryCollection, SingleResidueVariantQuery from deeprank2.tools.target import compute_ppi_scores -def _querycollection_tester( # pylint: disable=dangerous-default-value +def _querycollection_tester( query_type: str, n_queries: int = 3, - feature_modules: ModuleType | list[ModuleType] = [components, contact], + feature_modules: ModuleType | list[ModuleType] = [components, contact], # noqa: B006 (unsafe default value) cpu_count: int = 1, combine_output: bool = True, ): @@ -36,25 +35,31 @@ def _querycollection_tester( # pylint: disable=dangerous-default-value combine_output (bool): boolean for combining the hdf5 files generated by the processes. By default, the hdf5 files generated are combined into one, and then deleted. 
""" - - if query_type == 'ppi': - queries = [ProteinProteinInterfaceQuery( - pdb_path="tests/data/pdb/3C8P/3C8P.pdb", - resolution="residue", - chain_ids=["A","B"], - pssm_paths={"A": "tests/data/pssm/3C8P/3C8P.A.pdb.pssm", "B": "tests/data/pssm/3C8P/3C8P.B.pdb.pssm"}, - )] * n_queries - elif query_type == 'srv': - queries = [SingleResidueVariantQuery( - pdb_path="tests/data/pdb/101M/101M.pdb", - resolution="residue", - chain_ids="A", - variant_residue_number=None, # placeholder - insertion_code=None, - wildtype_amino_acid=alanine, - variant_amino_acid=phenylalanine, - pssm_paths={"A": "tests/data/pssm/101M/101M.A.pdb.pssm"}, - )] * n_queries + if query_type == "ppi": + queries = [ + ProteinProteinInterfaceQuery( + pdb_path="tests/data/pdb/3C8P/3C8P.pdb", + resolution="residue", + chain_ids=["A", "B"], + pssm_paths={ + "A": "tests/data/pssm/3C8P/3C8P.A.pdb.pssm", + "B": "tests/data/pssm/3C8P/3C8P.B.pdb.pssm", + }, + ) + ] * n_queries + elif query_type == "srv": + queries = [ + SingleResidueVariantQuery( + pdb_path="tests/data/pdb/101M/101M.pdb", + resolution="residue", + chain_ids="A", + variant_residue_number=None, # placeholder + insertion_code=None, + wildtype_amino_acid=alanine, + variant_amino_acid=phenylalanine, + pssm_paths={"A": "tests/data/pssm/101M/101M.A.pdb.pssm"}, + ) + ] * n_queries else: raise ValueError("Please insert a valid type (either ppi or srv).") @@ -63,13 +68,18 @@ def _querycollection_tester( # pylint: disable=dangerous-default-value collection = QueryCollection() for idx in range(n_queries): - if query_type == 'srv': + if query_type == "srv": queries[idx].variant_residue_number = idx + 1 collection.add(queries[idx]) else: collection.add(queries[idx], warn_duplicate=False) - output_paths = collection.process(prefix, feature_modules, cpu_count, combine_output) + output_paths = collection.process( + prefix, + feature_modules, + cpu_count, + combine_output, + ) assert len(output_paths) > 0 graph_names = [] @@ -88,15 +98,14 @@ def _assert_correct_modules( features: str | list[str], absent: str, ): - """Helper function to assert inclusion of correct features + """Helper function to assert inclusion of correct features. Args: output_paths (str): output_paths as returned from _querycollection_tester - features (str | list[str]]: feature(s) that should be present + features (str | list[str]): feature(s) that should be present absent (str): feature that should be absent """ - - if isinstance(features,str): + if isinstance(features, str): features = [features] with h5py.File(output_paths[0], "r") as f5: @@ -104,24 +113,21 @@ def _assert_correct_modules( for feat in features: try: if feat == Efeat.DISTANCE: - _ = f5[list(f5.keys())[0]][f"{Efeat.EDGE}/{feat}"] + _ = f5[next(iter(f5.keys()))][f"{Efeat.EDGE}/{feat}"] else: - _ = f5[list(f5.keys())[0]][f"{Nfeat.NODE}/{feat}"] + _ = f5[next(iter(f5.keys()))][f"{Nfeat.NODE}/{feat}"] except KeyError: missing.append(feat) if missing: - raise KeyError(f'The following feature(s) were not created: {missing}.') + raise KeyError(f"The following feature(s) were not created: {missing}.") with pytest.raises(KeyError): - _ = f5[list(f5.keys())[0]][f"{Nfeat.NODE}/{absent}"] + _ = f5[next(iter(f5.keys()))][f"{Nfeat.NODE}/{absent}"] def test_querycollection_process(): - """ - Tests processing method of QueryCollection class. 
- """ - - for query_type in ['ppi', 'srv']: + """Tests processing method of QueryCollection class.""" + for query_type in ["ppi", "srv"]: n_queries = 3 n_queries = 3 @@ -140,61 +146,67 @@ def test_querycollection_process_single_feature_module(): Tested for following input types: ModuleType, list[ModuleType] str, list[str] """ - - for query_type in ['ppi', 'srv']: - for testcase in [surfacearea, [surfacearea], 'surfacearea', ['surfacearea']]: + for query_type in ["ppi", "srv"]: + for testcase in [surfacearea, [surfacearea], "surfacearea", ["surfacearea"]]: _, output_directory, output_paths = _querycollection_tester(query_type, feature_modules=testcase) _assert_correct_modules(output_paths, Nfeat.BSA, Nfeat.HSE) rmtree(output_directory) def test_querycollection_process_all_features_modules(): - """ - Tests processing for generating all features. - """ - - one_feature_from_each_module = [Nfeat.RESTYPE, Nfeat.PSSM, Efeat.DISTANCE, Nfeat.HSE, Nfeat.SECSTRUCT, Nfeat.BSA, Nfeat.IRCTOTAL] - - _, output_directory, output_paths = _querycollection_tester('ppi', feature_modules='all') - _assert_correct_modules(output_paths, one_feature_from_each_module, 'dummy_feature') + """Tests processing for generating all features.""" + one_feature_from_each_module = [ + Nfeat.RESTYPE, + Nfeat.PSSM, + Efeat.DISTANCE, + Nfeat.HSE, + Nfeat.SECSTRUCT, + Nfeat.BSA, + Nfeat.IRCTOTAL, + ] + + _, output_directory, output_paths = _querycollection_tester("ppi", feature_modules="all") + _assert_correct_modules(output_paths, one_feature_from_each_module, "dummy_feature") rmtree(output_directory) - _, output_directory, output_paths = _querycollection_tester('srv', feature_modules='all') - _assert_correct_modules(output_paths, one_feature_from_each_module[:-1], Nfeat.IRCTOTAL) + _, output_directory, output_paths = _querycollection_tester("srv", feature_modules="all") + _assert_correct_modules( + output_paths, + one_feature_from_each_module[:-1], + Nfeat.IRCTOTAL, + ) rmtree(output_directory) def test_querycollection_process_default_features_modules(): - """ - Tests processing for generating all features. - """ - - for query_type in ['ppi', 'srv']: + """Tests processing for generating all features.""" + for query_type in ["ppi", "srv"]: _, output_directory, output_paths = _querycollection_tester(query_type) - _assert_correct_modules(output_paths, [Nfeat.RESTYPE, Efeat.DISTANCE], Nfeat.HSE) + _assert_correct_modules( + output_paths, + [Nfeat.RESTYPE, Efeat.DISTANCE], + Nfeat.HSE, + ) rmtree(output_directory) def test_querycollection_process_combine_output_true(): - """ - Tests processing for combining hdf5 files into one. 
- """ - - for query_type in ['ppi', 'srv']: + """Tests processing for combining hdf5 files into one.""" + for query_type in ["ppi", "srv"]: modules = [surfacearea, components] _, output_directory_t, output_paths_t = _querycollection_tester(query_type, feature_modules=modules) - _, output_directory_f, output_paths_f = _querycollection_tester(query_type, feature_modules=modules, combine_output = False, cpu_count=2) + _, output_directory_f, output_paths_f = _querycollection_tester(query_type, feature_modules=modules, combine_output=False, cpu_count=2) assert len(output_paths_t) == 1 keys_t = {} - with h5py.File(output_paths_t[0],'r') as file_t: + with h5py.File(output_paths_t[0], "r") as file_t: for key, value in file_t.items(): keys_t[key] = value keys_f = {} for output_path in output_paths_f: - with h5py.File(output_path,'r') as file_f: + with h5py.File(output_path, "r") as file_f: for key, value in file_f.items(): keys_f[key] = value assert keys_t == keys_f @@ -204,20 +216,17 @@ def test_querycollection_process_combine_output_true(): def test_querycollection_process_combine_output_false(): - """ - Tests processing for keeping all generated hdf5 files . - """ - - for query_type in ['ppi', 'srv']: + """Tests processing for keeping all generated hdf5 files .""" + for query_type in ["ppi", "srv"]: cpu_count = 2 combine_output = False modules = [surfacearea, components] _, output_directory, output_paths = _querycollection_tester( - query_type, - feature_modules=modules, - cpu_count = cpu_count, - combine_output = combine_output, - ) + query_type, + feature_modules=modules, + cpu_count=cpu_count, + combine_output=combine_output, + ) assert len(output_paths) == cpu_count rmtree(output_directory) @@ -225,7 +234,6 @@ def test_querycollection_process_combine_output_false(): def test_querycollection_duplicates_add(): """Tests add method of QueryCollection class.""" - ref_path = "tests/data/ref/1ATN/1ATN.pdb" pssm_path1 = "tests/data/pssm/1ATN/1ATN.A.pdb.pssm" pssm_path2 = "tests/data/pssm/1ATN/1ATN.B.pdb.pssm" @@ -237,7 +245,8 @@ def test_querycollection_duplicates_add(): "tests/data/pdb/1ATN/1ATN_1w.pdb", "tests/data/pdb/1ATN/1ATN_2w.pdb", "tests/data/pdb/1ATN/1ATN_2w.pdb", - "tests/data/pdb/1ATN/1ATN_3w.pdb"] + "tests/data/pdb/1ATN/1ATN_3w.pdb", + ] queries = QueryCollection() @@ -245,25 +254,28 @@ def test_querycollection_duplicates_add(): for pdb_path in pdb_paths: # Append data points targets = compute_ppi_scores(pdb_path, ref_path) - queries.add(ProteinProteinInterfaceQuery( - pdb_path = pdb_path, - resolution="residue", - chain_ids = [chain_id1, chain_id2], - targets = targets, - pssm_paths = { - chain_id1: pssm_path1, - chain_id2: pssm_path2 - } - )) - - #check id naming for all pdb files - model_ids = [] - for query in queries.queries: - model_ids.append(query.model_id) + queries.add( + ProteinProteinInterfaceQuery( + pdb_path=pdb_path, + resolution="residue", + chain_ids=[chain_id1, chain_id2], + targets=targets, + pssm_paths={chain_id1: pssm_path1, chain_id2: pssm_path2}, + ) + ) + + # check id naming for all pdb files + model_ids = [query.model_id for query in queries.queries] model_ids.sort() - # pylint: disable=protected-access - assert model_ids == ['1ATN_1w', '1ATN_1w_2', '1ATN_1w_3', '1ATN_2w', '1ATN_2w_2', '1ATN_3w'] - assert queries._ids_count['residue-ppi:A-B:1ATN_1w'] == 3 - assert queries._ids_count['residue-ppi:A-B:1ATN_2w'] == 2 - assert queries._ids_count['residue-ppi:A-B:1ATN_3w'] == 1 + assert model_ids == [ + "1ATN_1w", + "1ATN_1w_2", + "1ATN_1w_3", + "1ATN_2w", + 
"1ATN_2w_2", + "1ATN_3w", + ] + assert queries._ids_count["residue-ppi:A-B:1ATN_1w"] == 3 # noqa: SLF001 (private member accessed) + assert queries._ids_count["residue-ppi:A-B:1ATN_2w"] == 2 # noqa: SLF001 (private member accessed) + assert queries._ids_count["residue-ppi:A-B:1ATN_3w"] == 1 # noqa: SLF001 (private member accessed) diff --git a/tests/test_set_lossfunction.py b/tests/test_set_lossfunction.py index 1ad42a732..f495d0f0e 100644 --- a/tests/test_set_lossfunction.py +++ b/tests/test_set_lossfunction.py @@ -12,30 +12,34 @@ from deeprank2.neuralnets.gnn.naive_gnn import NaiveNetwork from deeprank2.trainer import Trainer -hdf5_path = 'tests/data/hdf5/test.hdf5' +hdf5_path = "tests/data/hdf5/test.hdf5" -def base_test(model_path, trainer: Trainer, lossfunction = None, override = False): +def base_test( + model_path, + trainer: Trainer, + lossfunction=None, + override=False, +): if lossfunction: - trainer.set_lossfunction(lossfunction = lossfunction, override_invalid=override) + trainer.set_lossfunction(lossfunction=lossfunction, override_invalid=override) # check correct passing to/picking up from pretrained model with warnings.catch_warnings(record=UserWarning): trainer.train(nepoch=2, best_model=False, filename=model_path) - trainer_pretrained = Trainer( - neuralnet = NaiveNetwork, + return Trainer( + neuralnet=NaiveNetwork, dataset_test=trainer.dataset_train, - pretrained_model=model_path) - - return trainer_pretrained + pretrained_model=model_path, + ) class TestLosses(unittest.TestCase): @classmethod def setUpClass(class_): class_.work_directory = tempfile.mkdtemp() - class_.save_path = class_.work_directory + 'test.tar' + class_.save_path = class_.work_directory + "test.tar" @classmethod def tearDownClass(class_): @@ -43,24 +47,27 @@ def tearDownClass(class_): # Classification tasks def test_classif_default(self): - dataset = GraphDataset(hdf5_path, - target = targets.BINARY) + dataset = GraphDataset( + hdf5_path, + target=targets.BINARY, + ) trainer = Trainer( - neuralnet = NaiveNetwork, - dataset_train = dataset, + neuralnet=NaiveNetwork, + dataset_train=dataset, ) trainer_pretrained = base_test(self.save_path, trainer) assert isinstance(trainer.lossfunction, nn.CrossEntropyLoss) assert isinstance(trainer_pretrained.lossfunction, nn.CrossEntropyLoss) - def test_classif_all(self): - dataset = GraphDataset(hdf5_path, - target = targets.BINARY) + dataset = GraphDataset( + hdf5_path, + target=targets.BINARY, + ) trainer = Trainer( - neuralnet = NaiveNetwork, - dataset_train = dataset, + neuralnet=NaiveNetwork, + dataset_train=dataset, ) # only NLLLoss and CrossEntropyLoss are currently working @@ -70,14 +77,15 @@ def test_classif_all(self): assert isinstance(trainer.lossfunction, lossfunction) assert isinstance(trainer_pretrained.lossfunction, lossfunction) - def test_classif_weighted(self): - dataset = GraphDataset(hdf5_path, - target = targets.BINARY) + dataset = GraphDataset( + hdf5_path, + target=targets.BINARY, + ) trainer = Trainer( - neuralnet = NaiveNetwork, - dataset_train = dataset, - class_weights = True + neuralnet=NaiveNetwork, + dataset_train=dataset, + class_weights=True, ) lossfunction = nn.NLLLoss @@ -86,7 +94,6 @@ def test_classif_weighted(self): assert isinstance(trainer_pretrained.lossfunction, lossfunction) assert trainer_pretrained.class_weights - # def test_classif_invalid_weighted(self): # dataset = GraphDataset(hdf5_path, # target=targets.BINARY) @@ -101,54 +108,61 @@ def test_classif_weighted(self): # with pytest.raises(ValueError): # 
base_test(self.save_path, trainer, lossfunction) - def test_classif_invalid_lossfunction(self): - dataset = GraphDataset(hdf5_path, - target = targets.BINARY) + dataset = GraphDataset( + hdf5_path, + target=targets.BINARY, + ) trainer = Trainer( - neuralnet = NaiveNetwork, - dataset_train = dataset, + neuralnet=NaiveNetwork, + dataset_train=dataset, ) lossfunction = nn.MSELoss with pytest.raises(ValueError): base_test(self.save_path, trainer, lossfunction) - def test_classif_invalid_lossfunction_override(self): - dataset = GraphDataset(hdf5_path, - target = targets.BINARY) + dataset = GraphDataset(hdf5_path, target=targets.BINARY) trainer = Trainer( - neuralnet = NaiveNetwork, - dataset_train = dataset, + neuralnet=NaiveNetwork, + dataset_train=dataset, ) lossfunction = nn.MSELoss with pytest.raises(RuntimeError): - base_test(self.save_path, trainer, lossfunction, override = True) - + base_test( + self.save_path, + trainer, + lossfunction, + override=True, + ) # Regression tasks def test_regress_default(self): - dataset = GraphDataset(hdf5_path, - target = 'BA', - task = 'regress') + dataset = GraphDataset( + hdf5_path, + target="BA", + task="regress", + ) trainer = Trainer( - neuralnet = NaiveNetwork, - dataset_train = dataset, + neuralnet=NaiveNetwork, + dataset_train=dataset, ) trainer_pretrained = base_test(self.save_path, trainer) assert isinstance(trainer.lossfunction, nn.MSELoss) assert isinstance(trainer_pretrained.lossfunction, nn.MSELoss) - def test_regress_all(self): - dataset = GraphDataset(hdf5_path, - target = 'BA', task = 'regress') + dataset = GraphDataset( + hdf5_path, + target="BA", + task="regress", + ) trainer = Trainer( - neuralnet = NaiveNetwork, - dataset_train = dataset, + neuralnet=NaiveNetwork, + dataset_train=dataset, ) for f in losses.regression_losses: lossfunction = f @@ -157,26 +171,30 @@ def test_regress_all(self): assert isinstance(trainer.lossfunction, lossfunction) assert isinstance(trainer_pretrained.lossfunction, lossfunction) - def test_regress_invalid_lossfunction(self): - dataset = GraphDataset(hdf5_path, - target = 'BA', task = 'regress') + dataset = GraphDataset( + hdf5_path, + target="BA", + task="regress", + ) trainer = Trainer( - neuralnet = NaiveNetwork, - dataset_train = dataset, + neuralnet=NaiveNetwork, + dataset_train=dataset, ) lossfunction = nn.CrossEntropyLoss with pytest.raises(ValueError): base_test(self.save_path, trainer, lossfunction) - def test_regress_invalid_lossfunction_override(self): - dataset = GraphDataset(hdf5_path, - target = 'BA', task = 'regress') + dataset = GraphDataset( + hdf5_path, + target="BA", + task="regress", + ) trainer = Trainer( - neuralnet = NaiveNetwork, - dataset_train = dataset, + neuralnet=NaiveNetwork, + dataset_train=dataset, ) lossfunction = nn.CrossEntropyLoss diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 2cf6e151f..96abc7ae2 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -21,14 +21,24 @@ from deeprank2.neuralnets.gnn.naive_gnn import NaiveNetwork from deeprank2.neuralnets.gnn.sgat import SGAT from deeprank2.trainer import Trainer, _divide_dataset -from deeprank2.utils.exporters import (HDF5OutputExporter, ScatterPlotExporter, - TensorboardBinaryClassificationExporter) +from deeprank2.utils.exporters import HDF5OutputExporter, ScatterPlotExporter, TensorboardBinaryClassificationExporter + +# ruff: noqa: FBT003 _log = logging.getLogger(__name__) -default_features = [Nfeat.RESTYPE, Nfeat.POLARITY, Nfeat.BSA, Nfeat.RESDEPTH, Nfeat.HSE, Nfeat.INFOCONTENT, Nfeat.PSSM] 
+default_features = [ + Nfeat.RESTYPE, + Nfeat.POLARITY, + Nfeat.BSA, + Nfeat.RESDEPTH, + Nfeat.HSE, + Nfeat.INFOCONTENT, + Nfeat.PSSM, +] + -def _model_base_test( # pylint: disable=too-many-arguments, too-many-locals +def _model_base_test( save_path, model_class, train_hdf5_path, @@ -41,34 +51,33 @@ def _model_base_test( # pylint: disable=too-many-arguments, too-many-locals target_transform, output_exporters, clustering_method, - use_cuda = False + use_cuda=False, ): - dataset_train = GraphDataset( - hdf5_path = train_hdf5_path, - node_features = node_features, - edge_features = edge_features, - clustering_method = clustering_method, - target = target, - target_transform = target_transform, - task = task - ) + hdf5_path=train_hdf5_path, + node_features=node_features, + edge_features=edge_features, + clustering_method=clustering_method, + target=target, + target_transform=target_transform, + task=task, + ) if val_hdf5_path is not None: dataset_val = GraphDataset( - hdf5_path = val_hdf5_path, - train_source = dataset_train, - clustering_method = clustering_method, - ) + hdf5_path=val_hdf5_path, + train_source=dataset_train, + clustering_method=clustering_method, + ) else: dataset_val = None if test_hdf5_path is not None: dataset_test = GraphDataset( - hdf5_path = test_hdf5_path, - train_source = dataset_train, - clustering_method = clustering_method, - ) + hdf5_path=test_hdf5_path, + train_source=dataset_train, + clustering_method=clustering_method, + ) else: dataset_test = None @@ -87,31 +96,41 @@ def _model_base_test( # pylint: disable=too-many-arguments, too-many-locals data = dataset_train.get(0) - for name, data_tensor in (("x", data.x), ("y", data.y), - (Efeat.INDEX, data.edge_index), - ("edge_attr", data.edge_attr), - (Nfeat.POSITION, data.pos), - ("cluster0",data.cluster0), - ("cluster1", data.cluster1)): - + for name, data_tensor in ( + ("x", data.x), + ("y", data.y), + (Efeat.INDEX, data.edge_index), + ("edge_attr", data.edge_attr), + (Nfeat.POSITION, data.pos), + ("cluster0", data.cluster0), + ("cluster1", data.cluster1), + ): if data_tensor is not None: assert data_tensor.is_cuda, f"data.{name} is not cuda" with warnings.catch_warnings(record=UserWarning): - trainer.train(nepoch=3, batch_size=64, validate=True, best_model=False, filename=save_path) + trainer.train( + nepoch=3, + batch_size=64, + validate=True, + best_model=False, + filename=save_path, + ) Trainer( model_class, dataset_train, dataset_val, dataset_test, - pretrained_model=save_path) + pretrained_model=save_path, + ) + class TestTrainer(unittest.TestCase): @classmethod def setUpClass(class_): class_.work_directory = tempfile.mkdtemp() - class_.save_path = class_.work_directory + 'test.tar' + class_.save_path = class_.work_directory + "test.tar" @classmethod def tearDownClass(class_): @@ -119,34 +138,33 @@ def tearDownClass(class_): def test_grid_regression(self): dataset = GridDataset( - hdf5_path = "tests/data/hdf5/1ATN_ppi.hdf5", - subset = None, - features = [Efeat.VDW], - target = targets.IRMSD, - task = targets.REGRESS - ) - trainer = Trainer( - CnnRegression, - dataset + hdf5_path="tests/data/hdf5/1ATN_ppi.hdf5", + subset=None, + features=[Efeat.VDW], + target=targets.IRMSD, + task=targets.REGRESS, ) + trainer = Trainer(CnnRegression, dataset) trainer.train(nepoch=1, batch_size=2, best_model=False, filename=None) def test_grid_classification(self): dataset = GridDataset( - hdf5_path = "tests/data/hdf5/1ATN_ppi.hdf5", - subset = None, - features = [Efeat.VDW], - target = targets.BINARY, - task = 
targets.CLASSIF - ) - trainer = Trainer( - CnnClassification, - dataset + hdf5_path="tests/data/hdf5/1ATN_ppi.hdf5", + subset=None, + features=[Efeat.VDW], + target=targets.BINARY, + task=targets.CLASSIF, + ) + trainer = Trainer(CnnClassification, dataset) + trainer.train( + nepoch=1, + batch_size=2, + best_model=False, + filename=None, ) - trainer.train(nepoch=1, batch_size = 2, best_model=False, filename=None) def test_ginet_sigmoid(self): - files = glob.glob(self.work_directory + '/*') + files = glob.glob(self.work_directory + "/*") for f in files: os.remove(f) assert len(os.listdir(self.work_directory)) == 0 @@ -168,7 +186,7 @@ def test_ginet_sigmoid(self): assert len(os.listdir(self.work_directory)) > 0 def test_ginet(self): - files = glob.glob(self.work_directory + '/*') + files = glob.glob(self.work_directory + "/*") for f in files: os.remove(f) assert len(os.listdir(self.work_directory)) == 0 @@ -190,7 +208,7 @@ def test_ginet(self): assert len(os.listdir(self.work_directory)) > 0 def test_ginet_class(self): - files = glob.glob(self.work_directory + '/*') + files = glob.glob(self.work_directory + "/*") for f in files: os.remove(f) assert len(os.listdir(self.work_directory)) == 0 @@ -213,7 +231,7 @@ def test_ginet_class(self): assert len(os.listdir(self.work_directory)) > 0 def test_fout(self): - files = glob.glob(self.work_directory + '/*') + files = glob.glob(self.work_directory + "/*") for f in files: os.remove(f) assert len(os.listdir(self.work_directory)) == 0 @@ -235,7 +253,7 @@ def test_fout(self): assert len(os.listdir(self.work_directory)) > 0 def test_sgat(self): - files = glob.glob(self.work_directory + '/*') + files = glob.glob(self.work_directory + "/*") for f in files: os.remove(f) assert len(os.listdir(self.work_directory)) == 0 @@ -257,7 +275,7 @@ def test_sgat(self): assert len(os.listdir(self.work_directory)) > 0 def test_naive(self): - files = glob.glob(self.work_directory + '/*') + files = glob.glob(self.work_directory + "/*") for f in files: os.remove(f) assert len(os.listdir(self.work_directory)) == 0 @@ -303,7 +321,13 @@ def test_incompatible_classification(self): "tests/data/hdf5/variants.hdf5", "tests/data/hdf5/variants.hdf5", "tests/data/hdf5/variants.hdf5", - [Nfeat.RESSIZE, Nfeat.POLARITY, Nfeat.SASA, Nfeat.INFOCONTENT, Nfeat.PSSM], + [ + Nfeat.RESSIZE, + Nfeat.POLARITY, + Nfeat.SASA, + Nfeat.INFOCONTENT, + Nfeat.PSSM, + ], [Efeat.DISTANCE], targets.CLASSIF, targets.BINARY, @@ -313,141 +337,126 @@ def test_incompatible_classification(self): ) def test_incompatible_no_pretrained_no_train(self): + dataset = GraphDataset( + hdf5_path="tests/data/hdf5/test.hdf5", + target=targets.BINARY, + ) + with pytest.raises(ValueError): - dataset = GraphDataset( - hdf5_path = "tests/data/hdf5/test.hdf5", - target = targets.BINARY, - ) Trainer( - neuralnet = NaiveNetwork, - dataset_test = dataset, + neuralnet=NaiveNetwork, + dataset_test=dataset, ) def test_incompatible_no_pretrained_no_Net(self): with pytest.raises(ValueError): - dataset = GraphDataset( + _ = GraphDataset( hdf5_path="tests/data/hdf5/test.hdf5", ) - Trainer( - neuralnet = NaiveNetwork, - dataset_train = dataset, - ) def test_incompatible_no_pretrained_no_target(self): + dataset = GraphDataset( + hdf5_path="tests/data/hdf5/test.hdf5", + target=targets.BINARY, + ) with pytest.raises(ValueError): - dataset = GraphDataset( - hdf5_path = "tests/data/hdf5/test.hdf5", - target = targets.BINARY, - ) Trainer( - dataset_train = dataset, + dataset_train=dataset, ) def test_incompatible_pretrained_no_test(self): + 
dataset = GraphDataset( + hdf5_path="tests/data/hdf5/test.hdf5", + clustering_method="mcl", + target=targets.BINARY, + ) + trainer = Trainer( + neuralnet=GINet, + dataset_train=dataset, + ) + + with warnings.catch_warnings(record=UserWarning): + trainer.train(nepoch=3, validate=True, best_model=False, filename=self.save_path) with pytest.raises(ValueError): - dataset = GraphDataset( - hdf5_path = "tests/data/hdf5/test.hdf5", - clustering_method = "mcl", - target = targets.BINARY - ) - trainer = Trainer( - neuralnet = GINet, - dataset_train = dataset, + Trainer( + neuralnet=GINet, + dataset_train=dataset, + pretrained_model=self.save_path, ) - with warnings.catch_warnings(record=UserWarning): - trainer.train(nepoch=3, validate=True, best_model=False, filename=self.save_path) - Trainer( - neuralnet = GINet, - dataset_train = dataset, - pretrained_model = self.save_path - ) - def test_incompatible_pretrained_no_Net(self): - with pytest.raises(ValueError): - dataset = GraphDataset( - hdf5_path = "tests/data/hdf5/test.hdf5", - clustering_method = "mcl", - target = targets.BINARY - ) - trainer = Trainer( - neuralnet = GINet, - dataset_train = dataset, - ) + dataset = GraphDataset( + hdf5_path="tests/data/hdf5/test.hdf5", + clustering_method="mcl", + target=targets.BINARY, + ) + trainer = Trainer( + neuralnet=GINet, + dataset_train=dataset, + ) - with warnings.catch_warnings(record=UserWarning): - trainer.train(nepoch=3, validate=True, best_model=False, filename=self.save_path) - Trainer( - dataset_test = dataset, - pretrained_model = self.save_path - ) + with warnings.catch_warnings(record=UserWarning): + trainer.train(nepoch=3, validate=True, best_model=False, filename=self.save_path) + with pytest.raises(ValueError): + Trainer(dataset_test=dataset, pretrained_model=self.save_path) def test_no_training_no_pretrained(self): dataset_train = GraphDataset( - hdf5_path = "tests/data/hdf5/test.hdf5", - clustering_method = "mcl", - target = targets.BINARY, - ) - dataset_val = GraphDataset( - hdf5_path = "tests/data/hdf5/test.hdf5", - train_source = dataset_train - ) - dataset_test = GraphDataset( - hdf5_path = "tests/data/hdf5/test.hdf5", - train_source = dataset_train + hdf5_path="tests/data/hdf5/test.hdf5", + clustering_method="mcl", + target=targets.BINARY, ) + dataset_val = GraphDataset(hdf5_path="tests/data/hdf5/test.hdf5", train_source=dataset_train) + dataset_test = GraphDataset(hdf5_path="tests/data/hdf5/test.hdf5", train_source=dataset_train) trainer = Trainer( - neuralnet = GINet, - dataset_train = dataset_train, - dataset_val = dataset_val, - dataset_test = dataset_test + neuralnet=GINet, + dataset_train=dataset_train, + dataset_val=dataset_val, + dataset_test=dataset_test, ) with pytest.raises(ValueError): trainer.test() def test_no_valid_provided(self): dataset = GraphDataset( - hdf5_path = "tests/data/hdf5/test.hdf5", - clustering_method = "mcl", - target = targets.BINARY, + hdf5_path="tests/data/hdf5/test.hdf5", + clustering_method="mcl", + target=targets.BINARY, ) trainer = Trainer( - neuralnet = GINet, - dataset_train = dataset, + neuralnet=GINet, + dataset_train=dataset, ) - trainer.train(batch_size = 1, best_model=False, filename=None) + trainer.train(batch_size=1, best_model=False, filename=None) assert len(trainer.train_loader) == int(0.75 * len(dataset)) assert len(trainer.valid_loader) == int(0.25 * len(dataset)) def test_no_test_provided(self): dataset_train = GraphDataset( - hdf5_path = "tests/data/hdf5/test.hdf5", - clustering_method = "mcl", - target = targets.BINARY, - ) - 
dataset_val = GraphDataset( - hdf5_path = "tests/data/hdf5/test.hdf5", - train_source = dataset_train + hdf5_path="tests/data/hdf5/test.hdf5", + clustering_method="mcl", + target=targets.BINARY, ) + dataset_val = GraphDataset(hdf5_path="tests/data/hdf5/test.hdf5", train_source=dataset_train) trainer = Trainer( - neuralnet = GINet, - dataset_train = dataset_train, - dataset_val = dataset_val, + neuralnet=GINet, + dataset_train=dataset_train, + dataset_val=dataset_val, ) - trainer.train(batch_size = 1, best_model=False, filename=None) + trainer.train(batch_size=1, best_model=False, filename=None) with pytest.raises(ValueError): trainer.test() def test_no_valid_full_train(self): dataset = GraphDataset( - hdf5_path = "tests/data/hdf5/test.hdf5", - clustering_method = "mcl", - target = targets.BINARY + hdf5_path="tests/data/hdf5/test.hdf5", + clustering_method="mcl", + target=targets.BINARY, ) trainer = Trainer( - neuralnet = GINet, - dataset_train = dataset, - val_size = 0 + neuralnet=GINet, + dataset_train=dataset, + val_size=0, ) trainer.train(batch_size=1, best_model=False, filename=None) assert len(trainer.train_loader) == len(dataset) @@ -455,12 +464,12 @@ def test_no_valid_full_train(self): def test_optim(self): dataset = GraphDataset( - hdf5_path = "tests/data/hdf5/test.hdf5", - target = targets.BINARY, + hdf5_path="tests/data/hdf5/test.hdf5", + target=targets.BINARY, ) trainer = Trainer( - neuralnet = NaiveNetwork, - dataset_train = dataset, + neuralnet=NaiveNetwork, + dataset_train=dataset, ) optimizer = torch.optim.Adamax @@ -475,9 +484,10 @@ def test_optim(self): with warnings.catch_warnings(record=UserWarning): trainer.train(nepoch=3, best_model=False, filename=self.save_path) trainer_pretrained = Trainer( - neuralnet = NaiveNetwork, + neuralnet=NaiveNetwork, dataset_test=dataset, - pretrained_model=self.save_path) + pretrained_model=self.save_path, + ) assert str(type(trainer_pretrained.optimizer)) == "" assert trainer_pretrained.lr == lr @@ -485,21 +495,21 @@ def test_optim(self): def test_default_optim(self): dataset = GraphDataset( - hdf5_path = "tests/data/hdf5/test.hdf5", - target = targets.BINARY, + hdf5_path="tests/data/hdf5/test.hdf5", + target=targets.BINARY, ) trainer = Trainer( - neuralnet = NaiveNetwork, - dataset_train = dataset, + neuralnet=NaiveNetwork, + dataset_train=dataset, ) assert isinstance(trainer.optimizer, torch.optim.Adam) assert trainer.lr == 0.001 assert trainer.weight_decay == 1e-05 - def test_cuda(self): # test_ginet, but with cuda + def test_cuda(self): # test_ginet, but with cuda if torch.cuda.is_available(): - files = glob.glob(self.work_directory + '/*') + files = glob.glob(self.work_directory + "/*") for f in files: os.remove(f) assert len(os.listdir(self.work_directory)) == 0 @@ -517,7 +527,7 @@ def test_cuda(self): # test_ginet, but with cuda False, [HDF5OutputExporter(self.work_directory)], "mcl", - True + True, ) assert len(os.listdir(self.work_directory)) > 0 @@ -529,48 +539,40 @@ def test_dataset_equivalence_no_pretrained(self): # TestCase: dataset_train set (no pretrained model assigned). # Raise error when train dataset is neither a GraphDataset or GridDataset. 
+ dataset_invalid_train = GINet(input_shape=2) with pytest.raises(TypeError): - dataset_invalid_train = GINet( - input_shape = 2 - ) Trainer( - neuralnet = GINet, - dataset_train = dataset_invalid_train, + neuralnet=GINet, + dataset_train=dataset_invalid_train, ) dataset_train = GraphDataset( - hdf5_path = "tests/data/hdf5/test.hdf5", - edge_features = [Efeat.DISTANCE, Efeat.COVALENT], - target = targets.BINARY + hdf5_path="tests/data/hdf5/test.hdf5", + edge_features=[Efeat.DISTANCE, Efeat.COVALENT], + target=targets.BINARY, ) # Raise error when train_source parameter in GraphDataset/GridDataset # is not equivalent to the dataset_train passed to Trainer. dataset_train_other = GraphDataset( - hdf5_path = "tests/data/hdf5/test.hdf5", - edge_features = [Efeat.SAMECHAIN, Efeat.COVALENT], - target = 'BA', - task='regress' - ) - dataset_val = GraphDataset( - hdf5_path = "tests/data/hdf5/test.hdf5", - train_source = dataset_train - ) - dataset_test = GraphDataset( - hdf5_path = "tests/data/hdf5/test.hdf5", - train_source = dataset_train + hdf5_path="tests/data/hdf5/test.hdf5", + edge_features=[Efeat.SAMECHAIN, Efeat.COVALENT], + target="BA", + task="regress", ) + dataset_val = GraphDataset(hdf5_path="tests/data/hdf5/test.hdf5", train_source=dataset_train) + dataset_test = GraphDataset(hdf5_path="tests/data/hdf5/test.hdf5", train_source=dataset_train) with pytest.raises(ValueError): Trainer( - neuralnet = GINet, - dataset_train = dataset_train_other, - dataset_val = dataset_val + neuralnet=GINet, + dataset_train=dataset_train_other, + dataset_val=dataset_val, ) with pytest.raises(ValueError): Trainer( - neuralnet = GINet, - dataset_train = dataset_train_other, - dataset_test = dataset_test + neuralnet=GINet, + dataset_train=dataset_train_other, + dataset_test=dataset_test, ) def test_dataset_equivalence_pretrained(self): @@ -578,38 +580,35 @@ def test_dataset_equivalence_pretrained(self): # Raise error when no dataset_test is set. 
dataset_train = GraphDataset( - hdf5_path = "tests/data/hdf5/test.hdf5", - edge_features = [Efeat.DISTANCE, Efeat.COVALENT], - clustering_method = "mcl", - target = targets.BINARY + hdf5_path="tests/data/hdf5/test.hdf5", + edge_features=[Efeat.DISTANCE, Efeat.COVALENT], + clustering_method="mcl", + target=targets.BINARY, ) trainer = Trainer( - neuralnet = GINet, - dataset_train = dataset_train, + neuralnet=GINet, + dataset_train=dataset_train, ) + with warnings.catch_warnings(record=UserWarning): + # train pretrained model + trainer.train(nepoch=3, validate=True, best_model=False, filename=self.save_path) + # pretrained model assigned(no dataset_train needed) with pytest.raises(ValueError): - with warnings.catch_warnings(record = UserWarning): - #train pretrained model - trainer.train(nepoch = 3, validate = True, best_model = False, filename = self.save_path) - #pretrained model assigned(no dataset_train needed) - Trainer( - neuralnet = GINet, - pretrained_model = self.save_path - ) + Trainer(neuralnet=GINet, pretrained_model=self.save_path) def test_trainsize(self): hdf5 = "tests/data/hdf5/train.hdf5" - hdf5_file = h5py.File(hdf5, 'r') # contains 44 datapoints - n_val = int ( 0.25 * len(hdf5_file) ) + hdf5_file = h5py.File(hdf5, "r") # contains 44 datapoints + n_val = int(0.25 * len(hdf5_file)) n_train = len(hdf5_file) - n_val test_cases = [None, 0.25, n_val] for t in test_cases: - dataset_train, dataset_val =_divide_dataset( - dataset = GraphDataset(hdf5_path = hdf5, target = targets.BINARY), - splitsize = t, + dataset_train, dataset_val = _divide_dataset( + dataset=GraphDataset(hdf5_path=hdf5, target=targets.BINARY), + splitsize=t, ) assert len(dataset_train) == n_train assert len(dataset_val) == n_val @@ -618,79 +617,66 @@ def test_trainsize(self): def test_invalid_trainsize(self): hdf5 = "tests/data/hdf5/train.hdf5" - hdf5_file = h5py.File(hdf5, 'r') # contains 44 datapoints + hdf5_file = h5py.File(hdf5, "r") # contains 44 datapoints n = len(hdf5_file) test_cases = [ - 1.0, n, # cannot be 100% validation data - -0.5, -1, # no negative values - 1.1, n + 1, # cannot use more than all data as input - ] + 1.0, + n, # fmt: skip; cannot be 100% validation data + -0.5, + -1, # fmt: skip; no negative values + 1.1, + n + 1, # fmt: skip; cannot use more than all data as input + ] for t in test_cases: print(t) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): _divide_dataset( - dataset = GraphDataset(hdf5_path = hdf5), - splitsize = t, + dataset=GraphDataset(hdf5_path=hdf5), + splitsize=t, ) hdf5_file.close() def test_invalid_cuda_ngpus(self): - dataset_train = GraphDataset( - hdf5_path = "tests/data/hdf5/test.hdf5", - target = targets.BINARY - ) - dataset_val = GraphDataset( - hdf5_path = "tests/data/hdf5/test.hdf5", - train_source = dataset_train - ) + dataset_train = GraphDataset(hdf5_path="tests/data/hdf5/test.hdf5", target=targets.BINARY) + dataset_val = GraphDataset(hdf5_path="tests/data/hdf5/test.hdf5", train_source=dataset_train) with pytest.raises(ValueError): Trainer( - neuralnet = GINet, - dataset_train = dataset_train, - dataset_val = dataset_val, - ngpu = 2 + neuralnet=GINet, + dataset_train=dataset_train, + dataset_val=dataset_val, + ngpu=2, ) def test_invalid_no_cuda_available(self): if not torch.cuda.is_available(): - dataset_train = GraphDataset( - hdf5_path = "tests/data/hdf5/test.hdf5", - target = targets.BINARY - ) - dataset_val = GraphDataset( - hdf5_path = "tests/data/hdf5/test.hdf5", - train_source = dataset_train - ) + dataset_train = 
GraphDataset(hdf5_path="tests/data/hdf5/test.hdf5", target=targets.BINARY) + dataset_val = GraphDataset(hdf5_path="tests/data/hdf5/test.hdf5", train_source=dataset_train) with pytest.raises(ValueError): Trainer( - neuralnet = GINet, - dataset_train = dataset_train, - dataset_val = dataset_val, - cuda = True + neuralnet=GINet, + dataset_train=dataset_train, + dataset_val=dataset_val, + cuda=True, ) else: - warnings.warn('CUDA is available; test_invalid_no_cuda_available was skipped') - _log.info('CUDA is available; test_invalid_no_cuda_available was skipped') + warnings.warn("CUDA is available; test_invalid_no_cuda_available was skipped") + _log.info("CUDA is available; test_invalid_no_cuda_available was skipped") def test_train_method_no_train(self): - # Graphs data test_data_graph = "tests/data/hdf5/test.hdf5" pretrained_model_graph = "tests/data/pretrained/testing_graph_model.pth.tar" - dataset_test = GraphDataset( - hdf5_path = test_data_graph, - train_source = pretrained_model_graph - ) + dataset_test = GraphDataset(hdf5_path=test_data_graph, train_source=pretrained_model_graph) trainer = Trainer( - neuralnet = NaiveNetwork, - dataset_test = dataset_test, - pretrained_model = pretrained_model_graph + neuralnet=NaiveNetwork, + dataset_test=dataset_test, + pretrained_model=pretrained_model_graph, ) with pytest.raises(ValueError): @@ -700,36 +686,29 @@ def test_train_method_no_train(self): test_data_grid = "tests/data/hdf5/1ATN_ppi.hdf5" pretrained_model_grid = "tests/data/pretrained/testing_grid_model.pth.tar" - dataset_test = GridDataset( - hdf5_path = test_data_grid, - train_source = pretrained_model_grid - ) + dataset_test = GridDataset(hdf5_path=test_data_grid, train_source=pretrained_model_grid) trainer = Trainer( - neuralnet = CnnClassification, - dataset_test = dataset_test, - pretrained_model = pretrained_model_grid + neuralnet=CnnClassification, + dataset_test=dataset_test, + pretrained_model=pretrained_model_grid, ) with pytest.raises(ValueError): trainer.train() def test_test_method_pretrained_model_on_dataset_with_target(self): - # Graphs data test_data_graph = "tests/data/hdf5/test.hdf5" pretrained_model_graph = "tests/data/pretrained/testing_graph_model.pth.tar" - dataset_test = GraphDataset( - hdf5_path = test_data_graph, - train_source = pretrained_model_graph - ) + dataset_test = GraphDataset(hdf5_path=test_data_graph, train_source=pretrained_model_graph) trainer = Trainer( - neuralnet = NaiveNetwork, - dataset_test = dataset_test, - pretrained_model = pretrained_model_graph, - output_exporters = [HDF5OutputExporter("./")] - ) + neuralnet=NaiveNetwork, + dataset_test=dataset_test, + pretrained_model=pretrained_model_graph, + output_exporters=[HDF5OutputExporter("./")], + ) trainer.test() @@ -740,17 +719,14 @@ def test_test_method_pretrained_model_on_dataset_with_target(self): test_data_grid = "tests/data/hdf5/1ATN_ppi.hdf5" pretrained_model_grid = "tests/data/pretrained/testing_grid_model.pth.tar" - dataset_test = GridDataset( - hdf5_path = test_data_grid, - train_source = pretrained_model_grid - ) + dataset_test = GridDataset(hdf5_path=test_data_grid, train_source=pretrained_model_grid) trainer = Trainer( - neuralnet = CnnClassification, - dataset_test = dataset_test, - pretrained_model = pretrained_model_grid, - output_exporters = [HDF5OutputExporter("./")] - ) + neuralnet=CnnClassification, + dataset_test=dataset_test, + pretrained_model=pretrained_model_grid, + output_exporters=[HDF5OutputExporter("./")], + ) trainer.test() @@ -762,17 +738,14 @@ def 
test_test_method_pretrained_model_on_dataset_without_target(self): test_data_graph = "tests/data/hdf5/test_no_target.hdf5" pretrained_model_graph = "tests/data/pretrained/testing_graph_model.pth.tar" - dataset_test = GraphDataset( - hdf5_path = test_data_graph, - train_source = pretrained_model_graph - ) + dataset_test = GraphDataset(hdf5_path=test_data_graph, train_source=pretrained_model_graph) trainer = Trainer( - neuralnet = NaiveNetwork, - dataset_test = dataset_test, - pretrained_model = pretrained_model_graph, - output_exporters = [HDF5OutputExporter("./")] - ) + neuralnet=NaiveNetwork, + dataset_test=dataset_test, + pretrained_model=pretrained_model_graph, + output_exporters=[HDF5OutputExporter("./")], + ) trainer.test() @@ -785,17 +758,14 @@ def test_test_method_pretrained_model_on_dataset_without_target(self): test_data_grid = "tests/data/hdf5/test_no_target.hdf5" pretrained_model_grid = "tests/data/pretrained/testing_grid_model.pth.tar" - dataset_test = GridDataset( - hdf5_path = test_data_grid, - train_source = pretrained_model_grid - ) + dataset_test = GridDataset(hdf5_path=test_data_grid, train_source=pretrained_model_grid) trainer = Trainer( - neuralnet = CnnClassification, - dataset_test = dataset_test, - pretrained_model = pretrained_model_grid, - output_exporters = [HDF5OutputExporter("./")] - ) + neuralnet=CnnClassification, + dataset_test=dataset_test, + pretrained_model=pretrained_model_grid, + output_exporters=[HDF5OutputExporter("./")], + ) trainer.test() diff --git a/tests/tools/__init__.py b/tests/tools/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/tools/test_target.py b/tests/tools/test_target.py index bf5af2420..32dac919a 100644 --- a/tests/tools/test_target.py +++ b/tests/tools/test_target.py @@ -33,43 +33,46 @@ def test_add_target(self): os.remove(target_path) os.remove(graph_path) - def test_compute_ppi_scores(self): scores = compute_ppi_scores( os.path.join(self.pdb_path, "1ATN_1w.pdb"), - os.path.join(self.ref, "1ATN.pdb")) + os.path.join(self.ref, "1ATN.pdb"), + ) sim = StructureSimilarity( os.path.join(self.pdb_path, "1ATN_1w.pdb"), - os.path.join(self.ref, "1ATN.pdb"), enforce_residue_matching=False) + os.path.join(self.ref, "1ATN.pdb"), + enforce_residue_matching=False, + ) lrmsd = sim.compute_lrmsd_fast(method="svd") irmsd = sim.compute_irmsd_fast(method="svd") fnat = sim.compute_fnat_fast() dockq = sim.compute_DockQScore(fnat, lrmsd, irmsd) binary = irmsd < 4.0 capri = 4 - for thr, val in zip([6.0, 4.0, 2.0, 1.0], [4, 3, 2, 1]): + for thr, val in zip([6.0, 4.0, 2.0, 1.0], [4, 3, 2, 1], strict=True): if irmsd < thr: capri = val - assert scores['irmsd'] == irmsd - assert scores['lrmsd'] == lrmsd - assert scores['fnat'] == fnat - assert scores['dockq'] == dockq - assert scores['binary'] == binary - assert scores['capri_class'] == capri + assert scores["irmsd"] == irmsd + assert scores["lrmsd"] == lrmsd + assert scores["fnat"] == fnat + assert scores["dockq"] == dockq + assert scores["binary"] == binary + assert scores["capri_class"] == capri def test_compute_ppi_scores_same_struct(self): scores = compute_ppi_scores( os.path.join(self.pdb_path, "1ATN_1w.pdb"), - os.path.join(self.pdb_path, "1ATN_1w.pdb")) + os.path.join(self.pdb_path, "1ATN_1w.pdb"), + ) - assert scores['irmsd'] == 0.0 - assert scores['lrmsd'] == 0.0 - assert scores['fnat'] == 1.0 - assert scores['dockq'] == 1.0 - assert scores['binary'] # True - assert scores['capri_class'] == 1 + assert scores["irmsd"] == 0.0 + assert scores["lrmsd"] == 0.0 + assert 
scores["fnat"] == 1.0 + assert scores["dockq"] == 1.0 + assert scores["binary"] # True + assert scores["capri_class"] == 1 if __name__ == "__main__": diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/utils/test_buildgraph.py b/tests/utils/test_buildgraph.py index 73d4d3be3..11c8b989b 100644 --- a/tests/utils/test_buildgraph.py +++ b/tests/utils/test_buildgraph.py @@ -2,9 +2,7 @@ from deeprank2.domain.aminoacidlist import valine from deeprank2.molstruct.atom import AtomicElement -from deeprank2.utils.buildgraph import (get_residue_contact_pairs, - get_structure, - get_surrounding_residues) +from deeprank2.utils.buildgraph import get_residue_contact_pairs, get_structure, get_surrounding_residues def test_get_structure_complete(): @@ -14,7 +12,7 @@ def test_get_structure_complete(): try: structure = get_structure(pdb, "101M") finally: - pdb._close() # pylint: disable=protected-access + pdb._close() # noqa: SLF001 (private member accessed) assert structure is not None @@ -42,7 +40,7 @@ def test_get_structure_from_nmr_with_dna(): try: structure = get_structure(pdb, "101M") finally: - pdb._close() # pylint: disable=protected-access + pdb._close() # noqa: SLF001 (private member accessed) assert structure is not None assert structure.chains[0].residues[0].amino_acid is None # DNA @@ -54,7 +52,7 @@ def test_residue_contact_pairs(): try: structure = get_structure(pdb, "1ATN") finally: - pdb._close() # pylint: disable=protected-access + pdb._close() # noqa: SLF001 (private member accessed) residue_pairs = get_residue_contact_pairs(pdb_path, structure, "A", "B", 8.5) assert len(residue_pairs) > 0 @@ -66,11 +64,11 @@ def test_surrounding_residues(): try: structure = get_structure(pdb, "101M") finally: - pdb._close() # pylint: disable=protected-access + pdb._close() # noqa: SLF001 (private member accessed) all_residues = structure.get_chain("A").residues # A nicely centered residue - residue = [r for r in all_residues if r.number == 138][0] + residue = next(r for r in all_residues if r.number == 138) close_residues = get_surrounding_residues(structure, residue, 10.0) assert len(close_residues) > 0, "no close residues found" diff --git a/tests/utils/test_community_pooling.py b/tests/utils/test_community_pooling.py index 11589b35d..f362dab3e 100644 --- a/tests/utils/test_community_pooling.py +++ b/tests/utils/test_community_pooling.py @@ -2,29 +2,28 @@ import numpy as np import torch -from deeprank2.utils.community_pooling import (community_detection, - community_detection_per_batch, - community_pooling) from torch_geometric.data import Batch, Data +from deeprank2.utils.community_pooling import ( + community_detection, + community_detection_per_batch, + community_pooling, +) + class TestCommunity(unittest.TestCase): def setUp(self): - self.edge_index = torch.tensor( - [[0, 1, 1, 2, 3, 4, 4, 5], [1, 0, 2, 1, 4, 3, 5, 4]], dtype=torch.long - ) + self.edge_index = torch.tensor([[0, 1, 1, 2, 3, 4, 4, 5], [1, 0, 2, 1, 4, 3, 5, 4]], dtype=torch.long) self.x = torch.tensor([[0], [1], [2], [3], [4], [5]], dtype=torch.float) self.data = Data(x=self.x, edge_index=self.edge_index) - self.data.pos = torch.tensor(np.random.rand(self.data.num_nodes, 3)) + self.data.pos = torch.tensor(np.random.rand(self.data.num_nodes, 3)) # noqa: NPY002 (legacy numpy code) def test_detection_mcl(self): community_detection(self.data.edge_index, self.data.num_nodes, method="mcl") def test_detection_louvain(self): - community_detection( - 
self.data.edge_index, self.data.num_nodes, method="louvain" - ) + community_detection(self.data.edge_index, self.data.num_nodes, method="louvain") @unittest.expectedFailure def test_detection_error(self): @@ -59,7 +58,6 @@ def test_detection_per_batch_louvain2(self): ) def test_pooling(self): - batch = Batch().from_data_list([self.data, self.data]) cluster = community_detection(batch.edge_index, batch.num_nodes) diff --git a/tests/utils/test_earlystopping.py b/tests/utils/test_earlystopping.py index e170d9752..dee4a9328 100644 --- a/tests/utils/test_earlystopping.py +++ b/tests/utils/test_earlystopping.py @@ -1,14 +1,20 @@ from deeprank2.utils.earlystopping import EarlyStopping -dummy_val_losses = [3,2,1,2,0.5,2,3,4,5,6,7] -dummy_train_losses = [3,2,1,2,0.5,2,3,4,5,1,7] +dummy_val_losses = [3, 2, 1, 2, 0.5, 2, 3, 4, 5, 6, 7] +dummy_train_losses = [3, 2, 1, 2, 0.5, 2, 3, 4, 5, 1, 7] + def base_earlystopper(patience=10, delta=0, maxgap=None): - early_stopping = EarlyStopping(patience=patience, delta=delta, maxgap=maxgap, min_epoch=0) + early_stopping = EarlyStopping( + patience=patience, + delta=delta, + maxgap=maxgap, + min_epoch=0, + ) for ep, loss in enumerate(dummy_val_losses): # check early stopping criteria - print (f'Epoch #{ep}', end=': ') + print(f"Epoch #{ep}", end=": ") early_stopping(ep, loss, dummy_train_losses[ep]) if early_stopping.early_stop: break @@ -21,6 +27,7 @@ def test_patience(): final_ep = base_earlystopper(patience=patience) assert final_ep == 7 + def test_patience_with_delta(): patience = 3 delta = 1 diff --git a/tests/utils/test_exporters.py b/tests/utils/test_exporters.py index 184901263..d49ffdd93 100644 --- a/tests/utils/test_exporters.py +++ b/tests/utils/test_exporters.py @@ -8,10 +8,7 @@ import h5py import pandas as pd -from deeprank2.utils.exporters import (HDF5OutputExporter, - OutputExporterCollection, - ScatterPlotExporter, - TensorboardBinaryClassificationExporter) +from deeprank2.utils.exporters import HDF5OutputExporter, OutputExporterCollection, ScatterPlotExporter, TensorboardBinaryClassificationExporter logging.getLogger(__name__) @@ -40,7 +37,14 @@ def test_collection(self): loss = 0.1 with collection: - collection.process(pass_name, epoch_number, entry_names, outputs, targets, loss) + collection.process( + pass_name, + epoch_number, + entry_names, + outputs, + targets, + loss, + ) assert len(os.listdir(self._work_dir)) == 2 # tensorboard & table @@ -56,7 +60,7 @@ def test_tensorboard_binary_classif(self, mock_add_scalar): targets = [0, 1, 1] loss = 0.1 - def _check_scalar(name, scalar, timestep): # pylint: disable=unused-argument + def _check_scalar(name, scalar, timestep): # noqa: ARG001 (unused argument) if name == f"{pass_name} cross entropy loss": assert scalar < 1.0 else: @@ -65,9 +69,7 @@ def _check_scalar(name, scalar, timestep): # pylint: disable=unused-argument mock_add_scalar.side_effect = _check_scalar with tensorboard_exporter: - tensorboard_exporter.process( - pass_name, epoch_number, entry_names, outputs, targets, loss - ) + tensorboard_exporter.process(pass_name, epoch_number, entry_names, outputs, targets, loss) assert mock_add_scalar.called def test_scatter_plot(self): @@ -82,7 +84,7 @@ def test_scatter_plot(self): ["entry1", "entry1", "entry2"], [0.1, 0.65, 0.98], [0.0, 0.5, 1.0], - 0.1 + 0.1, ) scatterplot_exporter.process( @@ -91,14 +93,14 @@ def test_scatter_plot(self): ["entryA", "entryB", "entryC"], [0.3, 0.35, 0.25], [0.0, 0.5, 1.0], - 0.1 + 0.1, ) assert os.path.isfile(scatterplot_exporter.get_filename(epoch_number)) - def 
test_hdf5_output(self): # pylint: disable=too-many-locals + def test_hdf5_output(self): output_exporter = HDF5OutputExporter(self._work_dir) - path_output_exporter = os.path.join(self._work_dir, 'output_exporter.hdf5') + path_output_exporter = os.path.join(self._work_dir, "output_exporter.hdf5") entry_names = ["entry1", "entry2", "entry3"] outputs = [[0.2, 0.1], [0.3, 0.8], [0.8, 0.9]] targets = [0, 1, 1] @@ -108,26 +110,18 @@ def test_hdf5_output(self): # pylint: disable=too-many-locals n_epoch_1 = 10 with output_exporter: for epoch_number in range(n_epoch_1): - output_exporter.process( - pass_name_1, epoch_number, entry_names, outputs, targets, loss - ) + output_exporter.process(pass_name_1, epoch_number, entry_names, outputs, targets, loss) pass_name_2 = "test_2" n_epoch_2 = 5 with output_exporter: for epoch_number in range(n_epoch_2): - output_exporter.process( - pass_name_2, epoch_number, entry_names, outputs, targets, loss - ) - - df_test_1 = pd.read_hdf( - path_output_exporter, - key=pass_name_1) - df_test_2 = pd.read_hdf( - path_output_exporter, - key=pass_name_2) - - df_hdf5 = h5py.File(path_output_exporter,'r') + output_exporter.process(pass_name_2, epoch_number, entry_names, outputs, targets, loss) + + df_test_1 = pd.read_hdf(path_output_exporter, key=pass_name_1) + df_test_2 = pd.read_hdf(path_output_exporter, key=pass_name_2) + + df_hdf5 = h5py.File(path_output_exporter, "r") df_keys = list(df_hdf5.keys()) df_keys.sort() # assert that the hdf5 output file contains exactly 2 Groups, test_1 and test_2 @@ -140,11 +134,11 @@ def test_hdf5_output(self): # pylint: disable=too-many-locals assert list(df_test_1.entry.unique()) == entry_names assert list(df_test_2.entry.unique()) == entry_names # assert there are len(entry_names) rows for each epoch - assert df_test_1[df_test_1.phase == pass_name_1].groupby(['epoch'], as_index=False).count().phase.unique() == len(entry_names) - assert df_test_2[df_test_2.phase == pass_name_2].groupby(['epoch'], as_index=False).count().phase.unique() == len(entry_names) + assert df_test_1[df_test_1.phase == pass_name_1].groupby(["epoch"], as_index=False).count().phase.unique() == len(entry_names) + assert df_test_2[df_test_2.phase == pass_name_2].groupby(["epoch"], as_index=False).count().phase.unique() == len(entry_names) # assert there are len(entry_names)*n_epoch rows - assert df_test_1[df_test_1.phase == pass_name_1].shape[0] == len(entry_names)*n_epoch_1 - assert df_test_2[df_test_2.phase == pass_name_2].shape[0] == len(entry_names)*n_epoch_2 + assert df_test_1[df_test_1.phase == pass_name_1].shape[0] == len(entry_names) * n_epoch_1 + assert df_test_2[df_test_2.phase == pass_name_2].shape[0] == len(entry_names) * n_epoch_2 # assert there are 6 columns ('phase', 'epoch', 'entry', 'output', 'target', 'loss') assert df_test_1[df_test_1.phase == pass_name_1].shape[1] == 6 assert df_test_2[df_test_2.phase == pass_name_2].shape[1] == 6 diff --git a/tests/utils/test_graph.py b/tests/utils/test_graph.py index 8d5563f1e..2e2c4bd05 100644 --- a/tests/utils/test_graph.py +++ b/tests/utils/test_graph.py @@ -27,17 +27,15 @@ target_value = 1.0 -@pytest.fixture +@pytest.fixture() def graph(): - """Build a simple graph of two nodes and one edge in between them. 
- """ - + """Build a simple graph of two nodes and one edge in between them.""" # load the structure pdb = pdb2sql("tests/data/pdb/101M/101M.pdb") try: structure = get_structure(pdb, entry_id) finally: - pdb._close() # pylint: disable=protected-access + pdb._close() # noqa: SLF001 (private member accessed) # build a contact from two residues residue0 = structure.chains[0].residues[0] @@ -62,9 +60,7 @@ def graph(): # init the graph graph = Graph(structure.id) - graph.center = np.mean( - [node0.features[Nfeat.POSITION], node1.features[Nfeat.POSITION]], - axis=0) + graph.center = np.mean([node0.features[Nfeat.POSITION], node1.features[Nfeat.POSITION]], axis=0) graph.targets[target_name] = target_value graph.add_node(node0) @@ -74,16 +70,13 @@ def graph(): def test_graph_write_to_hdf5(graph): - """Test that the graph is correctly written to hdf5 file. - """ - + """Test that the graph is correctly written to hdf5 file.""" # create a temporary hdf5 file to write to tmp_dir_path = tempfile.mkdtemp() hdf5_path = os.path.join(tmp_dir_path, "101m.hdf5") try: - # export graph to hdf5 graph.write_to_hdf5(hdf5_path) @@ -95,18 +88,15 @@ def test_graph_write_to_hdf5(graph): assert Nfeat.NODE in grp node_features_group = grp[Nfeat.NODE] assert node_feature_narray in node_features_group - assert len(np.nonzero( - node_features_group[node_feature_narray][()])) > 0 + assert len(np.nonzero(node_features_group[node_feature_narray][()])) > 0 assert node_features_group[node_feature_narray][()].shape == (2, 3) - assert node_features_group[node_feature_singleton][()].shape == ( - 2, ) + assert node_features_group[node_feature_singleton][()].shape == (2,) # edges assert Efeat.EDGE in grp edge_features_group = grp[Efeat.EDGE] assert edge_feature_narray in edge_features_group - assert len(np.nonzero( - edge_features_group[edge_feature_narray][()])) > 0 + assert len(np.nonzero(edge_features_group[edge_feature_narray][()])) > 0 assert edge_features_group[edge_feature_narray][()].shape == (1, 1) assert Efeat.INDEX in edge_features_group assert len(np.nonzero(edge_features_group[Efeat.INDEX][()])) > 0 @@ -119,22 +109,18 @@ def test_graph_write_to_hdf5(graph): def test_graph_write_as_grid_to_hdf5(graph): - """Test that the graph is correctly written to hdf5 file as a grid. 
- """ - + """Test that the graph is correctly written to hdf5 file as a grid.""" # create a temporary hdf5 file to write to tmp_dir_path = tempfile.mkdtemp() hdf5_path = os.path.join(tmp_dir_path, "101m.hdf5") try: - # export grid to hdf5 grid_settings = GridSettings([20, 20, 20], [20.0, 20.0, 20.0]) assert np.all(grid_settings.resolutions == np.array((1.0, 1.0, 1.0))) - graph.write_as_grid_to_hdf5(hdf5_path, grid_settings, - MapMethod.GAUSSIAN) + graph.write_as_grid_to_hdf5(hdf5_path, grid_settings, MapMethod.GAUSSIAN) # check the contents of the hdf5 file with h5py.File(hdf5_path, "r") as f5: @@ -145,13 +131,12 @@ def test_graph_write_as_grid_to_hdf5(graph): mapped_group = grp[gridstorage.MAPPED_FEATURES] ## narray features for feature_name in [ - f"{node_feature_narray}_000", f"{node_feature_narray}_001", - f"{node_feature_narray}_002", f"{edge_feature_narray}_000" + f"{node_feature_narray}_000", + f"{node_feature_narray}_001", + f"{node_feature_narray}_002", + f"{edge_feature_narray}_000", ]: - - assert ( - feature_name - in mapped_group), f"missing mapped feature {feature_name}" + assert feature_name in mapped_group, f"missing mapped feature {feature_name}" data = mapped_group[feature_name][()] assert len(np.nonzero(data)) > 0, f"{feature_name}: all zero" assert np.all(data.shape == tuple(grid_settings.points_counts)) @@ -168,38 +153,31 @@ def test_graph_write_as_grid_to_hdf5(graph): def test_graph_augmented_write_as_grid_to_hdf5(graph): - """Test that the graph is correctly written to hdf5 file as a grid. - """ - + """Test that the graph is correctly written to hdf5 file as a grid.""" # create a temporary hdf5 file to write to tmp_dir_path = tempfile.mkdtemp() hdf5_path = os.path.join(tmp_dir_path, "101m.hdf5") try: - # export grid to hdf5 grid_settings = GridSettings([20, 20, 20], [20.0, 20.0, 20.0]) assert np.all(grid_settings.resolutions == np.array((1.0, 1.0, 1.0))) # save to hdf5 - graph.write_as_grid_to_hdf5(hdf5_path, grid_settings, - MapMethod.GAUSSIAN) + graph.write_as_grid_to_hdf5(hdf5_path, grid_settings, MapMethod.GAUSSIAN) # two data points augmentation axis, angle = get_rot_axis_angle(randrange(100)) augmentation = Augmentation(axis, angle) - graph.write_as_grid_to_hdf5(hdf5_path, grid_settings, - MapMethod.GAUSSIAN, augmentation) + graph.write_as_grid_to_hdf5(hdf5_path, grid_settings, MapMethod.GAUSSIAN, augmentation) axis, angle = get_rot_axis_angle(randrange(100)) augmentation = Augmentation(axis, angle) - graph.write_as_grid_to_hdf5(hdf5_path, grid_settings, - MapMethod.GAUSSIAN, augmentation) + graph.write_as_grid_to_hdf5(hdf5_path, grid_settings, MapMethod.GAUSSIAN, augmentation) # check the contents of the hdf5 file with h5py.File(hdf5_path, "r") as f5: - assert list( - f5.keys()) == [entry_id, f"{entry_id}_000", f"{entry_id}_001"] + assert list(f5.keys()) == [entry_id, f"{entry_id}_000", f"{entry_id}_001"] grp = f5[entry_id] mapped_group = grp[gridstorage.MAPPED_FEATURES] # check that the feature value is preserved after augmentation @@ -213,32 +191,25 @@ def test_graph_augmented_write_as_grid_to_hdf5(graph): mapped_group = grp[gridstorage.MAPPED_FEATURES] ## narray features for feature_name in [ - f"{node_feature_narray}_000", - f"{node_feature_narray}_001", - f"{node_feature_narray}_002", - f"{edge_feature_narray}_000" + f"{node_feature_narray}_000", + f"{node_feature_narray}_001", + f"{node_feature_narray}_002", + f"{edge_feature_narray}_000", ]: - - assert (feature_name in mapped_group - ), f"missing mapped feature {feature_name}" + assert feature_name in 
mapped_group, f"missing mapped feature {feature_name}" data = mapped_group[feature_name][()] - assert len( - np.nonzero(data)) > 0, f"{feature_name}: all zero" - assert np.all( - data.shape == tuple(grid_settings.points_counts)) + assert len(np.nonzero(data)) > 0, f"{feature_name}: all zero" + assert np.all(data.shape == tuple(grid_settings.points_counts)) ## single value features data = mapped_group[node_feature_singleton][()] assert len(np.nonzero(data)) > 0, f"{feature_name}: all zero" assert np.all(data.shape == tuple(grid_settings.points_counts)) # check that the augmented data is the same, just different orientation - assert (f5[f"{entry_id}/grid_points/center"][( - )] == f5[f"{aug_id}/grid_points/center"][()]).all() - assert np.abs(np.sum(data) - - np.sum(unaugmented_data)).item() < 0.2 + assert (f5[f"{entry_id}/grid_points/center"][()] == f5[f"{aug_id}/grid_points/center"][()]).all() + assert np.abs(np.sum(data) - np.sum(unaugmented_data)).item() < 0.2 # target - assert grp[Target.VALUES][target_name][( - )] == target_value + assert grp[Target.VALUES][target_name][()] == target_value finally: shutil.rmtree(tmp_dir_path) # clean up after the test diff --git a/tests/utils/test_grid.py b/tests/utils/test_grid.py index 5e666f3e5..c1d53eec4 100644 --- a/tests/utils/test_grid.py +++ b/tests/utils/test_grid.py @@ -11,7 +11,7 @@ def test_grid_orientation(): grid_sizes = [30.0, 30.0, 30.0] # Extract data from original deeprank's preprocessed file. - with h5py.File("tests/data/hdf5/original-deeprank-1ak4.hdf5", 'r') as data_file: + with h5py.File("tests/data/hdf5/original-deeprank-1ak4.hdf5", "r") as data_file: grid_points_group = data_file["1AK4/grid_points"] target_xs = grid_points_group["x"][()] target_ys = grid_points_group["y"][()] @@ -23,7 +23,7 @@ def test_grid_orientation(): query = ProteinProteinInterfaceQuery( pdb_path="tests/data/pdb/1ak4/1ak4.pdb", resolution=resolution, - chain_ids=['C', 'D'], + chain_ids=["C", "D"], influence_radius=8.5, max_edge_length=8.5, ) diff --git a/tests/utils/test_pssmdata.py b/tests/utils/test_pssmdata.py index 97315fe50..cc25e0b4c 100644 --- a/tests/utils/test_pssmdata.py +++ b/tests/utils/test_pssmdata.py @@ -1,7 +1,8 @@ +from pdb2sql import pdb2sql + from deeprank2.domain.aminoacidlist import alanine from deeprank2.utils.buildgraph import get_structure from deeprank2.utils.parsing.pssm import parse_pssm -from pdb2sql import pdb2sql def test_add_pssm(): @@ -9,10 +10,10 @@ def test_add_pssm(): try: structure = get_structure(pdb, "1ATN") finally: - pdb._close() # pylint: disable=protected-access + pdb._close() # noqa: SLF001 (private member accessed) for chain in structure.chains: - with open(f"tests/data/pssm/1ATN/1ATN.{chain.id}.pdb.pssm", "rt", encoding="utf-8") as f: + with open(f"tests/data/pssm/1ATN/1ATN.{chain.id}.pdb.pssm", encoding="utf-8") as f: chain.pssm = parse_pssm(f, chain) # Verify that each residue is present and that the data makes sense: From c7bbad282c8e5ba28ff79340699b8eea5ebeae09 Mon Sep 17 00:00:00 2001 From: Dani Bodor Date: Tue, 16 Jan 2024 08:29:39 +0100 Subject: [PATCH 5/7] style: check commented out code --- deeprank2/neuralnets/gnn/foutnet.py | 8 ------- deeprank2/neuralnets/gnn/sgat.py | 3 --- deeprank2/trainer.py | 4 ---- deeprank2/utils/community_pooling.py | 1 - tests/features/test_secondary_structure.py | 4 ++-- tests/perf/ppi_perf.py | 4 ++-- tests/perf/srv_perf.py | 4 ++-- tests/test_set_lossfunction.py | 28 ++++++++++++---------- tests/utils/test_earlystopping.py | 8 +++---- 9 files changed, 24 insertions(+), 40 
deletions(-) diff --git a/deeprank2/neuralnets/gnn/foutnet.py b/deeprank2/neuralnets/gnn/foutnet.py index 216ef6b9d..2ae898860 100644 --- a/deeprank2/neuralnets/gnn/foutnet.py +++ b/deeprank2/neuralnets/gnn/foutnet.py @@ -48,11 +48,7 @@ def reset_parameters(self): def forward(self, x, edge_index): num_node = len(x) - - # alpha = x * Wc alpha = torch.mm(x, self.wc) - - # beta = x * Wn beta = torch.mm(x, self.wn) # gamma_i = 1/Ni Sum_j x_j * Wn @@ -62,7 +58,6 @@ def forward(self, x, edge_index): index = edge_index[:, edge_index[0, :] == n][1, :] gamma[n, :] = torch.mean(beta[index, :], dim=0) - # alpha = alpha + gamma alpha = alpha + gamma # add the bias @@ -95,7 +90,6 @@ def __init__( def forward(self, data): act = nn.Tanhshrink() act = F.relu - # act = nn.LeakyReLU(0.25) # first conv block data.x = act(self.conv1(data.x, data.edge_index)) @@ -111,7 +105,5 @@ def forward(self, data): x = scatter_mean(x, batch, dim=0) x = act(self.fc1(x)) x = self.fc2(x) - # x = F.dropout(x, training=self.training) return x # noqa:RET504 (unnecessary-assign) - # return F.relu(x) diff --git a/deeprank2/neuralnets/gnn/sgat.py b/deeprank2/neuralnets/gnn/sgat.py index e321ff465..b594181cc 100644 --- a/deeprank2/neuralnets/gnn/sgat.py +++ b/deeprank2/neuralnets/gnn/sgat.py @@ -100,7 +100,6 @@ def __init__( def forward(self, data): act = nn.Tanhshrink() act = F.relu - # act = nn.LeakyReLU(0.25) # first conv block data.x = act(self.conv1(data.x, data.edge_index, data.edge_attr)) @@ -116,7 +115,5 @@ def forward(self, data): x = scatter_mean(x, batch, dim=0) x = act(self.fc1(x)) x = self.fc2(x) - # x = F.dropout(x, training=self.training) return x # noqa:RET504 (unnecessary-assign) - # return F.relu(x) diff --git a/deeprank2/trainer.py b/deeprank2/trainer.py index 430e6cf01..3a5ef32d3 100644 --- a/deeprank2/trainer.py +++ b/deeprank2/trainer.py @@ -813,10 +813,6 @@ def _format_output(self, pred, target=None): target = torch.tensor([self.classes_to_index[x] if isinstance(x, str) else self.classes_to_index[int(x)] for x in target]) if isinstance(self.lossfunction, nn.BCELoss | nn.BCEWithLogitsLoss): # # pred must be in (0,1) range and target must be float with same shape as pred - # pred = F.softmax(pred) - # target = torch.tensor( - # [[0,1] if x == [1] else [1,0] for x in target] - # ).float() raise ValueError( "BCELoss and BCEWithLogitsLoss are currently not supported.\n\t" "For further details see: https://github.com/DeepRank/deeprank2/issues/318" diff --git a/deeprank2/utils/community_pooling.py b/deeprank2/utils/community_pooling.py index 14b8e1de7..747f77c32 100644 --- a/deeprank2/utils/community_pooling.py +++ b/deeprank2/utils/community_pooling.py @@ -86,7 +86,6 @@ def community_detection_per_batch( else: raise ValueError(f"Clustering method {method} not supported") - # return device = edge_index.device return torch.tensor(cluster).to(device) diff --git a/tests/features/test_secondary_structure.py b/tests/features/test_secondary_structure.py index 3af81666b..5c347b597 100644 --- a/tests/features/test_secondary_structure.py +++ b/tests/features/test_secondary_structure.py @@ -33,12 +33,12 @@ def test_secondary_structure_residue(): (267, "A", " ", SecondarySctructure.COIL), (46, "A", "S", SecondarySctructure.COIL), (104, "A", "T", SecondarySctructure.COIL), - # (None, '', 'P', SecondarySctructure.COIL), # not found in test file + # (None, '', 'P', SecondarySctructure.COIL), # not found in test file # noqa: ERA001 (commented-out code) (194, "A", "B", SecondarySctructure.STRAND), (385, "B", "E", 
SecondarySctructure.STRAND), (235, "A", "G", SecondarySctructure.HELIX), (263, "A", "H", SecondarySctructure.HELIX), - # (0, '', 'I', SecondarySctructure.HELIX), # not found in test file + # (0, '', 'I', SecondarySctructure.HELIX), # not found in test file # noqa: ERA001 (commented-out code) ] for res in residues: diff --git a/tests/perf/ppi_perf.py b/tests/perf/ppi_perf.py index 36b2f73cd..a184cc606 100644 --- a/tests/perf/ppi_perf.py +++ b/tests/perf/ppi_perf.py @@ -28,8 +28,8 @@ sizes=[1.0, 1.0, 1.0], ) grid_map_method = MapMethod.GAUSSIAN # None if you don't want grids -# grid_settings = None -# grid_map_method = None +# grid_settings = None # noqa: ERA001 (commented out code) +# grid_map_method = None # noqa: ERA001 (commented out code) feature_modules = [components, contact, exposure, irc, secondary_structure, surfacearea] cpu_count = 1 #################################################### diff --git a/tests/perf/srv_perf.py b/tests/perf/srv_perf.py index dd1d11f65..6e358a215 100644 --- a/tests/perf/srv_perf.py +++ b/tests/perf/srv_perf.py @@ -74,8 +74,8 @@ sizes=[1.0, 1.0, 1.0], ) grid_map_method = MapMethod.GAUSSIAN # None if you don't want grids -# grid_settings = None -# grid_map_method = None +# grid_settings = None # noqa: ERA001 (commented out code) +# grid_map_method = None # noqa: ERA001 (commented out code) feature_modules = [components, contact, exposure, irc, surfacearea, secondary_structure] cpu_count = 1 #################################################### diff --git a/tests/test_set_lossfunction.py b/tests/test_set_lossfunction.py index f495d0f0e..191d54820 100644 --- a/tests/test_set_lossfunction.py +++ b/tests/test_set_lossfunction.py @@ -94,19 +94,21 @@ def test_classif_weighted(self): assert isinstance(trainer_pretrained.lossfunction, lossfunction) assert trainer_pretrained.class_weights - # def test_classif_invalid_weighted(self): - # dataset = GraphDataset(hdf5_path, - # target=targets.BINARY) - # trainer = Trainer( - # neuralnet = NaiveNetwork, - # dataset_train = dataset, - # class_weights = True - # ) - # # use a loss function that does not allow for weighted loss, e.g. MultiLabelMarginLoss - # lossfunction = nn.MultiLabelMarginLoss - - # with pytest.raises(ValueError): - # base_test(self.save_path, trainer, lossfunction) + def test_classif_invalid_weighted(self): + dataset = GraphDataset( + hdf5_path, + target=targets.BINARY, + ) + trainer = Trainer( + neuralnet=NaiveNetwork, + dataset_train=dataset, + class_weights=True, + ) + # use a loss function that does not allow for weighted loss, e.g. 
MultiLabelMarginLoss + lossfunction = nn.MultiLabelMarginLoss + + with pytest.raises(ValueError): + base_test(self.save_path, trainer, lossfunction) def test_classif_invalid_lossfunction(self): dataset = GraphDataset( diff --git a/tests/utils/test_earlystopping.py b/tests/utils/test_earlystopping.py index dee4a9328..6c26198ea 100644 --- a/tests/utils/test_earlystopping.py +++ b/tests/utils/test_earlystopping.py @@ -25,6 +25,7 @@ def base_earlystopper(patience=10, delta=0, maxgap=None): def test_patience(): patience = 3 final_ep = base_earlystopper(patience=patience) + # should terminate at epoch 7 assert final_ep == 7 @@ -32,15 +33,12 @@ def test_patience_with_delta(): patience = 3 delta = 1 final_ep = base_earlystopper(patience=patience, delta=delta) + # should terminate at epoch 5 assert final_ep == 5 def test_maxgap(): maxgap = 1 final_ep = base_earlystopper(maxgap=maxgap) + # should terminate at epoch 9 assert final_ep == 9 - - -# test_patience() # should terminate at epoch 7 -# test_patience_with_delta() # should terminate at epoch 5 -# test_maxgap() # should terminate at epoch 9 From 054c7d3cd0c4b951d4efe5da921b0baa5a5360ce Mon Sep 17 00:00:00 2001 From: Dani Bodor Date: Tue, 16 Jan 2024 08:11:58 +0100 Subject: [PATCH 6/7] test: suppress some warnings --- deeprank2/dataset.py | 8 +++-- tests/features/test_secondary_structure.py | 6 +++- tests/tools/test_target.py | 39 ++++++++++++---------- 3 files changed, 31 insertions(+), 22 deletions(-) diff --git a/deeprank2/dataset.py b/deeprank2/dataset.py index 0aace279f..02a6ba4a0 100644 --- a/deeprank2/dataset.py +++ b/deeprank2/dataset.py @@ -430,9 +430,11 @@ def save_hist( else: raise ValueError("Please provide valid features names. They must be present in the current :class:`DeeprankDataset` children instance.") - fig.tight_layout() - fig.savefig(fname) - plt.close(fig) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + fig.tight_layout() + fig.savefig(fname) + plt.close(fig) def _compute_mean_std(self): means = { diff --git a/tests/features/test_secondary_structure.py b/tests/features/test_secondary_structure.py index 5c347b597..7d2ff1ff0 100644 --- a/tests/features/test_secondary_structure.py +++ b/tests/features/test_secondary_structure.py @@ -1,3 +1,5 @@ +import warnings + import numpy as np from deeprank2.domain import nodestorage as Nfeat @@ -19,7 +21,9 @@ def test_secondary_structure_residue(): influence_radius=10, max_edge_length=10, ) - add_features(pdb_path, graph) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + add_features(pdb_path, graph) # Create a list of node information (residue number, chain ID, and secondary structure features) node_info_list = [[node.id.number, node.id.chain.id, node.features[Nfeat.SECSTRUCT]] for node in graph.nodes] diff --git a/tests/tools/test_target.py b/tests/tools/test_target.py index 32dac919a..577b3aad7 100644 --- a/tests/tools/test_target.py +++ b/tests/tools/test_target.py @@ -2,6 +2,7 @@ import shutil import tempfile import unittest +import warnings from pdb2sql import StructureSimilarity @@ -34,25 +35,27 @@ def test_add_target(self): os.remove(graph_path) def test_compute_ppi_scores(self): - scores = compute_ppi_scores( - os.path.join(self.pdb_path, "1ATN_1w.pdb"), - os.path.join(self.ref, "1ATN.pdb"), - ) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") - sim = StructureSimilarity( - os.path.join(self.pdb_path, "1ATN_1w.pdb"), - os.path.join(self.ref, "1ATN.pdb"), - enforce_residue_matching=False, - ) - lrmsd = 
sim.compute_lrmsd_fast(method="svd") - irmsd = sim.compute_irmsd_fast(method="svd") - fnat = sim.compute_fnat_fast() - dockq = sim.compute_DockQScore(fnat, lrmsd, irmsd) - binary = irmsd < 4.0 - capri = 4 - for thr, val in zip([6.0, 4.0, 2.0, 1.0], [4, 3, 2, 1], strict=True): - if irmsd < thr: - capri = val + scores = compute_ppi_scores( + os.path.join(self.pdb_path, "1ATN_1w.pdb"), + os.path.join(self.ref, "1ATN.pdb"), + ) + sim = StructureSimilarity( + os.path.join(self.pdb_path, "1ATN_1w.pdb"), + os.path.join(self.ref, "1ATN.pdb"), + enforce_residue_matching=False, + ) + lrmsd = sim.compute_lrmsd_fast(method="svd") + irmsd = sim.compute_irmsd_fast(method="svd") + fnat = sim.compute_fnat_fast() + dockq = sim.compute_DockQScore(fnat, lrmsd, irmsd) + binary = irmsd < 4.0 + capri = 4 + for thr, val in zip([6.0, 4.0, 2.0, 1.0], [4, 3, 2, 1], strict=True): + if irmsd < thr: + capri = val assert scores["irmsd"] == irmsd assert scores["lrmsd"] == lrmsd From 7dc252669b02a009f3c35bb55f0940c2a236b4c6 Mon Sep 17 00:00:00 2001 From: Dani Bodor Date: Tue, 16 Jan 2024 09:11:11 +0100 Subject: [PATCH 7/7] style: format non-python files --- .github/ISSUE_TEMPLATE/bug_report.md | 22 +- .github/ISSUE_TEMPLATE/feature_request.md | 10 +- .../install-python-and-package/action.yml | 8 +- .github/workflows/build.yml | 4 +- .github/workflows/cffconvert.yml | 2 +- .github/workflows/coveralls.yml | 2 +- .github/workflows/draft-pdf.yml | 2 +- .github/workflows/fair-software.yml | 2 +- .github/workflows/linting.yml | 2 +- .github/workflows/markdown-link-check.yml | 10 +- .github/workflows/release.yml | 4 +- .github/workflows/stale_issue_pr.yml | 2 +- .readthedocs.yaml | 2 +- .vscode/settings.json | 34 +- .zenodo.json | 122 +++--- CHANGELOG.md | 245 +++++++------ CITATION.cff | 18 +- README.dev.md | 12 +- README.md | 54 +-- docs/docking.md | 4 +- docs/features.md | 20 +- docs/getstarted.md | 7 +- docs/index.rst | 6 +- docs/installation.md | 27 +- docs/requirements.txt | 2 +- paper/paper.bib | 2 +- paper/paper.md | 21 +- tests/data/hdf5/_generate_testdata.ipynb | 203 +++++----- tutorials/TUTORIAL.md | 7 +- tutorials/data_generation_ppi.ipynb | 211 +++++------ tutorials/data_generation_srv.ipynb | 248 +++++++------ tutorials/training.ipynb | 347 +++++++++--------- 32 files changed, 853 insertions(+), 809 deletions(-) diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md index 63264d4e7..baf8836c3 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -1,29 +1,29 @@ --- name: Bug report about: Create a report to help us improve -title: 'Bug: ' -labels: 'bug' -assignees: '' - +title: "Bug: " +labels: "bug" +assignees: "" --- **Describe the bug** A clear and concise description of what the bug is. **Environment:** + - OS system: - Version: -- Branch commit ID: +- Branch commit ID: - Inputs: **To Reproduce** Steps/commands/screenshots to reproduce the behaviour: - - 1. - - 2. - - 3. + +1. + +2. + +3. **Expected Results** A clear and concise description of what you expected to happen. diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md index 457d519dc..174d8b991 100644 --- a/.github/ISSUE_TEMPLATE/feature_request.md +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -1,10 +1,9 @@ --- name: Feature request about: Suggest an idea for this project -title: 'Add/edit' -labels: 'feature' -assignees: '' - +title: "Add/edit" +labels: "feature" +assignees: "" --- **Is your feature request related to a problem? 
Please describe.** @@ -20,7 +19,8 @@ A clear and concise description of any alternative solutions or features you've Add any other context or screenshots about the feature request here. **Checks for the developer** -After having implemented the request, please remember to: +After having implemented the request, please remember to: + - [ ] Add all the necessary tests. Make sure that the parameter functionality is well tested, from all points of views. - [ ] Add the proper documentation to the source code (docstrings). - [ ] Add the proper documentation to the readme. Examples about how using the new feature should be clear and easy to follow. diff --git a/.github/actions/install-python-and-package/action.yml b/.github/actions/install-python-and-package/action.yml index fd9f1a206..cd0089675 100644 --- a/.github/actions/install-python-and-package/action.yml +++ b/.github/actions/install-python-and-package/action.yml @@ -3,7 +3,6 @@ name: "Install Python and deeprank2" description: "Installs Python, updates pip and installs deeprank2 together with its dependencies." inputs: - python-version: required: false description: "The Python version to use. Specify major and minor version, e.g. '3.10'." @@ -15,14 +14,13 @@ inputs: default: "test" runs: - using: "composite" steps: - name: Cancel Previous Runs and Set up Python uses: styfle/cancel-workflow-action@0.4.0 with: - access_token: ${{ github.token }} + access_token: ${{ github.token }} - uses: actions/checkout@v3 - name: Setup conda uses: s-weigand/setup-conda@v1 @@ -43,7 +41,7 @@ runs: CMAKE_INSTALL_PREFIX: .local if: runner.os == 'Linux' run: | - # Install dependencies not handled by setuptools + # Install dependencies not handled by setuptools ## DSSP sudo apt-get install -y dssp ## MSMS @@ -59,7 +57,7 @@ runs: CMAKE_INSTALL_PREFIX: .local if: runner.os == 'macOS' run: | - # Install dependencies not handled by setuptools + # Install dependencies not handled by setuptools ## DSSP git clone https://github.com/PDB-REDO/libcifpp.git --recurse-submodules cd libcifpp diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 6b1ad6ee3..4eb534ded 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -36,8 +36,8 @@ jobs: strategy: fail-fast: false matrix: - os: ['ubuntu-latest'] - python-version: ['3.10', '3.11'] + os: ["ubuntu-latest"] + python-version: ["3.10", "3.11"] steps: - uses: actions/checkout@v3 diff --git a/.github/workflows/cffconvert.yml b/.github/workflows/cffconvert.yml index 75a06e34b..6851c52d3 100644 --- a/.github/workflows/cffconvert.yml +++ b/.github/workflows/cffconvert.yml @@ -16,4 +16,4 @@ jobs: - name: Check whether the citation metadata from CITATION.cff is valid uses: citation-file-format/cffconvert-github-action@2.0.0 with: - args: "--validate" \ No newline at end of file + args: "--validate" diff --git a/.github/workflows/coveralls.yml b/.github/workflows/coveralls.yml index a59459c20..ba0f5a391 100644 --- a/.github/workflows/coveralls.yml +++ b/.github/workflows/coveralls.yml @@ -37,7 +37,7 @@ jobs: fail-fast: false matrix: os: ["ubuntu-latest"] - python-version: ['3.10'] + python-version: ["3.10"] steps: - uses: actions/checkout@v3 diff --git a/.github/workflows/draft-pdf.yml b/.github/workflows/draft-pdf.yml index ae471e46b..8b5159f73 100644 --- a/.github/workflows/draft-pdf.yml +++ b/.github/workflows/draft-pdf.yml @@ -23,4 +23,4 @@ jobs: # This is the output path where Pandoc will write the compiled # PDF. 
Note, this should be the same directory as the input # paper.md - path: paper/paper.pdf \ No newline at end of file + path: paper/paper.pdf diff --git a/.github/workflows/fair-software.yml b/.github/workflows/fair-software.yml index e5384468b..f20d3c846 100644 --- a/.github/workflows/fair-software.yml +++ b/.github/workflows/fair-software.yml @@ -5,7 +5,7 @@ on: branches: - main pull_request: - types: [opened, synchronize, reopened, ready_for_review] + types: [opened, synchronize, reopened, ready_for_review] jobs: verify: diff --git a/.github/workflows/linting.yml b/.github/workflows/linting.yml index 8f78bb8b7..15b26684f 100644 --- a/.github/workflows/linting.yml +++ b/.github/workflows/linting.yml @@ -37,7 +37,7 @@ jobs: fail-fast: false matrix: os: ["ubuntu-latest"] - python-version: ['3.10'] + python-version: ["3.10"] steps: - uses: actions/checkout@v3 diff --git a/.github/workflows/markdown-link-check.yml b/.github/workflows/markdown-link-check.yml index 4756d1a86..016e38120 100644 --- a/.github/workflows/markdown-link-check.yml +++ b/.github/workflows/markdown-link-check.yml @@ -3,7 +3,7 @@ name: markdown-link-check on: push: branches: - - main + - main paths: # filetypes - "**.md" @@ -25,7 +25,7 @@ jobs: name: Check markdown links runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 - - uses: gaurav-nelson/github-action-markdown-link-check@v1 - with: - config-file: '.github/workflows/markdown-link-check.yml' + - uses: actions/checkout@v3 + - uses: gaurav-nelson/github-action-markdown-link-check@v1 + with: + config-file: ".github/workflows/markdown-link-check.yml" diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 71e471d85..887f160c6 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -14,8 +14,8 @@ jobs: strategy: fail-fast: false matrix: - os: ['ubuntu-latest'] - python-version: ['3.10'] + os: ["ubuntu-latest"] + python-version: ["3.10"] steps: - uses: actions/checkout@v3 diff --git a/.github/workflows/stale_issue_pr.yml b/.github/workflows/stale_issue_pr.yml index bb54f4cf1..d184c9c9d 100644 --- a/.github/workflows/stale_issue_pr.yml +++ b/.github/workflows/stale_issue_pr.yml @@ -21,4 +21,4 @@ jobs: days-before-pr-close: -1 stale-pr-message: "This PR is stale because it has been open for 14 days with no activity." close-pr-message: "This PR was closed because it has been inactive for 7 days since being marked as stale." 
- exempt-issue-labels: 'blocked' \ No newline at end of file + exempt-issue-labels: "blocked" diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 45bb2416b..a4daf4f3f 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -12,4 +12,4 @@ sphinx: # Explicitly set the version of Python and its requirements python: install: - - requirements: docs/requirements.txt \ No newline at end of file + - requirements: docs/requirements.txt diff --git a/.vscode/settings.json b/.vscode/settings.json index f0f387d5e..a519af241 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,19 +1,23 @@ { - // Python - "[python]": { - "editor.formatOnSave": true, - "editor.codeActionsOnSave": { - "source.fixAll": "explicit" - }, - "editor.defaultFormatter": "charliermarsh.ruff" + // Python + "[python]": { + "editor.formatOnSave": true, + "editor.codeActionsOnSave": { + "source.fixAll": "explicit" }, - "autoDocstring.docstringFormat": "google", + "editor.defaultFormatter": "charliermarsh.ruff" + }, + "autoDocstring.docstringFormat": "google", - // Notebooks - "notebook.lineNumbers": "on", - "notebook.formatOnSave.enabled": true, - "notebook.codeActionsOnSave": { - "notebook.source.fixAll": "explicit", - }, - "notebook.diff.ignoreMetadata": true, + // Notebooks + "notebook.lineNumbers": "on", + "notebook.formatOnSave.enabled": true, + "notebook.codeActionsOnSave": { + "notebook.source.fixAll": "explicit" + }, + "notebook.diff.ignoreMetadata": true, + + // Format all files on save + "editor.formatOnSave": true, + "editor.defaultFormatter": "esbenp.prettier-vscode" } diff --git a/.zenodo.json b/.zenodo.json index d7d1125da..85ec62a8e 100644 --- a/.zenodo.json +++ b/.zenodo.json @@ -1,64 +1,64 @@ { - "creators": [ - { - "affiliation": "Netherlands eScience Center", - "name": "Giulia Crocioni", - "orcid": "0000-0002-0823-0121" - }, - { - "affiliation": "Netherlands eScience Center", - "name": "Dani L. Bodor", - "orcid": "0000-0003-2109-2349" - }, - { - "affiliation": "Radboud University Medical Center", - "name": "Coos Baakman", - "orcid": "0000-0003-4317-1566" - }, - { - "affiliation": "Radboud University Medical Center", - "name": "Farzaneh M. Parizi", - "orcid": "0000-0003-4230-7492" - }, - { - "affiliation": "Radboud University Medical Center", - "name": "Daniel T. Rademaker", - "orcid": "0000-0003-1959-1317" - }, - { - "affiliation": "Radboud University Medical Center", - "name": "Gayatri Ramakrishnan", - "orcid": "0000-0001-8203-2783" - }, - { - "affiliation": "Netherlands eScience Center", - "name": "Sven van der Burg", - "orcid": "0000-0003-1250-6968" - }, - { - "affiliation": "Radboud University Medical Center", - "name": "Dario F. Marzella", - "orcid": "0000-0002-0043-3055" - }, - { - "name": "João M. C. Teixeira", - "orcid": "0000-0002-9113-0622" - }, - { - "affiliation": "Radboud University Medical Center", - "name": "Li C. Xue", - "orcid": "0000-0002-2613-538X" - } - ], - "keywords": [ - "Graph neural networks", - "Convolutional neural networks", - "Protein-protein interface", - "Single-residue variant", - "DeepRank" - ], - "license": { - "id": "Apache-2.0" + "creators": [ + { + "affiliation": "Netherlands eScience Center", + "name": "Giulia Crocioni", + "orcid": "0000-0002-0823-0121" }, - "title": "DeepRank2" + { + "affiliation": "Netherlands eScience Center", + "name": "Dani L. 
Bodor", + "orcid": "0000-0003-2109-2349" + }, + { + "affiliation": "Radboud University Medical Center", + "name": "Coos Baakman", + "orcid": "0000-0003-4317-1566" + }, + { + "affiliation": "Radboud University Medical Center", + "name": "Farzaneh M. Parizi", + "orcid": "0000-0003-4230-7492" + }, + { + "affiliation": "Radboud University Medical Center", + "name": "Daniel T. Rademaker", + "orcid": "0000-0003-1959-1317" + }, + { + "affiliation": "Radboud University Medical Center", + "name": "Gayatri Ramakrishnan", + "orcid": "0000-0001-8203-2783" + }, + { + "affiliation": "Netherlands eScience Center", + "name": "Sven van der Burg", + "orcid": "0000-0003-1250-6968" + }, + { + "affiliation": "Radboud University Medical Center", + "name": "Dario F. Marzella", + "orcid": "0000-0002-0043-3055" + }, + { + "name": "João M. C. Teixeira", + "orcid": "0000-0002-9113-0622" + }, + { + "affiliation": "Radboud University Medical Center", + "name": "Li C. Xue", + "orcid": "0000-0002-2613-538X" + } + ], + "keywords": [ + "Graph neural networks", + "Convolutional neural networks", + "Protein-protein interface", + "Single-residue variant", + "DeepRank" + ], + "license": { + "id": "Apache-2.0" + }, + "title": "DeepRank2" } diff --git a/CHANGELOG.md b/CHANGELOG.md index afabf568f..a2732a8a3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,97 +5,106 @@ ### Main changes #### Refactor -* refactor: make `preprocess` use all available feature modules as default by @gcroci2 in https://github.com/DeepRank/deeprank-core/pull/247 -* refactor: move preprocess function to `QueryDataset` class and rename by @gcroci2 in https://github.com/DeepRank/deeprank-core/pull/252 -* refactor: save preprocessed data into one .hdf5 file as default by @gcroci2 in https://github.com/DeepRank/deeprank-core/pull/250 -* refactor: clean up `GraphDataset` and `Trainer` class by @DaniBodor in https://github.com/DeepRank/deeprank-core/pull/255 -* refactor: reorganize deeprank2.utils.metrics module by @gcroci2 in https://github.com/DeepRank/deeprank-core/pull/262 -* refactor: fix `transform_sigmoid` logic and move it to `GraphDataset` class by @gcroci2 in https://github.com/DeepRank/deeprank-core/pull/288 -* refactor: add grid dataset class and make the trainer class work with it. 
by @cbaakman in https://github.com/DeepRank/deeprank-core/pull/294 -* refactor: update deprecated dataloader import by @DaniBodor in https://github.com/DeepRank/deeprank-core/pull/310 -* refactor: move tests/_utils.py to tests/__init__.py by @DaniBodor in https://github.com/DeepRank/deeprank-core/pull/322 -* refactor: delete all outputs from unit tests after run by @DaniBodor in https://github.com/DeepRank/deeprank-core/pull/324 -* refactor: test_contact.py function naming and output by @DaniBodor in https://github.com/DeepRank/deeprank-core/pull/372 -* refactor: split test contact.py by @joyceljy in https://github.com/DeepRank/deeprank-core/pull/369 -* refactor: change __repr__ of AminoAcid to 3 letter code by @DaniBodor in https://github.com/DeepRank/deeprank-core/pull/384 -* refactor: make feature modules and tests uniform and ditch duplicate code by @DaniBodor in https://github.com/DeepRank/deeprank-core/pull/400 + +- refactor: make `preprocess` use all available feature modules as default by @gcroci2 in https://github.com/DeepRank/deeprank-core/pull/247 +- refactor: move preprocess function to `QueryDataset` class and rename by @gcroci2 in https://github.com/DeepRank/deeprank-core/pull/252 +- refactor: save preprocessed data into one .hdf5 file as default by @gcroci2 in https://github.com/DeepRank/deeprank-core/pull/250 +- refactor: clean up `GraphDataset` and `Trainer` class by @DaniBodor in https://github.com/DeepRank/deeprank-core/pull/255 +- refactor: reorganize deeprank2.utils.metrics module by @gcroci2 in https://github.com/DeepRank/deeprank-core/pull/262 +- refactor: fix `transform_sigmoid` logic and move it to `GraphDataset` class by @gcroci2 in https://github.com/DeepRank/deeprank-core/pull/288 +- refactor: add grid dataset class and make the trainer class work with it. 
by @cbaakman in https://github.com/DeepRank/deeprank-core/pull/294 +- refactor: update deprecated dataloader import by @DaniBodor in https://github.com/DeepRank/deeprank-core/pull/310 +- refactor: move tests/\_utils.py to tests/**init**.py by @DaniBodor in https://github.com/DeepRank/deeprank-core/pull/322 +- refactor: delete all outputs from unit tests after run by @DaniBodor in https://github.com/DeepRank/deeprank-core/pull/324 +- refactor: test_contact.py function naming and output by @DaniBodor in https://github.com/DeepRank/deeprank-core/pull/372 +- refactor: split test contact.py by @joyceljy in https://github.com/DeepRank/deeprank-core/pull/369 +- refactor: change **repr** of AminoAcid to 3 letter code by @DaniBodor in https://github.com/DeepRank/deeprank-core/pull/384 +- refactor: make feature modules and tests uniform and ditch duplicate code by @DaniBodor in https://github.com/DeepRank/deeprank-core/pull/400 #### Features -* feat: improve amino acid features by @DaniBodor in https://github.com/DeepRank/deeprank-core/pull/272 -* feat: add `test_size` equivalent of `val_size` to Trainer class by @DaniBodor in https://github.com/DeepRank/deeprank-core/pull/291 -* feat: add the option to have a grid box of different x,y and z dimensions by @cbaakman in https://github.com/DeepRank/deeprank-core/pull/292 -* feat: add early stopping to `Trainer.train` by @DaniBodor in https://github.com/DeepRank/deeprank-core/pull/303 -* feat: add hist module for plotting raw hdf5 files features distributions by @gcroci2 in https://github.com/DeepRank/deeprank-core/pull/261 -* feat: allow for different loss functions other than the default by @DaniBodor in https://github.com/DeepRank/deeprank-core/pull/313 -* feat: center the grids as in the old deeprank by @cbaakman in https://github.com/DeepRank/deeprank-core/pull/323 -* feat: add data augmentation for grids by @cbaakman in https://github.com/DeepRank/deeprank-core/pull/336 -* feat: insert features standardization option in`DeeprankDataset` children classes by @gcroci2 in https://github.com/DeepRank/deeprank-core/pull/326 -* feat: add log transformation option for plotting features' hist by @joyceljy in https://github.com/DeepRank/deeprank-core/pull/389 -* feat: add inter-residue contact (IRC) node features by @DaniBodor in https://github.com/DeepRank/deeprank-core/pull/333 -* feat: add feature module for secondary structure by @DTRademaker in https://github.com/DeepRank/deeprank-core/pull/387 -* feat: use dictionary for flexibly transforming and standardizing features by @joyceljy in https://github.com/DeepRank/deeprank-core/pull/418 + +- feat: improve amino acid features by @DaniBodor in https://github.com/DeepRank/deeprank-core/pull/272 +- feat: add `test_size` equivalent of `val_size` to Trainer class by @DaniBodor in https://github.com/DeepRank/deeprank-core/pull/291 +- feat: add the option to have a grid box of different x,y and z dimensions by @cbaakman in https://github.com/DeepRank/deeprank-core/pull/292 +- feat: add early stopping to `Trainer.train` by @DaniBodor in https://github.com/DeepRank/deeprank-core/pull/303 +- feat: add hist module for plotting raw hdf5 files features distributions by @gcroci2 in https://github.com/DeepRank/deeprank-core/pull/261 +- feat: allow for different loss functions other than the default by @DaniBodor in https://github.com/DeepRank/deeprank-core/pull/313 +- feat: center the grids as in the old deeprank by @cbaakman in https://github.com/DeepRank/deeprank-core/pull/323 +- feat: add data augmentation for grids 
by @cbaakman in https://github.com/DeepRank/deeprank-core/pull/336 +- feat: insert features standardization option in`DeeprankDataset` children classes by @gcroci2 in https://github.com/DeepRank/deeprank-core/pull/326 +- feat: add log transformation option for plotting features' hist by @joyceljy in https://github.com/DeepRank/deeprank-core/pull/389 +- feat: add inter-residue contact (IRC) node features by @DaniBodor in https://github.com/DeepRank/deeprank-core/pull/333 +- feat: add feature module for secondary structure by @DTRademaker in https://github.com/DeepRank/deeprank-core/pull/387 +- feat: use dictionary for flexibly transforming and standardizing features by @joyceljy in https://github.com/DeepRank/deeprank-core/pull/418 #### Fix -* fix: list all submodules imported from deeprank2.features using pkgutil by @gcroci2 in https://github.com/DeepRank/deeprank-core/pull/263 -* fix: let `classes` argument be also categorical by @gcroci2 in https://github.com/DeepRank/deeprank-core/pull/286 -* fix: makes sure that the `map_feature` function can handle single value features. by @cbaakman in https://github.com/DeepRank/deeprank-core/pull/289 -* fix: raise exception for invalid optimizer by @DaniBodor in https://github.com/DeepRank/deeprank-core/pull/307 -* fix: `num_workers` parameter of Dataloader object by @gcroci2 in https://github.com/DeepRank/deeprank-core/pull/319 -* fix: gpu usage by @gcroci2 in https://github.com/DeepRank/deeprank-core/pull/334 -* fix: gpu and `entry_names` usage by @gcroci2 in https://github.com/DeepRank/deeprank-core/pull/335 -* fix: data generation threading locked by @gcroci2 in https://github.com/DeepRank/deeprank-core/pull/330 -* fix: `__hash__` circular dependency issue by @cbaakman in https://github.com/DeepRank/deeprank-core/pull/341 -* fix: make sure that Grid data also has target values, like graph data by @cbaakman in https://github.com/DeepRank/deeprank-core/pull/347 -* fix: change the internal structure of the grid data to match the graph data by @cbaakman in https://github.com/DeepRank/deeprank-core/pull/352 -* fix: conflicts in package by @DaniBodor in https://github.com/DeepRank/deeprank-core/pull/386 -* fix: correct usage of nonbond energy for close contacts by @DaniBodor in https://github.com/DeepRank/deeprank-core/pull/368 -* fix: Incorrect number of datapoints loaded to model by @joyceljy in https://github.com/DeepRank/deeprank-core/pull/397 -* fix: pytorch 2.0 by @DaniBodor in https://github.com/DeepRank/deeprank-core/pull/406 -* fix: covalent bonds cannot link nodes on separate branches by @DaniBodor in https://github.com/DeepRank/deeprank-core/pull/408 -* fix: `Trainer` error when only `dataset_test` and `pretrained_model` are used by @ntxxt in https://github.com/DeepRank/deeprank-core/pull/413 -* fix: check PSSMs by @DaniBodor in https://github.com/DeepRank/deeprank-core/pull/401 -* fix: only check pssms if conservation module was used by @DaniBodor in https://github.com/DeepRank/deeprank-core/pull/425 -* fix: epoch number in `test()` and test on the correct model by @gcroci2 in https://github.com/DeepRank/deeprank-core/pull/427 -* fix: convert list of arrays into arrays before converting to Pytorch tensor by @gcroci2 in https://github.com/DeepRank/deeprank-core/pull/438 + +- fix: list all submodules imported from deeprank2.features using pkgutil by @gcroci2 in https://github.com/DeepRank/deeprank-core/pull/263 +- fix: let `classes` argument be also categorical by @gcroci2 in https://github.com/DeepRank/deeprank-core/pull/286 +- fix: makes 
sure that the `map_feature` function can handle single value features. by @cbaakman in https://github.com/DeepRank/deeprank-core/pull/289 +- fix: raise exception for invalid optimizer by @DaniBodor in https://github.com/DeepRank/deeprank-core/pull/307 +- fix: `num_workers` parameter of Dataloader object by @gcroci2 in https://github.com/DeepRank/deeprank-core/pull/319 +- fix: gpu usage by @gcroci2 in https://github.com/DeepRank/deeprank-core/pull/334 +- fix: gpu and `entry_names` usage by @gcroci2 in https://github.com/DeepRank/deeprank-core/pull/335 +- fix: data generation threading locked by @gcroci2 in https://github.com/DeepRank/deeprank-core/pull/330 +- fix: `__hash__` circular dependency issue by @cbaakman in https://github.com/DeepRank/deeprank-core/pull/341 +- fix: make sure that Grid data also has target values, like graph data by @cbaakman in https://github.com/DeepRank/deeprank-core/pull/347 +- fix: change the internal structure of the grid data to match the graph data by @cbaakman in https://github.com/DeepRank/deeprank-core/pull/352 +- fix: conflicts in package by @DaniBodor in https://github.com/DeepRank/deeprank-core/pull/386 +- fix: correct usage of nonbond energy for close contacts by @DaniBodor in https://github.com/DeepRank/deeprank-core/pull/368 +- fix: Incorrect number of datapoints loaded to model by @joyceljy in https://github.com/DeepRank/deeprank-core/pull/397 +- fix: pytorch 2.0 by @DaniBodor in https://github.com/DeepRank/deeprank-core/pull/406 +- fix: covalent bonds cannot link nodes on separate branches by @DaniBodor in https://github.com/DeepRank/deeprank-core/pull/408 +- fix: `Trainer` error when only `dataset_test` and `pretrained_model` are used by @ntxxt in https://github.com/DeepRank/deeprank-core/pull/413 +- fix: check PSSMs by @DaniBodor in https://github.com/DeepRank/deeprank-core/pull/401 +- fix: only check pssms if conservation module was used by @DaniBodor in https://github.com/DeepRank/deeprank-core/pull/425 +- fix: epoch number in `test()` and test on the correct model by @gcroci2 in https://github.com/DeepRank/deeprank-core/pull/427 +- fix: convert list of arrays into arrays before converting to Pytorch tensor by @gcroci2 in https://github.com/DeepRank/deeprank-core/pull/438 #### Docs -* docs: add verbose arg to QueryCollection class by @gcroci2 in https://github.com/DeepRank/deeprank-core/pull/267 -* docs: improve `clustering_method` description and default value by @gcroci2 in https://github.com/DeepRank/deeprank-core/pull/293 -* docs: uniform docstrings format in modules by @joyceljy -* docs: incorrect usage of Union in Optional type hints by @DaniBodor in https://github.com/DeepRank/deeprank-core/pull/370 -* docs: improve docs for default exporter and results visualization by @gcroci2 in https://github.com/DeepRank/deeprank-core/pull/414 -* docs: update feature documentations by @DaniBodor in https://github.com/DeepRank/deeprank-core/pull/419 -* docs: add instructions for `GridDataset` by @gcroci2 in https://github.com/DeepRank/deeprank-core/pull/421 -* docs: fix getstarted hierarchy by @gcroci2 in https://github.com/DeepRank/deeprank-core/pull/422 -* docs: update dssp 4 install instructions by @DaniBodor in https://github.com/DeepRank/deeprank-core/pull/437 -* docs: change `external_distance_cutoff` and `interface_distance_cutoff` to `distance_cutoff` by @gcroci2 in https://github.com/DeepRank/deeprank-core/pull/246 + +- docs: add verbose arg to QueryCollection class by @gcroci2 in https://github.com/DeepRank/deeprank-core/pull/267 +- docs: 
improve `clustering_method` description and default value by @gcroci2 in https://github.com/DeepRank/deeprank-core/pull/293 +- docs: uniform docstrings format in modules by @joyceljy +- docs: incorrect usage of Union in Optional type hints by @DaniBodor in https://github.com/DeepRank/deeprank-core/pull/370 +- docs: improve docs for default exporter and results visualization by @gcroci2 in https://github.com/DeepRank/deeprank-core/pull/414 +- docs: update feature documentations by @DaniBodor in https://github.com/DeepRank/deeprank-core/pull/419 +- docs: add instructions for `GridDataset` by @gcroci2 in https://github.com/DeepRank/deeprank-core/pull/421 +- docs: fix getstarted hierarchy by @gcroci2 in https://github.com/DeepRank/deeprank-core/pull/422 +- docs: update dssp 4 install instructions by @DaniBodor in https://github.com/DeepRank/deeprank-core/pull/437 +- docs: change `external_distance_cutoff` and `interface_distance_cutoff` to `distance_cutoff` by @gcroci2 in https://github.com/DeepRank/deeprank-core/pull/246 #### Performances -* perf: features.contact by @DaniBodor in https://github.com/DeepRank/deeprank-core/pull/220 -* perf: suppress warnings in pytest and from PDBParser by @DaniBodor in https://github.com/DeepRank/deeprank-core/pull/249 -* perf: add try except clause to `_preprocess_one_query` method of `QueryCollection` class by @gcroci2 in https://github.com/DeepRank/deeprank-core/pull/264 -* perf: improve `process` speed for residue based graph building by @cbaakman in https://github.com/DeepRank/deeprank-core/pull/274 -* perf: add `cuda` and `ngpu` parameters to the `Trainer` class by @gcroci2 in https://github.com/DeepRank/deeprank-core/pull/311 -* perf: accelerate indexing of HDF5 files by @joyceljy in https://github.com/DeepRank/deeprank-core/pull/362 + +- perf: features.contact by @DaniBodor in https://github.com/DeepRank/deeprank-core/pull/220 +- perf: suppress warnings in pytest and from PDBParser by @DaniBodor in https://github.com/DeepRank/deeprank-core/pull/249 +- perf: add try except clause to `_preprocess_one_query` method of `QueryCollection` class by @gcroci2 in https://github.com/DeepRank/deeprank-core/pull/264 +- perf: improve `process` speed for residue based graph building by @cbaakman in https://github.com/DeepRank/deeprank-core/pull/274 +- perf: add `cuda` and `ngpu` parameters to the `Trainer` class by @gcroci2 in https://github.com/DeepRank/deeprank-core/pull/311 +- perf: accelerate indexing of HDF5 files by @joyceljy in https://github.com/DeepRank/deeprank-core/pull/362 #### Style -* style: restructure deeprank2 package and subpackages by @gcroci2 in https://github.com/DeepRank/deeprank-core/pull/240 -* style: reorganize features/contact.py by @DaniBodor in https://github.com/DeepRank/deeprank-core/pull/260 -* style: add .vscode settings.json by @DaniBodor in https://github.com/DeepRank/deeprank-core/pull/404 + +- style: restructure deeprank2 package and subpackages by @gcroci2 in https://github.com/DeepRank/deeprank-core/pull/240 +- style: reorganize features/contact.py by @DaniBodor in https://github.com/DeepRank/deeprank-core/pull/260 +- style: add .vscode settings.json by @DaniBodor in https://github.com/DeepRank/deeprank-core/pull/404 #### Test -* test: make sure that the grid orientation is as in the original deeprank for `ProteinProteinInterfaceAtomicQuery` by @cbaakman in https://github.com/DeepRank/deeprank-core/pull/312 -* test: check that the grid for residue-based protein-protein interfaces has the same center and orientation as in the 
original deeprank. by @cbaakman in https://github.com/DeepRank/deeprank-core/pull/339 -* test: improve `utils/test_graph.py` module by @gcroci2 in https://github.com/DeepRank/deeprank-core/pull/420 + +- test: make sure that the grid orientation is as in the original deeprank for `ProteinProteinInterfaceAtomicQuery` by @cbaakman in https://github.com/DeepRank/deeprank-core/pull/312 +- test: check that the grid for residue-based protein-protein interfaces has the same center and orientation as in the original deeprank. by @cbaakman in https://github.com/DeepRank/deeprank-core/pull/339 +- test: improve `utils/test_graph.py` module by @gcroci2 in https://github.com/DeepRank/deeprank-core/pull/420 #### CI -* ci: do not close stale issues or PRs by @DaniBodor in https://github.com/DeepRank/deeprank-core/pull/327 -* ci: remove incorrect message for stale branches by @DaniBodor in https://github.com/DeepRank/deeprank-core/pull/415 -* ci: automatically check markdown links by @DaniBodor in https://github.com/DeepRank/deeprank-core/pull/433 + +- ci: do not close stale issues or PRs by @DaniBodor in https://github.com/DeepRank/deeprank-core/pull/327 +- ci: remove incorrect message for stale branches by @DaniBodor in https://github.com/DeepRank/deeprank-core/pull/415 +- ci: automatically check markdown links by @DaniBodor in https://github.com/DeepRank/deeprank-core/pull/433 ### New Contributors -* @joyceljy made their first contribution in https://github.com/DeepRank/deeprank-core/pull/361 -* @ntxxt made their first contribution in https://github.com/DeepRank/deeprank-core/pull/413 + +- @joyceljy made their first contribution in https://github.com/DeepRank/deeprank-core/pull/361 +- @ntxxt made their first contribution in https://github.com/DeepRank/deeprank-core/pull/413 **Full Changelog**: https://github.com/DeepRank/deeprank-core/compare/v1.0.0...v2.0.0 @@ -105,25 +114,25 @@ Released on Oct 24, 2022 ### Added -* `weight_decay` parameter to NeuralNet #155 -* Exporter for generating a unique .csv file containing results per epoch #151 -* Automatized testing of all available features modules #163 -* `optimizer` parameter to NeuralNet #154 -* `atom` node feature #168 +- `weight_decay` parameter to NeuralNet #155 +- Exporter for generating a unique .csv file containing results per epoch #151 +- Automatized testing of all available features modules #163 +- `optimizer` parameter to NeuralNet #154 +- `atom` node feature #168 ### Changed -* `index` parameter of NeuralNet is now called `subset` #159 -* `percent` parameter of NeuralNet is now called `val_size`, and the logic behing it has been improved #183 -* Aligned the package to PyTorch high-level frameworks #172 - * NeuralNet is now called Trainer -* Clearer features names #145 -* Changed definitions in storage.py #150 -* `MAX_COVALENT_DISTANCE` is now 2.1 instead of 3 #205 +- `index` parameter of NeuralNet is now called `subset` #159 +- `percent` parameter of NeuralNet is now called `val_size`, and the logic behing it has been improved #183 +- Aligned the package to PyTorch high-level frameworks #172 + - NeuralNet is now called Trainer +- Clearer features names #145 +- Changed definitions in storage.py #150 +- `MAX_COVALENT_DISTANCE` is now 2.1 instead of 3 #205 ### Removed -* `threshold` input parameter from NeuralNet #157 +- `threshold` input parameter from NeuralNet #157 ## 0.2.0 @@ -131,17 +140,17 @@ Released on Aug 10, 2022 ### Added -* Automatic version bumping using `bump2version` with `.bumpversion.cfg` #126 -* `cffconvert.yml` to the CI 
workflow #139 -* Integration test for the Machine Learning pipeline #95 -* The package now is tested also on Python 3.10 #165 +- Automatic version bumping using `bump2version` with `.bumpversion.cfg` #126 +- `cffconvert.yml` to the CI workflow #139 +- Integration test for the Machine Learning pipeline #95 +- The package now is tested also on Python 3.10 #165 ### Changed -* Test PyPI package before publishing, by triggering a `workflow_dispatch` event from the Actions tab on `release.yml` workflow file #123 -* Coveralls is now working again #124 -* Wrong Zenodo entry has been corrected #138 -* Improved CUDA support (added for data tensors) #132 +- Test PyPI package before publishing, by triggering a `workflow_dispatch` event from the Actions tab on `release.yml` workflow file #123 +- Coveralls is now working again #124 +- Wrong Zenodo entry has been corrected #138 +- Improved CUDA support (added for data tensors) #132 ## 0.1.1 @@ -149,28 +158,28 @@ Released on June 28, 2022 ### Added -* Graph class #48 -* Tensorboard #15 -* CI Linting #30 -* Name, affiliation and orcid to `.zenodo.json` #18 -* Metrics class #17 -* QueryDataset class #53 -* Unit tests for NeuralNet class #86 -* Error message if you pick the wrong metrics #110 -* Unit tests for HDF5Dataset class parameters #82 -* Installation from PyPI in the readme #122 +- Graph class #48 +- Tensorboard #15 +- CI Linting #30 +- Name, affiliation and orcid to `.zenodo.json` #18 +- Metrics class #17 +- QueryDataset class #53 +- Unit tests for NeuralNet class #86 +- Error message if you pick the wrong metrics #110 +- Unit tests for HDF5Dataset class parameters #82 +- Installation from PyPI in the readme #122 ### Changed -* `test_process()` does not fail anymore #47 -* Tests have been speded up #36 -* `multiprocessing.Queue` has been replaced with `multiprocessing.pool.map` in PreProcessor #56 -* `test_preprocess.py` does not fail anymore on Mac M1 #74 -* It's now possible to pass your own train/test split to NeuralNet class #81 -* HDF5Dataset class now is used in the UX #83 -* IndexError running `NeuralNet.train()` has been fixed #89 -* pip installation has been fixed -* Repository has been renamed deeprank-core, and the package deeprank2 #101 -* The zero-division like error from TensorboardBinaryClassificationExporter has been fixed #112 -* h5xplorer is installed through `setup.cfg` file #121 -* Sphinx docs have been fixed #108 +- `test_process()` does not fail anymore #47 +- Tests have been speded up #36 +- `multiprocessing.Queue` has been replaced with `multiprocessing.pool.map` in PreProcessor #56 +- `test_preprocess.py` does not fail anymore on Mac M1 #74 +- It's now possible to pass your own train/test split to NeuralNet class #81 +- HDF5Dataset class now is used in the UX #83 +- IndexError running `NeuralNet.train()` has been fixed #89 +- pip installation has been fixed +- Repository has been renamed deeprank-core, and the package deeprank2 #101 +- The zero-division like error from TensorboardBinaryClassificationExporter has been fixed #112 +- h5xplorer is installed through `setup.cfg` file #121 +- Sphinx docs have been fixed #108 diff --git a/CITATION.cff b/CITATION.cff index c7e085c4c..1ff4b8b48 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -13,36 +13,36 @@ authors: family-names: Crocioni email: g.crocioni@esciencecenter.nl affiliation: Netherlands eScience Center - orcid: 'https://orcid.org/0000-0002-0823-0121' + orcid: "https://orcid.org/0000-0002-0823-0121" - given-names: Dani family-names: Bodor email: d.bodor@esciencecenter.nl 
affiliation: Netherlands eScience Center - orcid: 'https://orcid.org/0000-0003-2109-2349' + orcid: "https://orcid.org/0000-0003-2109-2349" - given-names: Coos family-names: Baakman affiliation: Radboud University Medical Center - orcid: 'https://orcid.org/0000-0003-4317-1566' + orcid: "https://orcid.org/0000-0003-4317-1566" - given-names: Daniel family-names: Rademaker affiliation: Radboud University Medical Center - orcid: 'https://orcid.org/0000-0003-1959-1317' + orcid: "https://orcid.org/0000-0003-1959-1317" - given-names: Gayatri family-names: Ramakrishnan affiliation: Radboud University Medical Center - orcid: 'https://orcid.org/0000-0001-8203-2783' + orcid: "https://orcid.org/0000-0001-8203-2783" - given-names: Sven family-names: van der Burg affiliation: Netherlands eScience Center - orcid: 'https://orcid.org/0000-0003-1250-6968' + orcid: "https://orcid.org/0000-0003-1250-6968" - given-names: Farzaneh Meimandi family-names: Parizi affiliation: Radboud University Medical Center - orcid: 'https://orcid.org/0000-0003-4230-7492' + orcid: "https://orcid.org/0000-0003-4230-7492" - given-names: Li C. family-names: Xue affiliation: Radboud University Medical Center - orcid: 'https://orcid.org/0000-0002-2613-538X' + orcid: "https://orcid.org/0000-0002-2613-538X" abstract: >- DeepRank2 is an open-source deep learning framework for data mining of protein-protein interfaces @@ -56,4 +56,4 @@ keywords: license: Apache-2.0 commit: 4e8823758ba03f824b4281f5689cb6a335ab2f6c version: "2.1.2" -date-released: '2023-12-21' +date-released: "2023-12-21" diff --git a/README.dev.md b/README.dev.md index 7b00c6847..9fdd58b1c 100644 --- a/README.dev.md +++ b/README.dev.md @@ -50,6 +50,8 @@ If you are using VS code, please install and activate the [Ruff extension](https Otherwise, please ensure check both linting (`ruff fix .`) and formatting (`ruff format .`) before requesting a review. +We use [prettier](https://prettier.io/) for formatting most other files. If you are editing or adding non-python files and using VS code, the [Prettier extension](https://marketplace.visualstudio.com/items?itemName=esbenp.prettier-vscode) can be installed to auto-format these files as well. + ## Versioning Bumping the version across all files is done before creating a new package release, running `bump2version [part]` from command line after having installed [bump2version](https://pypi.org/project/bump2version/) on your local environment. Instead of `[part]`, type the part of the version to increase, e.g. minor. The settings in `.bumpversion.cfg` will take care of updating all the files containing version strings. @@ -57,9 +59,12 @@ Bumping the version across all files is done before creating a new package relea ## Branching workflow We use a [Git Flow](https://nvie.com/posts/a-successful-git-branching-model/)-inspired branching workflow for development. DeepRank2's repository is based on two main branches with infinite lifetime: + - `main` — this branch contains production (stable) code. All development code is merged into `main` in sometime. - `dev` — this branch contains pre-production code. When the features are finished then they are merged into `dev`. + During the development cycle, three main supporting branches are used: + - Feature branches - Branches that branch off from `dev` and must merge into `dev`: used to develop new features for the upcoming releases. - Hotfix branches - Branches that branch off from `main` and must merge into `main` and `dev`: necessary to act immediately upon an undesired status of `main`. 
- Release branches - Branches that branch off from `dev` and must merge into `main` and `dev`: support preparation of a new production release. They allow many minor bug to be fixed and preparation of meta-data for a release. @@ -77,12 +82,13 @@ During the development cycle, three main supporting branches are used: 1. Branch from `dev` and prepare the branch for the release (e.g., removing the unnecessary dev files such as the current one, fix minor bugs if necessary). 2. [Bump the version](https://github.com/DeepRank/deeprank2/blob/dev/README.dev.md#versioning). 3. Verify that the information in `CITATION.cff` is correct (update the release date), and that `.zenodo.json` contains equivalent data. -5. Merge the release branch into `main` (and `dev`), and [run the tests](https://github.com/DeepRank/deeprank2/blob/dev/README.dev.md#running-the-tests). -6. Go to https://github.com/DeepRank/deeprank2/releases and draft a new release; create a new tag for the release, generate release notes automatically and adjust them, and finally publish the release as latest. This will trigger [a GitHub action](https://github.com/DeepRank/deeprank2/actions/workflows/release.yml) that will take care of publishing the package on PyPi. -7. Update the doi in `CITATION.cff` with the one corresponding to the new release. +4. Merge the release branch into `main` (and `dev`), and [run the tests](https://github.com/DeepRank/deeprank2/blob/dev/README.dev.md#running-the-tests). +5. Go to https://github.com/DeepRank/deeprank2/releases and draft a new release; create a new tag for the release, generate release notes automatically and adjust them, and finally publish the release as latest. This will trigger [a GitHub action](https://github.com/DeepRank/deeprank2/actions/workflows/release.yml) that will take care of publishing the package on PyPi. +6. Update the doi in `CITATION.cff` with the one corresponding to the new release. ## UML Code-base class diagrams updated on 02/11/2023, generated with https://www.gituml.com (save the images and open them in the browser for zooming). 
+ - Data processing classes and functions: - ML pipeline classes and functions: diff --git a/README.md b/README.md index 2407ba66a..d0ae4356b 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,13 @@ # Deeprank2 -| Badges | | -|:----:|----| -| **fairness** | [![fair-software.eu](https://img.shields.io/badge/fair--software.eu-%E2%97%8F%20%20%E2%97%8F%20%20%E2%97%8F%20%20%E2%97%8F%20%20%E2%97%8F-green)](https://fair-software.eu) [![CII Best Practices](https://bestpractices.coreinfrastructure.org/projects/6403/badge)](https://bestpractices.coreinfrastructure.org/projects/6403) | -| **package** | [![PyPI version](https://badge.fury.io/py/deeprank2.svg)](https://badge.fury.io/py/deeprank2) [![Codacy Badge](https://api.codacy.com/project/badge/Grade/f3f98b2d1883493ead50e3acaa23f2cc)](https://app.codacy.com/gh/DeepRank/deeprank2?utm_source=github.com&utm_medium=referral&utm_content=DeepRank/deeprank2&utm_campaign=Badge_Grade) | -| **docs** | [![Documentation Status](https://readthedocs.org/projects/deeprank2/badge/?version=latest)](https://deeprank2.readthedocs.io/en/latest/?badge=latest) [![DOI](https://zenodo.org/badge/450496579.svg)](https://zenodo.org/badge/latestdoi/450496579) | -| **tests** | [![Build Status](https://github.com/DeepRank/deeprank2/actions/workflows/build.yml/badge.svg)](https://github.com/DeepRank/deeprank2/actions) ![Linting status](https://github.com/DeepRank/deeprank2/actions/workflows/linting.yml/badge.svg?branch=main) [![Coverage Status](https://coveralls.io/repos/github/DeepRank/deeprank2/badge.svg?branch=main)](https://coveralls.io/github/DeepRank/deeprank2?branch=main) ![Python](https://img.shields.io/badge/python-3.10-blue.svg) ![Python](https://img.shields.io/badge/python-3.11-blue.svg) | -| **running on** | ![Ubuntu](https://img.shields.io/badge/Ubuntu-E95420?style=for-the-badge&logo=ubuntu&logoColor=white) | -| **license** | [![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/license/apache-2-0/) | +| Badges | | +| :------------: | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| **fairness** | [![fair-software.eu](https://img.shields.io/badge/fair--software.eu-%E2%97%8F%20%20%E2%97%8F%20%20%E2%97%8F%20%20%E2%97%8F%20%20%E2%97%8F-green)](https://fair-software.eu) [![CII Best Practices](https://bestpractices.coreinfrastructure.org/projects/6403/badge)](https://bestpractices.coreinfrastructure.org/projects/6403) | +| **package** | [![PyPI version](https://badge.fury.io/py/deeprank2.svg)](https://badge.fury.io/py/deeprank2) [![Codacy Badge](https://api.codacy.com/project/badge/Grade/f3f98b2d1883493ead50e3acaa23f2cc)](https://app.codacy.com/gh/DeepRank/deeprank2?utm_source=github.com&utm_medium=referral&utm_content=DeepRank/deeprank2&utm_campaign=Badge_Grade) | +| **docs** | [![Documentation Status](https://readthedocs.org/projects/deeprank2/badge/?version=latest)](https://deeprank2.readthedocs.io/en/latest/?badge=latest) [![DOI](https://zenodo.org/badge/450496579.svg)](https://zenodo.org/badge/latestdoi/450496579) | +| **tests** | [![Build 
Status](https://github.com/DeepRank/deeprank2/actions/workflows/build.yml/badge.svg)](https://github.com/DeepRank/deeprank2/actions) ![Linting status](https://github.com/DeepRank/deeprank2/actions/workflows/linting.yml/badge.svg?branch=main) [![Coverage Status](https://coveralls.io/repos/github/DeepRank/deeprank2/badge.svg?branch=main)](https://coveralls.io/github/DeepRank/deeprank2?branch=main) ![Python](https://img.shields.io/badge/python-3.10-blue.svg) ![Python](https://img.shields.io/badge/python-3.11-blue.svg) | +| **running on** | ![Ubuntu](https://img.shields.io/badge/Ubuntu-E95420?style=for-the-badge&logo=ubuntu&logoColor=white) | +| **license** | [![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/license/apache-2-0/) | ## Overview @@ -18,6 +18,7 @@ DeepRank2 is an open-source deep learning (DL) framework for data mining of prot DeepRank2 allows for transformation of (pdb formatted) molecular data into 3D representations (either grids or graphs) containing structural and physico-chemical information, which can be used for training neural networks. DeepRank2 also offers a pre-implemented training pipeline, using either [CNNs](https://en.wikipedia.org/wiki/Convolutional_neural_network) (for grids) or [GNNs](https://en.wikipedia.org/wiki/Graph_neural_network) (for graphs), as well as output exporters for evaluating performances. Main features: + - Predefined atom-level and residue-level feature types - e.g. atom/residue type, charge, size, potential energy - All features' documentation is available [here](https://deeprank2.readthedocs.io/en/latest/features.html) @@ -58,19 +59,19 @@ The package officially supports ubuntu-latest OS only, whose functioning is wide Before installing deeprank2 you need to install some dependencies. We advise to use a [conda environment](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html) with Python >= 3.10 installed. The following dependency installation instructions are updated as of 14/09/2023, but in case of issues during installation always refer to the official documentation which is linked below: -* [MSMS](https://anaconda.org/bioconda/msms): `conda install -c bioconda msms`. - * [Here](https://ssbio.readthedocs.io/en/latest/instructions/msms.html) for MacOS with M1 chip users. -* [PyTorch](https://pytorch.org/get-started/locally/) - * We support torch's CPU library as well as CUDA. - * Currently, the package is tested using [PyTorch 2.0.1](https://pytorch.org/get-started/previous-versions/#v201). -* [PyG](https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html) and its optional dependencies: `torch_scatter`, `torch_sparse`, `torch_cluster`, `torch_spline_conv`. -* [DSSP 4](https://swift.cmbi.umcn.nl/gv/dssp/) - * Check if `dssp` is installed: `dssp --version`. If this gives an error or shows a version lower than 4: - * on ubuntu 22.04 or newer: `sudo apt-get install dssp`. If the package cannot be located, first run `sudo apt-get update`. - * on older versions of ubuntu or on mac or lacking sudo priviliges: install from [here](https://github.com/pdb-redo/dssp), following the instructions listed. Alternatively, follow [this](https://github.com/PDB-REDO/libcifpp/issues/49) thread. -* [GCC](https://gcc.gnu.org/install/) - * Check if gcc is installed: `gcc --version`. If this gives an error, run `sudo apt-get install gcc`. -* For MacOS with M1 chip users only install [the conda version of PyTables](https://www.pytables.org/usersguide/installation.html). 
+- [MSMS](https://anaconda.org/bioconda/msms): `conda install -c bioconda msms`. + - [Here](https://ssbio.readthedocs.io/en/latest/instructions/msms.html) for MacOS with M1 chip users. +- [PyTorch](https://pytorch.org/get-started/locally/) + - We support torch's CPU library as well as CUDA. + - Currently, the package is tested using [PyTorch 2.0.1](https://pytorch.org/get-started/previous-versions/#v201). +- [PyG](https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html) and its optional dependencies: `torch_scatter`, `torch_sparse`, `torch_cluster`, `torch_spline_conv`. +- [DSSP 4](https://swift.cmbi.umcn.nl/gv/dssp/) + - Check if `dssp` is installed: `dssp --version`. If this gives an error or shows a version lower than 4: + - on ubuntu 22.04 or newer: `sudo apt-get install dssp`. If the package cannot be located, first run `sudo apt-get update`. + - on older versions of ubuntu or on mac or lacking sudo priviliges: install from [here](https://github.com/pdb-redo/dssp), following the instructions listed. Alternatively, follow [this](https://github.com/PDB-REDO/libcifpp/issues/49) thread. +- [GCC](https://gcc.gnu.org/install/) + - Check if gcc is installed: `gcc --version`. If this gives an error, run `sudo apt-get install gcc`. +- For MacOS with M1 chip users only install [the conda version of PyTables](https://www.pytables.org/usersguide/installation.html). ## Deeprank2 Package @@ -110,6 +111,7 @@ For more details, see the [extended documentation](https://deeprank2.rtfd.io/). For each protein-protein complex (or protein structure containing a missense variant), a `Query` can be created and added to the `QueryCollection` object, to be processed later on. Two subtypes of `Query` exist: `ProteinProteinInterfaceQuery` and `SingleResidueVariantQuery`. A `Query` takes as inputs: + - a `.pdb` file, representing the protein-protein structure, - the resolution (`"residue"` or `"atom"`), i.e. whether each node should represent an amino acid residue or an atom, - the ids of the chains composing the structure, and @@ -335,7 +337,7 @@ from deeprank2.utils.exporters import HDF5OutputExporter trainer = Trainer( NaiveNetwork, - dataset_test = dataset_test, + dataset_test = dataset_test, pretrained_model = "", output_exporters = [HDF5OutputExporter("")] ) @@ -350,11 +352,11 @@ For more details about how to run a pre-trained model on new data, see the [docs We measured the efficiency of data generation in DeepRank2 using the tutorials' [PDB files](https://zenodo.org/record/8187806) (~100 data points per data set), averaging the results run on Apple M1 Pro, using a single CPU. Parameter settings were: atomic resolution, `distance_cutoff` of 5.5 Å, radius (for SRV only) of 10 Å. The [features modules](https://deeprank2.readthedocs.io/en/latest/features.html) used were `components`, `contact`, `exposure`, `irc`, `secondary_structure`, `surfacearea`, for a total of 33 features for PPIs and 26 for SRVs (the latter do not use `irc` features). -| | Data processing speed
[seconds/structure] | Memory
[megabyte/structure] | -|------|:--------------------------------------------------------:|:--------------------------------------------------------:| +| | Data processing speed
[seconds/structure] | Memory
[megabyte/structure] | +| ---- | :--------------------------------------------------------------------: | :--------------------------------------------------------------------: | | PPIs | graph only: **2.99** (std 0.23)
graph+grid: **11.35** (std 1.30) | graph only: **0.54** (std 0.07)
graph+grid: **16.09** (std 0.44) | -| SRVs | graph only: **2.20** (std 0.08)
graph+grid: **2.85** (std 0.10) | graph only: **0.05** (std 0.01)
graph+grid: **17.52** (std 0.59) | +| SRVs | graph only: **2.20** (std 0.08)
graph+grid: **2.85** (std 0.10) | graph only: **0.05** (std 0.01)
graph+grid: **17.52** (std 0.59) | ## Package development -If you're looking for developer documentation, go [here](https://github.com/DeepRank/deeprank2/blob/dev/README.dev.md). \ No newline at end of file +If you're looking for developer documentation, go [here](https://github.com/DeepRank/deeprank2/blob/dev/README.dev.md). diff --git a/docs/docking.md b/docs/docking.md index 16406dcfe..eb5f83ae3 100644 --- a/docs/docking.md +++ b/docs/docking.md @@ -13,7 +13,7 @@ See https://onlinelibrary.wiley.com/doi/abs/10.1002/prot.10393 for more details ## Compute and add docking scores -The following code snippet shows an example of how to use deeprank2 to compute the docking scores for a given docking model, and how to add one of the scores (e.g., `dockq`) as a target to the already processed data. +The following code snippet shows an example of how to use deeprank2 to compute the docking scores for a given docking model, and how to add one of the scores (e.g., `dockq`) as a target to the already processed data. ```python from deeprank2.tools.target import add_target, compute_ppi_scores @@ -42,4 +42,4 @@ add_target("", "dockq", "") ``` -After having run the above code snipped, each processed data point within the indicated HDF5 file will contain a new Dataset called "dockq", containing the value computed through `compute_ppi_scores`. \ No newline at end of file +After having run the above code snipped, each processed data point within the indicated HDF5 file will contain a new Dataset called "dockq", containing the value computed through `compute_ppi_scores`. diff --git a/docs/features.md b/docs/features.md index 32b052bf6..de3e9e639 100644 --- a/docs/features.md +++ b/docs/features.md @@ -2,7 +2,6 @@ Features implemented in the code-base are defined in `deeprank2.feature` subpackage. - ## Custom features Users can add custom features by creating a new module and placing it in `deeprank2.feature` subpackage. One requirement for any feature module is to implement an `add_features` function, as shown below. This will be used in `deeprank2.models.query` to add the features to the nodes or edges of the graph. @@ -24,12 +23,15 @@ def add_features( The following is a brief description of the features already implemented in the code-base, for each features' module. ## Default node features + For atomic graphs, when features relate to residues then _all_ atoms of one residue receive the feature value for that residue. ### Core properties of atoms and residues: `deeprank2.features.components` + These features relate to the chemical components (atoms and amino acid residues) of which the graph is composed. Detailed information and descrepancies between sources are described can be found in `deeprank2.domain.aminoacidlist.py`. #### Atom properties: + These features are only used in atomic graphs. - `atom_type`: One-hot encoding of the atomic element. Options are: C, O, N, S, P, H. @@ -37,6 +39,7 @@ These features are only used in atomic graphs. - `pdb_occupancy`: Proportion of structures where the atom was detected at this position (float). In some cases a single atom was detected at different positions, in which case separate structures exist whose occupancies sum to 1. Only the highest occupancy atom is used by deeprank2. #### Residue properties: + - `res_type`: One-hot encoding of the amino acid residue (size 20). - `polarity`: One-hot encoding of the polarity of the amino acid (options: NONPOLAR, POLAR, NEGATIVE, POSITIVE). 
Note that sources vary on the polarity for few of the amino acids; see detailed information in `deeprank2.domain.aminoacidlist.py`. - `res_size`: The number of non-hydrogen atoms in the side chain (int). @@ -47,28 +50,32 @@ These features are only used in atomic graphs. - `hb_donors`, `hb_acceptors`: The number of hydrogen bond donor/acceptor atoms in the residue (int). Hydrogen bonds are noncovalent intermolecular interactions formed between an hydrogen atom (partially positively charged) bound to a small, highly electronegative atom (O, N, F) with an unshared electron pair. #### Properties related to variant residues: + These features are only used in SingleResidueVariant queries. - `variant_res`: One-hot encoding of variant amino acid (size 20). - `diff_charge`, `diff_polarity`, `diff_size`, `diff_mass`, `diff_pI`, `diff_hb_donors`, `diff_hb_acceptors`: Subtraction of the wildtype value of indicated feature from the variant value. For example, if the variant has 4 hb_donors and the wildtype has 5, then `diff_hb_donors == -1`. ### Conservation features: `deeprank2.features.conservation` + These features relate to the conservation state of individual residues. - `pssm`: [Position-specific scoring matrix](https://en.wikipedia.org/wiki/Position_weight_matrix) (also known as position weight matrix, PWM) values relative to the residue, is a score of the conservation of the amino acid along all 20 amino acids. - `info_content`: Information content is the difference between the given PSSM for an amino acid and a uniform distribution (float). -- `conservation` (only used in SingleResidueVariant queries): Conservation of the wild type amino acid (float). *More details required.* +- `conservation` (only used in SingleResidueVariant queries): Conservation of the wild type amino acid (float). _More details required._ - `diff_conservation` (only used in SingleResidueVariant queries): Subtraction of wildtype conservation from the variant conservation (float). ### Protein context features: #### Surface exposure: `deeprank2.features.exposure` + These features relate to the exposure of residues to the surface, and are computed using [biopython](https://biopython.org/docs/1.81/api/Bio.PDB.html). Note that these features can only be calculated per residue and not per atom. - `res_depth`: [Residue depth](https://en.wikipedia.org/wiki/Residue_depth) is the average distance (in Å) of the residue to the closest molecule of bulk water (float). See also [`Bio.PDB.ResidueDepth`](https://biopython.org/docs/1.75/api/Bio.PDB.ResidueDepth.html). - `hse`: [Half sphere exposure (HSE)](https://en.wikipedia.org/wiki/Half_sphere_exposure) is a protein solvent exposure measure indicating how buried an amino acid residue is in a protein (3 float values, see [Bio.PDB.HSExposure](https://biopython.org/docs/dev/api/Bio.PDB.HSExposure.html#module-Bio.PDB.HSExposure) for details). #### Surface accessibility: `deeprank2.features.surfacearea` + These features relate to the surface area of the residue, and are computed using [freesasa](https://freesasa.github.io). Note that these features can only be calculated per residue and not per atom. - `sasa`: [Solvent-Accessible Surface Area](https://en.wikipedia.org/wiki/Accessible_surface_area) is the surface area (in Å^2) of a biomolecule that is accessible to the solvent (float). 
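
The exposure and surface-area features above are computed through biopython and freesasa. As a rough illustration of the underlying freesasa entry point (a sketch only, not deeprank2's internal per-node featurization; the file path is a placeholder):

```python
import freesasa

# Placeholder path to any PDB file; freesasa parses it and computes
# solvent-accessible surface areas for the structure.
structure = freesasa.Structure("path/to/model.pdb")
result = freesasa.calc(structure)

# Total SASA of the whole structure in Å^2. deeprank2 assigns per-residue
# values to graph nodes, which is a finer-grained use of the same library.
print(f"Total SASA: {result.totalArea():.1f} Å^2")
```
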
@@ -76,31 +83,36 @@ These features relate to the surface area of the residue, and are computed using #### Secondary structure: `deeprank2.features.secondary_structure` -- `sec_struct`: One-hot encoding of the [DSSP](https://en.wikipedia.org/wiki/DSSP_(algorithm)) assigned secondary structure of the amino acid, using the three major classes (HELIX, STRAND, COIL). Calculated using [DSSP4](https://github.com/PDB-REDO/dssp). +- `sec_struct`: One-hot encoding of the [DSSP]() assigned secondary structure of the amino acid, using the three major classes (HELIX, STRAND, COIL). Calculated using [DSSP4](https://github.com/PDB-REDO/dssp). #### Inter-residue contacts (IRCs): `deeprank2.features.irc` + These features are only calculated for ProteinProteinInterface queries. - `irc_total`: The number of residues on the other chain that are within a cutoff distance of 5.5 Å (int). - `irc_nonpolar_nonpolar`, `irc_nonpolar_polar`, `irc_nonpolar_negative`, `irc_nonpolar_positive`, `irc_polar_polar`, `irc_polar_negative`, `irc_polar_positive`, `irc_negative_negative`, `irc_positive_positive`, `irc_negative_positive`: As above, but for specific residue polarity pairings. - ## Default edge features ### Contact features: `deeprank2.features.contact` + These features relate to relationships between individual nodes. For atomic graphs, when features relate to residues then _all_ atoms of one residue receive the feature value for that residue. #### Distance: + - `distance`: Interatomic distance between atoms in Å, computed from the xyz atomic coordinates taken from the .pdb file (float). For residue graphs, the the minimum distance between any atom of each residues is used. #### Structure: + These features relate to the structural relationship between nodes. + - `same_chain`: Boolean indicating whether the edge connects nodes belonging to the same chain (1) or separate chains (0). - `same_res`: Boolean indicating whether atoms belong to the same residue (1) or separate residues (0). Only used in atomic graphs. - `covalent`: Boolean indicating whether nodes are covalently bound (1) or not (0). Note that covalency is not directly assessed, but any edge with a maximum distance of 2.1 Å is considered covalent. #### Nonbond energies: + These features measure nonbond energy potentials between nodes. For residue graphs, the pairwise sum of potentials for all atoms from each residue is used. Note that no distance cutoff is used and the radius of influence is assumed to be infinite, although the potentials tends to 0 at large distance. Also edges are only assigned within a given cutoff radius when graphs are created. Nonbond energies are set to 0 for any atom pairs (on the same chain) that are within a cutoff radius of 3.6 Å, as these are assumed to be covalent neighbors or linked by no more than 2 covalent bonds (i.e. 1-3 pairs). diff --git a/docs/getstarted.md b/docs/getstarted.md index 893fff15e..bb5d3c311 100644 --- a/docs/getstarted.md +++ b/docs/getstarted.md @@ -9,6 +9,7 @@ For more details, see the [extended documentation](https://deeprank2.rtfd.io/). For each protein-protein complex (or protein structure containing a missense variant), a `Query` can be created and added to the `QueryCollection` object, to be processed later on. Two subtypes of `Query` exist: `ProteinProteinInterfaceQuery` and `SingleResidueVariantQuery`. A `Query` takes as inputs: + - a `.pdb` file, representing the protein-protein structure, - the resolution (`"residue"` or `"atom"`), i.e. 
whether each node should represent an amino acid residue or an atom, - the ids of the chains composing the structure, and @@ -432,7 +433,7 @@ hdf5_paths = queries.process( feature_modules = 'all') ``` -Then, the GraphDataset instance for the newly processed data can be created. Do this by specifying the path for the pre-trained model in `train_source`, together with the path to the HDF5 files just created. Note that there is no need of setting the dataset's parameters, since they are inherited from the information saved in the pre-trained model. +Then, the GraphDataset instance for the newly processed data can be created. Do this by specifying the path for the pre-trained model in `train_source`, together with the path to the HDF5 files just created. Note that there is no need of setting the dataset's parameters, since they are inherited from the information saved in the pre-trained model. ```python from deeprank2.dataset import GraphDataset @@ -452,7 +453,7 @@ from deeprank2.utils.exporters import HDF5OutputExporter trainer = Trainer( NaiveNetwork, - dataset_test = dataset_test, + dataset_test = dataset_test, pretrained_model = "", output_exporters = [HDF5OutputExporter("")] ) @@ -460,7 +461,7 @@ trainer = Trainer( trainer.test() ``` -The results can then be read in a Pandas Dataframe and visualized: +The results can then be read in a Pandas Dataframe and visualized: ```python import os diff --git a/docs/index.rst b/docs/index.rst index b7740fd89..3825a07c2 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -4,7 +4,7 @@ DeepRank2 |version| documentation DeepRank2 is an open-source deep learning (DL) framework for data mining of protein-protein interfaces (PPIs) or single-residue variants (SRVs). This package is an improved and unified version of three previously developed packages: `DeepRank`_, `DeepRank-GNN`_, and `DeepRank-Mut`_. -DeepRank2 allows for transformation of (pdb formatted) molecular data into 3D representations (either grids or graphs) containing structural and physico-chemical information, which can be used for training neural networks. DeepRank2 also offers a pre-implemented training pipeline, using either `convolutional neural networks`_ (for grids) or `graph neural networks`_ (for graphs), as well as output exporters for evaluating performances. +DeepRank2 allows for transformation of (pdb formatted) molecular data into 3D representations (either grids or graphs) containing structural and physico-chemical information, which can be used for training neural networks. DeepRank2 also offers a pre-implemented training pipeline, using either `convolutional neural networks`_ (for grids) or `graph neural networks`_ (for graphs), as well as output exporters for evaluating performances. Main features: @@ -32,7 +32,7 @@ Getting started :maxdepth: 2 :caption: Getting started :hidden: - + installation getstarted @@ -60,7 +60,7 @@ Notes Package reference =========== - + .. toctree:: :caption: API :hidden: diff --git a/docs/installation.md b/docs/installation.md index 7d1b0af9b..0b40d3197 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -1,24 +1,23 @@ # Installation -The package officially supports ubuntu-latest OS only, whose functioning is widely tested through the continuous integration workflows. +The package officially supports ubuntu-latest OS only, whose functioning is widely tested through the continuous integration workflows. ## Dependencies Before installing deeprank2 you need to install some dependencies. 
We advise to use a [conda environment](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html) with Python >= 3.10 installed. The following dependency installation instructions are updated as of 14/09/2023, but in case of issues during installation always refer to the official documentation which is linked below: -* [MSMS](https://anaconda.org/bioconda/msms): `conda install -c bioconda msms`. - * [Here](https://ssbio.readthedocs.io/en/latest/instructions/msms.html) for MacOS with M1 chip users. -* [PyTorch](https://pytorch.org/get-started/locally/) - * We support torch's CPU library as well as CUDA. -* [PyG](https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html) and its optional dependencies: `torch_scatter`, `torch_sparse`, `torch_cluster`, `torch_spline_conv`. -* [DSSP 4](https://swift.cmbi.umcn.nl/gv/dssp/) - * Check if `dssp` is installed: `dssp --version`. If this gives an error or shows a version lower than 4: - * on ubuntu 22.04 or newer: `sudo apt-get install dssp`. If the package cannot be located, first run `sudo apt-get update`. - * on older versions of ubuntu or on mac or lacking sudo priviliges: install from [here](https://github.com/pdb-redo/dssp), following the instructions listed. -* [GCC](https://gcc.gnu.org/install/) - * Check if gcc is installed: `gcc --version`. If this gives an error, run `sudo apt-get install gcc`. -* For MacOS with M1 chip users only install [the conda version of PyTables](https://www.pytables.org/usersguide/installation.html). - +- [MSMS](https://anaconda.org/bioconda/msms): `conda install -c bioconda msms`. + - [Here](https://ssbio.readthedocs.io/en/latest/instructions/msms.html) for MacOS with M1 chip users. +- [PyTorch](https://pytorch.org/get-started/locally/) + - We support torch's CPU library as well as CUDA. +- [PyG](https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html) and its optional dependencies: `torch_scatter`, `torch_sparse`, `torch_cluster`, `torch_spline_conv`. +- [DSSP 4](https://swift.cmbi.umcn.nl/gv/dssp/) +- Check if `dssp` is installed: `dssp --version`. If this gives an error or shows a version lower than 4: +- on ubuntu 22.04 or newer: `sudo apt-get install dssp`. If the package cannot be located, first run `sudo apt-get update`. +- on older versions of ubuntu or on mac or lacking sudo priviliges: install from [here](https://github.com/pdb-redo/dssp), following the instructions listed. +- [GCC](https://gcc.gnu.org/install/) +- Check if gcc is installed: `gcc --version`. If this gives an error, run `sudo apt-get install gcc`. +- For MacOS with M1 chip users only install [the conda version of PyTables](https://www.pytables.org/usersguide/installation.html). 
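
As an optional sanity check of the Python-side dependencies listed above (assuming `torch` and `torch_geometric` are installed in the active environment; DSSP, MSMS, and GCC are command-line tools and are checked with the commands given in the list):

```python
# Minimal check that the core deep learning dependencies import correctly.
import torch
import torch_geometric

print(f"torch {torch.__version__}; CUDA available: {torch.cuda.is_available()}")
print(f"torch_geometric {torch_geometric.__version__}")
```
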
## Deeprank2 Package diff --git a/docs/requirements.txt b/docs/requirements.txt index 3558cc33a..aac391aaf 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -2,4 +2,4 @@ sphinx==5.3.0 sphinx_rtd_theme==1.1.1 readthedocs-sphinx-search==0.1.1 myst-parser -toml \ No newline at end of file +toml diff --git a/paper/paper.bib b/paper/paper.bib index b7ba32e15..350a6cac5 100644 --- a/paper/paper.bib +++ b/paper/paper.bib @@ -256,4 +256,4 @@ @article{modeller author={Sanchez, R and Sali, A}, journal={Google Scholar There is no corresponding record for this reference}, year={1997} -} \ No newline at end of file +} diff --git a/paper/paper.md b/paper/paper.md index a8dbbb0f5..13a10b3ae 100644 --- a/paper/paper.md +++ b/paper/paper.md @@ -1,5 +1,5 @@ --- -title: 'DeepRank2: Mining 3D Protein Structures with Geometric Deep Learning' +title: "DeepRank2: Mining 3D Protein Structures with Geometric Deep Learning" tags: - Python - PyTorch @@ -44,18 +44,18 @@ authors: orcid: 0000-0002-2613-538X affiliation: 2 affiliations: - - name: Netherlands eScience Center, Amsterdam, The Netherlands - index: 1 - - name: Department of Medical BioSciences, Radboud University Medical Center, Nijmegen, The Netherlands - index: 2 - - name: Independent Researcher - index: 3 + - name: Netherlands eScience Center, Amsterdam, The Netherlands + index: 1 + - name: Department of Medical BioSciences, Radboud University Medical Center, Nijmegen, The Netherlands + index: 2 + - name: Independent Researcher + index: 3 date: 08 August 2023 bibliography: paper.bib - --- # Summary + [comment]: <> (CHECK FOR AUTHORS: Do the summary describe the high-level functionality and purpose of the software for a diverse, non-specialist audience?) We present DeepRank2, a deep learning (DL) framework geared towards making predictions on 3D protein structures for variety of biologically relevant applications. Our software can be used for predicting structural properties in drug design, immunotherapy, or designing novel proteins, among other fields. DeepRank2 allows for transformation and storage of 3D representations of both protein-protein interfaces (PPIs) and protein single-residue variants (SRVs) into either graphs or volumetric grids containing structural and physico-chemical information. These can be used for training neural networks for a variety of patterns of interest, using either our pre-implemented training pipeline for graph neural networks (GNNs) or convolutional neural networks (CNNs) or external pipelines. The entire framework flowchart is visualized in \autoref{fig:flowchart}. The package is fully open-source, follows the community-endorsed FAIR principles for research software, provides user-friendly APIs, publicily available [documentation](https://deeprank2.readthedocs.io/en/latest/), and in-depth [tutorials](https://github.com/DeepRank/deeprank2/blob/main/tutorials/TUTORIAL.md). @@ -88,8 +88,9 @@ These limitations create a growing demand for a generic and flexible DL framewor DeepRank2 allows to transform and store 3D representations of both PPIs and SRVs into 3D grids or graphs containing both geometric and physico-chemical information, and provides a DL pipeline that can be used for training pre-implemented neural networks for a given pattern of interest to the user. 
DeepRank2 is an improved and unified version of three previously developed packages: [DeepRank](https://github.com/DeepRank/deeprank), [DeepRank-GNN](https://github.com/DeepRank/Deeprank-GNN), and [DeepRank-Mut](https://github.com/DeepRank/DeepRank-Mut). As input, DeepRank2 takes [PDB-formatted](https://www.cgl.ucsf.edu/chimera/docs/UsersGuide/tutorials/pdbintro.html) atomic structures, which is one of the standard and most widely used formats in the field of structural biology. These are mapped to graphs, where nodes can represent either residues or atoms, as chosen by the user, and edges represent the interactions between them. The user can configure two types of 3D structures as input for the featurization phase: -- PPIs, for mining interaction patterns within protein-protein complexes; -- SRVs, for mining mutation phenotypes within protein structures. + +- PPIs, for mining interaction patterns within protein-protein complexes; +- SRVs, for mining mutation phenotypes within protein structures. The physico-chemical and geometrical features are then computed and assigned to each node and edge. The user can choose which features to generate from several pre-existing options defined in the package, or define custom features modules, as explained in the documentation. Examples of pre-defined node features are the type of the amino acid, its size and polarity, as well as more complex features such as its buried surface area and secondary structure features. Examples of pre-defined edge features are distance, covalency, and potential energy. A detailed list of predefined features can be found in the [documentation's features page](https://deeprank2.readthedocs.io/en/latest/features.html). Graphs can either be used directly or mapped to volumetric grids (i.e., 3D image-like representations), together with their features. Multiple CPUs can be used to parallelize and speed up the featurization process. The processed data are saved into HDF5 files, designed to efficiently store and organize big data. Users can then use the data for any ML or DL framework suited for the application. Specifically, graphs can be used for the training of GNNs, and 3D grids can be used for the training of CNNs. 
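
As a compact sketch of the pipeline described above, the snippet below strings together the same API calls that appear in the test-data notebook further down in this patch; the file path, chain ids, and target values are placeholders.

```python
from deeprank2.query import ProteinProteinInterfaceQuery, QueryCollection
from deeprank2.utils.grid import GridSettings, MapMethod

queries = QueryCollection()

# One query per PDB structure; nodes can represent residues or atoms.
queries.add(
    ProteinProteinInterfaceQuery(
        pdb_path="data_raw/ppi/pdb/complex_1.pdb",  # placeholder input structure
        resolution="residue",
        chain_ids=["A", "B"],
        targets={"binary": 0},  # placeholder target value(s)
    )
)

# Featurize in parallel and also map the graphs onto 3D grids;
# the processed data points are written to HDF5 files.
hdf5_paths = queries.process(
    cpu_count=4,
    prefix="ppi_example",
    grid_settings=GridSettings([20, 20, 20], [20.0, 20.0, 20.0]),
    grid_map_method=MapMethod.GAUSSIAN,
)
```

The resulting HDF5 files can then be fed to the package's dataset and training classes (graphs for GNNs, grids for CNNs), as shown in the documentation snippets elsewhere in this patch.
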
diff --git a/tests/data/hdf5/_generate_testdata.ipynb b/tests/data/hdf5/_generate_testdata.ipynb index 617b4a8f1..b2fc2677f 100644 --- a/tests/data/hdf5/_generate_testdata.ipynb +++ b/tests/data/hdf5/_generate_testdata.ipynb @@ -7,31 +7,36 @@ "outputs": [], "source": [ "from pathlib import Path\n", + "\n", "import pkg_resources as pkg\n", + "\n", "PATH_DEEPRANK_CORE = Path(pkg.resource_filename(\"deeprank2\", \"\"))\n", "ROOT = PATH_DEEPRANK_CORE.parent\n", "PATH_TEST = ROOT / \"tests\"\n", - "from deeprank2.query import (\n", - " QueryCollection,\n", - " ProteinProteinInterfaceQuery,\n", - " SingleResidueVariantQuery)\n", - "from deeprank2.tools.target import compute_ppi_scores\n", - "from deeprank2.dataset import save_hdf5_keys\n", - "from deeprank2.domain.aminoacidlist import alanine, phenylalanine\n", "import glob\n", "import os\n", "import re\n", "import sys\n", + "\n", "import h5py\n", "import numpy as np\n", - "import pandas as pd" + "import pandas as pd\n", + "\n", + "from deeprank2.dataset import save_hdf5_keys\n", + "from deeprank2.domain.aminoacidlist import alanine, phenylalanine\n", + "from deeprank2.query import (\n", + " ProteinProteinInterfaceQuery,\n", + " QueryCollection,\n", + " SingleResidueVariantQuery,\n", + ")\n", + "from deeprank2.tools.target import compute_ppi_scores" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "- Generating 1ATN_ppi.hdf5" + "- Generating 1ATN_ppi.hdf5\n" ] }, { @@ -41,7 +46,9 @@ "outputs": [], "source": [ "import warnings\n", + "\n", "from Bio import BiopythonWarning\n", + "\n", "from deeprank2.utils.grid import GridSettings, MapMethod\n", "\n", "with warnings.catch_warnings():\n", @@ -57,37 +64,38 @@ " str(PATH_TEST / \"data/pdb/1ATN/1ATN_1w.pdb\"),\n", " str(PATH_TEST / \"data/pdb/1ATN/1ATN_2w.pdb\"),\n", " str(PATH_TEST / \"data/pdb/1ATN/1ATN_3w.pdb\"),\n", - " str(PATH_TEST / \"data/pdb/1ATN/1ATN_4w.pdb\")]\n", + " str(PATH_TEST / \"data/pdb/1ATN/1ATN_4w.pdb\"),\n", + " ]\n", "\n", " queries = QueryCollection()\n", "\n", " for pdb_path in pdb_paths:\n", " # Append data points\n", " targets = compute_ppi_scores(pdb_path, ref_path)\n", - " queries.add(ProteinProteinInterfaceQuery(\n", - " pdb_path = pdb_path,\n", - " resolution = \"residue\",\n", - " chain_ids = [chain_id1, chain_id2],\n", - " targets = targets,\n", - " pssm_paths = {\n", - " chain_id1: pssm_path1,\n", - " chain_id2: pssm_path2\n", - " }\n", - " ))\n", + " queries.add(\n", + " ProteinProteinInterfaceQuery(\n", + " pdb_path=pdb_path,\n", + " resolution=\"residue\",\n", + " chain_ids=[chain_id1, chain_id2],\n", + " targets=targets,\n", + " pssm_paths={chain_id1: pssm_path1, chain_id2: pssm_path2},\n", + " )\n", + " )\n", "\n", " # Generate graphs and save them in hdf5 files\n", - " output_paths = queries.process(cpu_count=1,\n", - " prefix='1ATN_ppi',\n", - " grid_settings=GridSettings([20, 20, 20], [20.0, 20.0, 20.0]),\n", - " grid_map_method=MapMethod.GAUSSIAN,\n", - " )" + " output_paths = queries.process(\n", + " cpu_count=1,\n", + " prefix=\"1ATN_ppi\",\n", + " grid_settings=GridSettings([20, 20, 20], [20.0, 20.0, 20.0]),\n", + " grid_map_method=MapMethod.GAUSSIAN,\n", + " )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "- Generating residue.hdf5" + "- Generating residue.hdf5\n" ] }, { @@ -97,54 +105,53 @@ "outputs": [], "source": [ "# Local data\n", - "project_folder = '/home/dbodor/git/DeepRank/deeprank-core/tests/data/sample_25_07122022/'\n", - "csv_file_name = 'BA_pMHCI_human_quantitative.csv'\n", - "models_folder_name = 
'exp_nmers_all_HLA_quantitative'\n", - "data = 'pMHCI'\n", - "resolution = 'residue' # either 'residue' or 'atom'\n", - "influence_radius = 15 # max distance in Å between two interacting residues/atoms of two proteins\n", - "max_edge_length = 15 # max distance in Å between to create an edge\n", + "project_folder = \"/home/dbodor/git/DeepRank/deeprank-core/tests/data/sample_25_07122022/\"\n", + "csv_file_name = \"BA_pMHCI_human_quantitative.csv\"\n", + "models_folder_name = \"exp_nmers_all_HLA_quantitative\"\n", + "data = \"pMHCI\"\n", + "resolution = \"residue\" # either 'residue' or 'atom'\n", + "influence_radius = 15 # max distance in Å between two interacting residues/atoms of two proteins\n", + "max_edge_length = 15 # max distance in Å between to create an edge\n", "\n", - "csv_file_path = f'{project_folder}data/external/processed/I/{csv_file_name}'\n", - "models_folder_path = f'{project_folder}data/{data}/features_input_folder/{models_folder_name}'\n", + "csv_file_path = f\"{project_folder}data/external/processed/I/{csv_file_name}\"\n", + "models_folder_path = f\"{project_folder}data/{data}/features_input_folder/{models_folder_name}\"\n", "\n", - "pdb_files = glob.glob(os.path.join(models_folder_path + '/pdb', '*.pdb'))\n", + "pdb_files = glob.glob(os.path.join(models_folder_path + \"/pdb\", \"*.pdb\"))\n", "pdb_files.sort()\n", - "print(f'{len(pdb_files)} pdbs found.')\n", - "pssm_m = glob.glob(os.path.join(models_folder_path + '/pssm', '*.M.*.pssm'))\n", + "print(f\"{len(pdb_files)} pdbs found.\")\n", + "pssm_m = glob.glob(os.path.join(models_folder_path + \"/pssm\", \"*.M.*.pssm\"))\n", "pssm_m.sort()\n", - "print(f'{len(pssm_m)} MHC pssms found.')\n", - "pssm_p = glob.glob(os.path.join(models_folder_path + '/pssm', '*.P.*.pssm'))\n", + "print(f\"{len(pssm_m)} MHC pssms found.\")\n", + "pssm_p = glob.glob(os.path.join(models_folder_path + \"/pssm\", \"*.P.*.pssm\"))\n", "pssm_p.sort()\n", - "print(f'{len(pssm_p)} peptide pssms found.')\n", + "print(f\"{len(pssm_p)} peptide pssms found.\")\n", "csv_data = pd.read_csv(csv_file_path)\n", "csv_data.cluster = csv_data.cluster.fillna(-1)\n", - "pdb_ids_csv = [pdb_file.split('/')[-1].split('.')[0].replace('-', '_') for pdb_file in pdb_files]\n", - "clusters = [csv_data[csv_data.ID == pdb_id].cluster.values[0] for pdb_id in pdb_ids_csv]\n", - "bas = [csv_data[csv_data.ID == pdb_id].measurement_value.values[0] for pdb_id in pdb_ids_csv]\n", + "pdb_ids_csv = [pdb_file.split(\"/\")[-1].split(\".\")[0].replace(\"-\", \"_\") for pdb_file in pdb_files]\n", + "clusters = [csv_data[pdb_id == csv_data.ID].cluster.values[0] for pdb_id in pdb_ids_csv]\n", + "bas = [csv_data[pdb_id == csv_data.ID].measurement_value.values[0] for pdb_id in pdb_ids_csv]\n", "\n", "queries = QueryCollection()\n", - "print(f'Adding {len(pdb_files)} queries to the query collection ...')\n", + "print(f\"Adding {len(pdb_files)} queries to the query collection ...\")\n", "for i in range(len(pdb_files)):\n", " queries.add(\n", " ProteinProteinInterfaceQuery(\n", - " pdb_path = pdb_files[i],\n", - " resolution = \"residue\",\n", - " chain_ids = [\"M\", \"P\"],\n", - " influence_radius = influence_radius,\n", - " max_edge_length = max_edge_length,\n", - " targets = {\n", - " 'binary': int(float(bas[i]) <= 500), # binary target value\n", - " 'BA': bas[i], # continuous target value\n", - " 'cluster': clusters[i]\n", - " },\n", - " pssm_paths = {\n", - " \"M\": pssm_m[i],\n", - " \"P\": pssm_p[i]\n", - " }))\n", - "print(f'Queries created and ready to be processed.\\n')\n", + " 
pdb_path=pdb_files[i],\n", + " resolution=\"residue\",\n", + " chain_ids=[\"M\", \"P\"],\n", + " influence_radius=influence_radius,\n", + " max_edge_length=max_edge_length,\n", + " targets={\n", + " \"binary\": int(float(bas[i]) <= 500), # binary target value\n", + " \"BA\": bas[i], # continuous target value\n", + " \"cluster\": clusters[i],\n", + " },\n", + " pssm_paths={\"M\": pssm_m[i], \"P\": pssm_p[i]},\n", + " )\n", + " )\n", + "print(\"Queries created and ready to be processed.\\n\")\n", "\n", - "output_paths = queries.process(prefix='residue')\n", + "output_paths = queries.process(prefix=\"residue\")\n", "print(output_paths)" ] }, @@ -153,7 +160,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "- Generating train.hdf5, valid.hdf5, test.hdf5" + "- Generating train.hdf5, valid.hdf5, test.hdf5\n" ] }, { @@ -163,20 +170,19 @@ "outputs": [], "source": [ "# dividing hdf5 file in train, valid, test\n", - "hdf5_path = 'residue.hdf5'\n", + "hdf5_path = \"residue.hdf5\"\n", "train_clusters = [3, 4, 5, 2]\n", "val_clusters = [1, 8]\n", "test_clusters = [6]\n", - "target = 'target_values'\n", - "feature = 'cluster'\n", + "target = \"target_values\"\n", + "feature = \"cluster\"\n", "\n", "clusters = {}\n", "train_ids = []\n", "val_ids = []\n", "test_ids = []\n", "\n", - "with h5py.File(hdf5_path, 'r') as hdf5:\n", - "\n", + "with h5py.File(hdf5_path, \"r\") as hdf5:\n", " for key in hdf5.keys():\n", " feature_value = float(hdf5[key][target][feature][()])\n", " if feature_value in train_clusters:\n", @@ -191,24 +197,23 @@ " else:\n", " clusters[int(feature_value)] = 1\n", "\n", + " print(f\"Trainset contains {len(train_ids)} data points, {round(100*len(train_ids)/len(hdf5.keys()), 2)}% of the total data.\")\n", + " print(f\"Validation set contains {len(val_ids)} data points, {round(100*len(val_ids)/len(hdf5.keys()), 2)}% of the total data.\")\n", + " print(f\"Test set contains {len(test_ids)} data points, {round(100*len(test_ids)/len(hdf5.keys()), 2)}% of the total data.\\n\")\n", "\n", - " print(f'Trainset contains {len(train_ids)} data points, {round(100*len(train_ids)/len(hdf5.keys()), 2)}% of the total data.')\n", - " print(f'Validation set contains {len(val_ids)} data points, {round(100*len(val_ids)/len(hdf5.keys()), 2)}% of the total data.')\n", - " print(f'Test set contains {len(test_ids)} data points, {round(100*len(test_ids)/len(hdf5.keys()), 2)}% of the total data.\\n')\n", - "\n", - " for (key, value) in dict(sorted(clusters.items(), key=lambda x:x[1], reverse=True)).items():\n", - " print(f'Group with value {key}: {value} data points, {round(100*value/len(hdf5.keys()), 2)}% of total data.')\n", + " for key, value in dict(sorted(clusters.items(), key=lambda x: x[1], reverse=True)).items():\n", + " print(f\"Group with value {key}: {value} data points, {round(100*value/len(hdf5.keys()), 2)}% of total data.\")\n", "\n", - "save_hdf5_keys(hdf5_path, train_ids, 'train.hdf5', hardcopy = True)\n", - "save_hdf5_keys(hdf5_path, val_ids, 'valid.hdf5', hardcopy = True)\n", - "save_hdf5_keys(hdf5_path, test_ids, 'test.hdf5', hardcopy = True)" + "save_hdf5_keys(hdf5_path, train_ids, \"train.hdf5\", hardcopy=True)\n", + "save_hdf5_keys(hdf5_path, val_ids, \"valid.hdf5\", hardcopy=True)\n", + "save_hdf5_keys(hdf5_path, test_ids, \"test.hdf5\", hardcopy=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "- Generating variants.hdf5" + "- Generating variants.hdf5\n" ] }, { @@ -225,21 +230,19 @@ "\n", "for number in range(1, count_queries + 1):\n", " query = 
SingleResidueVariantQuery(\n", - " pdb_path = pdb_path,\n", - " resolution = \"residue\",\n", - " chain_ids = \"A\",\n", - " variant_residue_number = number,\n", - " insertion_code = None,\n", - " wildtype_amino_acid = alanine,\n", - " variant_amino_acid = phenylalanine,\n", - " pssm_paths = {\n", - " \"A\": str(PATH_TEST / \"data/pssm/3C8P/3C8P.A.pdb.pssm\"),\n", - " \"B\": str(PATH_TEST / \"data/pssm/3C8P/3C8P.B.pdb.pssm\")},\n", - " targets = targets\n", + " pdb_path=pdb_path,\n", + " resolution=\"residue\",\n", + " chain_ids=\"A\",\n", + " variant_residue_number=number,\n", + " insertion_code=None,\n", + " wildtype_amino_acid=alanine,\n", + " variant_amino_acid=phenylalanine,\n", + " pssm_paths={\"A\": str(PATH_TEST / \"data/pssm/3C8P/3C8P.A.pdb.pssm\"), \"B\": str(PATH_TEST / \"data/pssm/3C8P/3C8P.B.pdb.pssm\")},\n", + " targets=targets,\n", " )\n", " queries.add(query)\n", "\n", - "output_paths = queries.process(cpu_count = 1, prefix='variants')" + "output_paths = queries.process(cpu_count=1, prefix=\"variants\")" ] }, { @@ -247,7 +250,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "- Generating atom.hdf5" + "- Generating atom.hdf5\n" ] }, { @@ -265,26 +268,22 @@ " str(PATH_TEST / \"data/pdb/1ATN/1ATN_1w.pdb\"),\n", " str(PATH_TEST / \"data/pdb/1ATN/1ATN_2w.pdb\"),\n", " str(PATH_TEST / \"data/pdb/1ATN/1ATN_3w.pdb\"),\n", - " str(PATH_TEST / \"data/pdb/1ATN/1ATN_4w.pdb\")]\n", + " str(PATH_TEST / \"data/pdb/1ATN/1ATN_4w.pdb\"),\n", + "]\n", "\n", "queries = QueryCollection()\n", "\n", "for pdb_path in pdb_paths:\n", " # Append data points\n", " targets = compute_ppi_scores(pdb_path, ref_path)\n", - " queries.add(ProteinProteinInterfaceQuery(\n", - " pdb_path = pdb_path,\n", - " resolution=\"atom\",\n", - " chain_ids = [chain_id1, chain_id2],\n", - " targets = targets,\n", - " pssm_paths = {\n", - " chain_id1: pssm_path1,\n", - " chain_id2: pssm_path2\n", - " }\n", - " ))\n", + " queries.add(\n", + " ProteinProteinInterfaceQuery(\n", + " pdb_path=pdb_path, resolution=\"atom\", chain_ids=[chain_id1, chain_id2], targets=targets, pssm_paths={chain_id1: pssm_path1, chain_id2: pssm_path2}\n", + " )\n", + " )\n", "\n", "# Generate graphs and save them in hdf5 files\n", - "output_paths = queries.process(cpu_count=1, prefix = 'atom')" + "output_paths = queries.process(cpu_count=1, prefix=\"atom\")" ] } ], diff --git a/tutorials/TUTORIAL.md b/tutorials/TUTORIAL.md index 5ede2257a..d41135c2e 100644 --- a/tutorials/TUTORIAL.md +++ b/tutorials/TUTORIAL.md @@ -1,9 +1,10 @@ ## Introduction The tutorial notebooks in this folder can be run to learn how to use DeepRank2. -- There are two tutorial notebooks for data generation, which demonstrate how to create *.hdf5-formatted input training data from raw *.pdb-formatted data using DeepRank2. + +- There are two tutorial notebooks for data generation, which demonstrate how to create _.hdf5-formatted input training data from raw _.pdb-formatted data using DeepRank2. - protein-protein interface (PPI) data ([data_generation_ppi.ipynb](https://github.com/DeepRank/deeprank2/blob/main/tutorials/data_generation_ppi.ipynb)); - - single-residue variant (SRV) data ([data_generation_srv.ipynb](https://github.com/DeepRank/deeprank2/blob/main/tutorials/data_generation_srv.ipynb)). + - single-residue variant (SRV) data ([data_generation_srv.ipynb](https://github.com/DeepRank/deeprank2/blob/main/tutorials/data_generation_srv.ipynb)). - The [training tutorial](tutorials/training_ppi.ipynb) will demonstrate how to train neural networks using DeepRank2. 
### Use cases @@ -35,9 +36,9 @@ PDB models and target data used in this tutorial have been retrieved from [Ramak - Navigate to your deeprank2 folder. - Run `pytest tests`. All tests should pass at this point. - ## Running the notebooks The tutorial notebooks can be run: + - from inside your IDE, if it has that functionality (e.g., VS Code), - on JupyterLab by navigating to the tutorials directory in your terminal and running `jupyter-lab`. diff --git a/tutorials/data_generation_ppi.ipynb b/tutorials/data_generation_ppi.ipynb index 1dcc41972..2bcc213d1 100644 --- a/tutorials/data_generation_ppi.ipynb +++ b/tutorials/data_generation_ppi.ipynb @@ -5,7 +5,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Data preparation for protein-protein interfaces" + "# Data preparation for protein-protein interfaces\n" ] }, { @@ -17,9 +17,9 @@ "\n", "\n", "\n", - "This tutorial will demonstrate the use of DeepRank2 for generating protein-protein interface (PPI) graphs and saving them as [HDF5 files](https://en.wikipedia.org/wiki/Hierarchical_Data_Format) files, using [PBD files](https://en.wikipedia.org/wiki/Protein_Data_Bank_(file_format)) of protein-protein complexes as input.\n", + "This tutorial will demonstrate the use of DeepRank2 for generating protein-protein interface (PPI) graphs and saving them as [HDF5 files](https://en.wikipedia.org/wiki/Hierarchical_Data_Format) files, using [PBD files]() of protein-protein complexes as input.\n", "\n", - "In this data processing phase, for each protein-protein complex an interface is selected according to a distance threshold that the user can customize, and it is mapped to a graph. Nodes either represent residues or atoms, and edges are the interactions between them. Each node and edge can have several different features, which are generated and added during the processing phase as well. Optionally, the graphs can be mapped to volumetric grids (i.e., 3D image-like representations), together with their features. The mapped data are finally saved into HDF5 files, and can be used for later models' training (for details go to [training_ppi.ipynb](https://github.com/DeepRank/deeprank2/blob/main/tutorials/training_ppi.ipynb) tutorial). In particular, graphs can be used for the training of Graph Neural Networks (GNNs), and grids can be used for the training of Convolutional Neural Networks (CNNs)." + "In this data processing phase, for each protein-protein complex an interface is selected according to a distance threshold that the user can customize, and it is mapped to a graph. Nodes either represent residues or atoms, and edges are the interactions between them. Each node and edge can have several different features, which are generated and added during the processing phase as well. Optionally, the graphs can be mapped to volumetric grids (i.e., 3D image-like representations), together with their features. The mapped data are finally saved into HDF5 files, and can be used for later models' training (for details go to [training_ppi.ipynb](https://github.com/DeepRank/deeprank2/blob/main/tutorials/training_ppi.ipynb) tutorial). In particular, graphs can be used for the training of Graph Neural Networks (GNNs), and grids can be used for the training of Convolutional Neural Networks (CNNs).\n" ] }, { @@ -31,7 +31,7 @@ "\n", "The example data used in this tutorial are available on Zenodo at [this record address](https://zenodo.org/record/8349335). To download the raw data used in this tutorial, please visit the link and download `data_raw.zip`. 
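If preferred, this manual step can also be scripted. The sketch below uses only the Python standard library; the direct file URL is an assumption about how Zenodo exposes `data_raw.zip` for record 8349335, not something stated in the tutorial, so adjust it if the record layout differs.

```python
import urllib.request
import zipfile

# Hypothetical direct-download URL for the archive of Zenodo record 8349335.
url = "https://zenodo.org/record/8349335/files/data_raw.zip"
urllib.request.urlretrieve(url, "data_raw.zip")

# Extract next to this notebook, producing the data_raw/ folder used below.
with zipfile.ZipFile("data_raw.zip") as archive:
    archive.extractall(".")
```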
Unzip it, and save the `data_raw/` folder in the same directory as this notebook. The name and the location of the folder are optional but recommended, as they are the name and the location we will use to refer to the folder throughout the tutorial.\n", "\n", - "Note that the dataset contains only 100 data points, which is not enough to develop an impactful predictive model, and the scope of its use is indeed only demonstrative and informative for the users." + "Note that the dataset contains only 100 data points, which is not enough to develop an impactful predictive model, and the scope of its use is indeed only demonstrative and informative for the users.\n" ] }, { @@ -39,7 +39,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Utilities" + "## Utilities\n" ] }, { @@ -47,7 +47,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Libraries" + "### Libraries\n" ] }, { @@ -55,7 +55,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "The libraries needed for this tutorial:" + "The libraries needed for this tutorial:\n" ] }, { @@ -82,7 +82,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Raw files and paths" + "### Raw files and paths\n" ] }, { @@ -90,7 +90,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "The paths for reading raw data and saving the processed ones:" + "The paths for reading raw data and saving the processed ones:\n" ] }, { @@ -112,7 +112,7 @@ "source": [ "- Raw data are PDB files in `data_raw/ppi/pdb/`, which contains atomic coordinates of the protein-protein complexes of interest, so in our case of pMHC complexes.\n", "- Target data, so in our case the BA values for the pMHC complex, are in `data_raw/ppi/BA_values.csv`.\n", - "- The final PPI processed data will be saved in `data_processed/ppi/` folder, which in turns contains a folder for residue-level data and another one for atomic-level data. More details about such different levels will come a few cells below." + "- The final PPI processed data will be saved in `data_processed/ppi/` folder, which in turns contains a folder for residue-level data and another one for atomic-level data. 
More details about such different levels will come a few cells below.\n" ] }, { @@ -120,7 +120,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "`get_pdb_files_and_target_data` is an helper function used to retrieve the raw pdb files names in a list and the BA target values from a CSV containing the IDs of the PDB models as well:" + "`get_pdb_files_and_target_data` is an helper function used to retrieve the raw pdb files names in a list and the BA target values from a CSV containing the IDs of the PDB models as well:\n" ] }, { @@ -130,14 +130,15 @@ "outputs": [], "source": [ "def get_pdb_files_and_target_data(data_path):\n", - "\tcsv_data = pd.read_csv(os.path.join(data_path, \"BA_values.csv\"))\n", - "\tpdb_files = glob.glob(os.path.join(data_path, \"pdb\", '*.pdb'))\n", - "\tpdb_files.sort()\n", - "\tpdb_ids_csv = [pdb_file.split('/')[-1].split('.')[0] for pdb_file in pdb_files]\n", - "\tcsv_data_indexed = csv_data.set_index('ID')\n", - "\tcsv_data_indexed = csv_data_indexed.loc[pdb_ids_csv]\n", - "\tbas = csv_data_indexed.measurement_value.values.tolist()\n", - "\treturn pdb_files, bas\n", + " csv_data = pd.read_csv(os.path.join(data_path, \"BA_values.csv\"))\n", + " pdb_files = glob.glob(os.path.join(data_path, \"pdb\", \"*.pdb\"))\n", + " pdb_files.sort()\n", + " pdb_ids_csv = [pdb_file.split(\"/\")[-1].split(\".\")[0] for pdb_file in pdb_files]\n", + " csv_data_indexed = csv_data.set_index(\"ID\")\n", + " csv_data_indexed = csv_data_indexed.loc[pdb_ids_csv]\n", + " bas = csv_data_indexed.measurement_value.values.tolist()\n", + " return pdb_files, bas\n", + "\n", "\n", "pdb_files, bas = get_pdb_files_and_target_data(data_path)" ] @@ -147,7 +148,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## `QueryCollection` and `Query` objects" + "## `QueryCollection` and `Query` objects\n" ] }, { @@ -165,7 +166,7 @@ "- The interaction radius, which determines the threshold distance (in Ångström) for residues/atoms surrounding the interface that will be included in the graph.\n", "- The target values associated with the query. For each query/data point, in the use case demonstrated in this tutorial will add two targets: \"BA\" and \"binary\". The first represents the actual BA value of the complex in nM, while the second represents its binary mapping, being 0 (BA > 500 nM) a not-binding complex and 1 (BA <= 500 nM) a binding one.\n", "- The max edge distance, which is the maximum distance between two nodes to generate an edge between them.\n", - "- Optional: The correspondent [Position-Specific Scoring Matrices (PSSMs)](https://en.wikipedia.org/wiki/Position_weight_matrix), in the form of .pssm files. PSSMs are optional and will not be used in this tutorial." + "- Optional: The correspondent [Position-Specific Scoring Matrices (PSSMs)](https://en.wikipedia.org/wiki/Position_weight_matrix), in the form of .pssm files. 
PSSMs are optional and will not be used in this tutorial.\n" ] }, { @@ -173,7 +174,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Residue-level PPIs using `ProteinProteinInterfaceQuery`" + "## Residue-level PPIs using `ProteinProteinInterfaceQuery`\n" ] }, { @@ -187,25 +188,27 @@ "influence_radius = 8 # max distance in Å between two interacting residues/atoms of two proteins\n", "max_edge_length = 8\n", "\n", - "print(f'Adding {len(pdb_files)} queries to the query collection ...')\n", + "print(f\"Adding {len(pdb_files)} queries to the query collection ...\")\n", "count = 0\n", "for i in range(len(pdb_files)):\n", - "\tqueries.add(\n", - "\t\tProteinProteinInterfaceQuery(\n", - "\t\t\tpdb_path = pdb_files[i],\n", - "\t\t\tresolution = \"residue\",\n", - "\t\t\tchain_ids = [\"M\", \"P\"],\n", - "\t\t\tinfluence_radius = influence_radius,\n", - "\t\t\tmax_edge_length = max_edge_length,\n", - "\t\t\ttargets = {\n", - "\t\t\t\t'binary': int(float(bas[i]) <= 500), # binary target value\n", - "\t\t\t\t'BA': bas[i], # continuous target value\n", - "\t\t\t\t}))\n", - "\tcount +=1\n", - "\tif count % 20 == 0:\n", - "\t\tprint(f'{count} queries added to the collection.')\n", + " queries.add(\n", + " ProteinProteinInterfaceQuery(\n", + " pdb_path=pdb_files[i],\n", + " resolution=\"residue\",\n", + " chain_ids=[\"M\", \"P\"],\n", + " influence_radius=influence_radius,\n", + " max_edge_length=max_edge_length,\n", + " targets={\n", + " \"binary\": int(float(bas[i]) <= 500), # binary target value\n", + " \"BA\": bas[i], # continuous target value\n", + " },\n", + " )\n", + " )\n", + " count += 1\n", + " if count % 20 == 0:\n", + " print(f\"{count} queries added to the collection.\")\n", "\n", - "print(f'Queries ready to be processed.\\n')" + "print(f\"Queries ready to be processed.\\n\")" ] }, { @@ -221,7 +224,7 @@ "- `feature_modules` allows you to choose which feature generating modules you want to use. By default, the basic features contained in `deeprank2.features.components` and `deeprank2.features.contact` are generated. Users can add custom features by creating a new module and placing it in the `deeprank2.feature` subpackage. A complete and detailed list of the pre-implemented features per module and more information about how to add custom features can be found [here](https://deeprank2.readthedocs.io/en/latest/features.html).\n", " - Note that all features generated by a module will be added if that module was selected, and there is no way to only generate specific features from that module. However, during the training phase shown in `training_ppi.ipynb`, it is possible to select only a subset of available features.\n", "- `cpu_count` can be used to specify how many processes to be run simultaneously, and will coincide with the number of HDF5 files generated. By default it takes all available CPU cores and HDF5 files are squashed into a single file using the `combine_output` setting.\n", - "- Optional: If you want to include grids in the HDF5 files, which represent the mapping of the graphs to a volumetric box, you need to define `grid_settings` and `grid_map_method`, as shown in the example below. If they are `None` (default), only graphs are saved." + "- Optional: If you want to include grids in the HDF5 files, which represent the mapping of the graphs to a volumetric box, you need to define `grid_settings` and `grid_map_method`, as shown in the example below. 
If they are `None` (default), only graphs are saved.\n" ] }, { @@ -230,20 +233,22 @@ "metadata": {}, "outputs": [], "source": [ - "grid_settings = GridSettings( # None if you don't want grids\n", - "\t# the number of points on the x, y, z edges of the cube\n", - "\tpoints_counts = [35, 30, 30],\n", - "\t# x, y, z sizes of the box in Å\n", - "\tsizes = [1.0, 1.0, 1.0])\n", - "grid_map_method = MapMethod.GAUSSIAN # None if you don't want grids\n", + "grid_settings = GridSettings( # None if you don't want grids\n", + " # the number of points on the x, y, z edges of the cube\n", + " points_counts=[35, 30, 30],\n", + " # x, y, z sizes of the box in Å\n", + " sizes=[1.0, 1.0, 1.0],\n", + ")\n", + "grid_map_method = MapMethod.GAUSSIAN # None if you don't want grids\n", "\n", "queries.process(\n", - "\tprefix = os.path.join(processed_data_path, \"residue\", \"proc\"),\n", - "\tfeature_modules = [components, contact],\n", - " cpu_count = 8,\n", - "\tcombine_output = False,\n", - "\tgrid_settings = grid_settings,\n", - "\tgrid_map_method = grid_map_method)\n", + " prefix=os.path.join(processed_data_path, \"residue\", \"proc\"),\n", + " feature_modules=[components, contact],\n", + " cpu_count=8,\n", + " combine_output=False,\n", + " grid_settings=grid_settings,\n", + " grid_map_method=grid_map_method,\n", + ")\n", "\n", "print(f'The queries processing is done. The generated HDF5 files are in {os.path.join(processed_data_path, \"residue\")}.')" ] @@ -309,7 +314,7 @@ "\n", "`edge_features`, `node_features`, `mapped_features` are [HDF5 Groups](https://docs.h5py.org/en/stable/high/group.html) which contain [HDF5 Datasets](https://docs.h5py.org/en/stable/high/dataset.html) (e.g., `_index`, `electrostatic`, etc.), which in turn contains features values in the form of arrays. `edge_features` and `node_features` refer specificly to the graph representation, while `grid_points` and `mapped_features` refer to the grid mapped from the graph. Each data point generated by deeprank2 has the above structure, with the features and the target changing according to the user's settings. Features starting with `_` are present for human inspection of the data, but they are not used for training models.\n", "\n", - "It is always a good practice to first explore the data, and then make decision about splitting them in training, test and validation sets. There are different possible ways for doing it." + "It is always a good practice to first explore the data, and then make decision about splitting them in training, test and validation sets. There are different possible ways for doing it.\n" ] }, { @@ -319,7 +324,7 @@ "source": [ "#### Pandas dataframe\n", "\n", - "The edge and node features just generated can be explored by instantiating the `GraphDataset` object, and then using `hdf5_to_pandas` method which converts node and edge features into a [Pandas](https://pandas.pydata.org/) dataframe. Each row represents a ppi in the form of a graph. " + "The edge and node features just generated can be explored by instantiating the `GraphDataset` object, and then using `hdf5_to_pandas` method which converts node and edge features into a [Pandas](https://pandas.pydata.org/) dataframe. Each row represents a ppi in the form of a graph.\n" ] }, { @@ -339,7 +344,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "We can also generate histograms for looking at the features distributions. An example:" + "We can also generate histograms for looking at the features distributions. 
An example:\n" ] }, { @@ -349,12 +354,10 @@ "outputs": [], "source": [ "fname = os.path.join(processed_data_path, \"residue\", \"_\".join([\"res_mass\", \"distance\", \"electrostatic\"]))\n", - "dataset.save_hist(\n", - " features = [\"res_mass\", \"distance\", \"electrostatic\"],\n", - " fname = fname)\n", + "dataset.save_hist(features=[\"res_mass\", \"distance\", \"electrostatic\"], fname=fname)\n", "\n", "im = img.imread(fname + \".png\")\n", - "plt.figure(figsize = (15,10))\n", + "plt.figure(figsize=(15, 10))\n", "fig = plt.imshow(im)\n", "fig.axes.get_xaxis().set_visible(False)\n", "fig.axes.get_yaxis().set_visible(False)" @@ -369,12 +372,12 @@ "\n", "- [HDFView](https://www.hdfgroup.org/downloads/hdfview/), a visual tool written in Java for browsing and editing HDF5 files.\n", " As representative example, the following is the structure for `BA-100600.pdb` seen from HDF5View:\n", - " \n", + "\n", " \n", "\n", - " Using this tool you can inspect the values of the features visually, for each data point. \n", + " Using this tool you can inspect the values of the features visually, for each data point.\n", "\n", - "- Python packages such as [h5py](https://docs.h5py.org/en/stable/index.html). Examples:" + "- Python packages such as [h5py](https://docs.h5py.org/en/stable/index.html). Examples:\n" ] }, { @@ -386,19 +389,19 @@ "with h5py.File(processed_data[0], \"r\") as hdf5:\n", " # List of all graphs in hdf5, each graph representing a ppi\n", " ids = list(hdf5.keys())\n", - " print(f'IDs of PPIs in {processed_data[0]}: {ids}')\n", - " node_features = list(hdf5[ids[0]][\"node_features\"]) \n", - " print(f'Node features: {node_features}')\n", + " print(f\"IDs of PPIs in {processed_data[0]}: {ids}\")\n", + " node_features = list(hdf5[ids[0]][\"node_features\"])\n", + " print(f\"Node features: {node_features}\")\n", " edge_features = list(hdf5[ids[0]][\"edge_features\"])\n", - " print(f'Edge features: {edge_features}')\n", + " print(f\"Edge features: {edge_features}\")\n", " target_features = list(hdf5[ids[0]][\"target_values\"])\n", - " print(f'Targets features: {target_features}')\n", + " print(f\"Targets features: {target_features}\")\n", " # Polarity feature for ids[0], numpy.ndarray\n", " node_feat_polarity = hdf5[ids[0]][\"node_features\"][\"polarity\"][:]\n", - " print(f'Polarity feature shape: {node_feat_polarity.shape}')\n", + " print(f\"Polarity feature shape: {node_feat_polarity.shape}\")\n", " # Electrostatic feature for ids[0], numpy.ndarray\n", " edge_feat_electrostatic = hdf5[ids[0]][\"edge_features\"][\"electrostatic\"][:]\n", - " print(f'Electrostatic feature shape: {edge_feat_electrostatic.shape}')" + " print(f\"Electrostatic feature shape: {edge_feat_electrostatic.shape}\")" ] }, { @@ -408,7 +411,7 @@ "source": [ "## Atomic-level PPIs using `ProteinProteinInterfaceQuery`\n", "\n", - "Graphs can also be generated at an atomic resolution, very similarly to what has just been done for residue-level. 
" + "Graphs can also be generated at an atomic resolution, very similarly to what has just been done for residue-level.\n" ] }, { @@ -422,25 +425,27 @@ "influence_radius = 5 # max distance in Å between two interacting residues/atoms of two proteins\n", "max_edge_length = 5\n", "\n", - "print(f'Adding {len(pdb_files)} queries to the query collection ...')\n", + "print(f\"Adding {len(pdb_files)} queries to the query collection ...\")\n", "count = 0\n", "for i in range(len(pdb_files)):\n", - "\tqueries.add(\n", - "\t\tProteinProteinInterfaceQuery(\n", - "\t\t\tpdb_path = pdb_files[i],\n", - "\t\t\tresolution = \"atom\",\n", - "\t\t\tchain_ids = [\"M\",\"P\"],\n", - "\t\t\tinfluence_radius = influence_radius,\n", - "\t\t\tmax_edge_length = max_edge_length,\n", - "\t\t\ttargets = {\n", - "\t\t\t\t'binary': int(float(bas[i]) <= 500), # binary target value\n", - "\t\t\t\t'BA': bas[i], # continuous target value\n", - "\t\t\t\t}))\n", - "\tcount +=1\n", - "\tif count % 20 == 0:\n", - "\t\tprint(f'{count} queries added to the collection.')\n", + " queries.add(\n", + " ProteinProteinInterfaceQuery(\n", + " pdb_path=pdb_files[i],\n", + " resolution=\"atom\",\n", + " chain_ids=[\"M\", \"P\"],\n", + " influence_radius=influence_radius,\n", + " max_edge_length=max_edge_length,\n", + " targets={\n", + " \"binary\": int(float(bas[i]) <= 500), # binary target value\n", + " \"BA\": bas[i], # continuous target value\n", + " },\n", + " )\n", + " )\n", + " count += 1\n", + " if count % 20 == 0:\n", + " print(f\"{count} queries added to the collection.\")\n", "\n", - "print(f'Queries ready to be processed.\\n')" + "print(f\"Queries ready to be processed.\\n\")" ] }, { @@ -449,20 +454,22 @@ "metadata": {}, "outputs": [], "source": [ - "grid_settings = GridSettings( # None if you don't want grids\n", - "\t# the number of points on the x, y, z edges of the cube\n", - "\tpoints_counts = [35, 30, 30],\n", - "\t# x, y, z sizes of the box in Å\n", - "\tsizes = [1.0, 1.0, 1.0])\n", - "grid_map_method = MapMethod.GAUSSIAN # None if you don't want grids\n", + "grid_settings = GridSettings( # None if you don't want grids\n", + " # the number of points on the x, y, z edges of the cube\n", + " points_counts=[35, 30, 30],\n", + " # x, y, z sizes of the box in Å\n", + " sizes=[1.0, 1.0, 1.0],\n", + ")\n", + "grid_map_method = MapMethod.GAUSSIAN # None if you don't want grids\n", "\n", "queries.process(\n", - "\tprefix = os.path.join(processed_data_path, \"atomic\", \"proc\"),\n", - "\tfeature_modules = [components, contact],\n", - " cpu_count = 8,\n", - "\tcombine_output = False,\n", - "\tgrid_settings = grid_settings,\n", - "\tgrid_map_method = grid_map_method)\n", + " prefix=os.path.join(processed_data_path, \"atomic\", \"proc\"),\n", + " feature_modules=[components, contact],\n", + " cpu_count=8,\n", + " combine_output=False,\n", + " grid_settings=grid_settings,\n", + " grid_map_method=grid_map_method,\n", + ")\n", "\n", "print(f'The queries processing is done. The generated HDF5 files are in {os.path.join(processed_data_path, \"atomic\")}.')" ] @@ -472,7 +479,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Again, the data can be inspected using `hdf5_to_pandas` function." 
+ "Again, the data can be inspected using `hdf5_to_pandas` function.\n" ] }, { @@ -494,12 +501,10 @@ "outputs": [], "source": [ "fname = os.path.join(processed_data_path, \"atomic\", \"atom_charge\")\n", - "dataset.save_hist(\n", - " features = \"atom_charge\",\n", - " fname = fname)\n", + "dataset.save_hist(features=\"atom_charge\", fname=fname)\n", "\n", "im = img.imread(fname + \".png\")\n", - "plt.figure(figsize = (8,8))\n", + "plt.figure(figsize=(8, 8))\n", "fig = plt.imshow(im)\n", "fig.axes.get_xaxis().set_visible(False)\n", "fig.axes.get_yaxis().set_visible(False)" @@ -510,7 +515,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Note that some of the features are different from the ones generated with the residue-level queries. There are indeed features in `deeprank2.features.components` module which are generated only in atomic graphs, i.e. `atom_type`, `atom_charge`, and `pdb_occupancy`, because they don't make sense only in the atomic graphs' representation." + "Note that some of the features are different from the ones generated with the residue-level queries. There are indeed features in `deeprank2.features.components` module which are generated only in atomic graphs, i.e. `atom_type`, `atom_charge`, and `pdb_occupancy`, because they don't make sense only in the atomic graphs' representation.\n" ] } ], diff --git a/tutorials/data_generation_srv.ipynb b/tutorials/data_generation_srv.ipynb index f8c3ffddf..d4835ff4f 100644 --- a/tutorials/data_generation_srv.ipynb +++ b/tutorials/data_generation_srv.ipynb @@ -5,7 +5,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Data preparation for single-residue variants" + "# Data preparation for single-residue variants\n" ] }, { @@ -17,9 +17,9 @@ "\n", "\n", "\n", - "This tutorial will demonstrate the use of DeepRank2 for generating single-residue variants (SRVs) graphs and saving them as [HDF5 files](https://en.wikipedia.org/wiki/Hierarchical_Data_Format) files, using [PBD files](https://en.wikipedia.org/wiki/Protein_Data_Bank_(file_format)) of protein structures as input.\n", + "This tutorial will demonstrate the use of DeepRank2 for generating single-residue variants (SRVs) graphs and saving them as [HDF5 files](https://en.wikipedia.org/wiki/Hierarchical_Data_Format) files, using [PBD files]() of protein structures as input.\n", "\n", - "In this data processing phase, a local neighborhood around the mutated residue is selected for each SRV according to a radius threshold that the user can customize. All atoms or residues within the threshold are mapped as the nodes to a graph and the interactions between them are the edges of the graph. Each node and edge can have several distinct (structural or physico-chemical) features, which are generated and added during the processing phase as well. Optionally, the graphs can be mapped to volumetric grids (i.e., 3D image-like representations), together with their features. Finally, the mapped data are saved as HDF5 files, which can be used for training predictive models (for details see [training_ppi.ipynb](https://github.com/DeepRank/deeprank-core/blob/main/tutorials/training_ppi.ipynb) tutorial). In particular, graphs can be used for the training of Graph Neural Networks (GNNs), and grids can be used for the training of Convolutional Neural Networks (CNNs)." + "In this data processing phase, a local neighborhood around the mutated residue is selected for each SRV according to a radius threshold that the user can customize. 
All atoms or residues within the threshold are mapped as the nodes to a graph and the interactions between them are the edges of the graph. Each node and edge can have several distinct (structural or physico-chemical) features, which are generated and added during the processing phase as well. Optionally, the graphs can be mapped to volumetric grids (i.e., 3D image-like representations), together with their features. Finally, the mapped data are saved as HDF5 files, which can be used for training predictive models (for details see [training_ppi.ipynb](https://github.com/DeepRank/deeprank-core/blob/main/tutorials/training_ppi.ipynb) tutorial). In particular, graphs can be used for the training of Graph Neural Networks (GNNs), and grids can be used for the training of Convolutional Neural Networks (CNNs).\n" ] }, { @@ -31,7 +31,7 @@ "\n", "The example data used in this tutorial are available on Zenodo at [this record address](https://zenodo.org/record/8349335). To download the raw data used in this tutorial, please visit the link and download `data_raw.zip`. Unzip it, and save the `data_raw/` folder in the same directory as this notebook. The name and the location of the folder are optional but recommended, as they are the name and the location we will use to refer to the folder throughout the tutorial.\n", "\n", - "Note that the dataset contains only 96 data points, which is not enough to develop an impactful predictive model, and the scope of its use is indeed only demonstrative and informative for the users." + "Note that the dataset contains only 96 data points, which is not enough to develop an impactful predictive model, and the scope of its use is indeed only demonstrative and informative for the users.\n" ] }, { @@ -39,7 +39,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Utilities" + "## Utilities\n" ] }, { @@ -47,7 +47,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Libraries" + "### Libraries\n" ] }, { @@ -55,7 +55,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "The libraries needed for this tutorial:" + "The libraries needed for this tutorial:\n" ] }, { @@ -75,7 +75,7 @@ "from deeprank2.domain.aminoacidlist import amino_acids_by_code\n", "from deeprank2.features import components, contact\n", "from deeprank2.utils.grid import GridSettings, MapMethod\n", - "from deeprank2.dataset import GraphDataset\n" + "from deeprank2.dataset import GraphDataset" ] }, { @@ -83,7 +83,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Raw files and paths" + "### Raw files and paths\n" ] }, { @@ -91,7 +91,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "The paths for reading raw data and saving the processed ones:" + "The paths for reading raw data and saving the processed ones:\n" ] }, { @@ -111,9 +111,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "- Raw data are PDB files in `data_raw/srv/pdb/`, which contains atomic coordinates of the protein structure containing the variant. \n", + "- Raw data are PDB files in `data_raw/srv/pdb/`, which contains atomic coordinates of the protein structure containing the variant.\n", "- Target data, so in our case pathogenic versus benign labels, are in `data_raw/srv/srv_target_values.csv`.\n", - "- The final SRV processed data will be saved in `data_processed/srv/` folder, which in turns contains a folder for residue-level data and another one for atomic-level data. More details about such different levels will come a few cells below." 
+ "- The final SRV processed data will be saved in `data_processed/srv/` folder, which in turns contains a folder for residue-level data and another one for atomic-level data. More details about such different levels will come a few cells below.\n" ] }, { @@ -121,7 +121,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "`get_pdb_files_and_target_data` is an helper function used to retrieve the raw pdb files names, SRVs information and target values in a list from the CSV:" + "`get_pdb_files_and_target_data` is an helper function used to retrieve the raw pdb files names, SRVs information and target values in a list from the CSV:\n" ] }, { @@ -131,19 +131,20 @@ "outputs": [], "source": [ "def get_pdb_files_and_target_data(data_path):\n", - "\tcsv_data = pd.read_csv(os.path.join(data_path, \"srv_target_values.csv\"))\n", - "\tpdb_files = glob.glob(os.path.join(data_path, \"pdb\", '*.ent'))\n", - "\tpdb_files.sort()\n", - "\tpdb_file_names = [os.path.basename(pdb_file) for pdb_file in pdb_files]\n", - "\tcsv_data_indexed = csv_data.set_index('pdb_file')\n", - "\tcsv_data_indexed = csv_data_indexed.loc[pdb_file_names]\n", - "\tres_numbers = csv_data_indexed.res_number.values.tolist()\n", - "\tres_wildtypes = csv_data_indexed.res_wildtype.values.tolist()\n", - "\tres_variants = csv_data_indexed.res_variant.values.tolist()\n", - "\ttargets = csv_data_indexed.target.values.tolist()\n", - "\tpdb_names = csv_data_indexed.index.values.tolist()\n", - "\tpdb_files = [data_path + \"/pdb/\" + pdb_name for pdb_name in pdb_names]\n", - "\treturn pdb_files, res_numbers, res_wildtypes, res_variants, targets\n", + " csv_data = pd.read_csv(os.path.join(data_path, \"srv_target_values.csv\"))\n", + " pdb_files = glob.glob(os.path.join(data_path, \"pdb\", \"*.ent\"))\n", + " pdb_files.sort()\n", + " pdb_file_names = [os.path.basename(pdb_file) for pdb_file in pdb_files]\n", + " csv_data_indexed = csv_data.set_index(\"pdb_file\")\n", + " csv_data_indexed = csv_data_indexed.loc[pdb_file_names]\n", + " res_numbers = csv_data_indexed.res_number.values.tolist()\n", + " res_wildtypes = csv_data_indexed.res_wildtype.values.tolist()\n", + " res_variants = csv_data_indexed.res_variant.values.tolist()\n", + " targets = csv_data_indexed.target.values.tolist()\n", + " pdb_names = csv_data_indexed.index.values.tolist()\n", + " pdb_files = [data_path + \"/pdb/\" + pdb_name for pdb_name in pdb_names]\n", + " return pdb_files, res_numbers, res_wildtypes, res_variants, targets\n", + "\n", "\n", "pdb_files, res_numbers, res_wildtypes, res_variants, targets = get_pdb_files_and_target_data(data_path)" ] @@ -153,7 +154,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## `QueryCollection` and `Query` objects" + "## `QueryCollection` and `Query` objects\n" ] }, { @@ -163,7 +164,6 @@ "source": [ "For each SRV, so for each data point, a query can be created and added to the `QueryCollection` object, to be processed later on. Different types of queries exist, based on the molecular resolution needed:\n", "\n", - "\n", "A query takes as inputs:\n", "\n", "- A `.pdb` file, representing the protein structure containing the SRV.\n", @@ -171,12 +171,12 @@ "- The chain id of the SRV.\n", "- The residue number of the missense mutation.\n", "- The insertion code, used when two residues have the same numbering. The combination of residue numbering and insertion code defines the unique residue.\n", - "- The wildtype amino acid. \n", - "- The variant amino acid. 
\n", + "- The wildtype amino acid.\n", + "- The variant amino acid.\n", "- The interaction radius, which determines the threshold distance (in Ångström) for residues/atoms surrounding the mutation that will be included in the graph.\n", - "- The target values associated with the query. For each query/data point, in the use case demonstrated in this tutorial will add a 0 if the SRV belongs to the benign class, and 1 if it belongs to the pathogenic one. \n", + "- The target values associated with the query. For each query/data point, in the use case demonstrated in this tutorial will add a 0 if the SRV belongs to the benign class, and 1 if it belongs to the pathogenic one.\n", "- The max edge distance, which is the maximum distance between two nodes to generate an edge between them.\n", - "- Optional: The correspondent [Position-Specific Scoring Matrices (PSSMs)](https://en.wikipedia.org/wiki/Position_weight_matrix), per chain identifier, in the form of .pssm files. PSSMs are optional and will not be used in this tutorial." + "- Optional: The correspondent [Position-Specific Scoring Matrices (PSSMs)](https://en.wikipedia.org/wiki/Position_weight_matrix), per chain identifier, in the form of .pssm files. PSSMs are optional and will not be used in this tutorial.\n" ] }, { @@ -184,7 +184,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Residue-level SRV: `SingleResidueVariantQuery`" + "## Residue-level SRV: `SingleResidueVariantQuery`\n" ] }, { @@ -195,29 +195,31 @@ "source": [ "queries = QueryCollection()\n", "\n", - "influence_radius = 10.0 # radius to select the local neighborhood around the SRV\n", - "max_edge_length = 4.5 # ??\n", + "influence_radius = 10.0 # radius to select the local neighborhood around the SRV\n", + "max_edge_length = 4.5 # ??\n", "\n", - "print(f'Adding {len(pdb_files)} queries to the query collection ...')\n", + "print(f\"Adding {len(pdb_files)} queries to the query collection ...\")\n", "count = 0\n", "for i in range(len(pdb_files)):\n", - "\tqueries.add(SingleResidueVariantQuery(\n", - "\t\tpdb_path = pdb_files[i],\n", - "\t\tresolution = \"residue\",\n", - "\t\tchain_ids = \"A\",\n", - "\t\tvariant_residue_number = res_numbers[i],\n", - "\t\tinsertion_code = None,\n", - "\t\twildtype_amino_acid = amino_acids_by_code[res_wildtypes[i]],\n", - "\t\tvariant_amino_acid = amino_acids_by_code[res_variants[i]],\n", - "\t\ttargets = {'binary': targets[i]},\n", - "\t\tinfluence_radius = influence_radius,\n", - "\t\tmax_edge_length = max_edge_length,\n", - "\t\t))\n", - "\tcount +=1\n", - "\tif count % 20 == 0:\n", - "\t\tprint(f'{count} queries added to the collection.')\n", - "\n", - "print(f'Queries ready to be processed.\\n')" + " queries.add(\n", + " SingleResidueVariantQuery(\n", + " pdb_path=pdb_files[i],\n", + " resolution=\"residue\",\n", + " chain_ids=\"A\",\n", + " variant_residue_number=res_numbers[i],\n", + " insertion_code=None,\n", + " wildtype_amino_acid=amino_acids_by_code[res_wildtypes[i]],\n", + " variant_amino_acid=amino_acids_by_code[res_variants[i]],\n", + " targets={\"binary\": targets[i]},\n", + " influence_radius=influence_radius,\n", + " max_edge_length=max_edge_length,\n", + " )\n", + " )\n", + " count += 1\n", + " if count % 20 == 0:\n", + " print(f\"{count} queries added to the collection.\")\n", + "\n", + "print(f\"Queries ready to be processed.\\n\")" ] }, { @@ -233,7 +235,7 @@ "- `feature_modules` allows you to choose which feature generating modules you want to use. 
By default, the basic features contained in `deeprank2.features.components` and `deeprank2.features.contact` are generated. Users can add custom features by creating a new module and placing it in the `deeprank2.feature` subpackage. A complete and detailed list of the pre-implemented features per module and more information about how to add custom features can be found [here](https://deeprank2.readthedocs.io/en/latest/features.html).\n", " - Note that all features generated by a module will be added if that module was selected, and there is no way to only generate specific features from that module. However, during the training phase shown in `training_ppi.ipynb`, it is possible to select only a subset of available features.\n", "- `cpu_count` can be used to specify how many processes to be run simultaneously, and will coincide with the number of HDF5 files generated. By default it takes all available CPU cores and HDF5 files are squashed into a single file using the `combine_output` setting.\n", - "- Optional: If you want to include grids in the HDF5 files, which represent the mapping of the graphs to a volumetric box, you need to define `grid_settings` and `grid_map_method`, as shown in the example below. If they are `None` (default), only graphs are saved." + "- Optional: If you want to include grids in the HDF5 files, which represent the mapping of the graphs to a volumetric box, you need to define `grid_settings` and `grid_map_method`, as shown in the example below. If they are `None` (default), only graphs are saved.\n" ] }, { @@ -242,20 +244,22 @@ "metadata": {}, "outputs": [], "source": [ - "grid_settings = GridSettings( # None if you don't want grids\n", - "\t# the number of points on the x, y, z edges of the cube\n", - "\tpoints_counts = [35, 30, 30],\n", - "\t# x, y, z sizes of the box in Å\n", - "\tsizes = [1.0, 1.0, 1.0])\n", - "grid_map_method = MapMethod.GAUSSIAN # None if you don't want grids\n", + "grid_settings = GridSettings( # None if you don't want grids\n", + " # the number of points on the x, y, z edges of the cube\n", + " points_counts=[35, 30, 30],\n", + " # x, y, z sizes of the box in Å\n", + " sizes=[1.0, 1.0, 1.0],\n", + ")\n", + "grid_map_method = MapMethod.GAUSSIAN # None if you don't want grids\n", "\n", "queries.process(\n", - "\tprefix = os.path.join(processed_data_path, \"residue\", \"proc\"),\n", - "\tfeature_modules = [components, contact],\n", - " cpu_count = 8,\n", - "\tcombine_output = False,\n", - "\tgrid_settings = grid_settings,\n", - "\tgrid_map_method = grid_map_method)\n", + " prefix=os.path.join(processed_data_path, \"residue\", \"proc\"),\n", + " feature_modules=[components, contact],\n", + " cpu_count=8,\n", + " combine_output=False,\n", + " grid_settings=grid_settings,\n", + " grid_map_method=grid_map_method,\n", + ")\n", "\n", "print(f'The queries processing is done. The generated HDF5 files are in {os.path.join(processed_data_path, \"residue\")}.')" ] @@ -328,7 +332,7 @@ "\n", "`edge_features`, `node_features`, `mapped_features` are [HDF5 Groups](https://docs.h5py.org/en/stable/high/group.html) which contain [HDF5 Datasets](https://docs.h5py.org/en/stable/high/dataset.html) (e.g., `_index`, `electrostatic`, etc.), which in turn contains features values in the form of arrays. `edge_features` and `node_features` refer specificly to the graph representation, while `grid_points` and `mapped_features` refer to the grid mapped from the graph. 
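This hierarchy can also be listed programmatically with `h5py`. A minimal sketch, assuming `processed_data_path` is the output-folder variable defined earlier in this notebook and that the residue-level files produced above sit under its `residue/` subfolder:

```python
import glob
import os

import h5py

# Walk every group and dataset of the first residue-level HDF5 file, mirroring
# the structure described above (node_features, edge_features, target_values, ...).
files = sorted(glob.glob(os.path.join(processed_data_path, "residue", "*.hdf5")))
with h5py.File(files[0], "r") as hdf5:
    hdf5.visititems(lambda name, obj: print(f"{name}  [{type(obj).__name__}]"))
```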
Each data point generated by deeprank2 has the above structure, with the features and the target changing according to the user's settings. Features starting with `_` are present for human inspection of the data, but they are not used for training models.\n", "\n", - "It is always a good practice to first explore the data, and then make decision about splitting them in training, test and validation sets. There are different possible ways for doing it." + "It is always a good practice to first explore the data, and then make decision about splitting them in training, test and validation sets. There are different possible ways for doing it.\n" ] }, { @@ -338,7 +342,7 @@ "source": [ "#### Pandas dataframe\n", "\n", - "The edge and node features just generated can be explored by instantiating the `GraphDataset` object, and then using `hdf5_to_pandas` method which converts node and edge features into a [Pandas](https://pandas.pydata.org/) dataframe. Each row represents a ppi in the form of a graph. " + "The edge and node features just generated can be explored by instantiating the `GraphDataset` object, and then using `hdf5_to_pandas` method which converts node and edge features into a [Pandas](https://pandas.pydata.org/) dataframe. Each row represents a ppi in the form of a graph.\n" ] }, { @@ -358,7 +362,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "We can also generate histograms for looking at the features distributions. An example:" + "We can also generate histograms for looking at the features distributions. An example:\n" ] }, { @@ -368,12 +372,10 @@ "outputs": [], "source": [ "fname = os.path.join(processed_data_path, \"residue\", \"_\".join([\"res_mass\", \"distance\", \"electrostatic\"]))\n", - "dataset.save_hist(\n", - " features = [\"res_mass\", \"distance\", \"electrostatic\"],\n", - " fname = fname)\n", + "dataset.save_hist(features=[\"res_mass\", \"distance\", \"electrostatic\"], fname=fname)\n", "\n", "im = img.imread(fname + \".png\")\n", - "plt.figure(figsize = (15,10))\n", + "plt.figure(figsize=(15, 10))\n", "fig = plt.imshow(im)\n", "fig.axes.get_xaxis().set_visible(False)\n", "fig.axes.get_yaxis().set_visible(False)" @@ -388,12 +390,12 @@ "\n", "- [HDFView](https://www.hdfgroup.org/downloads/hdfview/), a visual tool written in Java for browsing and editing HDF5 files.\n", " As representative example, the following is the structure for `pdb2ooh.ent` seen from HDF5View:\n", - " \n", + "\n", " \n", "\n", - " Using this tool you can inspect the values of the features visually, for each data point. \n", + " Using this tool you can inspect the values of the features visually, for each data point.\n", "\n", - "- Python packages such as [h5py](https://docs.h5py.org/en/stable/index.html). Examples:" + "- Python packages such as [h5py](https://docs.h5py.org/en/stable/index.html). 
Examples:\n" ] }, { @@ -406,19 +408,19 @@ " # List of all graphs in hdf5, each graph representing\n", " # a SRV and its sourrouding environment\n", " ids = list(hdf5.keys())\n", - " print(f'IDs of SRVs in {processed_data[0]}: {ids}')\n", - " node_features = list(hdf5[ids[0]][\"node_features\"]) \n", - " print(f'Node features: {node_features}')\n", + " print(f\"IDs of SRVs in {processed_data[0]}: {ids}\")\n", + " node_features = list(hdf5[ids[0]][\"node_features\"])\n", + " print(f\"Node features: {node_features}\")\n", " edge_features = list(hdf5[ids[0]][\"edge_features\"])\n", - " print(f'Edge features: {edge_features}')\n", + " print(f\"Edge features: {edge_features}\")\n", " target_features = list(hdf5[ids[0]][\"target_values\"])\n", - " print(f'Targets features: {target_features}')\n", + " print(f\"Targets features: {target_features}\")\n", " # Polarity feature for ids[0], numpy.ndarray\n", " node_feat_polarity = hdf5[ids[0]][\"node_features\"][\"polarity\"][:]\n", - " print(f'Polarity feature shape: {node_feat_polarity.shape}')\n", + " print(f\"Polarity feature shape: {node_feat_polarity.shape}\")\n", " # Electrostatic feature for ids[0], numpy.ndarray\n", " edge_feat_electrostatic = hdf5[ids[0]][\"edge_features\"][\"electrostatic\"][:]\n", - " print(f'Electrostatic feature shape: {edge_feat_electrostatic.shape}')" + " print(f\"Electrostatic feature shape: {edge_feat_electrostatic.shape}\")" ] }, { @@ -428,7 +430,7 @@ "source": [ "## Atomic-level SRV: `SingleResidueVariantQuery`\n", "\n", - "Graphs can also be generated at an atomic resolution, very similarly to what has just been done for residue-level. " + "Graphs can also be generated at an atomic resolution, very similarly to what has just been done for residue-level.\n" ] }, { @@ -439,29 +441,31 @@ "source": [ "queries = QueryCollection()\n", "\n", - "influence_radius = 10.0 # radius to select the local neighborhood around the SRV\n", - "max_edge_length = 4.5 # ??\n", + "influence_radius = 10.0 # radius to select the local neighborhood around the SRV\n", + "max_edge_length = 4.5 # ??\n", "\n", - "print(f'Adding {len(pdb_files)} queries to the query collection ...')\n", + "print(f\"Adding {len(pdb_files)} queries to the query collection ...\")\n", "count = 0\n", "for i in range(len(pdb_files)):\n", - "\tqueries.add(SingleResidueVariantQuery(\n", - "\t\tpdb_path = pdb_files[i],\n", - " \t\tresolution = \"atom\",\n", - "\t\tchain_ids = \"A\",\n", - "\t\tvariant_residue_number = res_numbers[i],\n", - "\t\tinsertion_code = None,\n", - "\t\twildtype_amino_acid = amino_acids_by_code[res_wildtypes[i]],\n", - "\t\tvariant_amino_acid = amino_acids_by_code[res_variants[i]],\n", - "\t\ttargets = {'binary': targets[i]},\n", - "\t\tinfluence_radius = influence_radius,\n", - "\t\tmax_edge_length = max_edge_length,\n", - "\t\t))\n", - "\tcount +=1\n", - "\tif count % 20 == 0:\n", - "\t\tprint(f'{count} queries added to the collection.')\n", - "\n", - "print(f'Queries ready to be processed.\\n')" + " queries.add(\n", + " SingleResidueVariantQuery(\n", + " pdb_path=pdb_files[i],\n", + " resolution=\"atom\",\n", + " chain_ids=\"A\",\n", + " variant_residue_number=res_numbers[i],\n", + " insertion_code=None,\n", + " wildtype_amino_acid=amino_acids_by_code[res_wildtypes[i]],\n", + " variant_amino_acid=amino_acids_by_code[res_variants[i]],\n", + " targets={\"binary\": targets[i]},\n", + " influence_radius=influence_radius,\n", + " max_edge_length=max_edge_length,\n", + " )\n", + " )\n", + " count += 1\n", + " if count % 20 == 0:\n", + " 
print(f\"{count} queries added to the collection.\")\n", + "\n", + "print(f\"Queries ready to be processed.\\n\")" ] }, { @@ -470,20 +474,22 @@ "metadata": {}, "outputs": [], "source": [ - "grid_settings = GridSettings( # None if you don't want grids\n", - "\t# the number of points on the x, y, z edges of the cube\n", - "\tpoints_counts = [35, 30, 30],\n", - "\t# x, y, z sizes of the box in Å\n", - "\tsizes = [1.0, 1.0, 1.0])\n", - "grid_map_method = MapMethod.GAUSSIAN # None if you don't want grids\n", + "grid_settings = GridSettings( # None if you don't want grids\n", + " # the number of points on the x, y, z edges of the cube\n", + " points_counts=[35, 30, 30],\n", + " # x, y, z sizes of the box in Å\n", + " sizes=[1.0, 1.0, 1.0],\n", + ")\n", + "grid_map_method = MapMethod.GAUSSIAN # None if you don't want grids\n", "\n", "queries.process(\n", - "\tprefix = os.path.join(processed_data_path, \"atomic\", \"proc\"),\n", - "\tfeature_modules = [components, contact],\n", - " cpu_count = 8,\n", - "\tcombine_output = False,\n", - "\tgrid_settings = grid_settings,\n", - "\tgrid_map_method = grid_map_method)\n", + " prefix=os.path.join(processed_data_path, \"atomic\", \"proc\"),\n", + " feature_modules=[components, contact],\n", + " cpu_count=8,\n", + " combine_output=False,\n", + " grid_settings=grid_settings,\n", + " grid_map_method=grid_map_method,\n", + ")\n", "\n", "print(f'The queries processing is done. The generated HDF5 files are in {os.path.join(processed_data_path, \"atomic\")}.')" ] @@ -493,7 +499,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Again, the data can be inspected using `hdf5_to_pandas` function." + "Again, the data can be inspected using `hdf5_to_pandas` function.\n" ] }, { @@ -515,12 +521,10 @@ "outputs": [], "source": [ "fname = os.path.join(processed_data_path, \"atomic\", \"atom_charge\")\n", - "dataset.save_hist(\n", - " features = \"atom_charge\",\n", - " fname = fname)\n", + "dataset.save_hist(features=\"atom_charge\", fname=fname)\n", "\n", "im = img.imread(fname + \".png\")\n", - "plt.figure(figsize = (8,8))\n", + "plt.figure(figsize=(8, 8))\n", "fig = plt.imshow(im)\n", "fig.axes.get_xaxis().set_visible(False)\n", "fig.axes.get_yaxis().set_visible(False)" @@ -531,7 +535,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Note that some of the features are different from the ones generated with the residue-level queries. There are indeed features in `deeprank2.features.components` module which are generated only in atomic graphs, i.e. `atom_type`, `atom_charge`, and `pdb_occupancy`, because they don't make sense only in the atomic graphs' representation." + "Note that some of the features are different from the ones generated with the residue-level queries. There are indeed features in `deeprank2.features.components` module which are generated only in atomic graphs, i.e. `atom_type`, `atom_charge`, and `pdb_occupancy`, because they don't make sense only in the atomic graphs' representation.\n" ] } ], diff --git a/tutorials/training.ipynb b/tutorials/training.ipynb index 6e0340970..784bc03c7 100644 --- a/tutorials/training.ipynb +++ b/tutorials/training.ipynb @@ -5,7 +5,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Training Neural Networks" + "# Training Neural Networks\n" ] }, { @@ -21,7 +21,7 @@ "\n", "This tutorial assumes that the PPI data of interest have already been generated and saved as [HDF5 files](https://en.wikipedia.org/wiki/Hierarchical_Data_Format), with the data structure that DeepRank2 expects. 
This data can be generated using the [data_generation_ppi.ipynb](https://github.com/DeepRank/deeprank2/blob/main/tutorials/data_generation_ppi.ipynb) tutorial or downloaded from Zenodo at [this record address](https://zenodo.org/record/8349335). For more details on the data structure, please refer to the other tutorial, which also contains a detailed description of how the data is generated from PDB files.\n", "\n", - "This tutorial assumes also a basic knowledge of the [PyTorch](https://pytorch.org/) framework, on top of which the machine learning pipeline of DeepRank2 has been developed, for which many online tutorials exist." + "This tutorial assumes also a basic knowledge of the [PyTorch](https://pytorch.org/) framework, on top of which the machine learning pipeline of DeepRank2 has been developed, for which many online tutorials exist.\n" ] }, { @@ -33,9 +33,9 @@ "\n", "If you have previously run `data_generation_ppi.ipynb` or `data_generation_srv.ipynb` notebook, then their output can be directly used as input for this tutorial.\n", "\n", - "Alternatively, preprocessed HDF5 files can be downloaded directly from Zenodo at [this record address](https://zenodo.org/record/8349335). To download the data used in this tutorial, please visit the link and download `data_processed.zip`. Unzip it, and save the `data_processed/` folder in the same directory as this notebook. The name and the location of the folder are optional but recommended, as they are the name and the location we will use to refer to the folder throughout the tutorial. \n", + "Alternatively, preprocessed HDF5 files can be downloaded directly from Zenodo at [this record address](https://zenodo.org/record/8349335). To download the data used in this tutorial, please visit the link and download `data_processed.zip`. Unzip it, and save the `data_processed/` folder in the same directory as this notebook. The name and the location of the folder are optional but recommended, as they are the name and the location we will use to refer to the folder throughout the tutorial.\n", "\n", - "Note that the datasets contain only ~100 data points each, which is not enough to develop an impactful predictive model, and the scope of their use is indeed only demonstrative and informative for the users." 
+ "Note that the datasets contain only ~100 data points each, which is not enough to develop an impactful predictive model, and the scope of their use is indeed only demonstrative and informative for the users.\n" ] }, { @@ -43,7 +43,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Utilities" + "## Utilities\n" ] }, { @@ -51,7 +51,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Libraries" + "### Libraries\n" ] }, { @@ -59,7 +59,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "The libraries needed for this tutorial:" + "The libraries needed for this tutorial:\n" ] }, { @@ -74,19 +74,15 @@ "import h5py\n", "import pandas as pd\n", "from sklearn.model_selection import train_test_split\n", - "from sklearn.metrics import (\n", - " roc_curve,\n", - " auc,\n", - " precision_score,\n", - " recall_score,\n", - " accuracy_score,\n", - " f1_score)\n", + "from sklearn.metrics import roc_curve, auc, precision_score, recall_score, accuracy_score, f1_score\n", "import plotly.express as px\n", "import torch\n", "import numpy as np\n", - "np.seterr(divide = 'ignore')\n", - "np.seterr(invalid='ignore')\n", + "\n", + "np.seterr(divide=\"ignore\")\n", + "np.seterr(invalid=\"ignore\")\n", "import pandas as pd\n", + "\n", "logging.basicConfig(level=logging.INFO)\n", "from deeprank2.dataset import GraphDataset, GridDataset\n", "from deeprank2.trainer import Trainer\n", @@ -94,7 +90,8 @@ "from deeprank2.neuralnets.cnn.model3d import CnnClassification\n", "from deeprank2.utils.exporters import HDF5OutputExporter\n", "import warnings\n", - "warnings.filterwarnings('ignore')" + "\n", + "warnings.filterwarnings(\"ignore\")" ] }, { @@ -104,7 +101,7 @@ "source": [ "### Paths and sets\n", "\n", - "The paths for reading the processed data:" + "The paths for reading the processed data:\n" ] }, { @@ -116,8 +113,8 @@ "data_type = \"ppi\"\n", "level = \"residue\"\n", "processed_data_path = os.path.join(\"data_processed\", data_type, level)\n", - "input_data_path = glob.glob(os.path.join(processed_data_path, '*.hdf5'))\n", - "output_path = os.path.join(\"data_processed\", data_type, level) # for saving predictions results" + "input_data_path = glob.glob(os.path.join(processed_data_path, \"*.hdf5\"))\n", + "output_path = os.path.join(\"data_processed\", data_type, level) # for saving predictions results" ] }, { @@ -127,7 +124,7 @@ "source": [ "The `data_type` can be either \"ppi\" or \"srv\", depending on which application the user is most interested in. The `level` can be either \"residue\" or \"atomic\", and refers to the structural resolution, where each node either represents a single residue or a single atom from the molecular structure.\n", "\n", - "In this tutorial, we will use PPI residue-level data by default, but the same code can be applied to SRV or/and atomic-level data with no changes, apart from setting `data_type` and `level` parameters in the cell above." 
+ "In this tutorial, we will use PPI residue-level data by default, but the same code can be applied to SRV or/and atomic-level data with no changes, apart from setting `data_type` and `level` parameters in the cell above.\n" ] }, { @@ -135,7 +132,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "A Pandas DataFrame containing data points' IDs and the binary target values can be defined:" + "A Pandas DataFrame containing data points' IDs and the binary target values can be defined:\n" ] }, { @@ -145,14 +142,14 @@ "outputs": [], "source": [ "df_dict = {}\n", - "df_dict['entry'] = []\n", - "df_dict['target'] = []\n", + "df_dict[\"entry\"] = []\n", + "df_dict[\"target\"] = []\n", "for fname in input_data_path:\n", - " with h5py.File(fname, 'r') as hdf5:\n", + " with h5py.File(fname, \"r\") as hdf5:\n", " for mol in hdf5.keys():\n", " target_value = float(hdf5[mol][\"target_values\"][\"binary\"][()])\n", - " df_dict['entry'].append(mol)\n", - " df_dict['target'].append(target_value)\n", + " df_dict[\"entry\"].append(mol)\n", + " df_dict[\"target\"].append(target_value)\n", "\n", "df = pd.DataFrame(data=df_dict)\n", "df.head()" @@ -165,9 +162,9 @@ "source": [ "As explained in [data_generation_ppi.ipynb](https://github.com/DeepRank/deeprank2/blob/main/tutorials/data_generation_ppi.ipynb), for each data point there are two targets: \"BA\" and \"binary\". The first represents the strength of the interaction between two molecules that bind reversibly (interact) in nM, while the second represents its binary mapping, being 0 (BA > 500 nM) a not-binding complex and 1 (BA <= 500 nM) binding one.\n", "\n", - "For SRVs, each data point has a single target, \"binary\", which is 0 if the SRV is considered benign, and 1 if it is pathogenic, as explained in [data_generation_srv.ipynb](https://github.com/DeepRank/deeprank-core/blob/main/tutorials/data_generation_srv.ipynb). \n", + "For SRVs, each data point has a single target, \"binary\", which is 0 if the SRV is considered benign, and 1 if it is pathogenic, as explained in [data_generation_srv.ipynb](https://github.com/DeepRank/deeprank-core/blob/main/tutorials/data_generation_srv.ipynb).\n", "\n", - "The pandas DataFrame `df` is used only to split data points into training, validation and test sets according to the \"binary\" target - using target stratification to keep the proportion of 0s and 1s constant among the different sets. Training and validation sets will be used during the training for updating the network weights, while the test set will be held out as an independent test and will be used later for the model evaluation." + "The pandas DataFrame `df` is used only to split data points into training, validation and test sets according to the \"binary\" target - using target stratification to keep the proportion of 0s and 1s constant among the different sets. 
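Since the stratification relies on that label, a quick look at the class balance of `df` before splitting can be useful; a small sketch using the DataFrame built above:

```python
# Counts and proportions of the two classes in the binary target.
print(df["target"].value_counts())
print(df["target"].value_counts(normalize=True).round(2))
```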
Training and validation sets will be used during the training for updating the network weights, while the test set will be held out as an independent test and will be used later for the model evaluation.\n" ] }, { @@ -179,17 +176,17 @@ "df_train, df_test = train_test_split(df, test_size=0.1, stratify=df.target, random_state=42)\n", "df_train, df_valid = train_test_split(df_train, test_size=0.2, stratify=df_train.target, random_state=42)\n", "\n", - "print(f'Data statistics:\\n')\n", - "print(f'Total samples: {len(df)}\\n')\n", - "print(f'Training set: {len(df_train)} samples, {round(100*len(df_train)/len(df))}%')\n", - "print(f'\\t- Class 0: {len(df_train[df_train.target == 0])} samples, {round(100*len(df_train[df_train.target == 0])/len(df_train))}%')\n", - "print(f'\\t- Class 1: {len(df_train[df_train.target == 1])} samples, {round(100*len(df_train[df_train.target == 1])/len(df_train))}%')\n", - "print(f'Validation set: {len(df_valid)} samples, {round(100*len(df_valid)/len(df))}%')\n", - "print(f'\\t- Class 0: {len(df_valid[df_valid.target == 0])} samples, {round(100*len(df_valid[df_valid.target == 0])/len(df_valid))}%')\n", - "print(f'\\t- Class 1: {len(df_valid[df_valid.target == 1])} samples, {round(100*len(df_valid[df_valid.target == 1])/len(df_valid))}%')\n", - "print(f'Testing set: {len(df_test)} samples, {round(100*len(df_test)/len(df))}%')\n", - "print(f'\\t- Class 0: {len(df_test[df_test.target == 0])} samples, {round(100*len(df_test[df_test.target == 0])/len(df_test))}%')\n", - "print(f'\\t- Class 1: {len(df_test[df_test.target == 1])} samples, {round(100*len(df_test[df_test.target == 1])/len(df_test))}%')" + "print(f\"Data statistics:\\n\")\n", + "print(f\"Total samples: {len(df)}\\n\")\n", + "print(f\"Training set: {len(df_train)} samples, {round(100*len(df_train)/len(df))}%\")\n", + "print(f\"\\t- Class 0: {len(df_train[df_train.target == 0])} samples, {round(100*len(df_train[df_train.target == 0])/len(df_train))}%\")\n", + "print(f\"\\t- Class 1: {len(df_train[df_train.target == 1])} samples, {round(100*len(df_train[df_train.target == 1])/len(df_train))}%\")\n", + "print(f\"Validation set: {len(df_valid)} samples, {round(100*len(df_valid)/len(df))}%\")\n", + "print(f\"\\t- Class 0: {len(df_valid[df_valid.target == 0])} samples, {round(100*len(df_valid[df_valid.target == 0])/len(df_valid))}%\")\n", + "print(f\"\\t- Class 1: {len(df_valid[df_valid.target == 1])} samples, {round(100*len(df_valid[df_valid.target == 1])/len(df_valid))}%\")\n", + "print(f\"Testing set: {len(df_test)} samples, {round(100*len(df_test)/len(df))}%\")\n", + "print(f\"\\t- Class 0: {len(df_test[df_test.target == 0])} samples, {round(100*len(df_test[df_test.target == 0])/len(df_test))}%\")\n", + "print(f\"\\t- Class 1: {len(df_test[df_test.target == 1])} samples, {round(100*len(df_test[df_test.target == 1])/len(df_test))}%\")" ] }, { @@ -199,7 +196,7 @@ "source": [ "## Classification example\n", "\n", - "A GNN and a CNN can be trained for a classification predictive task, which consists in predicting the \"binary\" target values. " + "A GNN and a CNN can be trained for a classification predictive task, which consists in predicting the \"binary\" target values.\n" ] }, { @@ -207,7 +204,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### GNN" + "### GNN\n" ] }, { @@ -215,7 +212,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "#### GraphDataset" + "#### GraphDataset\n" ] }, { @@ -226,14 +223,15 @@ "For training GNNs the user can create `GraphDataset` instances. 
This class inherits from `DeeprankDataset` class, which in turns inherits from `Dataset` [PyTorch geometric class](https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/data/dataset.html), a base class for creating graph datasets.\n", "\n", "A few notes about `GraphDataset` parameters:\n", - "- By default, all features contained in the HDF5 files are used, but the user can specify `node_features` and `edge_features` in `GraphDataset` if not all of them are needed. See the [docs](https://deeprank2.readthedocs.io/en/latest/features.html) for more details about all the possible pre-implemented features. \n", + "\n", + "- By default, all features contained in the HDF5 files are used, but the user can specify `node_features` and `edge_features` in `GraphDataset` if not all of them are needed. See the [docs](https://deeprank2.readthedocs.io/en/latest/features.html) for more details about all the possible pre-implemented features.\n", "- For regression, `task` should be set to `regress` and the `target` to `BA`, which is a continuous variable and therefore suitable for regression tasks.\n", "- For the `GraphDataset` class it is possible to define a dictionary to indicate which transformations to apply to the features, being the transformations lambda functions and/or standardization.\n", " - If the `standardize` key is `True`, standardization is applied after transformation. Standardization consists in applying the following formula on each feature's value: ${x' = \\frac{x - \\mu}{\\sigma}}$, being ${\\mu}$ the mean and ${\\sigma}$ the standard deviation. Standardization is a scaling method where the values are centered around mean with a unit standard deviation.\n", " - The transformation to apply can be speficied as a lambda function as a value of the key `transform`, which defaults to `None`.\n", - " - Since in the provided example standardization is applied, the training features' means and standard deviations need to be used for scaling validation and test sets. For doing so, `train_source` parameter is used. When `train_source` parameter is set, it will be used to scale the validation/testing sets. You need to pass `features_transform` to the training dataset only, since in other cases it will be ignored and only the one of `train_source` will be considered. \n", - " - Note that transformations have not currently been implemented for the `GridDataset` class. \n", - " - In the example below a logarithmic transformation and then the standardization are applied to all the features. It is also possible to use specific features as keys for indicating that transformation and/or standardization need to be apply to few features only." + " - Since in the provided example standardization is applied, the training features' means and standard deviations need to be used for scaling validation and test sets. For doing so, `train_source` parameter is used. When `train_source` parameter is set, it will be used to scale the validation/testing sets. You need to pass `features_transform` to the training dataset only, since in other cases it will be ignored and only the one of `train_source` will be considered.\n", + " - Note that transformations have not currently been implemented for the `GridDataset` class.\n", + " - In the example below a logarithmic transformation and then the standardization are applied to all the features. 
It is also possible to use specific features as keys for indicating that transformation and/or standardization need to be apply to few features only.\n" ] }, { @@ -246,29 +244,29 @@ "task = \"classif\"\n", "node_features = [\"res_type\"]\n", "edge_features = [\"distance\"]\n", - "features_transform = {'all': {'transform': lambda x: np.cbrt(x), 'standardize': True}}\n", + "features_transform = {\"all\": {\"transform\": lambda x: np.cbrt(x), \"standardize\": True}}\n", "\n", - "print('Loading training data...')\n", + "print(\"Loading training data...\")\n", "dataset_train = GraphDataset(\n", - " hdf5_path = input_data_path,\n", - " subset = list(df_train.entry), # selects only data points with ids in df_train.entry\n", - " node_features = node_features,\n", - " edge_features = edge_features,\n", - " features_transform = features_transform,\n", - " target = target,\n", - " task = task\n", + " hdf5_path=input_data_path,\n", + " subset=list(df_train.entry), # selects only data points with ids in df_train.entry\n", + " node_features=node_features,\n", + " edge_features=edge_features,\n", + " features_transform=features_transform,\n", + " target=target,\n", + " task=task,\n", ")\n", - "print('\\nLoading validation data...')\n", + "print(\"\\nLoading validation data...\")\n", "dataset_val = GraphDataset(\n", - " hdf5_path = input_data_path,\n", - " subset = list(df_valid.entry), # selects only data points with ids in df_valid.entry\n", - " train_source = dataset_train\n", + " hdf5_path=input_data_path,\n", + " subset=list(df_valid.entry), # selects only data points with ids in df_valid.entry\n", + " train_source=dataset_train,\n", ")\n", - "print('\\nLoading test data...')\n", + "print(\"\\nLoading test data...\")\n", "dataset_test = GraphDataset(\n", - " hdf5_path = input_data_path,\n", - " subset = list(df_test.entry), # selects only data points with ids in df_test.entry\n", - " train_source = dataset_train\n", + " hdf5_path=input_data_path,\n", + " subset=list(df_test.entry), # selects only data points with ids in df_test.entry\n", + " train_source=dataset_train,\n", ")" ] }, @@ -279,7 +277,7 @@ "source": [ "#### Trainer\n", "\n", - "The class `Trainer` implements training, validation and testing of PyTorch-based neural networks. " + "The class `Trainer` implements training, validation and testing of PyTorch-based neural networks.\n" ] }, { @@ -288,10 +286,11 @@ "metadata": {}, "source": [ "A few notes about `Trainer` parameters:\n", + "\n", "- `neuralnet` can be any neural network class that inherits from `torch.nn.Module`, and it shouldn't be specific to regression or classification in terms of output shape. The `Trainer` class takes care of formatting the output shape according to the task. This tutorial uses a simple network, `NaiveNetwork` (implemented in `deeprank2.neuralnets.gnn.naive_gnn`). All GNN architectures already implemented in the pakcage can be found [here](https://github.com/DeepRank/deeprank-core/tree/main/deeprank2/neuralnets/gnn) and can be used for training or as a basis for implementing new ones.\n", "- `class_weights` is used for classification tasks only and assigns class weights based on the training dataset content to account for any potential inbalance between the classes. In this case the dataset is balanced (50% 0 and 50% 1), so it is not necessary to use it. It defaults to False.\n", "- `cuda` and `ngpu` are used for indicating whether to use CUDA and how many GPUs. 
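The per-feature option mentioned in the `GraphDataset` notes above is described but never shown in the notebook, so here is a minimal sketch of what such a `features_transform` dictionary could look like. It assumes `res_type` and `distance` are among the loaded node and edge features (as in the cells above); the feature names and transformations are illustrative only and not part of the original notebook.

```python
import numpy as np

# Illustrative only: apply a cube-root transform plus standardization to the
# "distance" edge feature, and standardization alone to the "res_type" node
# feature, instead of targeting all features via the "all" key.
features_transform = {
    "distance": {"transform": lambda x: np.cbrt(x), "standardize": True},
    "res_type": {"standardize": True},
}
```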
By default, CUDA is not used and `ngpu` is 0.\n", - "- The user can specify a deeprank2 exporter or a custom one in `output_exporters` parameter, together with the path where to save the results. Exporters are used for storing predictions information collected later on during training and testing. Later the results saved by `HDF5OutputExporter` will be read and evaluated." + "- The user can specify a deeprank2 exporter or a custom one in `output_exporters` parameter, together with the path where to save the results. Exporters are used for storing predictions information collected later on during training and testing. Later the results saved by `HDF5OutputExporter` will be read and evaluated.\n" ] }, { @@ -299,7 +298,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "##### Training" + "##### Training\n" ] }, { @@ -309,11 +308,11 @@ "outputs": [], "source": [ "trainer = Trainer(\n", - " neuralnet = NaiveNetwork,\n", - " dataset_train = dataset_train,\n", - " dataset_val = dataset_val,\n", - " dataset_test = dataset_test,\n", - " output_exporters = [HDF5OutputExporter(os.path.join(output_path, f\"gnn_{task}\"))]\n", + " neuralnet=NaiveNetwork,\n", + " dataset_train=dataset_train,\n", + " dataset_val=dataset_val,\n", + " dataset_test=dataset_test,\n", + " output_exporters=[HDF5OutputExporter(os.path.join(output_path, f\"gnn_{task}\"))],\n", ")" ] }, @@ -322,7 +321,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "The default optimizer is `torch.optim.Adam`. It is possible to specify optimizer's parameters or to use another PyTorch optimizer object:" + "The default optimizer is `torch.optim.Adam`. It is possible to specify optimizer's parameters or to use another PyTorch optimizer object:\n" ] }, { @@ -348,9 +347,10 @@ "Then the model can be trained using the `train()` method of the `Trainer` class.\n", "\n", "A few notes about `train()` method parameters:\n", + "\n", "- `earlystop_patience`, `earlystop_maxgap` and `min_epoch` are used for controlling early stopping logic. `earlystop_patience` indicates the number of epochs after which the training ends if the validation loss does not improve. `earlystop_maxgap` indicated the maximum difference allowed between validation and training loss, and `min_epoch` is the minimum number of epochs to be reached before evaluating `maxgap`.\n", "- If `validate` is set to `True`, validation is performed on an independent dataset, which has been called `dataset_val` few cells above. If set to `False`, validation is performed on the training dataset itself (not recommended).\n", - "- `num_workers` can be set for indicating how many subprocesses to use for data loading. The default is 0 and it means that the data will be loaded in the main process." + "- `num_workers` can be set for indicating how many subprocesses to use for data loading. 
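The notebook cell that sets the optimizer is unchanged by this patch and therefore not visible in the hunks above, but the variables `optimizer`, `lr`, and `weight_decay` reused in the CNN section below are presumably defined in a cell along these lines. This is a sketch with placeholder values, not the notebook's actual choices; it relies only on the `configure_optimizers(optimizer, lr, weight_decay)` call that appears later in this diff.

```python
import torch

# Placeholder hyperparameters for illustration only.
optimizer = torch.optim.SGD  # any torch.optim optimizer class can be passed
lr = 1e-3
weight_decay = 0.001

# Overrides the Trainer's default torch.optim.Adam configuration.
trainer.configure_optimizers(optimizer, lr, weight_decay)
```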
The default is 0 and it means that the data will be loaded in the main process.\n" ] }, { @@ -366,20 +366,21 @@ "min_epoch = 10\n", "\n", "trainer.train(\n", - " nepoch = epochs,\n", - " batch_size = batch_size,\n", - " earlystop_patience = earlystop_patience,\n", - " earlystop_maxgap = earlystop_maxgap,\n", - " min_epoch = min_epoch,\n", - " validate = True,\n", - " filename = os.path.join(output_path, f\"gnn_{task}\", \"model.pth.tar\"))\n", + " nepoch=epochs,\n", + " batch_size=batch_size,\n", + " earlystop_patience=earlystop_patience,\n", + " earlystop_maxgap=earlystop_maxgap,\n", + " min_epoch=min_epoch,\n", + " validate=True,\n", + " filename=os.path.join(output_path, f\"gnn_{task}\", \"model.pth.tar\"),\n", + ")\n", "\n", "epoch = trainer.epoch_saved_model\n", "print(f\"Model saved at epoch {epoch}\")\n", "pytorch_total_params = sum(p.numel() for p in trainer.model.parameters())\n", - "print(f'Total # of parameters: {pytorch_total_params}')\n", + "print(f\"Total # of parameters: {pytorch_total_params}\")\n", "pytorch_trainable_params = sum(p.numel() for p in trainer.model.parameters() if p.requires_grad)\n", - "print(f'Total # of trainable parameters: {pytorch_trainable_params}')" + "print(f\"Total # of trainable parameters: {pytorch_trainable_params}\")" ] }, { @@ -389,7 +390,7 @@ "source": [ "##### Testing\n", "\n", - "And the trained model can be tested on `dataset_test`:" + "And the trained model can be tested on `dataset_test`:\n" ] }, { @@ -410,7 +411,7 @@ "\n", "Finally, the results saved by `HDF5OutputExporter` can be inspected, which can be found in the `data/ppi/gnn_classif/` folder in the form of an HDF5 file, `output_exporter.hdf5`. Note that the folder contains the saved pre-trained model as well.\n", "\n", - "`output_exporter.hdf5` contains [HDF5 Groups](https://docs.h5py.org/en/stable/high/group.html) which refer to each phase, e.g. training and testing if both are run, only one of them otherwise. Training phase includes validation results as well. This HDF5 file can be read as a Pandas Dataframe:" + "`output_exporter.hdf5` contains [HDF5 Groups](https://docs.h5py.org/en/stable/high/group.html) which refer to each phase, e.g. training and testing if both are run, only one of them otherwise. Training phase includes validation results as well. 
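Since the markdown above notes that the saved pre-trained model (`model.pth.tar`) is stored next to `output_exporter.hdf5`, a later session could reload it for testing only. The sketch below assumes the `pretrained_model` argument of `Trainer` described in the DeepRank2 README; that argument is not used anywhere in this notebook, so treat this as an assumption to check against the installed version.

```python
# Assumption: Trainer accepts a `pretrained_model` path (see the DeepRank2 README).
trainer_pretrained = Trainer(
    neuralnet=NaiveNetwork,
    dataset_test=dataset_test,
    pretrained_model=os.path.join(output_path, f"gnn_{task}", "model.pth.tar"),
    output_exporters=[HDF5OutputExporter(os.path.join(output_path, f"gnn_{task}"))],
)
trainer_pretrained.test()
```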
This HDF5 file can be read as a Pandas Dataframe:\n" ] }, { @@ -419,8 +420,12 @@ "metadata": {}, "outputs": [], "source": [ - "output_train = pd.read_hdf(os.path.join(output_path, f\"gnn_{task}\", \"output_exporter.hdf5\"), key=\"training\")\n", - "output_test = pd.read_hdf(os.path.join(output_path, f\"gnn_{task}\", \"output_exporter.hdf5\"), key=\"testing\")\n", + "output_train = pd.read_hdf(\n", + " os.path.join(output_path, f\"gnn_{task}\", \"output_exporter.hdf5\"), key=\"training\"\n", + ")\n", + "output_test = pd.read_hdf(\n", + " os.path.join(output_path, f\"gnn_{task}\", \"output_exporter.hdf5\"), key=\"testing\"\n", + ")\n", "output_train.head()" ] }, @@ -431,7 +436,7 @@ "source": [ "The dataframes contain `phase`, `epoch`, `entry`, `output`, `target`, and `loss` columns, and can be easily used to visualize the results.\n", "\n", - "For example, the loss across the epochs can be plotted for the training and the validation sets:" + "For example, the loss across the epochs can be plotted for the training and the validation sets:\n" ] }, { @@ -440,20 +445,16 @@ "metadata": {}, "outputs": [], "source": [ - "fig = px.line(\n", - " output_train,\n", - " x='epoch',\n", - " y='loss',\n", - " color='phase',\n", - " markers=True)\n", + "fig = px.line(output_train, x=\"epoch\", y=\"loss\", color=\"phase\", markers=True)\n", "\n", "fig.add_vline(x=trainer.epoch_saved_model, line_width=3, line_dash=\"dash\", line_color=\"green\")\n", "\n", "fig.update_layout(\n", - " xaxis_title='Epoch #',\n", - " yaxis_title='Loss',\n", - " title='Loss vs epochs - GNN training',\n", - " width=700, height=400,\n", + " xaxis_title=\"Epoch #\",\n", + " yaxis_title=\"Loss\",\n", + " title=\"Loss vs epochs - GNN training\",\n", + " width=700,\n", + " height=400,\n", ")" ] }, @@ -462,7 +463,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "And now a few metrics of interest for classification tasks can be printed out: the [area under the ROC curve](https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve) (AUC), and for a threshold of 0.5 the [precision, recall, accuracy and f1 score](https://en.wikipedia.org/wiki/Precision_and_recall#Definition)." 
+ "And now a few metrics of interest for classification tasks can be printed out: the [area under the ROC curve](https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve) (AUC), and for a threshold of 0.5 the [precision, recall, accuracy and f1 score](https://en.wikipedia.org/wiki/Precision_and_recall#Definition).\n" ] }, { @@ -473,23 +474,23 @@ "source": [ "threshold = 0.5\n", "df = pd.concat([output_train, output_test])\n", - "df_plot = df[(df.epoch == trainer.epoch_saved_model) | ((df.epoch == trainer.epoch_saved_model) & (df.phase == 'testing'))]\n", + "df_plot = df[(df.epoch == trainer.epoch_saved_model) | ((df.epoch == trainer.epoch_saved_model) & (df.phase == \"testing\"))]\n", "\n", - "for idx, set in enumerate(['training', 'validation', 'testing']):\n", + "for idx, set in enumerate([\"training\", \"validation\", \"testing\"]):\n", " df_plot_phase = df_plot[(df_plot.phase == set)]\n", " y_true = df_plot_phase.target\n", " y_score = np.array(df_plot_phase.output.values.tolist())[:, 1]\n", "\n", - " print(f'\\nMetrics for {set}:')\n", + " print(f\"\\nMetrics for {set}:\")\n", " fpr_roc, tpr_roc, thr_roc = roc_curve(y_true, y_score)\n", " auc_score = auc(fpr_roc, tpr_roc)\n", - " print(f'AUC: {round(auc_score, 1)}')\n", - " print(f'Considering a threshold of {threshold}')\n", - " y_pred = (y_score > threshold)*1\n", - " print(f'- Precision: {round(precision_score(y_true, y_pred), 1)}')\n", - " print(f'- Recall: {round(recall_score(y_true, y_pred), 1)}')\n", - " print(f'- Accuracy: {round(accuracy_score(y_true, y_pred), 1)}')\n", - " print(f'- F1: {round(f1_score(y_true, y_pred), 1)}')" + " print(f\"AUC: {round(auc_score, 1)}\")\n", + " print(f\"Considering a threshold of {threshold}\")\n", + " y_pred = (y_score > threshold) * 1\n", + " print(f\"- Precision: {round(precision_score(y_true, y_pred), 1)}\")\n", + " print(f\"- Recall: {round(recall_score(y_true, y_pred), 1)}\")\n", + " print(f\"- Accuracy: {round(accuracy_score(y_true, y_pred), 1)}\")\n", + " print(f\"- F1: {round(f1_score(y_true, y_pred), 1)}\")" ] }, { @@ -499,7 +500,7 @@ "source": [ "Note that the poor performance of this network is due to the small number of datapoints used in this tutorial. For a more reliable network we suggest using a number of data points on the order of at least tens of thousands.\n", "\n", - "The same exercise can be repeated but using grids instead of graphs and CNNs instead of GNNs." + "The same exercise can be repeated but using grids instead of graphs and CNNs instead of GNNs.\n" ] }, { @@ -507,7 +508,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### CNN" + "### CNN\n" ] }, { @@ -515,7 +516,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "#### GridDataset" + "#### GridDataset\n" ] }, { @@ -526,8 +527,9 @@ "For training CNNs the user can create `GridDataset` instances.\n", "\n", "A few notes about `GridDataset` parameters:\n", - "- By default, all features contained in the HDF5 files are used, but the user can specify `features` in `GridDataset` if not all of them are needed. Since grids features are derived from node and edge features mapped from graphs to grid, the easiest way to see which features are available is to look at the HDF5 file, as explained in detail in `data_generation_ppi.ipynb` and `data_generation_srv.ipynb`, section \"Other tools\". \n", - "- As is the case for a `GraphDataset`, `task` can be assigned to `regress` and `target` to `BA` to perform a regression task. 
As mentioned previously, we do not provide sample data to perform a regression task for SRVs." + "\n", + "- By default, all features contained in the HDF5 files are used, but the user can specify `features` in `GridDataset` if not all of them are needed. Since grids features are derived from node and edge features mapped from graphs to grid, the easiest way to see which features are available is to look at the HDF5 file, as explained in detail in `data_generation_ppi.ipynb` and `data_generation_srv.ipynb`, section \"Other tools\".\n", + "- As is the case for a `GraphDataset`, `task` can be assigned to `regress` and `target` to `BA` to perform a regression task. As mentioned previously, we do not provide sample data to perform a regression task for SRVs.\n" ] }, { @@ -539,24 +541,24 @@ "target = \"binary\"\n", "task = \"classif\"\n", "\n", - "print('Loading training data...')\n", + "print(\"Loading training data...\")\n", "dataset_train = GridDataset(\n", - " hdf5_path = input_data_path,\n", - " subset = list(df_train.entry), # selects only data points with ids in df_train.entry\n", - " target = target,\n", - " task = task\n", + " hdf5_path=input_data_path,\n", + " subset=list(df_train.entry), # selects only data points with ids in df_train.entry\n", + " target=target,\n", + " task=task,\n", ")\n", - "print('\\nLoading validation data...')\n", + "print(\"\\nLoading validation data...\")\n", "dataset_val = GridDataset(\n", - " hdf5_path = input_data_path,\n", - " subset = list(df_valid.entry), # selects only data points with ids in df_valid.entry\n", - " train_source = dataset_train\n", + " hdf5_path=input_data_path,\n", + " subset=list(df_valid.entry), # selects only data points with ids in df_valid.entry\n", + " train_source=dataset_train,\n", ")\n", - "print('\\nLoading test data...')\n", + "print(\"\\nLoading test data...\")\n", "dataset_test = GridDataset(\n", - " hdf5_path = input_data_path,\n", - " subset = list(df_test.entry), # selects only data points with ids in df_test.entry\n", - " train_source = dataset_train \n", + " hdf5_path=input_data_path,\n", + " subset=list(df_test.entry), # selects only data points with ids in df_test.entry\n", + " train_source=dataset_train,\n", ")" ] }, @@ -567,7 +569,7 @@ "source": [ "#### Trainer\n", "\n", - "As for graphs, the class `Trainer` is used for training, validation and testing of the PyTorch-based CNN. " + "As for graphs, the class `Trainer` is used for training, validation and testing of the PyTorch-based CNN.\n" ] }, { @@ -576,7 +578,7 @@ "metadata": {}, "source": [ "- Also in this case, `neuralnet` can be any neural network class that inherits from `torch.nn.Module`, and it shouldn't be specific to regression or classification in terms of output shape. This tutorial uses `CnnClassification` (implemented in `deeprank2.neuralnets.cnn.model3d`). All CNN architectures already implemented in the pakcage can be found [here](https://github.com/DeepRank/deeprank2/tree/main/deeprank2/neuralnets/cnn) and can be used for training or as a basis for implementing new ones.\n", - "- The rest of the `Trainer` parameters can be used as explained already for graphs." 
+ "- The rest of the `Trainer` parameters can be used as explained already for graphs.\n" ] }, { @@ -584,7 +586,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "##### Training" + "##### Training\n" ] }, { @@ -603,30 +605,31 @@ "min_epoch = 10\n", "\n", "trainer = Trainer(\n", - " neuralnet = CnnClassification,\n", - " dataset_train = dataset_train,\n", - " dataset_val = dataset_val,\n", - " dataset_test = dataset_test,\n", - " output_exporters = [HDF5OutputExporter(os.path.join(output_path, f\"cnn_{task}\"))]\n", + " neuralnet=CnnClassification,\n", + " dataset_train=dataset_train,\n", + " dataset_val=dataset_val,\n", + " dataset_test=dataset_test,\n", + " output_exporters=[HDF5OutputExporter(os.path.join(output_path, f\"cnn_{task}\"))],\n", ")\n", "\n", "trainer.configure_optimizers(optimizer, lr, weight_decay)\n", "\n", "trainer.train(\n", - " nepoch = epochs,\n", - " batch_size = batch_size,\n", - " earlystop_patience = earlystop_patience,\n", - " earlystop_maxgap = earlystop_maxgap,\n", - " min_epoch = min_epoch,\n", - " validate = True,\n", - " filename = os.path.join(output_path, f\"cnn_{task}\", \"model.pth.tar\"))\n", + " nepoch=epochs,\n", + " batch_size=batch_size,\n", + " earlystop_patience=earlystop_patience,\n", + " earlystop_maxgap=earlystop_maxgap,\n", + " min_epoch=min_epoch,\n", + " validate=True,\n", + " filename=os.path.join(output_path, f\"cnn_{task}\", \"model.pth.tar\"),\n", + ")\n", "\n", "epoch = trainer.epoch_saved_model\n", "print(f\"Model saved at epoch {epoch}\")\n", "pytorch_total_params = sum(p.numel() for p in trainer.model.parameters())\n", - "print(f'Total # of parameters: {pytorch_total_params}')\n", + "print(f\"Total # of parameters: {pytorch_total_params}\")\n", "pytorch_trainable_params = sum(p.numel() for p in trainer.model.parameters() if p.requires_grad)\n", - "print(f'Total # of trainable parameters: {pytorch_trainable_params}')" + "print(f\"Total # of trainable parameters: {pytorch_trainable_params}\")" ] }, { @@ -636,7 +639,7 @@ "source": [ "##### Testing\n", "\n", - "And the trained model can be tested on `dataset_test`:" + "And the trained model can be tested on `dataset_test`:\n" ] }, { @@ -655,7 +658,7 @@ "source": [ "##### Results visualization\n", "\n", - "As for GNNs, the results saved by `HDF5OutputExporter` can be inspected, and are saved in the `data/ppi/cnn_classif/` or `data/srv/cnn_classif/` folder in the form of an HDF5 file, `output_exporter.hdf5`, together with the saved pre-trained model. 
" + "As for GNNs, the results saved by `HDF5OutputExporter` can be inspected, and are saved in the `data/ppi/cnn_classif/` or `data/srv/cnn_classif/` folder in the form of an HDF5 file, `output_exporter.hdf5`, together with the saved pre-trained model.\n" ] }, { @@ -664,8 +667,12 @@ "metadata": {}, "outputs": [], "source": [ - "output_train = pd.read_hdf(os.path.join(output_path, f\"cnn_{task}\", \"output_exporter.hdf5\"), key=\"training\")\n", - "output_test = pd.read_hdf(os.path.join(output_path, f\"cnn_{task}\", \"output_exporter.hdf5\"), key=\"testing\")\n", + "output_train = pd.read_hdf(\n", + " os.path.join(output_path, f\"cnn_{task}\", \"output_exporter.hdf5\"), key=\"training\"\n", + ")\n", + "output_test = pd.read_hdf(\n", + " os.path.join(output_path, f\"cnn_{task}\", \"output_exporter.hdf5\"), key=\"testing\"\n", + ")\n", "output_train.head()" ] }, @@ -674,7 +681,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Also in this case, the loss across the epochs can be plotted for the training and the validation sets:" + "Also in this case, the loss across the epochs can be plotted for the training and the validation sets:\n" ] }, { @@ -683,20 +690,16 @@ "metadata": {}, "outputs": [], "source": [ - "fig = px.line(\n", - " output_train,\n", - " x='epoch',\n", - " y='loss',\n", - " color='phase',\n", - " markers=True)\n", + "fig = px.line(output_train, x=\"epoch\", y=\"loss\", color=\"phase\", markers=True)\n", "\n", "fig.add_vline(x=trainer.epoch_saved_model, line_width=3, line_dash=\"dash\", line_color=\"green\")\n", "\n", "fig.update_layout(\n", - " xaxis_title='Epoch #',\n", - " yaxis_title='Loss',\n", - " title='Loss vs epochs - CNN training',\n", - " width=700, height=400,\n", + " xaxis_title=\"Epoch #\",\n", + " yaxis_title=\"Loss\",\n", + " title=\"Loss vs epochs - CNN training\",\n", + " width=700,\n", + " height=400,\n", ")" ] }, @@ -705,7 +708,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "And some metrics of interest for classification tasks:" + "And some metrics of interest for classification tasks:\n" ] }, { @@ -716,23 +719,23 @@ "source": [ "threshold = 0.5\n", "df = pd.concat([output_train, output_test])\n", - "df_plot = df[(df.epoch == trainer.epoch_saved_model) | ((df.epoch == trainer.epoch_saved_model) & (df.phase == 'testing'))]\n", + "df_plot = df[(df.epoch == trainer.epoch_saved_model) | ((df.epoch == trainer.epoch_saved_model) & (df.phase == \"testing\"))]\n", "\n", - "for idx, set in enumerate(['training', 'validation', 'testing']):\n", + "for idx, set in enumerate([\"training\", \"validation\", \"testing\"]):\n", " df_plot_phase = df_plot[(df_plot.phase == set)]\n", " y_true = df_plot_phase.target\n", " y_score = np.array(df_plot_phase.output.values.tolist())[:, 1]\n", "\n", - " print(f'\\nMetrics for {set}:')\n", + " print(f\"\\nMetrics for {set}:\")\n", " fpr_roc, tpr_roc, thr_roc = roc_curve(y_true, y_score)\n", " auc_score = auc(fpr_roc, tpr_roc)\n", - " print(f'AUC: {round(auc_score, 1)}')\n", - " print(f'Considering a threshold of {threshold}')\n", - " y_pred = (y_score > threshold)*1\n", - " print(f'- Precision: {round(precision_score(y_true, y_pred), 1)}')\n", - " print(f'- Recall: {round(recall_score(y_true, y_pred), 1)}')\n", - " print(f'- Accuracy: {round(accuracy_score(y_true, y_pred), 1)}')\n", - " print(f'- F1: {round(f1_score(y_true, y_pred), 1)}')" + " print(f\"AUC: {round(auc_score, 1)}\")\n", + " print(f\"Considering a threshold of {threshold}\")\n", + " y_pred = (y_score > threshold) * 1\n", + " print(f\"- Precision: 
{round(precision_score(y_true, y_pred), 1)}\")\n", + " print(f\"- Recall: {round(recall_score(y_true, y_pred), 1)}\")\n", + " print(f\"- Accuracy: {round(accuracy_score(y_true, y_pred), 1)}\")\n", + " print(f\"- F1: {round(f1_score(y_true, y_pred), 1)}\")" ] }, { @@ -740,7 +743,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "It's important to note that the dataset used in this analysis is not sufficiently large to provide conclusive and reliable insights. Depending on your specific application, you might find regression, classification, GNNs, and/or CNNs to be valuable options. Feel free to choose the approach that best aligns with your particular problem!" + "It's important to note that the dataset used in this analysis is not sufficiently large to provide conclusive and reliable insights. Depending on your specific application, you might find regression, classification, GNNs, and/or CNNs to be valuable options. Feel free to choose the approach that best aligns with your particular problem!\n" ] } ], @@ -760,7 +763,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.12" + "version": "3.10.13" }, "orig_nbformat": 4 },