Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: unify buildgraph.py and build_graph functions for atom and residue level graphs #507

Merged
merged 6 commits into from
Nov 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions deeprank2/domain/aminoacidlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,10 @@
# pyrrolysine,
]

amino_acids_by_code = {amino_acid.three_letter_code: amino_acid for amino_acid in amino_acids}
amino_acids_by_letter = {amino_acid.one_letter_code: amino_acid for amino_acid in amino_acids}
amino_acids_by_name = {amino_acid.name: amino_acid for amino_acid in amino_acids}

def convert_aa_nomenclature(aa: str, output_type: int | None = None):
try:
if len(aa) == 1:
Expand Down
6 changes: 3 additions & 3 deletions deeprank2/features/irc.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pdb2sql

from deeprank2.domain import nodestorage as Nfeat
from deeprank2.domain.aminoacidlist import amino_acids
from deeprank2.domain.aminoacidlist import amino_acids_by_code
from deeprank2.molstruct.aminoacid import Polarity
from deeprank2.molstruct.atom import Atom
from deeprank2.molstruct.residue import Residue, SingleResidueVariant
Expand Down Expand Up @@ -67,7 +67,7 @@ def get_IRCs(pdb_path: str, chains: list[str], cutoff: float = 5.5) -> dict[str,
for chain1_res, chain2_residues in pdb2sql_contacts.items():
aa1_code = chain1_res[2]
try:
aa1 = [amino_acid for amino_acid in amino_acids if amino_acid.three_letter_code == aa1_code][0]
aa1 = amino_acids_by_code[aa1_code]
except IndexError:
continue # skip keys that are not an amino acid

Expand All @@ -78,7 +78,7 @@ def get_IRCs(pdb_path: str, chains: list[str], cutoff: float = 5.5) -> dict[str,
for chain2_res in chain2_residues:
aa2_code = chain2_res[2]
try:
aa2 = [amino_acid for amino_acid in amino_acids if amino_acid.three_letter_code == aa2_code][0]
aa2 = amino_acids_by_code[aa2_code]
except IndexError:
continue # skip keys that are not an amino acid

Expand Down
3 changes: 2 additions & 1 deletion deeprank2/molstruct/aminoacid.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from enum import Enum

import numpy as np
from numpy.typing import NDArray


class Polarity(Enum):
Expand Down Expand Up @@ -108,7 +109,7 @@ def hydrogen_bond_acceptors(self) -> int:
return self._hydrogen_bond_acceptors

@property
def onehot(self) -> np.ndarray:
def onehot(self) -> NDArray:
if self._index is None:
raise ValueError(
f"Amino acid {self._name} index is not set, thus no onehot can be computed."
Expand Down
2 changes: 1 addition & 1 deletion deeprank2/molstruct/pair.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __repr__(self) -> str:


class Contact(Pair, ABC):
"""Parent class to bind `ResidueContact` and `ResidueContact` objects."""
"""Parent class to bind `ResidueContact` and `AtomicContact` objects."""


class ResidueContact(Contact):
Expand Down
2 changes: 1 addition & 1 deletion deeprank2/molstruct/residue.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def __repr__(self) -> str:

@property
def position(self) -> np.array:
return np.mean([atom.position for atom in self._atoms], axis=0)
return self.get_center()

def get_center(self) -> NDArray:
"""Find the center position of a `Residue`.
Expand Down
11 changes: 5 additions & 6 deletions deeprank2/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@
from deeprank2.molstruct.structure import PDBStructure
from deeprank2.utils.buildgraph import (get_contact_atoms, get_structure,
get_surrounding_residues)
from deeprank2.utils.graph import (Graph, build_atomic_graph,
build_residue_graph)
from deeprank2.utils.graph import Graph
from deeprank2.utils.grid import Augmentation, GridSettings, MapMethod
from deeprank2.utils.parsing.pssm import parse_pssm

Expand Down Expand Up @@ -298,7 +297,7 @@ def _build_helper(self) -> Graph:

# build the graph
if self.resolution == 'residue':
graph = build_residue_graph(residues, self.get_query_id(), self.max_edge_length)
graph = Graph.build_graph(residues, self.get_query_id(), self.max_edge_length)
elif self.resolution == 'atom':
residues.append(variant_residue)
atoms = set([])
Expand All @@ -308,7 +307,7 @@ def _build_helper(self) -> Graph:
atoms.add(atom)
atoms = list(atoms)

graph = build_atomic_graph(atoms, self.get_query_id(), self.max_edge_length)
graph = Graph.build_graph(atoms, self.get_query_id(), self.max_edge_length)

else:
raise NotImplementedError(f"No function exists to build graphs with resolution of {self.resolution}.")
Expand Down Expand Up @@ -367,10 +366,10 @@ def _build_helper(self) -> Graph:

# build the graph
if self.resolution == 'atom':
graph = build_atomic_graph(contact_atoms, self.get_query_id(), self.max_edge_length)
graph = Graph.build_graph(contact_atoms, self.get_query_id(), self.max_edge_length)
elif self.resolution == 'residue':
residues_selected = list({atom.residue for atom in contact_atoms})
graph = build_residue_graph(residues_selected, self.get_query_id(), self.max_edge_length)
graph = Graph.build_graph(residues_selected, self.get_query_id(), self.max_edge_length)

graph.center = np.mean([atom.position for atom in contact_atoms], axis=0)
structure = contact_atoms[0].residue.chain.model
Expand Down
Loading
Loading