Skip to content

Commit

Permalink
use separate radius and max_edge_dist in tests and notebooks
Browse files Browse the repository at this point in the history
  • Loading branch information
DaniBodor committed Nov 3, 2023
1 parent d0d5957 commit 9645914
Show file tree
Hide file tree
Showing 13 changed files with 170 additions and 88 deletions.
6 changes: 4 additions & 2 deletions tests/data/hdf5/_generate_testdata.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,8 @@
"models_folder_name = 'exp_nmers_all_HLA_quantitative'\n",
"data = 'pMHCI'\n",
"resolution = 'residue' # either 'residue' or 'atom'\n",
"distance_cutoff = 15 # max distance in Å between two interacting residues/atoms of two proteins\n",
"interaction_radius = 15 # max distance in Å between two interacting residues/atoms of two proteins\n",
"max_edge_distance = 15 # max distance in Å between to create an edge\n",
"\n",
"csv_file_path = f'{project_folder}data/external/processed/I/{csv_file_name}'\n",
"models_folder_path = f'{project_folder}data/{data}/features_input_folder/{models_folder_name}'\n",
Expand Down Expand Up @@ -130,7 +131,8 @@
" pdb_path = pdb_files[i],\n",
" resolution = \"residue\",\n",
" chain_ids = [\"M\", \"P\"],\n",
" distance_cutoff = distance_cutoff,\n",
" interaction_radius = interaction_radius,\n",
" max_edge_distance = max_edge_distance,\n",
" targets = {\n",
" 'binary': int(float(bas[i]) <= 500), # binary target value\n",
" 'BA': bas[i], # continuous target value\n",
Expand Down
28 changes: 15 additions & 13 deletions tests/features/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pathlib import Path
from typing import Optional, Tuple, Union
from typing import Literal, Optional, Tuple, Union

from pdb2sql import pdb2sql

Expand All @@ -25,19 +25,21 @@ def _get_residue(chain: Chain, number: int) -> Residue:

def build_testgraph( # pylint: disable=too-many-locals, too-many-arguments # noqa:MC0001
pdb_path: str,
cutoff: float,
detail: str,
detail: Literal['atomic', 'residue'],
interaction_radius: float,
max_edge_distance: float,
central_res: Optional[int] = None,
variant: Optional[AminoAcid] = None,
chain_ids: Optional[Union[str, Tuple[str, str]]] = None,
) -> Union[Graph, Tuple[Graph, SingleResidueVariant]]:
) -> Tuple[Graph, Union[SingleResidueVariant, None]]:

""" Creates a Graph object for feature tests.
Args:
pdb_path (str): Path of pdb file.
cutoff (float): Cutoff distance of the graph (also used as radius for single-chain graphs).
detail (str): Level of detail. Accepted values are: 'residue' or 'atom'.
detail (Literal['atomic', 'residue']): Level of detail.
interaction_radius (float): max distance to include in graph.
max_edge_distance (float): max distance to create an edge.
central_res (Optional[int], optional): Residue to center a single-chain graph around.
Use None to create a 2-chain graph, or any value for a single-chain graph
Defaults to None.
Expand All @@ -51,7 +53,7 @@ def build_testgraph( # pylint: disable=too-many-locals, too-many-arguments # noq
Returns:
Graph: As generated by build_residue_graph or build_atomic_graph
SingleResidueVariant: Only resturned if central_res is not None
SingleResidueVariant: returns None if central_res is None
"""

pdb = pdb2sql(pdb_path)
Expand All @@ -69,7 +71,7 @@ def build_testgraph( # pylint: disable=too-many-locals, too-many-arguments # noq
for residue1, residue2 in get_residue_contact_pairs(
pdb_path, structure,
chains[0], chains[1],
cutoff
interaction_radius
):
if detail == 'residue':
nodes.add(residue1)
Expand All @@ -82,9 +84,9 @@ def build_testgraph( # pylint: disable=too-many-locals, too-many-arguments # noq
nodes.add(atom)

if detail == 'residue':
return build_residue_graph(list(nodes), structure.id, cutoff)
return build_residue_graph(list(nodes), structure.id, max_edge_distance), None
if detail == 'atom':
return build_atomic_graph(list(nodes), structure.id, cutoff)
return build_atomic_graph(list(nodes), structure.id, max_edge_distance), None
raise TypeError('detail must be "atom" or "residue"')

else:
Expand All @@ -93,7 +95,7 @@ def build_testgraph( # pylint: disable=too-many-locals, too-many-arguments # noq
else:
chain = structure.get_chain(chain_ids)
residue = _get_residue(chain, central_res)
surrounding_residues = list(get_surrounding_residues(structure, residue, cutoff))
surrounding_residues = list(get_surrounding_residues(structure, residue, interaction_radius))

try:
with open(f"tests/data/pssm/{structure.id}/{structure.id}.{chain.id}.pdb.pssm", "rt", encoding="utf-8") as f:
Expand All @@ -102,8 +104,8 @@ def build_testgraph( # pylint: disable=too-many-locals, too-many-arguments # noq
pass

if detail == 'residue':
return build_residue_graph(surrounding_residues, structure.id, cutoff), SingleResidueVariant(residue, variant)
return build_residue_graph(surrounding_residues, structure.id, max_edge_distance), SingleResidueVariant(residue, variant)
if detail == 'atom':
atoms = set(atom for residue in surrounding_residues for atom in residue.atoms)
return build_atomic_graph(list(atoms), structure.id, cutoff), SingleResidueVariant(residue, variant)
return build_atomic_graph(list(atoms), structure.id, max_edge_distance), SingleResidueVariant(residue, variant)
raise TypeError('detail must be "atom" or "residue"')
25 changes: 18 additions & 7 deletions tests/features/test_components.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,39 @@
import numpy as np
from deeprank2.domain.aminoacidlist import glycine, serine
from deeprank2.features.components import add_features

from deeprank2.domain import nodestorage as Nfeat
from deeprank2.domain.aminoacidlist import glycine, serine
from deeprank2.features.components import add_features

from . import build_testgraph


def test_atom_features():
pdb_path = "tests/data/pdb/101M/101M.pdb"
graph, _ = build_testgraph(pdb_path, 10, 'atom', 25)

graph, _ = build_testgraph(
pdb_path=pdb_path,
detail='atom',
interaction_radius=10,
max_edge_distance=10,
central_res=25,
)
add_features(pdb_path, graph)

assert not any(np.isnan(node.features[Nfeat.ATOMCHARGE]) for node in graph.nodes)
assert not any(np.isnan(node.features[Nfeat.PDBOCCUPANCY]) for node in graph.nodes)


def test_aminoacid_features():
pdb_path = "tests/data/pdb/101M/101M.pdb"
graph, variant = build_testgraph(pdb_path, 10, 'residue', 25, serine)
graph, variant = build_testgraph(
pdb_path=pdb_path,
detail='residue',
interaction_radius=10,
max_edge_distance=10,
central_res=25,
variant=serine,
)
add_features(pdb_path, graph, variant)

node = graph.nodes[25].id

for node in graph.nodes:
if node.id == variant.residue: # GLY -> SER
assert sum(node.features[Nfeat.RESTYPE]) == 1
Expand Down
32 changes: 27 additions & 5 deletions tests/features/test_conservation.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,23 @@
import numpy as np
import pytest
from deeprank2.domain.aminoacidlist import alanine
from deeprank2.features.conservation import add_features

from deeprank2.domain import nodestorage as Nfeat
from deeprank2.domain.aminoacidlist import alanine
from deeprank2.features.conservation import add_features

from . import build_testgraph


def test_conservation_residue():
pdb_path = "tests/data/pdb/101M/101M.pdb"
graph, variant = build_testgraph(pdb_path, 10, 'residue', 25, alanine)
graph, variant = build_testgraph(
pdb_path=pdb_path,
detail='residue',
interaction_radius=10,
max_edge_distance=10,
central_res=25,
variant=alanine,
)
add_features(pdb_path, graph, variant)

for feature_name in (
Expand All @@ -24,7 +31,14 @@ def test_conservation_residue():

def test_conservation_atom():
pdb_path = "tests/data/pdb/101M/101M.pdb"
graph, variant = build_testgraph(pdb_path, 10, 'atom', 25, alanine)
graph, variant = build_testgraph(
pdb_path=pdb_path,
detail='atom',
interaction_radius=10,
max_edge_distance=10,
central_res=25,
variant=alanine,
)
add_features(pdb_path, graph, variant)

for feature_name in (
Expand All @@ -38,6 +52,14 @@ def test_conservation_atom():

def test_no_pssm_file_error():
pdb_path = "tests/data/pdb/1CRN/1CRN.pdb"
graph, variant = build_testgraph(pdb_path, 10, 'residue', 17, alanine)
graph, variant = build_testgraph(
pdb_path=pdb_path,
detail='residue',
interaction_radius=10,
max_edge_distance=10,
central_res=17,
variant=alanine,
)

with pytest.raises(FileNotFoundError):
add_features(pdb_path, graph, variant)
20 changes: 14 additions & 6 deletions tests/features/test_exposure.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import numpy as np
from deeprank2.features.exposure import add_features
from deeprank2.utils.graph import Graph

from deeprank2.domain import nodestorage as Nfeat
from deeprank2.features.exposure import add_features
from deeprank2.utils.graph import Graph

from . import build_testgraph

Expand All @@ -19,15 +19,23 @@ def _run_assertions(graph: Graph):

def test_exposure_residue():
pdb_path = "tests/data/pdb/1ATN/1ATN_1w.pdb"
graph = build_testgraph(pdb_path, 8.5, 'residue')

graph, _ = build_testgraph(
pdb_path=pdb_path,
detail='residue',
interaction_radius=8.5,
max_edge_distance=8.5,
)
add_features(pdb_path, graph)
_run_assertions(graph)


def test_exposure_atom():
pdb_path = "tests/data/pdb/1ak4/1ak4.pdb"
graph = build_testgraph(pdb_path, 4.5, 'atom')

graph, _ = build_testgraph(
pdb_path=pdb_path,
detail='atom',
interaction_radius=4.5,
max_edge_distance=4.5,
)
add_features(pdb_path, graph)
_run_assertions(graph)
20 changes: 14 additions & 6 deletions tests/features/test_irc.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import numpy as np
from deeprank2.features.irc import add_features
from deeprank2.utils.graph import Graph

from deeprank2.domain import nodestorage as Nfeat
from deeprank2.features.irc import add_features
from deeprank2.utils.graph import Graph

from . import build_testgraph

Expand All @@ -25,15 +25,23 @@ def _run_assertions(graph: Graph):

def test_irc_residue():
pdb_path = "tests/data/pdb/1ATN/1ATN_1w.pdb"
graph = build_testgraph(pdb_path, 8.5, 'residue')

graph, _ = build_testgraph(
pdb_path=pdb_path,
detail='residue',
interaction_radius=8.5,
max_edge_distance=8.5,
)
add_features(pdb_path, graph)
_run_assertions(graph)


def test_irc_atom():
pdb_path = "tests/data/pdb/1A0Z/1A0Z.pdb"
graph = build_testgraph(pdb_path, 4.5, 'atom')

graph, _ = build_testgraph(
pdb_path=pdb_path,
detail='residue',
interaction_radius=4.5,
max_edge_distance=4.5,
)
add_features(pdb_path, graph)
_run_assertions(graph)
18 changes: 14 additions & 4 deletions tests/features/test_secondary_structure.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,22 @@
import numpy as np

from deeprank2.domain import nodestorage as Nfeat
from deeprank2.features.secondary_structure import (SecondarySctructure,
_classify_secstructure,
add_features)

from deeprank2.domain import nodestorage as Nfeat

from . import build_testgraph


def test_secondary_structure_residue():
test_case = '9api' # properly formatted pdb file
pdb_path = f"tests/data/pdb/{test_case}/{test_case}.pdb"
graph = build_testgraph(pdb_path, 10, 'residue')
graph, _ = build_testgraph(
pdb_path=pdb_path,
detail='residue',
interaction_radius=10,
max_edge_distance=10,
)
add_features(pdb_path, graph)

# Create a list of node information (residue number, chain ID, and secondary structure features)
Expand Down Expand Up @@ -53,7 +58,12 @@ def test_secondary_structure_residue():
def test_secondary_structure_atom():
test_case = '1ak4' # ATOM list
pdb_path = f"tests/data/pdb/{test_case}/{test_case}.pdb"
graph = build_testgraph(pdb_path, 4.5, 'atom')
graph, _ = build_testgraph(
pdb_path=pdb_path,
detail='atom',
interaction_radius=4.5,
max_edge_distance=4.5,
)
add_features(pdb_path, graph)

# Create a list of node information (residue number, chain ID, and secondary structure features)
Expand Down
32 changes: 27 additions & 5 deletions tests/features/test_surfacearea.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
from deeprank2.features.surfacearea import add_features

from deeprank2.domain import nodestorage as Nfeat
from deeprank2.features.surfacearea import add_features

from . import build_testgraph

Expand All @@ -28,7 +28,12 @@ def _find_atom_node(graph, chain_id, residue_number, atom_name):

def test_bsa_residue():
pdb_path = "tests/data/pdb/1ATN/1ATN_1w.pdb"
graph = build_testgraph(pdb_path, 8.5, 'residue')
graph, _ = build_testgraph(
pdb_path=pdb_path,
detail='residue',
interaction_radius=8.5,
max_edge_distance=8.5,
)
add_features(pdb_path, graph)

# chain B ASP 93, at interface
Expand All @@ -38,7 +43,12 @@ def test_bsa_residue():

def test_bsa_atom():
pdb_path = "tests/data/pdb/1ATN/1ATN_1w.pdb"
graph = build_testgraph(pdb_path, 4.5, 'atom')
graph, _ = build_testgraph(
pdb_path=pdb_path,
detail='atom',
interaction_radius=4.5,
max_edge_distance=4.5,
)
add_features(pdb_path, graph)

# chain B ASP 93, at interface
Expand All @@ -48,7 +58,13 @@ def test_bsa_atom():

def test_sasa_residue():
pdb_path = "tests/data/pdb/101M/101M.pdb"
graph, _ = build_testgraph(pdb_path, 10, 'residue', 108)
graph, _ = build_testgraph(
pdb_path=pdb_path,
detail='residue',
interaction_radius=10,
max_edge_distance=10,
central_res=108,
)
add_features(pdb_path, graph)

# check for NaN
Expand All @@ -67,7 +83,13 @@ def test_sasa_residue():

def test_sasa_atom():
pdb_path = "tests/data/pdb/101M/101M.pdb"
graph, _ = build_testgraph(pdb_path, 10, 'atom', 108)
graph, _ = build_testgraph(
pdb_path=pdb_path,
detail='atom',
interaction_radius=10,
max_edge_distance=10,
central_res=108,
)
add_features(pdb_path, graph)

# check for NaN
Expand Down
Loading

0 comments on commit 9645914

Please sign in to comment.