From 2b69803e20ccfa13c00f8a0f852c64f1c2ebb573 Mon Sep 17 00:00:00 2001 From: Dani Bodor Date: Fri, 22 Sep 2023 15:02:29 +0200 Subject: [PATCH 1/6] refactor: amino acid dictionaries create dictionaties of amino acids by letter, 3-letter code, and name and use these in other modules rather than defining similar structures locally --- deeprank2/domain/aminoacidlist.py | 4 ++++ deeprank2/features/irc.py | 6 +++--- deeprank2/molstruct/aminoacid.py | 3 ++- deeprank2/utils/buildgraph.py | 12 +++-------- deeprank2/utils/parsing/pssm.py | 3 +-- tests/domain/test_aminoacidlist.py | 24 +++++++++++----------- tests/test_query.py | 32 ++++++++++++++--------------- tutorials/data_generation_srv.ipynb | 25 ++++++---------------- 8 files changed, 46 insertions(+), 63 deletions(-) diff --git a/deeprank2/domain/aminoacidlist.py b/deeprank2/domain/aminoacidlist.py index c5727bf3c..9906e2a32 100644 --- a/deeprank2/domain/aminoacidlist.py +++ b/deeprank2/domain/aminoacidlist.py @@ -350,6 +350,10 @@ # pyrrolysine, ] +amino_acids_by_code = {amino_acid.three_letter_code: amino_acid for amino_acid in amino_acids} +amino_acids_by_letter = {amino_acid.one_letter_code: amino_acid for amino_acid in amino_acids} +amino_acids_by_name = {amino_acid.name: amino_acid for amino_acid in amino_acids} + def convert_aa_nomenclature(aa: str, output_type: int | None = None): try: if len(aa) == 1: diff --git a/deeprank2/features/irc.py b/deeprank2/features/irc.py index acaf40c17..4d7583f9d 100644 --- a/deeprank2/features/irc.py +++ b/deeprank2/features/irc.py @@ -4,7 +4,7 @@ import pdb2sql from deeprank2.domain import nodestorage as Nfeat -from deeprank2.domain.aminoacidlist import amino_acids +from deeprank2.domain.aminoacidlist import amino_acids_by_code from deeprank2.molstruct.aminoacid import Polarity from deeprank2.molstruct.atom import Atom from deeprank2.molstruct.residue import Residue, SingleResidueVariant @@ -67,7 +67,7 @@ def get_IRCs(pdb_path: str, chains: list[str], cutoff: float = 5.5) -> dict[str, for chain1_res, chain2_residues in pdb2sql_contacts.items(): aa1_code = chain1_res[2] try: - aa1 = [amino_acid for amino_acid in amino_acids if amino_acid.three_letter_code == aa1_code][0] + aa1 = amino_acids_by_code[aa1_code] except IndexError: continue # skip keys that are not an amino acid @@ -78,7 +78,7 @@ def get_IRCs(pdb_path: str, chains: list[str], cutoff: float = 5.5) -> dict[str, for chain2_res in chain2_residues: aa2_code = chain2_res[2] try: - aa2 = [amino_acid for amino_acid in amino_acids if amino_acid.three_letter_code == aa2_code][0] + aa2 = amino_acids_by_code[aa2_code] except IndexError: continue # skip keys that are not an amino acid diff --git a/deeprank2/molstruct/aminoacid.py b/deeprank2/molstruct/aminoacid.py index 88c615a87..189b5a564 100644 --- a/deeprank2/molstruct/aminoacid.py +++ b/deeprank2/molstruct/aminoacid.py @@ -1,6 +1,7 @@ from enum import Enum import numpy as np +from numpy.typing import NDArray class Polarity(Enum): @@ -108,7 +109,7 @@ def hydrogen_bond_acceptors(self) -> int: return self._hydrogen_bond_acceptors @property - def onehot(self) -> np.ndarray: + def onehot(self) -> NDArray: if self._index is None: raise ValueError( f"Amino acid {self._name} index is not set, thus no onehot can be computed." diff --git a/deeprank2/utils/buildgraph.py b/deeprank2/utils/buildgraph.py index 56d873e38..b489934bc 100644 --- a/deeprank2/utils/buildgraph.py +++ b/deeprank2/utils/buildgraph.py @@ -5,7 +5,7 @@ from pdb2sql import interface as get_interface from scipy.spatial import distance_matrix -from deeprank2.domain.aminoacidlist import amino_acids +from deeprank2.domain.aminoacidlist import amino_acids_by_code from deeprank2.molstruct.atom import Atom, AtomicElement from deeprank2.molstruct.pair import Pair from deeprank2.molstruct.residue import Residue @@ -27,14 +27,8 @@ def _add_atom_to_residue(atom, residue): residue.add_atom(atom) -_amino_acids_by_code = { - amino_acid.three_letter_code: amino_acid for amino_acid in amino_acids -} - - _elements_by_name = {element.name: element for element in AtomicElement} - def _add_atom_data_to_structure(structure: PDBStructure, # pylint: disable=too-many-arguments, too-many-locals x: float, y: float, z: float, atom_name: str, @@ -73,8 +67,8 @@ def _add_atom_data_to_structure(structure: PDBStructure, # pylint: disable=too- insertion_code = None # The amino acid is only valid when we deal with protein residues. - if residue_name in _amino_acids_by_code: - amino_acid = _amino_acids_by_code[residue_name] + if residue_name in amino_acids_by_code: + amino_acid = amino_acids_by_code[residue_name] else: amino_acid = None diff --git a/deeprank2/utils/parsing/pssm.py b/deeprank2/utils/parsing/pssm.py index d7130bffd..069dc1074 100644 --- a/deeprank2/utils/parsing/pssm.py +++ b/deeprank2/utils/parsing/pssm.py @@ -1,6 +1,6 @@ from typing import TextIO -from deeprank2.domain.aminoacidlist import amino_acids +from deeprank2.domain.aminoacidlist import amino_acids, amino_acids_by_letter from deeprank2.molstruct.residue import Residue from deeprank2.molstruct.structure import Chain from deeprank2.utils.pssmdata import PssmRow, PssmTable @@ -17,7 +17,6 @@ def parse_pssm(file_: TextIO, chain: Chain) -> PssmTable: PssmTable: The position-specific scoring table, parsed from the pssm file. """ - amino_acids_by_letter = {amino_acid.one_letter_code: amino_acid for amino_acid in amino_acids} conservation_rows = {} # Read the pssm header. diff --git a/tests/domain/test_aminoacidlist.py b/tests/domain/test_aminoacidlist.py index 2ab89c50a..7bafb2c89 100644 --- a/tests/domain/test_aminoacidlist.py +++ b/tests/domain/test_aminoacidlist.py @@ -12,15 +12,15 @@ ] def test_all_different_onehot(): - for amino_acid in amino_acids: - for other in amino_acids: - if other != amino_acid: - try: - assert not np.all(amino_acid.onehot == other.onehot) - except AssertionError as e: - if other in EXCEPTIONS[0] and amino_acid in EXCEPTIONS[0]: - assert np.all(amino_acid.onehot == other.onehot) - elif other in EXCEPTIONS[1] and amino_acid in EXCEPTIONS[1]: - assert np.all(amino_acid.onehot == other.onehot) - else: - raise AssertionError(f"one-hot index {amino_acid.index} is occupied by both {amino_acid} and {other}") from e + + for aa1, aa2 in zip(amino_acids, amino_acids): + if aa1 == aa2: + continue + + try: + assert not np.all(aa1.onehot == aa2.onehot) + except AssertionError as e: + if (aa1 in EXCEPTIONS[0] and aa2 in EXCEPTIONS[0]) or (aa1 in EXCEPTIONS[1] and aa2 in EXCEPTIONS[1]): + assert np.all(aa1.onehot == aa2.onehot) + else: + raise AssertionError(f"one-hot index {aa1.index} is occupied by both {aa1} and {aa2}") from e diff --git a/tests/test_query.py b/tests/test_query.py index c50fc558b..3415a30ca 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -7,12 +7,10 @@ import pytest from deeprank2.dataset import GraphDataset, GridDataset +from deeprank2.domain import aminoacidlist as aa from deeprank2.domain import edgestorage as Efeat from deeprank2.domain import nodestorage as Nfeat from deeprank2.domain import targetstorage as targets -from deeprank2.domain.aminoacidlist import (alanine, arginine, asparagine, - cysteine, glutamate, glycine, - leucine, lysine, phenylalanine) from deeprank2.features import components, conservation, contact, surfacearea from deeprank2.query import (ProteinProteinInterfaceQuery, QueryCollection, SingleResidueVariantQuery) @@ -146,8 +144,8 @@ def test_variant_graph_101M(): chain_ids="A", variant_residue_number=27, insertion_code=None, - wildtype_amino_acid=asparagine, - variant_amino_acid=phenylalanine, + wildtype_amino_acid=aa.asparagine, + variant_amino_acid=aa.phenylalanine, pssm_paths={"A": "tests/data/pssm/101M/101M.A.pdb.pssm"}, targets={targets.BINARY: 0}, influence_radius=5.0, @@ -179,8 +177,8 @@ def test_variant_graph_1A0Z(): chain_ids="A", variant_residue_number=125, insertion_code=None, - wildtype_amino_acid=leucine, - variant_amino_acid=arginine, + wildtype_amino_acid=aa.leucine, + variant_amino_acid=aa.arginine, pssm_paths={ "A": "tests/data/pssm/1A0Z/1A0Z.A.pdb.pssm", "B": "tests/data/pssm/1A0Z/1A0Z.B.pdb.pssm", @@ -217,8 +215,8 @@ def test_variant_graph_9API(): chain_ids="A", variant_residue_number=310, insertion_code=None, - wildtype_amino_acid=lysine, - variant_amino_acid=glutamate, + wildtype_amino_acid=aa.lysine, + variant_amino_acid=aa.glutamate, pssm_paths={ "A": "tests/data/pssm/9api/9api.A.pdb.pssm", "B": "tests/data/pssm/9api/9api.B.pdb.pssm", @@ -253,8 +251,8 @@ def test_variant_residue_graph_101M(): chain_ids="A", variant_residue_number=25, insertion_code=None, - wildtype_amino_acid=glycine, - variant_amino_acid=alanine, + wildtype_amino_acid=aa.glycine, + variant_amino_acid=aa.alanine, pssm_paths={"A": "tests/data/pssm/101M/101M.A.pdb.pssm"}, targets={targets.BINARY: 0}, ) @@ -315,8 +313,8 @@ def test_augmentation(): chain_ids="A", variant_residue_number=25, insertion_code=None, - wildtype_amino_acid=glycine, - variant_amino_acid=alanine, + wildtype_amino_acid=aa.glycine, + variant_amino_acid=aa.alanine, pssm_paths={"A": "tests/data/pssm/101M/101M.A.pdb.pssm"}, targets={targets.BINARY: 0}, )) @@ -327,8 +325,8 @@ def test_augmentation(): chain_ids="A", variant_residue_number=27, insertion_code=None, - wildtype_amino_acid=asparagine, - variant_amino_acid=phenylalanine, + wildtype_amino_acid=aa.asparagine, + variant_amino_acid=aa.phenylalanine, pssm_paths={"A": "tests/data/pssm/101M/101M.A.pdb.pssm"}, targets={targets.BINARY: 0}, influence_radius=3.0, @@ -470,8 +468,8 @@ def test_variant_query_multiple_chains(): chain_ids = "A", variant_residue_number = 14, insertion_code = None, - wildtype_amino_acid = arginine, - variant_amino_acid = cysteine, + wildtype_amino_acid = aa.arginine, + variant_amino_acid = aa.cysteine, pssm_paths = {"A": "tests/data/pssm/2g98/2g98.A.pdb.pssm"}, targets = {targets.BINARY: 1}, influence_radius = 10.0, diff --git a/tutorials/data_generation_srv.ipynb b/tutorials/data_generation_srv.ipynb index 2e7be5c8d..f8c3ffddf 100644 --- a/tutorials/data_generation_srv.ipynb +++ b/tutorials/data_generation_srv.ipynb @@ -72,23 +72,10 @@ "import matplotlib.pyplot as plt\n", "from deeprank2.query import QueryCollection\n", "from deeprank2.query import SingleResidueVariantQuery, SingleResidueVariantQuery\n", - "from deeprank2.domain.aminoacidlist import (alanine, arginine, asparagine,\n", - " serine, glycine, leucine, aspartate,\n", - " glutamine, glutamate, lysine, phenylalanine, histidine,\n", - " tyrosine, tryptophan, valine, proline,\n", - " cysteine, isoleucine, methionine, threonine)\n", + "from deeprank2.domain.aminoacidlist import amino_acids_by_code\n", "from deeprank2.features import components, contact\n", "from deeprank2.utils.grid import GridSettings, MapMethod\n", - "from deeprank2.dataset import GraphDataset\n", - "\n", - "aa_dict = {\"ALA\": alanine, \"CYS\": cysteine, \"ASP\": aspartate,\n", - " \"GLU\": glutamate, \"PHE\": phenylalanine, \"GLY\": glycine, \n", - " \"HIS\": histidine, \"ILE\": isoleucine, \"LYS\": lysine,\n", - " \"LEU\": leucine, \"MET\": methionine, \"ASN\": asparagine,\n", - " \"PRO\": proline, \"GLN\": glutamine, \"ARG\": arginine,\n", - " \"SER\": serine, \"THR\": threonine, \"VAL\": valine,\n", - " \"TRP\": tryptophan, \"TYR\": tyrosine\n", - " }" + "from deeprank2.dataset import GraphDataset\n" ] }, { @@ -220,8 +207,8 @@ "\t\tchain_ids = \"A\",\n", "\t\tvariant_residue_number = res_numbers[i],\n", "\t\tinsertion_code = None,\n", - "\t\twildtype_amino_acid = aa_dict[res_wildtypes[i]],\n", - "\t\tvariant_amino_acid = aa_dict[res_variants[i]],\n", + "\t\twildtype_amino_acid = amino_acids_by_code[res_wildtypes[i]],\n", + "\t\tvariant_amino_acid = amino_acids_by_code[res_variants[i]],\n", "\t\ttargets = {'binary': targets[i]},\n", "\t\tinfluence_radius = influence_radius,\n", "\t\tmax_edge_length = max_edge_length,\n", @@ -464,8 +451,8 @@ "\t\tchain_ids = \"A\",\n", "\t\tvariant_residue_number = res_numbers[i],\n", "\t\tinsertion_code = None,\n", - "\t\twildtype_amino_acid = aa_dict[res_wildtypes[i]],\n", - "\t\tvariant_amino_acid = aa_dict[res_variants[i]],\n", + "\t\twildtype_amino_acid = amino_acids_by_code[res_wildtypes[i]],\n", + "\t\tvariant_amino_acid = amino_acids_by_code[res_variants[i]],\n", "\t\ttargets = {'binary': targets[i]},\n", "\t\tinfluence_radius = influence_radius,\n", "\t\tmax_edge_length = max_edge_length,\n", From 0d994de6b262386566361f33520a34faf99a2237 Mon Sep 17 00:00:00 2001 From: Dani Bodor Date: Fri, 22 Sep 2023 16:36:20 +0200 Subject: [PATCH 2/6] docs: improve code documentation of utils/buildgraph.py type hinting, docstrings, code comments, excessive white lines, etc --- deeprank2/utils/buildgraph.py | 84 ++++++++++++----------------------- 1 file changed, 28 insertions(+), 56 deletions(-) diff --git a/deeprank2/utils/buildgraph.py b/deeprank2/utils/buildgraph.py index b489934bc..c7adaafbc 100644 --- a/deeprank2/utils/buildgraph.py +++ b/deeprank2/utils/buildgraph.py @@ -2,7 +2,8 @@ import os import numpy as np -from pdb2sql import interface as get_interface +from pdb2sql import interface as pdb2sql_interface +from pdb2sql import pdb2sql as pdb2sql_object from scipy.spatial import distance_matrix from deeprank2.domain.aminoacidlist import amino_acids_by_code @@ -14,21 +15,21 @@ _log = logging.getLogger(__name__) -def _add_atom_to_residue(atom, residue): +def _add_atom_to_residue(atom: Atom, residue: Residue): + """Adds an `Atom` to a `Residue` if not already there. + + If no matching atom is found, add the current atom to the residue. + If there's another atom with the same name, choose the one with the highest occupancy. + """ + for other_atom in residue.atoms: if other_atom.name == atom.name: - # Don't allow two atoms with the same name, pick the highest - # occupancy if other_atom.occupancy < atom.occupancy: other_atom.change_altloc(atom) return - - # not there yet, add it residue.add_atom(atom) -_elements_by_name = {element.name: element for element in AtomicElement} - def _add_atom_data_to_structure(structure: PDBStructure, # pylint: disable=too-many-arguments, too-many-locals x: float, y: float, z: float, atom_name: str, @@ -38,7 +39,6 @@ def _add_atom_data_to_structure(structure: PDBStructure, # pylint: disable=too- residue_number: int, residue_name: str, insertion_code: str): - """ This is a subroutine, to be used in other methods for converting pdb2sql atomic data into a deeprank structure object. It should be called for one atom. @@ -62,55 +62,34 @@ def _add_atom_data_to_structure(structure: PDBStructure, # pylint: disable=too- if altloc is not None and altloc != "" and altloc != "A": return - # We use None to indicate that the residue has no insertion code. - if insertion_code == "": - insertion_code = None - - # The amino acid is only valid when we deal with protein residues. - if residue_name in amino_acids_by_code: - amino_acid = amino_acids_by_code[residue_name] - else: - amino_acid = None - - # Turn the x,y,z into a vector: + insertion_code = None if insertion_code == "" else insertion_code + amino_acid = amino_acids_by_code[residue_name] if residue_name in amino_acids_by_code else None atom_position = np.array([x, y, z]) - # Init chain. if not structure.has_chain(chain_id): + structure.add_chain(Chain(structure, chain_id)) + chain = structure.get_chain(chain_id) - chain = Chain(structure, chain_id) - structure.add_chain(chain) - else: - chain = structure.get_chain(chain_id) - - # Init residue. if not chain.has_residue(residue_number, insertion_code): + chain.add_residue(Residue(chain, residue_number, amino_acid, insertion_code)) + residue = chain.get_residue(residue_number, insertion_code) - residue = Residue(chain, residue_number, amino_acid, insertion_code) - chain.add_residue(residue) - else: - residue = chain.get_residue(residue_number, insertion_code) - - # Init atom. atom = Atom( - residue, atom_name, _elements_by_name[element_name], atom_position, occupancy + residue, atom_name, AtomicElement[element_name], atom_position, occupancy ) _add_atom_to_residue(atom, residue) -def get_structure(pdb, id_: str) -> PDBStructure: +def get_structure(pdb: pdb2sql_object, id_: str) -> PDBStructure: """Builds a structure from rows in a pdb file. Args: pdb (pdb2sql object): The pdb structure that we're investigating. - id (str): Unique id for the pdb structure. + id_ (str): Unique id for the pdb structure. Returns: PDBStructure: The structure object, giving access to chains, residues, atoms. """ - - # We need these intermediary dicts to keep track of which residues and - # chains have already been created. structure = PDBStructure(id_) # Iterate over the atom output from pdb2sql @@ -153,7 +132,7 @@ def get_contact_atoms( # pylint: disable=too-many-locals ) -> list[Atom]: """Gets the contact atoms from pdb2sql and wraps them in python objects.""" - interface = get_interface(pdb_path) + interface = pdb2sql_interface(pdb_path) try: atom_indexes = interface.get_contact_atoms( cutoff=influence_radius, @@ -218,7 +197,7 @@ def get_residue_contact_pairs( # pylint: disable=too-many-locals """ # Find out which residues are pairs - interface = get_interface(pdb_path) + interface = pdb2sql_interface(pdb_path) try: contact_residues = interface.get_contact_residues( cutoff=influence_radius, @@ -283,30 +262,23 @@ def get_surrounding_residues( """Get the residues that lie within a radius around a residue. Args: - structure (Union[:class:`Chain`, :class:`PDBStructure`]): The structure to take residues from. + structure (:class:`Chain` | :class:`PDBStructure`): The structure to take residues from. residue (:class:`Residue`): The residue in the structure. radius (float): Max distance in Ångström between atoms of the residue and the other residues. Returns: - (a set of deeprank residues): The surrounding residues. + list[:class:`Residue`]: The surrounding residues. """ structure_atoms = structure.get_atoms() - residue_atoms = residue.atoms - structure_atom_positions = [atom.position for atom in structure_atoms] - residue_atom_positions = [atom.position for atom in residue_atoms] - - distances = distance_matrix(structure_atom_positions, residue_atom_positions, p=2) - - close_residues = set([]) + residue_atom_positions = [atom.position for atom in residue.atoms] + pairwise_distances = distance_matrix(structure_atom_positions, residue_atom_positions, p=2) + surrounding_residues = set([]) for structure_atom_index, structure_atom in enumerate(structure_atoms): - - shortest_distance = np.min(distances[structure_atom_index, :]) - + shortest_distance = np.min(pairwise_distances[structure_atom_index, :]) if shortest_distance < radius: + surrounding_residues.add(structure_atom.residue) - close_residues.add(structure_atom.residue) - - return list(close_residues) + return list(surrounding_residues) From 60775bd95f471e31de55658a9d8949e19e731b56 Mon Sep 17 00:00:00 2001 From: Dani Bodor Date: Fri, 22 Sep 2023 21:29:06 +0200 Subject: [PATCH 3/6] refactor: reading atom data from pdb2sql object --- deeprank2/utils/buildgraph.py | 161 +++++++++++----------------------- 1 file changed, 52 insertions(+), 109 deletions(-) diff --git a/deeprank2/utils/buildgraph.py b/deeprank2/utils/buildgraph.py index c7adaafbc..a8c213000 100644 --- a/deeprank2/utils/buildgraph.py +++ b/deeprank2/utils/buildgraph.py @@ -30,57 +30,57 @@ def _add_atom_to_residue(atom: Atom, residue: Residue): residue.add_atom(atom) -def _add_atom_data_to_structure(structure: PDBStructure, # pylint: disable=too-many-arguments, too-many-locals - x: float, y: float, z: float, - atom_name: str, - altloc: str, occupancy: float, - element_name: str, - chain_id: str, - residue_number: int, - residue_name: str, - insertion_code: str): - """ - This is a subroutine, to be used in other methods for converting pdb2sql atomic data into a - deeprank structure object. It should be called for one atom. +def _add_atom_data_to_structure( + structure: PDBStructure, + pdb_obj: pdb2sql_object, + **kwargs +): + """This subroutine retrieves pdb2sql atomic data for `PDBStructure` objects as defined in DeepRank2. + + This function should be called for one atom at a time. Args: - structure (:class:`PDBStructure`): Where this atom should be added to. - x (float): x-coordinate of atom. - y (float): y-coordinate of atom. - z (float): z-coordinate of atom. - atom_name (str): Name of atom: 'CA', 'C', 'N', 'O', 'CB', etc. - altloc (str): Pdb alternative location id for this atom (can be empty): 'A', 'B', 'C', etc. - occupancy (float): Pdb occupancy of this atom, ranging from 0.0 to 1.0. Should be used with altloc. - element_name (str): Pdb element symbol of this atom: 'C', 'O', 'H', 'N', 'S'. - chain_id (str): Pdb chain identifier: 'A', 'B', 'C', etc. - residue_number (int): Pdb residue number, a positive integer. - residue_name (str): Pdb residue name: "ALA", "CYS", "ASP", etc. - insertion_code (str): Pdb residue insertion code (can be empty) : '', 'A', 'B', 'C', etc. + structure (:class:`PDBStructure`): The structure to which this atom should be added to. + pdb (pdb2sql_object): The `pdb2sql` object to retrieve the data from. + kwargs: as required by the get function for the `pdb2sql` object. """ - # Make sure not to take the same atom twice. - if altloc is not None and altloc != "" and altloc != "A": - return - - insertion_code = None if insertion_code == "" else insertion_code - amino_acid = amino_acids_by_code[residue_name] if residue_name in amino_acids_by_code else None - atom_position = np.array([x, y, z]) - - if not structure.has_chain(chain_id): - structure.add_chain(Chain(structure, chain_id)) - chain = structure.get_chain(chain_id) - - if not chain.has_residue(residue_number, insertion_code): - chain.add_residue(Residue(chain, residue_number, amino_acid, insertion_code)) - residue = chain.get_residue(residue_number, insertion_code) - - atom = Atom( - residue, atom_name, AtomicElement[element_name], atom_position, occupancy - ) - _add_atom_to_residue(atom, residue) + pdb2sql_columns = "x,y,z,name,altLoc,occ,element,chainID,resSeq,resName,iCode" + data_keys = pdb2sql_columns.split(sep=',') + for data_values in pdb_obj.get(pdb2sql_columns, **kwargs): + atom_data = dict(zip(data_keys, data_values)) + + # exit function if this atom is already part of the structure + if atom_data["altLoc"] not in (None, "", "A"): + return + + atom_data["iCode"] = None if atom_data["iCode"] == "" else atom_data["iCode"] + + try: + atom_data["aa"] = amino_acids_by_code[atom_data["resName"]] + except KeyError: + atom_data["aa"] = None + atom_data["coordinates"] = np.array(data_values[:3]) + + if not structure.has_chain(atom_data["chainID"]): + structure.add_chain(Chain(structure, atom_data["chainID"])) + chain = structure.get_chain(atom_data["chainID"]) + + if not chain.has_residue(atom_data["resSeq"], atom_data["iCode"]): + chain.add_residue(Residue(chain, atom_data["resSeq"], atom_data["aa"], atom_data["iCode"])) + residue = chain.get_residue(atom_data["resSeq"], atom_data["iCode"]) + + atom = Atom( + residue, + atom_data["name"], + AtomicElement[atom_data["element"]], + atom_data["coordinates"], + atom_data["occ"], + ) + _add_atom_to_residue(atom, residue) -def get_structure(pdb: pdb2sql_object, id_: str) -> PDBStructure: +def get_structure(pdb_obj: pdb2sql_object, id_: str) -> PDBStructure: """Builds a structure from rows in a pdb file. Args: @@ -91,41 +91,11 @@ def get_structure(pdb: pdb2sql_object, id_: str) -> PDBStructure: PDBStructure: The structure object, giving access to chains, residues, atoms. """ structure = PDBStructure(id_) - - # Iterate over the atom output from pdb2sql - for row in pdb.get( - "x,y,z,rowID,name,altLoc,occ,element,chainID,resSeq,resName,iCode", model=0 - ): - - ( - x, - y, - z, - _, - atom_name, - altloc, - occupancy, - element_name, - chain_id, - residue_number, - residue_name, - insertion_code, - ) = row - - _add_atom_data_to_structure(structure, - x, y, z, - atom_name, - altloc, occupancy, - element_name, - chain_id, - residue_number, - residue_name, - insertion_code) - + _add_atom_data_to_structure(structure, pdb_obj, model=0) return structure -def get_contact_atoms( # pylint: disable=too-many-locals +def get_contact_atoms( pdb_path: str, chain_ids: list[str], influence_radius: float @@ -133,47 +103,20 @@ def get_contact_atoms( # pylint: disable=too-many-locals """Gets the contact atoms from pdb2sql and wraps them in python objects.""" interface = pdb2sql_interface(pdb_path) + pdb_name = os.path.splitext(os.path.basename(pdb_path))[0] + structure = PDBStructure(f"contact_atoms_{pdb_name}") + try: atom_indexes = interface.get_contact_atoms( cutoff=influence_radius, chain1=chain_ids[0], chain2=chain_ids[1], ) - rows = interface.get( - "x,y,z,name,element,altLoc,occ,chainID,resSeq,resName,iCode", - rowID=atom_indexes[chain_ids[0]] + atom_indexes[chain_ids[1]] - ) + pdb_rowID = atom_indexes[chain_ids[0]] + atom_indexes[chain_ids[1]] + _add_atom_data_to_structure(structure, interface, rowID=pdb_rowID) finally: interface._close() # pylint: disable=protected-access - pdb_name = os.path.splitext(os.path.basename(pdb_path))[0] - structure = PDBStructure(f"contact_atoms_{pdb_name}") - - for row in rows: - ( - x, - y, - z, - atom_name, - element_name, - altloc, - occupancy, - chain_id, - residue_number, - residue_name, - insertion_code - ) = row - - _add_atom_data_to_structure(structure, - x, y, z, - atom_name, - altloc, occupancy, - element_name, - chain_id, - residue_number, - residue_name, - insertion_code) - return structure.get_atoms() def get_residue_contact_pairs( # pylint: disable=too-many-locals From da0f510ea29160967e261ae8a48720eaf351e2c8 Mon Sep 17 00:00:00 2001 From: Dani Bodor Date: Fri, 22 Sep 2023 21:30:10 +0200 Subject: [PATCH 4/6] refactor: finding residue from pdb2sql key by using a helper function rather than repeating code --- deeprank2/utils/buildgraph.py | 67 ++++++++++++++--------------------- 1 file changed, 27 insertions(+), 40 deletions(-) diff --git a/deeprank2/utils/buildgraph.py b/deeprank2/utils/buildgraph.py index a8c213000..83511d05a 100644 --- a/deeprank2/utils/buildgraph.py +++ b/deeprank2/utils/buildgraph.py @@ -119,7 +119,8 @@ def get_contact_atoms( return structure.get_atoms() -def get_residue_contact_pairs( # pylint: disable=too-many-locals + +def get_residue_contact_pairs( pdb_path: str, structure: PDBStructure, chain_id1: str, @@ -153,50 +154,36 @@ def get_residue_contact_pairs( # pylint: disable=too-many-locals # Map to residue objects residue_pairs = set([]) - for residue_key1, _ in contact_residues.items(): - residue_chain_id1, residue_number1, residue_name1 = residue_key1 - - chain1 = structure.get_chain(residue_chain_id1) - - residue1 = None - for residue in chain1.residues: - if ( - residue.number == residue_number1 - and residue.amino_acid is not None - and residue.amino_acid.three_letter_code == residue_name1 - ): - residue1 = residue - break - else: - raise ValueError( - f"Not found: {pdb_path} {residue_chain_id1} {residue_number1} {residue_name1}" - ) - - for residue_chain_id2, residue_number2, residue_name2 in contact_residues[ # pylint: disable=unnecessary-dict-index-lookup - residue_key1 - ]: - - chain2 = structure.get_chain(residue_chain_id2) - - residue2 = None - for residue in chain2.residues: - if ( - residue.number == residue_number2 - and residue.amino_acid is not None - and residue.amino_acid.three_letter_code == residue_name2 - ): - residue2 = residue - break - else: - raise ValueError( - f"Not found: {pdb_path} {residue_chain_id2} {residue_number2} {residue_name2}" - ) - + for residue_key1, residue_contacts in contact_residues.items(): + residue1 = _get_residue_from_key(structure, residue_key1) + for residue_key2 in residue_contacts: + residue2 = _get_residue_from_key(structure, residue_key2) residue_pairs.add(Pair(residue1, residue2)) return residue_pairs +def _get_residue_from_key( + structure: PDBStructure, + residue_key: tuple[str, int, str], +) -> Residue: + """Returns a residue object given a pdb2sql-formatted residue key.""" + + residue_chain_id, residue_number, residue_name = residue_key + chain = structure.get_chain(residue_chain_id) + + for residue in chain.residues: + if ( + residue.number == residue_number + and residue.amino_acid is not None + and residue.amino_acid.three_letter_code == residue_name + ): + return residue + raise ValueError( + f"Residue ({residue_key}) not found in {structure.id}." + ) + + def get_surrounding_residues( structure: Chain | PDBStructure, residue: Residue, From 2708c5cd08a642f113b7c61f5ef6fdb34da9b590 Mon Sep 17 00:00:00 2001 From: Dani Bodor Date: Fri, 22 Sep 2023 22:22:06 +0200 Subject: [PATCH 5/6] style: make build_atomic_graph and build_residue_graph look similar this is a preparation for the next commit, where the functions are unified. here only variables are renamed and spacing is changed so that it is easy to see what the similarities and differences between the functions are. --- deeprank2/molstruct/residue.py | 2 +- deeprank2/utils/graph.py | 56 +++++++++++++++------------------- 2 files changed, 26 insertions(+), 32 deletions(-) diff --git a/deeprank2/molstruct/residue.py b/deeprank2/molstruct/residue.py index e9938b219..552d1a015 100644 --- a/deeprank2/molstruct/residue.py +++ b/deeprank2/molstruct/residue.py @@ -97,7 +97,7 @@ def __repr__(self) -> str: @property def position(self) -> np.array: - return np.mean([atom.position for atom in self._atoms], axis=0) + return self.get_center() def get_center(self) -> NDArray: """Find the center position of a `Residue`. diff --git a/deeprank2/utils/graph.py b/deeprank2/utils/graph.py index 4ab0546e1..593d0dab0 100644 --- a/deeprank2/utils/graph.py +++ b/deeprank2/utils/graph.py @@ -59,7 +59,6 @@ def __init__(self, id_: Atom | Residue): raise TypeError(type(id_)) self.id = id_ - self.features = {} @property @@ -319,8 +318,10 @@ def get_all_chains(self) -> list[str]: return list(chains) -def build_atomic_graph( # pylint: disable=too-many-locals - atoms: list[Atom], graph_id: str, max_edge_length: float +def build_atomic_graph( + atoms: list[Atom], + graph_id: str, + max_edge_length: float, ) -> Graph: """Builds a graph, using the atoms as nodes. @@ -333,19 +334,18 @@ def build_atomic_graph( # pylint: disable=too-many-locals distances = distance_matrix(positions, positions, p=2) neighbours = distances < max_edge_length + atom_index_pairs = np.transpose(np.nonzero(neighbours)) graph = Graph(graph_id) - for atom1_index, atom2_index in np.transpose(np.nonzero(neighbours)): - if atom1_index != atom2_index: + for index1, index2 in atom_index_pairs: + if index1 != index2: - atom1 = atoms[atom1_index] - atom2 = atoms[atom2_index] - contact = AtomicContact(atom1, atom2) + node1 = Node(atoms[index1]) + node2 = Node(atoms[index2]) + contact = AtomicContact(node1.id, node2.id) - node1 = Node(atom1) - node2 = Node(atom2) - node1.features[Nfeat.POSITION] = atom1.position - node2.features[Nfeat.POSITION] = atom2.position + node1.features[Nfeat.POSITION] = node1.id.position + node2.features[Nfeat.POSITION] = node2.id.position graph.add_node(node1) graph.add_node(node2) @@ -355,11 +355,13 @@ def build_atomic_graph( # pylint: disable=too-many-locals def build_residue_graph( # pylint: disable=too-many-locals - residues: list[Residue], graph_id: str, max_edge_length: float + residues: list[Residue], + graph_id: str, + max_edge_length: float, ) -> Graph: """Builds a graph, using the residues as nodes. - The max edge distance is in Ångströms. + The max edge length is in Ångströms. It's the shortest interatomic distance between two residues. """ @@ -379,8 +381,6 @@ def build_residue_graph( # pylint: disable=too-many-locals positions[atom_index] = atom.position distances = distance_matrix(positions, positions, p=2) - - # determine which atoms are close enough neighbours = distances < max_edge_length atom_index_pairs = np.transpose(np.nonzero(neighbours)) @@ -390,26 +390,20 @@ def build_residue_graph( # pylint: disable=too-many-locals # build the graph graph = Graph(graph_id) - for residue1_index, residue2_index in residue_index_pairs: + for index1, index2 in residue_index_pairs: + if index1 != index2: - residue1: Residue = residues[residue1_index] - residue2: Residue = residues[residue2_index] + node1 = Node(residues[index1]) + node2 = Node(residues[index2]) + contact = ResidueContact(node1.id, node2.id) - if residue1 != residue2: + node1.features[Nfeat.POSITION] = node1.id.position + node2.features[Nfeat.POSITION] = node2.id.position - contact = ResidueContact(residue1, residue2) - - node1 = Node(residue1) - node2 = Node(residue2) - edge = Edge(contact) - - node1.features[Nfeat.POSITION] = residue1.get_center() - node2.features[Nfeat.POSITION] = residue2.get_center() - - # The same residue will be added multiple times as a node, + # The same residue will be added multiple times as a node, # but the Graph class fixes this. graph.add_node(node1) graph.add_node(node2) - graph.add_edge(edge) + graph.add_edge(Edge(contact)) return graph From a777cb761c76ef68ed10c071a48468d39a800ad1 Mon Sep 17 00:00:00 2001 From: Dani Bodor Date: Fri, 22 Sep 2023 23:46:52 +0200 Subject: [PATCH 6/6] refactor: unify build_graph functions previously separate functions for building atom and residue graphs --- deeprank2/molstruct/pair.py | 2 +- deeprank2/query.py | 11 ++- deeprank2/utils/graph.py | 139 +++++++++++++++--------------------- tests/features/__init__.py | 61 ++++++++-------- 4 files changed, 93 insertions(+), 120 deletions(-) diff --git a/deeprank2/molstruct/pair.py b/deeprank2/molstruct/pair.py index ecdb5febf..bdf7f5423 100644 --- a/deeprank2/molstruct/pair.py +++ b/deeprank2/molstruct/pair.py @@ -41,7 +41,7 @@ def __repr__(self) -> str: class Contact(Pair, ABC): - """Parent class to bind `ResidueContact` and `ResidueContact` objects.""" + """Parent class to bind `ResidueContact` and `AtomicContact` objects.""" class ResidueContact(Contact): diff --git a/deeprank2/query.py b/deeprank2/query.py index 341d348a0..f0fefd113 100644 --- a/deeprank2/query.py +++ b/deeprank2/query.py @@ -25,8 +25,7 @@ from deeprank2.molstruct.structure import PDBStructure from deeprank2.utils.buildgraph import (get_contact_atoms, get_structure, get_surrounding_residues) -from deeprank2.utils.graph import (Graph, build_atomic_graph, - build_residue_graph) +from deeprank2.utils.graph import Graph from deeprank2.utils.grid import Augmentation, GridSettings, MapMethod from deeprank2.utils.parsing.pssm import parse_pssm @@ -298,7 +297,7 @@ def _build_helper(self) -> Graph: # build the graph if self.resolution == 'residue': - graph = build_residue_graph(residues, self.get_query_id(), self.max_edge_length) + graph = Graph.build_graph(residues, self.get_query_id(), self.max_edge_length) elif self.resolution == 'atom': residues.append(variant_residue) atoms = set([]) @@ -308,7 +307,7 @@ def _build_helper(self) -> Graph: atoms.add(atom) atoms = list(atoms) - graph = build_atomic_graph(atoms, self.get_query_id(), self.max_edge_length) + graph = Graph.build_graph(atoms, self.get_query_id(), self.max_edge_length) else: raise NotImplementedError(f"No function exists to build graphs with resolution of {self.resolution}.") @@ -367,10 +366,10 @@ def _build_helper(self) -> Graph: # build the graph if self.resolution == 'atom': - graph = build_atomic_graph(contact_atoms, self.get_query_id(), self.max_edge_length) + graph = Graph.build_graph(contact_atoms, self.get_query_id(), self.max_edge_length) elif self.resolution == 'residue': residues_selected = list({atom.residue for atom in contact_atoms}) - graph = build_residue_graph(residues_selected, self.get_query_id(), self.max_edge_length) + graph = Graph.build_graph(residues_selected, self.get_query_id(), self.max_edge_length) graph.center = np.mean([atom.position for atom in contact_atoms], axis=0) structure = contact_atoms[0].residue.chain.model diff --git a/deeprank2/utils/graph.py b/deeprank2/utils/graph.py index 593d0dab0..fafa7519d 100644 --- a/deeprank2/utils/graph.py +++ b/deeprank2/utils/graph.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import logging import os from typing import Callable @@ -318,92 +320,67 @@ def get_all_chains(self) -> list[str]: return list(chains) -def build_atomic_graph( - atoms: list[Atom], - graph_id: str, - max_edge_length: float, -) -> Graph: - """Builds a graph, using the atoms as nodes. - - The max edge distance is in Ångströms. - """ - - positions = np.empty((len(atoms), 3)) - for atom_index, atom in enumerate(atoms): - positions[atom_index] = atom.position - - distances = distance_matrix(positions, positions, p=2) - neighbours = distances < max_edge_length - atom_index_pairs = np.transpose(np.nonzero(neighbours)) - - graph = Graph(graph_id) - for index1, index2 in atom_index_pairs: - if index1 != index2: - - node1 = Node(atoms[index1]) - node2 = Node(atoms[index2]) - contact = AtomicContact(node1.id, node2.id) - - node1.features[Nfeat.POSITION] = node1.id.position - node2.features[Nfeat.POSITION] = node2.id.position - - graph.add_node(node1) - graph.add_node(node2) - graph.add_edge(Edge(contact)) - - return graph - - -def build_residue_graph( # pylint: disable=too-many-locals - residues: list[Residue], - graph_id: str, - max_edge_length: float, -) -> Graph: - """Builds a graph, using the residues as nodes. - - The max edge length is in Ångströms. - It's the shortest interatomic distance between two residues. - """ - - # collect the set of atoms and remember which are on the same residue (by index) - atoms = [] - atoms_residues = [] - for residue_index, residue in enumerate(residues): - for atom in residue.atoms: - atoms.append(atom) - atoms_residues.append(residue_index) - - atoms_residues = np.array(atoms_residues) - - # calculate the distance matrix - positions = np.empty((len(atoms), 3)) - for atom_index, atom in enumerate(atoms): - positions[atom_index] = atom.position + @staticmethod + def build_graph( # pylint: disable=too-many-locals + nodes: list[Atom] | list[Residue], + graph_id: str, + max_edge_length: float, + ) -> Graph: + """Builds a graph. + + Args: + nodes (list[Atom] | list[Residue]): List of `Atom`s or `Residue`s to include in graph. + All nodes must be of same type. + graph_id (str): Human readable identifier for graph. + max_edge_length (float): Maximum distance between two nodes to connect them with an edge. + + Returns: + Graph: Containing nodes (with positions) and edges. + + Raises: + TypeError: if `nodes` argument contains a mix of different types. + """ + + if all(isinstance(node, Atom) for node in nodes): + atoms = nodes + NodeContact = AtomicContact + elif all(isinstance(node, Residue) for node in nodes): + # collect the set of atoms and remember which are on the same residue (by index) + atoms = [] + atoms_residues = [] + for residue_index, residue in enumerate(nodes): + for atom in residue.atoms: + atoms.append(atom) + atoms_residues.append(residue_index) + atoms_residues = np.array(atoms_residues) + NodeContact = ResidueContact + else: + raise TypeError("All nodes in the graph must be of the same type.") - distances = distance_matrix(positions, positions, p=2) - neighbours = distances < max_edge_length + positions = np.empty((len(atoms), 3)) + for atom_index, atom in enumerate(atoms): + positions[atom_index] = atom.position + neighbours = max_edge_length > distance_matrix(positions, positions, p=2) - atom_index_pairs = np.transpose(np.nonzero(neighbours)) + index_pairs = np.transpose(np.nonzero(neighbours)) # atom pairs + if NodeContact == ResidueContact: + index_pairs = np.unique(atoms_residues[index_pairs], axis=0) # residue pairs - # point out the unique residues for the atom pairs - residue_index_pairs = np.unique(atoms_residues[atom_index_pairs], axis=0) + graph = Graph(graph_id) - # build the graph - graph = Graph(graph_id) - for index1, index2 in residue_index_pairs: - if index1 != index2: + for index1, index2 in index_pairs: + if index1 != index2: - node1 = Node(residues[index1]) - node2 = Node(residues[index2]) - contact = ResidueContact(node1.id, node2.id) + node1 = Node(nodes[index1]) + node2 = Node(nodes[index2]) + contact = NodeContact(node1.id, node2.id) - node1.features[Nfeat.POSITION] = node1.id.position - node2.features[Nfeat.POSITION] = node2.id.position + node1.features[Nfeat.POSITION] = node1.id.position + node2.features[Nfeat.POSITION] = node2.id.position - # The same residue will be added multiple times as a node, - # but the Graph class fixes this. - graph.add_node(node1) - graph.add_node(node2) - graph.add_edge(Edge(contact)) + # The same node will be added multiple times, but the Graph class fixes this. + graph.add_node(node1) + graph.add_node(node2) + graph.add_edge(Edge(contact)) - return graph + return graph diff --git a/tests/features/__init__.py b/tests/features/__init__.py index eace0fbaf..c5c60aa90 100644 --- a/tests/features/__init__.py +++ b/tests/features/__init__.py @@ -9,8 +9,7 @@ from deeprank2.utils.buildgraph import (get_residue_contact_pairs, get_structure, get_surrounding_residues) -from deeprank2.utils.graph import (Graph, build_atomic_graph, - build_residue_graph) +from deeprank2.utils.graph import Graph from deeprank2.utils.parsing.pssm import parse_pssm @@ -25,7 +24,7 @@ def _get_residue(chain: Chain, number: int) -> Residue: def build_testgraph( # pylint: disable=too-many-locals, too-many-arguments # noqa:MC0001 pdb_path: str, - detail: Literal['atomic', 'residue'], + detail: Literal['atom', 'residue'], influence_radius: float, max_edge_length: float, central_res: int | None = None, @@ -37,11 +36,11 @@ def build_testgraph( # pylint: disable=too-many-locals, too-many-arguments # noq Args: pdb_path (str): Path of pdb file. - detail (Literal['atomic', 'residue']): Level of detail. + detail (Literal['atom', 'residue']): Type of graph to create. influence_radius (float): max distance to include in graph. max_edge_length (float): max distance to create an edge. central_res (int | None, optional): Residue to center a single-chain graph around. - Use None to create a 2-chain graph, or any value for a single-chain graph + Use None to create a 2-chain graph, or any value for a single-chain graph. Defaults to None. variant (AminoAcid | None, optional): Amino acid to use as a variant amino acid. Defaults to None. @@ -52,7 +51,7 @@ def build_testgraph( # pylint: disable=too-many-locals, too-many-arguments # noq TypeError: if detail is set to anything other than 'residue' or 'atom' Returns: - Graph: As generated by build_residue_graph or build_atomic_graph + Graph: As generated by Graph.build_graph SingleResidueVariant: returns None if central_res is None """ @@ -62,12 +61,13 @@ def build_testgraph( # pylint: disable=too-many-locals, too-many-arguments # noq finally: pdb._close() # pylint: disable=protected-access - if not central_res: # pylint: disable=no-else-raise + if not central_res: nodes = set([]) if not chain_ids: chains = (structure.chains[0].id, structure.chains[1].id) else: chains = [structure.get_chain(chain_id) for chain_id in chain_ids] + for residue1, residue2 in get_residue_contact_pairs( pdb_path, structure, chains[0], chains[1], @@ -76,36 +76,33 @@ def build_testgraph( # pylint: disable=too-many-locals, too-many-arguments # noq if detail == 'residue': nodes.add(residue1) nodes.add(residue2) - elif detail == 'atom': for atom in residue1.atoms: nodes.add(atom) for atom in residue2.atoms: nodes.add(atom) + else: + raise TypeError('detail must be "atom" or "residue"') - if detail == 'residue': - return build_residue_graph(list(nodes), structure.id, max_edge_length), None - if detail == 'atom': - return build_atomic_graph(list(nodes), structure.id, max_edge_length), None - raise TypeError('detail must be "atom" or "residue"') + return Graph.build_graph(list(nodes), structure.id, max_edge_length), None + # if central_res + if not chain_ids: + chain: Chain = structure.chains[0] else: - if not chain_ids: - chain: Chain = structure.chains[0] - else: - chain = structure.get_chain(chain_ids) - residue = _get_residue(chain, central_res) - surrounding_residues = list(get_surrounding_residues(structure, residue, influence_radius)) - - try: - with open(f"tests/data/pssm/{structure.id}/{structure.id}.{chain.id}.pdb.pssm", "rt", encoding="utf-8") as f: - chain.pssm = parse_pssm(f, chain) - except FileNotFoundError: - pass - - if detail == 'residue': - return build_residue_graph(surrounding_residues, structure.id, max_edge_length), SingleResidueVariant(residue, variant) - if detail == 'atom': - atoms = set(atom for residue in surrounding_residues for atom in residue.atoms) - return build_atomic_graph(list(atoms), structure.id, max_edge_length), SingleResidueVariant(residue, variant) - raise TypeError('detail must be "atom" or "residue"') + chain = structure.get_chain(chain_ids) + residue = _get_residue(chain, central_res) + surrounding_residues = list(get_surrounding_residues(structure, residue, influence_radius)) + + try: + with open(f"tests/data/pssm/{structure.id}/{structure.id}.{chain.id}.pdb.pssm", "rt", encoding="utf-8") as f: + chain.pssm = parse_pssm(f, chain) + except FileNotFoundError: + pass + + if detail == 'residue': + return Graph.build_graph(surrounding_residues, structure.id, max_edge_length), SingleResidueVariant(residue, variant) + if detail == 'atom': + atoms = set(atom for residue in surrounding_residues for atom in residue.atoms) + return Graph.build_graph(list(atoms), structure.id, max_edge_length), SingleResidueVariant(residue, variant) + raise TypeError('detail must be "atom" or "residue"')