Skip to content

Commit

Permalink
Merge pull request #507 from DeepRank/506_buildgraph_unification_dbodor
Browse files Browse the repository at this point in the history
refactor: unify buildgraph.py and `build_graph` functions for atom and residue level graphs
  • Loading branch information
DaniBodor authored Nov 23, 2023
2 parents e92ae10 + a777cb7 commit 9d508c0
Show file tree
Hide file tree
Showing 13 changed files with 236 additions and 384 deletions.
4 changes: 4 additions & 0 deletions deeprank2/domain/aminoacidlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,10 @@
# pyrrolysine,
]

amino_acids_by_code = {amino_acid.three_letter_code: amino_acid for amino_acid in amino_acids}
amino_acids_by_letter = {amino_acid.one_letter_code: amino_acid for amino_acid in amino_acids}
amino_acids_by_name = {amino_acid.name: amino_acid for amino_acid in amino_acids}

def convert_aa_nomenclature(aa: str, output_type: int | None = None):
try:
if len(aa) == 1:
Expand Down
6 changes: 3 additions & 3 deletions deeprank2/features/irc.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pdb2sql

from deeprank2.domain import nodestorage as Nfeat
from deeprank2.domain.aminoacidlist import amino_acids
from deeprank2.domain.aminoacidlist import amino_acids_by_code
from deeprank2.molstruct.aminoacid import Polarity
from deeprank2.molstruct.atom import Atom
from deeprank2.molstruct.residue import Residue, SingleResidueVariant
Expand Down Expand Up @@ -67,7 +67,7 @@ def get_IRCs(pdb_path: str, chains: list[str], cutoff: float = 5.5) -> dict[str,
for chain1_res, chain2_residues in pdb2sql_contacts.items():
aa1_code = chain1_res[2]
try:
aa1 = [amino_acid for amino_acid in amino_acids if amino_acid.three_letter_code == aa1_code][0]
aa1 = amino_acids_by_code[aa1_code]
except IndexError:
continue # skip keys that are not an amino acid

Expand All @@ -78,7 +78,7 @@ def get_IRCs(pdb_path: str, chains: list[str], cutoff: float = 5.5) -> dict[str,
for chain2_res in chain2_residues:
aa2_code = chain2_res[2]
try:
aa2 = [amino_acid for amino_acid in amino_acids if amino_acid.three_letter_code == aa2_code][0]
aa2 = amino_acids_by_code[aa2_code]
except IndexError:
continue # skip keys that are not an amino acid

Expand Down
3 changes: 2 additions & 1 deletion deeprank2/molstruct/aminoacid.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from enum import Enum

import numpy as np
from numpy.typing import NDArray


class Polarity(Enum):
Expand Down Expand Up @@ -108,7 +109,7 @@ def hydrogen_bond_acceptors(self) -> int:
return self._hydrogen_bond_acceptors

@property
def onehot(self) -> np.ndarray:
def onehot(self) -> NDArray:
if self._index is None:
raise ValueError(
f"Amino acid {self._name} index is not set, thus no onehot can be computed."
Expand Down
2 changes: 1 addition & 1 deletion deeprank2/molstruct/pair.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __repr__(self) -> str:


class Contact(Pair, ABC):
"""Parent class to bind `ResidueContact` and `ResidueContact` objects."""
"""Parent class to bind `ResidueContact` and `AtomicContact` objects."""


class ResidueContact(Contact):
Expand Down
2 changes: 1 addition & 1 deletion deeprank2/molstruct/residue.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def __repr__(self) -> str:

@property
def position(self) -> np.array:
return np.mean([atom.position for atom in self._atoms], axis=0)
return self.get_center()

def get_center(self) -> NDArray:
"""Find the center position of a `Residue`.
Expand Down
11 changes: 5 additions & 6 deletions deeprank2/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@
from deeprank2.molstruct.structure import PDBStructure
from deeprank2.utils.buildgraph import (get_contact_atoms, get_structure,
get_surrounding_residues)
from deeprank2.utils.graph import (Graph, build_atomic_graph,
build_residue_graph)
from deeprank2.utils.graph import Graph
from deeprank2.utils.grid import Augmentation, GridSettings, MapMethod
from deeprank2.utils.parsing.pssm import parse_pssm

Expand Down Expand Up @@ -298,7 +297,7 @@ def _build_helper(self) -> Graph:

# build the graph
if self.resolution == 'residue':
graph = build_residue_graph(residues, self.get_query_id(), self.max_edge_length)
graph = Graph.build_graph(residues, self.get_query_id(), self.max_edge_length)
elif self.resolution == 'atom':
residues.append(variant_residue)
atoms = set([])
Expand All @@ -308,7 +307,7 @@ def _build_helper(self) -> Graph:
atoms.add(atom)
atoms = list(atoms)

graph = build_atomic_graph(atoms, self.get_query_id(), self.max_edge_length)
graph = Graph.build_graph(atoms, self.get_query_id(), self.max_edge_length)

else:
raise NotImplementedError(f"No function exists to build graphs with resolution of {self.resolution}.")
Expand Down Expand Up @@ -367,10 +366,10 @@ def _build_helper(self) -> Graph:

# build the graph
if self.resolution == 'atom':
graph = build_atomic_graph(contact_atoms, self.get_query_id(), self.max_edge_length)
graph = Graph.build_graph(contact_atoms, self.get_query_id(), self.max_edge_length)
elif self.resolution == 'residue':
residues_selected = list({atom.residue for atom in contact_atoms})
graph = build_residue_graph(residues_selected, self.get_query_id(), self.max_edge_length)
graph = Graph.build_graph(residues_selected, self.get_query_id(), self.max_edge_length)

graph.center = np.mean([atom.position for atom in contact_atoms], axis=0)
structure = contact_atoms[0].residue.chain.model
Expand Down
Loading

0 comments on commit 9d508c0

Please sign in to comment.