From 8b9bd3af3e03cde14eaec5b97760a47eec53fb82 Mon Sep 17 00:00:00 2001 From: Dani Bodor Date: Fri, 22 Sep 2023 23:46:52 +0200 Subject: [PATCH] refactor: unify build_graph functions previously separate functions for building atom and residue graphs --- deeprank2/molstruct/pair.py | 2 +- deeprank2/query.py | 11 ++- deeprank2/utils/graph.py | 136 +++++++++++++++--------------------- tests/features/__init__.py | 61 ++++++++-------- 4 files changed, 90 insertions(+), 120 deletions(-) diff --git a/deeprank2/molstruct/pair.py b/deeprank2/molstruct/pair.py index ecdb5febf..bdf7f5423 100644 --- a/deeprank2/molstruct/pair.py +++ b/deeprank2/molstruct/pair.py @@ -41,7 +41,7 @@ def __repr__(self) -> str: class Contact(Pair, ABC): - """Parent class to bind `ResidueContact` and `ResidueContact` objects.""" + """Parent class to bind `ResidueContact` and `AtomicContact` objects.""" class ResidueContact(Contact): diff --git a/deeprank2/query.py b/deeprank2/query.py index 341d348a0..f0fefd113 100644 --- a/deeprank2/query.py +++ b/deeprank2/query.py @@ -25,8 +25,7 @@ from deeprank2.molstruct.structure import PDBStructure from deeprank2.utils.buildgraph import (get_contact_atoms, get_structure, get_surrounding_residues) -from deeprank2.utils.graph import (Graph, build_atomic_graph, - build_residue_graph) +from deeprank2.utils.graph import Graph from deeprank2.utils.grid import Augmentation, GridSettings, MapMethod from deeprank2.utils.parsing.pssm import parse_pssm @@ -298,7 +297,7 @@ def _build_helper(self) -> Graph: # build the graph if self.resolution == 'residue': - graph = build_residue_graph(residues, self.get_query_id(), self.max_edge_length) + graph = Graph.build_graph(residues, self.get_query_id(), self.max_edge_length) elif self.resolution == 'atom': residues.append(variant_residue) atoms = set([]) @@ -308,7 +307,7 @@ def _build_helper(self) -> Graph: atoms.add(atom) atoms = list(atoms) - graph = build_atomic_graph(atoms, self.get_query_id(), self.max_edge_length) + graph = Graph.build_graph(atoms, self.get_query_id(), self.max_edge_length) else: raise NotImplementedError(f"No function exists to build graphs with resolution of {self.resolution}.") @@ -367,10 +366,10 @@ def _build_helper(self) -> Graph: # build the graph if self.resolution == 'atom': - graph = build_atomic_graph(contact_atoms, self.get_query_id(), self.max_edge_length) + graph = Graph.build_graph(contact_atoms, self.get_query_id(), self.max_edge_length) elif self.resolution == 'residue': residues_selected = list({atom.residue for atom in contact_atoms}) - graph = build_residue_graph(residues_selected, self.get_query_id(), self.max_edge_length) + graph = Graph.build_graph(residues_selected, self.get_query_id(), self.max_edge_length) graph.center = np.mean([atom.position for atom in contact_atoms], axis=0) structure = contact_atoms[0].residue.chain.model diff --git a/deeprank2/utils/graph.py b/deeprank2/utils/graph.py index 593d0dab0..0273293cf 100644 --- a/deeprank2/utils/graph.py +++ b/deeprank2/utils/graph.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import logging import os from typing import Callable @@ -318,92 +320,64 @@ def get_all_chains(self) -> list[str]: return list(chains) -def build_atomic_graph( - atoms: list[Atom], - graph_id: str, - max_edge_length: float, -) -> Graph: - """Builds a graph, using the atoms as nodes. - - The max edge distance is in Ångströms. - """ - - positions = np.empty((len(atoms), 3)) - for atom_index, atom in enumerate(atoms): - positions[atom_index] = atom.position - - distances = distance_matrix(positions, positions, p=2) - neighbours = distances < max_edge_length - atom_index_pairs = np.transpose(np.nonzero(neighbours)) - - graph = Graph(graph_id) - for index1, index2 in atom_index_pairs: - if index1 != index2: - - node1 = Node(atoms[index1]) - node2 = Node(atoms[index2]) - contact = AtomicContact(node1.id, node2.id) - - node1.features[Nfeat.POSITION] = node1.id.position - node2.features[Nfeat.POSITION] = node2.id.position - - graph.add_node(node1) - graph.add_node(node2) - graph.add_edge(Edge(contact)) - - return graph - - -def build_residue_graph( # pylint: disable=too-many-locals - residues: list[Residue], - graph_id: str, - max_edge_length: float, -) -> Graph: - """Builds a graph, using the residues as nodes. - - The max edge length is in Ångströms. - It's the shortest interatomic distance between two residues. - """ - - # collect the set of atoms and remember which are on the same residue (by index) - atoms = [] - atoms_residues = [] - for residue_index, residue in enumerate(residues): - for atom in residue.atoms: - atoms.append(atom) - atoms_residues.append(residue_index) - - atoms_residues = np.array(atoms_residues) - - # calculate the distance matrix - positions = np.empty((len(atoms), 3)) - for atom_index, atom in enumerate(atoms): - positions[atom_index] = atom.position + @staticmethod + def build_graph( # pylint: disable=too-many-locals + nodes: list[Atom] | list[Residue], + graph_id: str, + max_edge_length: float, + ) -> Graph: + """Builds a graph. + + Args: + nodes (list[Atom] | list[Residue]): List of `Atoms` or `Residues` to include in graph. + All nodes must be of same type. + graph_id (str): Human readible identifier for graph. + max_edge_length (float): Maximum distance between two nodes to connect them with an edge. + + Returns: + Graph: Containing nodes (with positions) and edges. + """ + + if all(isinstance(node, Atom) for node in nodes): + atoms = nodes + NodeContact = AtomicContact + elif all(isinstance(node, Residue) for node in nodes): + # collect the set of atoms and remember which are on the same residue (by index) + atoms = [] + atoms_residues = [] + for residue_index, residue in enumerate(nodes): + for atom in residue.atoms: + atoms.append(atom) + atoms_residues.append(residue_index) + atoms_residues = np.array(atoms_residues) + NodeContact = ResidueContact + else: + raise TypeError("All nodes in the graph must be of the same type.") - distances = distance_matrix(positions, positions, p=2) - neighbours = distances < max_edge_length + positions = np.empty((len(atoms), 3)) + for atom_index, atom in enumerate(atoms): + positions[atom_index] = atom.position + neighbours = max_edge_length > distance_matrix(positions, positions, p=2) - atom_index_pairs = np.transpose(np.nonzero(neighbours)) + index_pairs = np.transpose(np.nonzero(neighbours)) # atom pairs + if NodeContact == ResidueContact: + index_pairs = np.unique(atoms_residues[index_pairs], axis=0) # residue pairs - # point out the unique residues for the atom pairs - residue_index_pairs = np.unique(atoms_residues[atom_index_pairs], axis=0) + graph = Graph(graph_id) - # build the graph - graph = Graph(graph_id) - for index1, index2 in residue_index_pairs: - if index1 != index2: + for index1, index2 in index_pairs: + if index1 != index2: - node1 = Node(residues[index1]) - node2 = Node(residues[index2]) - contact = ResidueContact(node1.id, node2.id) + node1 = Node(nodes[index1]) + node2 = Node(nodes[index2]) + contact = NodeContact(node1.id, node2.id) - node1.features[Nfeat.POSITION] = node1.id.position - node2.features[Nfeat.POSITION] = node2.id.position + node1.features[Nfeat.POSITION] = node1.id.position + node2.features[Nfeat.POSITION] = node2.id.position - # The same residue will be added multiple times as a node, - # but the Graph class fixes this. - graph.add_node(node1) - graph.add_node(node2) - graph.add_edge(Edge(contact)) + # The same node will be added multiple times, but the Graph class fixes this. + graph.add_node(node1) + graph.add_node(node2) + graph.add_edge(Edge(contact)) - return graph + return graph diff --git a/tests/features/__init__.py b/tests/features/__init__.py index eace0fbaf..c5c60aa90 100644 --- a/tests/features/__init__.py +++ b/tests/features/__init__.py @@ -9,8 +9,7 @@ from deeprank2.utils.buildgraph import (get_residue_contact_pairs, get_structure, get_surrounding_residues) -from deeprank2.utils.graph import (Graph, build_atomic_graph, - build_residue_graph) +from deeprank2.utils.graph import Graph from deeprank2.utils.parsing.pssm import parse_pssm @@ -25,7 +24,7 @@ def _get_residue(chain: Chain, number: int) -> Residue: def build_testgraph( # pylint: disable=too-many-locals, too-many-arguments # noqa:MC0001 pdb_path: str, - detail: Literal['atomic', 'residue'], + detail: Literal['atom', 'residue'], influence_radius: float, max_edge_length: float, central_res: int | None = None, @@ -37,11 +36,11 @@ def build_testgraph( # pylint: disable=too-many-locals, too-many-arguments # noq Args: pdb_path (str): Path of pdb file. - detail (Literal['atomic', 'residue']): Level of detail. + detail (Literal['atom', 'residue']): Type of graph to create. influence_radius (float): max distance to include in graph. max_edge_length (float): max distance to create an edge. central_res (int | None, optional): Residue to center a single-chain graph around. - Use None to create a 2-chain graph, or any value for a single-chain graph + Use None to create a 2-chain graph, or any value for a single-chain graph. Defaults to None. variant (AminoAcid | None, optional): Amino acid to use as a variant amino acid. Defaults to None. @@ -52,7 +51,7 @@ def build_testgraph( # pylint: disable=too-many-locals, too-many-arguments # noq TypeError: if detail is set to anything other than 'residue' or 'atom' Returns: - Graph: As generated by build_residue_graph or build_atomic_graph + Graph: As generated by Graph.build_graph SingleResidueVariant: returns None if central_res is None """ @@ -62,12 +61,13 @@ def build_testgraph( # pylint: disable=too-many-locals, too-many-arguments # noq finally: pdb._close() # pylint: disable=protected-access - if not central_res: # pylint: disable=no-else-raise + if not central_res: nodes = set([]) if not chain_ids: chains = (structure.chains[0].id, structure.chains[1].id) else: chains = [structure.get_chain(chain_id) for chain_id in chain_ids] + for residue1, residue2 in get_residue_contact_pairs( pdb_path, structure, chains[0], chains[1], @@ -76,36 +76,33 @@ def build_testgraph( # pylint: disable=too-many-locals, too-many-arguments # noq if detail == 'residue': nodes.add(residue1) nodes.add(residue2) - elif detail == 'atom': for atom in residue1.atoms: nodes.add(atom) for atom in residue2.atoms: nodes.add(atom) + else: + raise TypeError('detail must be "atom" or "residue"') - if detail == 'residue': - return build_residue_graph(list(nodes), structure.id, max_edge_length), None - if detail == 'atom': - return build_atomic_graph(list(nodes), structure.id, max_edge_length), None - raise TypeError('detail must be "atom" or "residue"') + return Graph.build_graph(list(nodes), structure.id, max_edge_length), None + # if central_res + if not chain_ids: + chain: Chain = structure.chains[0] else: - if not chain_ids: - chain: Chain = structure.chains[0] - else: - chain = structure.get_chain(chain_ids) - residue = _get_residue(chain, central_res) - surrounding_residues = list(get_surrounding_residues(structure, residue, influence_radius)) - - try: - with open(f"tests/data/pssm/{structure.id}/{structure.id}.{chain.id}.pdb.pssm", "rt", encoding="utf-8") as f: - chain.pssm = parse_pssm(f, chain) - except FileNotFoundError: - pass - - if detail == 'residue': - return build_residue_graph(surrounding_residues, structure.id, max_edge_length), SingleResidueVariant(residue, variant) - if detail == 'atom': - atoms = set(atom for residue in surrounding_residues for atom in residue.atoms) - return build_atomic_graph(list(atoms), structure.id, max_edge_length), SingleResidueVariant(residue, variant) - raise TypeError('detail must be "atom" or "residue"') + chain = structure.get_chain(chain_ids) + residue = _get_residue(chain, central_res) + surrounding_residues = list(get_surrounding_residues(structure, residue, influence_radius)) + + try: + with open(f"tests/data/pssm/{structure.id}/{structure.id}.{chain.id}.pdb.pssm", "rt", encoding="utf-8") as f: + chain.pssm = parse_pssm(f, chain) + except FileNotFoundError: + pass + + if detail == 'residue': + return Graph.build_graph(surrounding_residues, structure.id, max_edge_length), SingleResidueVariant(residue, variant) + if detail == 'atom': + atoms = set(atom for residue in surrounding_residues for atom in residue.atoms) + return Graph.build_graph(list(atoms), structure.id, max_edge_length), SingleResidueVariant(residue, variant) + raise TypeError('detail must be "atom" or "residue"')