Skip to content

Commit

Permalink
refactor: unify build_graph functions
Browse files Browse the repository at this point in the history
previously separate functions for building atom and residue graphs
  • Loading branch information
DaniBodor committed Nov 22, 2023
1 parent 3c6e016 commit 9103e45
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 120 deletions.
2 changes: 1 addition & 1 deletion deeprank2/molstruct/pair.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __repr__(self) -> str:


class Contact(Pair, ABC):
"""Parent class to bind `ResidueContact` and `ResidueContact` objects."""
"""Parent class to bind `ResidueContact` and `AtomicContact` objects."""


class ResidueContact(Contact):
Expand Down
11 changes: 5 additions & 6 deletions deeprank2/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@
from deeprank2.molstruct.structure import PDBStructure
from deeprank2.utils.buildgraph import (get_contact_atoms, get_structure,
get_surrounding_residues)
from deeprank2.utils.graph import (Graph, build_atomic_graph,
build_residue_graph)
from deeprank2.utils.graph import Graph
from deeprank2.utils.grid import Augmentation, GridSettings, MapMethod
from deeprank2.utils.parsing.pssm import parse_pssm

Expand Down Expand Up @@ -298,7 +297,7 @@ def _build_helper(self) -> Graph:

# build the graph
if self.resolution == 'residue':
graph = build_residue_graph(residues, self.get_query_id(), self.max_edge_length)
graph = Graph.build_graph(residues, self.get_query_id(), self.max_edge_length)
elif self.resolution == 'atom':
residues.append(variant_residue)
atoms = set([])
Expand All @@ -308,7 +307,7 @@ def _build_helper(self) -> Graph:
atoms.add(atom)
atoms = list(atoms)

graph = build_atomic_graph(atoms, self.get_query_id(), self.max_edge_length)
graph = Graph.build_graph(atoms, self.get_query_id(), self.max_edge_length)

else:
raise NotImplementedError(f"No function exists to build graphs with resolution of {self.resolution}.")
Expand Down Expand Up @@ -367,10 +366,10 @@ def _build_helper(self) -> Graph:

# build the graph
if self.resolution == 'atom':
graph = build_atomic_graph(contact_atoms, self.get_query_id(), self.max_edge_length)
graph = Graph.build_graph(contact_atoms, self.get_query_id(), self.max_edge_length)
elif self.resolution == 'residue':
residues_selected = list({atom.residue for atom in contact_atoms})
graph = build_residue_graph(residues_selected, self.get_query_id(), self.max_edge_length)
graph = Graph.build_graph(residues_selected, self.get_query_id(), self.max_edge_length)

graph.center = np.mean([atom.position for atom in contact_atoms], axis=0)
structure = contact_atoms[0].residue.chain.model
Expand Down
139 changes: 58 additions & 81 deletions deeprank2/utils/graph.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import logging
import os
from typing import Callable
Expand Down Expand Up @@ -318,92 +320,67 @@ def get_all_chains(self) -> list[str]:
return list(chains)


def build_atomic_graph(
atoms: list[Atom],
graph_id: str,
max_edge_length: float,
) -> Graph:
"""Builds a graph, using the atoms as nodes.
The max edge distance is in Ångströms.
"""

positions = np.empty((len(atoms), 3))
for atom_index, atom in enumerate(atoms):
positions[atom_index] = atom.position

distances = distance_matrix(positions, positions, p=2)
neighbours = distances < max_edge_length
atom_index_pairs = np.transpose(np.nonzero(neighbours))

graph = Graph(graph_id)
for index1, index2 in atom_index_pairs:
if index1 != index2:

node1 = Node(atoms[index1])
node2 = Node(atoms[index2])
contact = AtomicContact(node1.id, node2.id)

node1.features[Nfeat.POSITION] = node1.id.position
node2.features[Nfeat.POSITION] = node2.id.position

graph.add_node(node1)
graph.add_node(node2)
graph.add_edge(Edge(contact))

return graph


def build_residue_graph( # pylint: disable=too-many-locals
residues: list[Residue],
graph_id: str,
max_edge_length: float,
) -> Graph:
"""Builds a graph, using the residues as nodes.
The max edge length is in Ångströms.
It's the shortest interatomic distance between two residues.
"""

# collect the set of atoms and remember which are on the same residue (by index)
atoms = []
atoms_residues = []
for residue_index, residue in enumerate(residues):
for atom in residue.atoms:
atoms.append(atom)
atoms_residues.append(residue_index)

atoms_residues = np.array(atoms_residues)

# calculate the distance matrix
positions = np.empty((len(atoms), 3))
for atom_index, atom in enumerate(atoms):
positions[atom_index] = atom.position
@staticmethod
def build_graph( # pylint: disable=too-many-locals
nodes: list[Atom] | list[Residue],
graph_id: str,
max_edge_length: float,
) -> Graph:
"""Builds a graph.
Args:
nodes (list[Atom] | list[Residue]): List of `Atom`s or `Residue`s to include in graph.
All nodes must be of same type.
graph_id (str): Human readable identifier for graph.
max_edge_length (float): Maximum distance between two nodes to connect them with an edge.
Returns:
Graph: Containing nodes (with positions) and edges.
Raises:
TypeError: if `nodes` argument contains a mix of different types.
"""

if all(isinstance(node, Atom) for node in nodes):
atoms = nodes
NodeContact = AtomicContact
elif all(isinstance(node, Residue) for node in nodes):
# collect the set of atoms and remember which are on the same residue (by index)
atoms = []
atoms_residues = []
for residue_index, residue in enumerate(nodes):
for atom in residue.atoms:
atoms.append(atom)
atoms_residues.append(residue_index)
atoms_residues = np.array(atoms_residues)
NodeContact = ResidueContact
else:
raise TypeError("All nodes in the graph must be of the same type.")

distances = distance_matrix(positions, positions, p=2)
neighbours = distances < max_edge_length
positions = np.empty((len(atoms), 3))
for atom_index, atom in enumerate(atoms):
positions[atom_index] = atom.position
neighbours = max_edge_length > distance_matrix(positions, positions, p=2)

atom_index_pairs = np.transpose(np.nonzero(neighbours))
index_pairs = np.transpose(np.nonzero(neighbours)) # atom pairs
if NodeContact == ResidueContact:
index_pairs = np.unique(atoms_residues[index_pairs], axis=0) # residue pairs

# point out the unique residues for the atom pairs
residue_index_pairs = np.unique(atoms_residues[atom_index_pairs], axis=0)
graph = Graph(graph_id)

# build the graph
graph = Graph(graph_id)
for index1, index2 in residue_index_pairs:
if index1 != index2:
for index1, index2 in index_pairs:
if index1 != index2:

node1 = Node(residues[index1])
node2 = Node(residues[index2])
contact = ResidueContact(node1.id, node2.id)
node1 = Node(nodes[index1])
node2 = Node(nodes[index2])
contact = NodeContact(node1.id, node2.id)

node1.features[Nfeat.POSITION] = node1.id.position
node2.features[Nfeat.POSITION] = node2.id.position
node1.features[Nfeat.POSITION] = node1.id.position
node2.features[Nfeat.POSITION] = node2.id.position

# The same residue will be added multiple times as a node,
# but the Graph class fixes this.
graph.add_node(node1)
graph.add_node(node2)
graph.add_edge(Edge(contact))
# The same node will be added multiple times, but the Graph class fixes this.
graph.add_node(node1)
graph.add_node(node2)
graph.add_edge(Edge(contact))

return graph
return graph
61 changes: 29 additions & 32 deletions tests/features/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@
from deeprank2.utils.buildgraph import (get_residue_contact_pairs,
get_structure,
get_surrounding_residues)
from deeprank2.utils.graph import (Graph, build_atomic_graph,
build_residue_graph)
from deeprank2.utils.graph import Graph
from deeprank2.utils.parsing.pssm import parse_pssm


Expand All @@ -25,7 +24,7 @@ def _get_residue(chain: Chain, number: int) -> Residue:

def build_testgraph( # pylint: disable=too-many-locals, too-many-arguments # noqa:MC0001
pdb_path: str,
detail: Literal['atomic', 'residue'],
detail: Literal['atom', 'residue'],
influence_radius: float,
max_edge_length: float,
central_res: int | None = None,
Expand All @@ -37,11 +36,11 @@ def build_testgraph( # pylint: disable=too-many-locals, too-many-arguments # noq
Args:
pdb_path (str): Path of pdb file.
detail (Literal['atomic', 'residue']): Level of detail.
detail (Literal['atom', 'residue']): Type of graph to create.
influence_radius (float): max distance to include in graph.
max_edge_length (float): max distance to create an edge.
central_res (int | None, optional): Residue to center a single-chain graph around.
Use None to create a 2-chain graph, or any value for a single-chain graph
Use None to create a 2-chain graph, or any value for a single-chain graph.
Defaults to None.
variant (AminoAcid | None, optional): Amino acid to use as a variant amino acid.
Defaults to None.
Expand All @@ -52,7 +51,7 @@ def build_testgraph( # pylint: disable=too-many-locals, too-many-arguments # noq
TypeError: if detail is set to anything other than 'residue' or 'atom'
Returns:
Graph: As generated by build_residue_graph or build_atomic_graph
Graph: As generated by Graph.build_graph
SingleResidueVariant: returns None if central_res is None
"""

Expand All @@ -62,12 +61,13 @@ def build_testgraph( # pylint: disable=too-many-locals, too-many-arguments # noq
finally:
pdb._close() # pylint: disable=protected-access

if not central_res: # pylint: disable=no-else-raise
if not central_res:
nodes = set([])
if not chain_ids:
chains = (structure.chains[0].id, structure.chains[1].id)
else:
chains = [structure.get_chain(chain_id) for chain_id in chain_ids]

for residue1, residue2 in get_residue_contact_pairs(
pdb_path, structure,
chains[0], chains[1],
Expand All @@ -76,36 +76,33 @@ def build_testgraph( # pylint: disable=too-many-locals, too-many-arguments # noq
if detail == 'residue':
nodes.add(residue1)
nodes.add(residue2)

elif detail == 'atom':
for atom in residue1.atoms:
nodes.add(atom)
for atom in residue2.atoms:
nodes.add(atom)
else:
raise TypeError('detail must be "atom" or "residue"')

if detail == 'residue':
return build_residue_graph(list(nodes), structure.id, max_edge_length), None
if detail == 'atom':
return build_atomic_graph(list(nodes), structure.id, max_edge_length), None
raise TypeError('detail must be "atom" or "residue"')
return Graph.build_graph(list(nodes), structure.id, max_edge_length), None

# if central_res
if not chain_ids:
chain: Chain = structure.chains[0]
else:
if not chain_ids:
chain: Chain = structure.chains[0]
else:
chain = structure.get_chain(chain_ids)
residue = _get_residue(chain, central_res)
surrounding_residues = list(get_surrounding_residues(structure, residue, influence_radius))

try:
with open(f"tests/data/pssm/{structure.id}/{structure.id}.{chain.id}.pdb.pssm", "rt", encoding="utf-8") as f:
chain.pssm = parse_pssm(f, chain)
except FileNotFoundError:
pass

if detail == 'residue':
return build_residue_graph(surrounding_residues, structure.id, max_edge_length), SingleResidueVariant(residue, variant)
if detail == 'atom':
atoms = set(atom for residue in surrounding_residues for atom in residue.atoms)
return build_atomic_graph(list(atoms), structure.id, max_edge_length), SingleResidueVariant(residue, variant)
raise TypeError('detail must be "atom" or "residue"')
chain = structure.get_chain(chain_ids)
residue = _get_residue(chain, central_res)
surrounding_residues = list(get_surrounding_residues(structure, residue, influence_radius))

try:
with open(f"tests/data/pssm/{structure.id}/{structure.id}.{chain.id}.pdb.pssm", "rt", encoding="utf-8") as f:
chain.pssm = parse_pssm(f, chain)
except FileNotFoundError:
pass

if detail == 'residue':
return Graph.build_graph(surrounding_residues, structure.id, max_edge_length), SingleResidueVariant(residue, variant)
if detail == 'atom':
atoms = set(atom for residue in surrounding_residues for atom in residue.atoms)
return Graph.build_graph(list(atoms), structure.id, max_edge_length), SingleResidueVariant(residue, variant)
raise TypeError('detail must be "atom" or "residue"')

0 comments on commit 9103e45

Please sign in to comment.