diff --git a/deeprank2/domain/aminoacidlist.py b/deeprank2/domain/aminoacidlist.py index c5727bf3c..9906e2a32 100644 --- a/deeprank2/domain/aminoacidlist.py +++ b/deeprank2/domain/aminoacidlist.py @@ -350,6 +350,10 @@ # pyrrolysine, ] +amino_acids_by_code = {amino_acid.three_letter_code: amino_acid for amino_acid in amino_acids} +amino_acids_by_letter = {amino_acid.one_letter_code: amino_acid for amino_acid in amino_acids} +amino_acids_by_name = {amino_acid.name: amino_acid for amino_acid in amino_acids} + def convert_aa_nomenclature(aa: str, output_type: int | None = None): try: if len(aa) == 1: diff --git a/deeprank2/features/irc.py b/deeprank2/features/irc.py index acaf40c17..4d7583f9d 100644 --- a/deeprank2/features/irc.py +++ b/deeprank2/features/irc.py @@ -4,7 +4,7 @@ import pdb2sql from deeprank2.domain import nodestorage as Nfeat -from deeprank2.domain.aminoacidlist import amino_acids +from deeprank2.domain.aminoacidlist import amino_acids_by_code from deeprank2.molstruct.aminoacid import Polarity from deeprank2.molstruct.atom import Atom from deeprank2.molstruct.residue import Residue, SingleResidueVariant @@ -67,7 +67,7 @@ def get_IRCs(pdb_path: str, chains: list[str], cutoff: float = 5.5) -> dict[str, for chain1_res, chain2_residues in pdb2sql_contacts.items(): aa1_code = chain1_res[2] try: - aa1 = [amino_acid for amino_acid in amino_acids if amino_acid.three_letter_code == aa1_code][0] + aa1 = amino_acids_by_code[aa1_code] except IndexError: continue # skip keys that are not an amino acid @@ -78,7 +78,7 @@ def get_IRCs(pdb_path: str, chains: list[str], cutoff: float = 5.5) -> dict[str, for chain2_res in chain2_residues: aa2_code = chain2_res[2] try: - aa2 = [amino_acid for amino_acid in amino_acids if amino_acid.three_letter_code == aa2_code][0] + aa2 = amino_acids_by_code[aa2_code] except IndexError: continue # skip keys that are not an amino acid diff --git a/deeprank2/molstruct/aminoacid.py b/deeprank2/molstruct/aminoacid.py index 88c615a87..189b5a564 100644 --- a/deeprank2/molstruct/aminoacid.py +++ b/deeprank2/molstruct/aminoacid.py @@ -1,6 +1,7 @@ from enum import Enum import numpy as np +from numpy.typing import NDArray class Polarity(Enum): @@ -108,7 +109,7 @@ def hydrogen_bond_acceptors(self) -> int: return self._hydrogen_bond_acceptors @property - def onehot(self) -> np.ndarray: + def onehot(self) -> NDArray: if self._index is None: raise ValueError( f"Amino acid {self._name} index is not set, thus no onehot can be computed." diff --git a/deeprank2/molstruct/pair.py b/deeprank2/molstruct/pair.py index ecdb5febf..bdf7f5423 100644 --- a/deeprank2/molstruct/pair.py +++ b/deeprank2/molstruct/pair.py @@ -41,7 +41,7 @@ def __repr__(self) -> str: class Contact(Pair, ABC): - """Parent class to bind `ResidueContact` and `ResidueContact` objects.""" + """Parent class to bind `ResidueContact` and `AtomicContact` objects.""" class ResidueContact(Contact): diff --git a/deeprank2/molstruct/residue.py b/deeprank2/molstruct/residue.py index e9938b219..552d1a015 100644 --- a/deeprank2/molstruct/residue.py +++ b/deeprank2/molstruct/residue.py @@ -97,7 +97,7 @@ def __repr__(self) -> str: @property def position(self) -> np.array: - return np.mean([atom.position for atom in self._atoms], axis=0) + return self.get_center() def get_center(self) -> NDArray: """Find the center position of a `Residue`. diff --git a/deeprank2/query.py b/deeprank2/query.py index 341d348a0..f0fefd113 100644 --- a/deeprank2/query.py +++ b/deeprank2/query.py @@ -25,8 +25,7 @@ from deeprank2.molstruct.structure import PDBStructure from deeprank2.utils.buildgraph import (get_contact_atoms, get_structure, get_surrounding_residues) -from deeprank2.utils.graph import (Graph, build_atomic_graph, - build_residue_graph) +from deeprank2.utils.graph import Graph from deeprank2.utils.grid import Augmentation, GridSettings, MapMethod from deeprank2.utils.parsing.pssm import parse_pssm @@ -298,7 +297,7 @@ def _build_helper(self) -> Graph: # build the graph if self.resolution == 'residue': - graph = build_residue_graph(residues, self.get_query_id(), self.max_edge_length) + graph = Graph.build_graph(residues, self.get_query_id(), self.max_edge_length) elif self.resolution == 'atom': residues.append(variant_residue) atoms = set([]) @@ -308,7 +307,7 @@ def _build_helper(self) -> Graph: atoms.add(atom) atoms = list(atoms) - graph = build_atomic_graph(atoms, self.get_query_id(), self.max_edge_length) + graph = Graph.build_graph(atoms, self.get_query_id(), self.max_edge_length) else: raise NotImplementedError(f"No function exists to build graphs with resolution of {self.resolution}.") @@ -367,10 +366,10 @@ def _build_helper(self) -> Graph: # build the graph if self.resolution == 'atom': - graph = build_atomic_graph(contact_atoms, self.get_query_id(), self.max_edge_length) + graph = Graph.build_graph(contact_atoms, self.get_query_id(), self.max_edge_length) elif self.resolution == 'residue': residues_selected = list({atom.residue for atom in contact_atoms}) - graph = build_residue_graph(residues_selected, self.get_query_id(), self.max_edge_length) + graph = Graph.build_graph(residues_selected, self.get_query_id(), self.max_edge_length) graph.center = np.mean([atom.position for atom in contact_atoms], axis=0) structure = contact_atoms[0].residue.chain.model diff --git a/deeprank2/utils/buildgraph.py b/deeprank2/utils/buildgraph.py index 56d873e38..83511d05a 100644 --- a/deeprank2/utils/buildgraph.py +++ b/deeprank2/utils/buildgraph.py @@ -2,10 +2,11 @@ import os import numpy as np -from pdb2sql import interface as get_interface +from pdb2sql import interface as pdb2sql_interface +from pdb2sql import pdb2sql as pdb2sql_object from scipy.spatial import distance_matrix -from deeprank2.domain.aminoacidlist import amino_acids +from deeprank2.domain.aminoacidlist import amino_acids_by_code from deeprank2.molstruct.atom import Atom, AtomicElement from deeprank2.molstruct.pair import Pair from deeprank2.molstruct.residue import Residue @@ -14,196 +15,112 @@ _log = logging.getLogger(__name__) -def _add_atom_to_residue(atom, residue): +def _add_atom_to_residue(atom: Atom, residue: Residue): + """Adds an `Atom` to a `Residue` if not already there. + + If no matching atom is found, add the current atom to the residue. + If there's another atom with the same name, choose the one with the highest occupancy. + """ + for other_atom in residue.atoms: if other_atom.name == atom.name: - # Don't allow two atoms with the same name, pick the highest - # occupancy if other_atom.occupancy < atom.occupancy: other_atom.change_altloc(atom) return - - # not there yet, add it residue.add_atom(atom) -_amino_acids_by_code = { - amino_acid.three_letter_code: amino_acid for amino_acid in amino_acids -} - - -_elements_by_name = {element.name: element for element in AtomicElement} - +def _add_atom_data_to_structure( + structure: PDBStructure, + pdb_obj: pdb2sql_object, + **kwargs +): + """This subroutine retrieves pdb2sql atomic data for `PDBStructure` objects as defined in DeepRank2. -def _add_atom_data_to_structure(structure: PDBStructure, # pylint: disable=too-many-arguments, too-many-locals - x: float, y: float, z: float, - atom_name: str, - altloc: str, occupancy: float, - element_name: str, - chain_id: str, - residue_number: int, - residue_name: str, - insertion_code: str): - - """ - This is a subroutine, to be used in other methods for converting pdb2sql atomic data into a - deeprank structure object. It should be called for one atom. + This function should be called for one atom at a time. Args: - structure (:class:`PDBStructure`): Where this atom should be added to. - x (float): x-coordinate of atom. - y (float): y-coordinate of atom. - z (float): z-coordinate of atom. - atom_name (str): Name of atom: 'CA', 'C', 'N', 'O', 'CB', etc. - altloc (str): Pdb alternative location id for this atom (can be empty): 'A', 'B', 'C', etc. - occupancy (float): Pdb occupancy of this atom, ranging from 0.0 to 1.0. Should be used with altloc. - element_name (str): Pdb element symbol of this atom: 'C', 'O', 'H', 'N', 'S'. - chain_id (str): Pdb chain identifier: 'A', 'B', 'C', etc. - residue_number (int): Pdb residue number, a positive integer. - residue_name (str): Pdb residue name: "ALA", "CYS", "ASP", etc. - insertion_code (str): Pdb residue insertion code (can be empty) : '', 'A', 'B', 'C', etc. + structure (:class:`PDBStructure`): The structure to which this atom should be added to. + pdb (pdb2sql_object): The `pdb2sql` object to retrieve the data from. + kwargs: as required by the get function for the `pdb2sql` object. """ - # Make sure not to take the same atom twice. - if altloc is not None and altloc != "" and altloc != "A": - return - - # We use None to indicate that the residue has no insertion code. - if insertion_code == "": - insertion_code = None - - # The amino acid is only valid when we deal with protein residues. - if residue_name in _amino_acids_by_code: - amino_acid = _amino_acids_by_code[residue_name] - else: - amino_acid = None - - # Turn the x,y,z into a vector: - atom_position = np.array([x, y, z]) - - # Init chain. - if not structure.has_chain(chain_id): - - chain = Chain(structure, chain_id) - structure.add_chain(chain) - else: - chain = structure.get_chain(chain_id) - - # Init residue. - if not chain.has_residue(residue_number, insertion_code): - - residue = Residue(chain, residue_number, amino_acid, insertion_code) - chain.add_residue(residue) - else: - residue = chain.get_residue(residue_number, insertion_code) - - # Init atom. - atom = Atom( - residue, atom_name, _elements_by_name[element_name], atom_position, occupancy - ) - _add_atom_to_residue(atom, residue) + pdb2sql_columns = "x,y,z,name,altLoc,occ,element,chainID,resSeq,resName,iCode" + data_keys = pdb2sql_columns.split(sep=',') + for data_values in pdb_obj.get(pdb2sql_columns, **kwargs): + atom_data = dict(zip(data_keys, data_values)) + + # exit function if this atom is already part of the structure + if atom_data["altLoc"] not in (None, "", "A"): + return + + atom_data["iCode"] = None if atom_data["iCode"] == "" else atom_data["iCode"] + + try: + atom_data["aa"] = amino_acids_by_code[atom_data["resName"]] + except KeyError: + atom_data["aa"] = None + atom_data["coordinates"] = np.array(data_values[:3]) + + if not structure.has_chain(atom_data["chainID"]): + structure.add_chain(Chain(structure, atom_data["chainID"])) + chain = structure.get_chain(atom_data["chainID"]) + + if not chain.has_residue(atom_data["resSeq"], atom_data["iCode"]): + chain.add_residue(Residue(chain, atom_data["resSeq"], atom_data["aa"], atom_data["iCode"])) + residue = chain.get_residue(atom_data["resSeq"], atom_data["iCode"]) + + atom = Atom( + residue, + atom_data["name"], + AtomicElement[atom_data["element"]], + atom_data["coordinates"], + atom_data["occ"], + ) + _add_atom_to_residue(atom, residue) -def get_structure(pdb, id_: str) -> PDBStructure: +def get_structure(pdb_obj: pdb2sql_object, id_: str) -> PDBStructure: """Builds a structure from rows in a pdb file. Args: pdb (pdb2sql object): The pdb structure that we're investigating. - id (str): Unique id for the pdb structure. + id_ (str): Unique id for the pdb structure. Returns: PDBStructure: The structure object, giving access to chains, residues, atoms. """ - - # We need these intermediary dicts to keep track of which residues and - # chains have already been created. structure = PDBStructure(id_) - - # Iterate over the atom output from pdb2sql - for row in pdb.get( - "x,y,z,rowID,name,altLoc,occ,element,chainID,resSeq,resName,iCode", model=0 - ): - - ( - x, - y, - z, - _, - atom_name, - altloc, - occupancy, - element_name, - chain_id, - residue_number, - residue_name, - insertion_code, - ) = row - - _add_atom_data_to_structure(structure, - x, y, z, - atom_name, - altloc, occupancy, - element_name, - chain_id, - residue_number, - residue_name, - insertion_code) - + _add_atom_data_to_structure(structure, pdb_obj, model=0) return structure -def get_contact_atoms( # pylint: disable=too-many-locals +def get_contact_atoms( pdb_path: str, chain_ids: list[str], influence_radius: float ) -> list[Atom]: """Gets the contact atoms from pdb2sql and wraps them in python objects.""" - interface = get_interface(pdb_path) + interface = pdb2sql_interface(pdb_path) + pdb_name = os.path.splitext(os.path.basename(pdb_path))[0] + structure = PDBStructure(f"contact_atoms_{pdb_name}") + try: atom_indexes = interface.get_contact_atoms( cutoff=influence_radius, chain1=chain_ids[0], chain2=chain_ids[1], ) - rows = interface.get( - "x,y,z,name,element,altLoc,occ,chainID,resSeq,resName,iCode", - rowID=atom_indexes[chain_ids[0]] + atom_indexes[chain_ids[1]] - ) + pdb_rowID = atom_indexes[chain_ids[0]] + atom_indexes[chain_ids[1]] + _add_atom_data_to_structure(structure, interface, rowID=pdb_rowID) finally: interface._close() # pylint: disable=protected-access - pdb_name = os.path.splitext(os.path.basename(pdb_path))[0] - structure = PDBStructure(f"contact_atoms_{pdb_name}") - - for row in rows: - ( - x, - y, - z, - atom_name, - element_name, - altloc, - occupancy, - chain_id, - residue_number, - residue_name, - insertion_code - ) = row - - _add_atom_data_to_structure(structure, - x, y, z, - atom_name, - altloc, occupancy, - element_name, - chain_id, - residue_number, - residue_name, - insertion_code) - return structure.get_atoms() -def get_residue_contact_pairs( # pylint: disable=too-many-locals + +def get_residue_contact_pairs( pdb_path: str, structure: PDBStructure, chain_id1: str, @@ -224,7 +141,7 @@ def get_residue_contact_pairs( # pylint: disable=too-many-locals """ # Find out which residues are pairs - interface = get_interface(pdb_path) + interface = pdb2sql_interface(pdb_path) try: contact_residues = interface.get_contact_residues( cutoff=influence_radius, @@ -237,50 +154,36 @@ def get_residue_contact_pairs( # pylint: disable=too-many-locals # Map to residue objects residue_pairs = set([]) - for residue_key1, _ in contact_residues.items(): - residue_chain_id1, residue_number1, residue_name1 = residue_key1 - - chain1 = structure.get_chain(residue_chain_id1) - - residue1 = None - for residue in chain1.residues: - if ( - residue.number == residue_number1 - and residue.amino_acid is not None - and residue.amino_acid.three_letter_code == residue_name1 - ): - residue1 = residue - break - else: - raise ValueError( - f"Not found: {pdb_path} {residue_chain_id1} {residue_number1} {residue_name1}" - ) - - for residue_chain_id2, residue_number2, residue_name2 in contact_residues[ # pylint: disable=unnecessary-dict-index-lookup - residue_key1 - ]: - - chain2 = structure.get_chain(residue_chain_id2) - - residue2 = None - for residue in chain2.residues: - if ( - residue.number == residue_number2 - and residue.amino_acid is not None - and residue.amino_acid.three_letter_code == residue_name2 - ): - residue2 = residue - break - else: - raise ValueError( - f"Not found: {pdb_path} {residue_chain_id2} {residue_number2} {residue_name2}" - ) - + for residue_key1, residue_contacts in contact_residues.items(): + residue1 = _get_residue_from_key(structure, residue_key1) + for residue_key2 in residue_contacts: + residue2 = _get_residue_from_key(structure, residue_key2) residue_pairs.add(Pair(residue1, residue2)) return residue_pairs +def _get_residue_from_key( + structure: PDBStructure, + residue_key: tuple[str, int, str], +) -> Residue: + """Returns a residue object given a pdb2sql-formatted residue key.""" + + residue_chain_id, residue_number, residue_name = residue_key + chain = structure.get_chain(residue_chain_id) + + for residue in chain.residues: + if ( + residue.number == residue_number + and residue.amino_acid is not None + and residue.amino_acid.three_letter_code == residue_name + ): + return residue + raise ValueError( + f"Residue ({residue_key}) not found in {structure.id}." + ) + + def get_surrounding_residues( structure: Chain | PDBStructure, residue: Residue, @@ -289,30 +192,23 @@ def get_surrounding_residues( """Get the residues that lie within a radius around a residue. Args: - structure (Union[:class:`Chain`, :class:`PDBStructure`]): The structure to take residues from. + structure (:class:`Chain` | :class:`PDBStructure`): The structure to take residues from. residue (:class:`Residue`): The residue in the structure. radius (float): Max distance in Ångström between atoms of the residue and the other residues. Returns: - (a set of deeprank residues): The surrounding residues. + list[:class:`Residue`]: The surrounding residues. """ structure_atoms = structure.get_atoms() - residue_atoms = residue.atoms - structure_atom_positions = [atom.position for atom in structure_atoms] - residue_atom_positions = [atom.position for atom in residue_atoms] - - distances = distance_matrix(structure_atom_positions, residue_atom_positions, p=2) - - close_residues = set([]) + residue_atom_positions = [atom.position for atom in residue.atoms] + pairwise_distances = distance_matrix(structure_atom_positions, residue_atom_positions, p=2) + surrounding_residues = set([]) for structure_atom_index, structure_atom in enumerate(structure_atoms): - - shortest_distance = np.min(distances[structure_atom_index, :]) - + shortest_distance = np.min(pairwise_distances[structure_atom_index, :]) if shortest_distance < radius: + surrounding_residues.add(structure_atom.residue) - close_residues.add(structure_atom.residue) - - return list(close_residues) + return list(surrounding_residues) diff --git a/deeprank2/utils/graph.py b/deeprank2/utils/graph.py index 4ab0546e1..fafa7519d 100644 --- a/deeprank2/utils/graph.py +++ b/deeprank2/utils/graph.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import logging import os from typing import Callable @@ -59,7 +61,6 @@ def __init__(self, id_: Atom | Residue): raise TypeError(type(id_)) self.id = id_ - self.features = {} @property @@ -319,97 +320,67 @@ def get_all_chains(self) -> list[str]: return list(chains) -def build_atomic_graph( # pylint: disable=too-many-locals - atoms: list[Atom], graph_id: str, max_edge_length: float -) -> Graph: - """Builds a graph, using the atoms as nodes. - - The max edge distance is in Ångströms. - """ - - positions = np.empty((len(atoms), 3)) - for atom_index, atom in enumerate(atoms): - positions[atom_index] = atom.position - - distances = distance_matrix(positions, positions, p=2) - neighbours = distances < max_edge_length - - graph = Graph(graph_id) - for atom1_index, atom2_index in np.transpose(np.nonzero(neighbours)): - if atom1_index != atom2_index: - - atom1 = atoms[atom1_index] - atom2 = atoms[atom2_index] - contact = AtomicContact(atom1, atom2) - - node1 = Node(atom1) - node2 = Node(atom2) - node1.features[Nfeat.POSITION] = atom1.position - node2.features[Nfeat.POSITION] = atom2.position - - graph.add_node(node1) - graph.add_node(node2) - graph.add_edge(Edge(contact)) - - return graph - - -def build_residue_graph( # pylint: disable=too-many-locals - residues: list[Residue], graph_id: str, max_edge_length: float -) -> Graph: - """Builds a graph, using the residues as nodes. - - The max edge distance is in Ångströms. - It's the shortest interatomic distance between two residues. - """ - - # collect the set of atoms and remember which are on the same residue (by index) - atoms = [] - atoms_residues = [] - for residue_index, residue in enumerate(residues): - for atom in residue.atoms: - atoms.append(atom) - atoms_residues.append(residue_index) - - atoms_residues = np.array(atoms_residues) - - # calculate the distance matrix - positions = np.empty((len(atoms), 3)) - for atom_index, atom in enumerate(atoms): - positions[atom_index] = atom.position - - distances = distance_matrix(positions, positions, p=2) - - # determine which atoms are close enough - neighbours = distances < max_edge_length - - atom_index_pairs = np.transpose(np.nonzero(neighbours)) - - # point out the unique residues for the atom pairs - residue_index_pairs = np.unique(atoms_residues[atom_index_pairs], axis=0) + @staticmethod + def build_graph( # pylint: disable=too-many-locals + nodes: list[Atom] | list[Residue], + graph_id: str, + max_edge_length: float, + ) -> Graph: + """Builds a graph. + + Args: + nodes (list[Atom] | list[Residue]): List of `Atom`s or `Residue`s to include in graph. + All nodes must be of same type. + graph_id (str): Human readable identifier for graph. + max_edge_length (float): Maximum distance between two nodes to connect them with an edge. + + Returns: + Graph: Containing nodes (with positions) and edges. + + Raises: + TypeError: if `nodes` argument contains a mix of different types. + """ + + if all(isinstance(node, Atom) for node in nodes): + atoms = nodes + NodeContact = AtomicContact + elif all(isinstance(node, Residue) for node in nodes): + # collect the set of atoms and remember which are on the same residue (by index) + atoms = [] + atoms_residues = [] + for residue_index, residue in enumerate(nodes): + for atom in residue.atoms: + atoms.append(atom) + atoms_residues.append(residue_index) + atoms_residues = np.array(atoms_residues) + NodeContact = ResidueContact + else: + raise TypeError("All nodes in the graph must be of the same type.") - # build the graph - graph = Graph(graph_id) - for residue1_index, residue2_index in residue_index_pairs: + positions = np.empty((len(atoms), 3)) + for atom_index, atom in enumerate(atoms): + positions[atom_index] = atom.position + neighbours = max_edge_length > distance_matrix(positions, positions, p=2) - residue1: Residue = residues[residue1_index] - residue2: Residue = residues[residue2_index] + index_pairs = np.transpose(np.nonzero(neighbours)) # atom pairs + if NodeContact == ResidueContact: + index_pairs = np.unique(atoms_residues[index_pairs], axis=0) # residue pairs - if residue1 != residue2: + graph = Graph(graph_id) - contact = ResidueContact(residue1, residue2) + for index1, index2 in index_pairs: + if index1 != index2: - node1 = Node(residue1) - node2 = Node(residue2) - edge = Edge(contact) + node1 = Node(nodes[index1]) + node2 = Node(nodes[index2]) + contact = NodeContact(node1.id, node2.id) - node1.features[Nfeat.POSITION] = residue1.get_center() - node2.features[Nfeat.POSITION] = residue2.get_center() + node1.features[Nfeat.POSITION] = node1.id.position + node2.features[Nfeat.POSITION] = node2.id.position - # The same residue will be added multiple times as a node, - # but the Graph class fixes this. - graph.add_node(node1) - graph.add_node(node2) - graph.add_edge(edge) + # The same node will be added multiple times, but the Graph class fixes this. + graph.add_node(node1) + graph.add_node(node2) + graph.add_edge(Edge(contact)) - return graph + return graph diff --git a/deeprank2/utils/parsing/pssm.py b/deeprank2/utils/parsing/pssm.py index d7130bffd..069dc1074 100644 --- a/deeprank2/utils/parsing/pssm.py +++ b/deeprank2/utils/parsing/pssm.py @@ -1,6 +1,6 @@ from typing import TextIO -from deeprank2.domain.aminoacidlist import amino_acids +from deeprank2.domain.aminoacidlist import amino_acids, amino_acids_by_letter from deeprank2.molstruct.residue import Residue from deeprank2.molstruct.structure import Chain from deeprank2.utils.pssmdata import PssmRow, PssmTable @@ -17,7 +17,6 @@ def parse_pssm(file_: TextIO, chain: Chain) -> PssmTable: PssmTable: The position-specific scoring table, parsed from the pssm file. """ - amino_acids_by_letter = {amino_acid.one_letter_code: amino_acid for amino_acid in amino_acids} conservation_rows = {} # Read the pssm header. diff --git a/tests/domain/test_aminoacidlist.py b/tests/domain/test_aminoacidlist.py index 2ab89c50a..7bafb2c89 100644 --- a/tests/domain/test_aminoacidlist.py +++ b/tests/domain/test_aminoacidlist.py @@ -12,15 +12,15 @@ ] def test_all_different_onehot(): - for amino_acid in amino_acids: - for other in amino_acids: - if other != amino_acid: - try: - assert not np.all(amino_acid.onehot == other.onehot) - except AssertionError as e: - if other in EXCEPTIONS[0] and amino_acid in EXCEPTIONS[0]: - assert np.all(amino_acid.onehot == other.onehot) - elif other in EXCEPTIONS[1] and amino_acid in EXCEPTIONS[1]: - assert np.all(amino_acid.onehot == other.onehot) - else: - raise AssertionError(f"one-hot index {amino_acid.index} is occupied by both {amino_acid} and {other}") from e + + for aa1, aa2 in zip(amino_acids, amino_acids): + if aa1 == aa2: + continue + + try: + assert not np.all(aa1.onehot == aa2.onehot) + except AssertionError as e: + if (aa1 in EXCEPTIONS[0] and aa2 in EXCEPTIONS[0]) or (aa1 in EXCEPTIONS[1] and aa2 in EXCEPTIONS[1]): + assert np.all(aa1.onehot == aa2.onehot) + else: + raise AssertionError(f"one-hot index {aa1.index} is occupied by both {aa1} and {aa2}") from e diff --git a/tests/features/__init__.py b/tests/features/__init__.py index eace0fbaf..c5c60aa90 100644 --- a/tests/features/__init__.py +++ b/tests/features/__init__.py @@ -9,8 +9,7 @@ from deeprank2.utils.buildgraph import (get_residue_contact_pairs, get_structure, get_surrounding_residues) -from deeprank2.utils.graph import (Graph, build_atomic_graph, - build_residue_graph) +from deeprank2.utils.graph import Graph from deeprank2.utils.parsing.pssm import parse_pssm @@ -25,7 +24,7 @@ def _get_residue(chain: Chain, number: int) -> Residue: def build_testgraph( # pylint: disable=too-many-locals, too-many-arguments # noqa:MC0001 pdb_path: str, - detail: Literal['atomic', 'residue'], + detail: Literal['atom', 'residue'], influence_radius: float, max_edge_length: float, central_res: int | None = None, @@ -37,11 +36,11 @@ def build_testgraph( # pylint: disable=too-many-locals, too-many-arguments # noq Args: pdb_path (str): Path of pdb file. - detail (Literal['atomic', 'residue']): Level of detail. + detail (Literal['atom', 'residue']): Type of graph to create. influence_radius (float): max distance to include in graph. max_edge_length (float): max distance to create an edge. central_res (int | None, optional): Residue to center a single-chain graph around. - Use None to create a 2-chain graph, or any value for a single-chain graph + Use None to create a 2-chain graph, or any value for a single-chain graph. Defaults to None. variant (AminoAcid | None, optional): Amino acid to use as a variant amino acid. Defaults to None. @@ -52,7 +51,7 @@ def build_testgraph( # pylint: disable=too-many-locals, too-many-arguments # noq TypeError: if detail is set to anything other than 'residue' or 'atom' Returns: - Graph: As generated by build_residue_graph or build_atomic_graph + Graph: As generated by Graph.build_graph SingleResidueVariant: returns None if central_res is None """ @@ -62,12 +61,13 @@ def build_testgraph( # pylint: disable=too-many-locals, too-many-arguments # noq finally: pdb._close() # pylint: disable=protected-access - if not central_res: # pylint: disable=no-else-raise + if not central_res: nodes = set([]) if not chain_ids: chains = (structure.chains[0].id, structure.chains[1].id) else: chains = [structure.get_chain(chain_id) for chain_id in chain_ids] + for residue1, residue2 in get_residue_contact_pairs( pdb_path, structure, chains[0], chains[1], @@ -76,36 +76,33 @@ def build_testgraph( # pylint: disable=too-many-locals, too-many-arguments # noq if detail == 'residue': nodes.add(residue1) nodes.add(residue2) - elif detail == 'atom': for atom in residue1.atoms: nodes.add(atom) for atom in residue2.atoms: nodes.add(atom) + else: + raise TypeError('detail must be "atom" or "residue"') - if detail == 'residue': - return build_residue_graph(list(nodes), structure.id, max_edge_length), None - if detail == 'atom': - return build_atomic_graph(list(nodes), structure.id, max_edge_length), None - raise TypeError('detail must be "atom" or "residue"') + return Graph.build_graph(list(nodes), structure.id, max_edge_length), None + # if central_res + if not chain_ids: + chain: Chain = structure.chains[0] else: - if not chain_ids: - chain: Chain = structure.chains[0] - else: - chain = structure.get_chain(chain_ids) - residue = _get_residue(chain, central_res) - surrounding_residues = list(get_surrounding_residues(structure, residue, influence_radius)) - - try: - with open(f"tests/data/pssm/{structure.id}/{structure.id}.{chain.id}.pdb.pssm", "rt", encoding="utf-8") as f: - chain.pssm = parse_pssm(f, chain) - except FileNotFoundError: - pass - - if detail == 'residue': - return build_residue_graph(surrounding_residues, structure.id, max_edge_length), SingleResidueVariant(residue, variant) - if detail == 'atom': - atoms = set(atom for residue in surrounding_residues for atom in residue.atoms) - return build_atomic_graph(list(atoms), structure.id, max_edge_length), SingleResidueVariant(residue, variant) - raise TypeError('detail must be "atom" or "residue"') + chain = structure.get_chain(chain_ids) + residue = _get_residue(chain, central_res) + surrounding_residues = list(get_surrounding_residues(structure, residue, influence_radius)) + + try: + with open(f"tests/data/pssm/{structure.id}/{structure.id}.{chain.id}.pdb.pssm", "rt", encoding="utf-8") as f: + chain.pssm = parse_pssm(f, chain) + except FileNotFoundError: + pass + + if detail == 'residue': + return Graph.build_graph(surrounding_residues, structure.id, max_edge_length), SingleResidueVariant(residue, variant) + if detail == 'atom': + atoms = set(atom for residue in surrounding_residues for atom in residue.atoms) + return Graph.build_graph(list(atoms), structure.id, max_edge_length), SingleResidueVariant(residue, variant) + raise TypeError('detail must be "atom" or "residue"') diff --git a/tests/test_query.py b/tests/test_query.py index c50fc558b..3415a30ca 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -7,12 +7,10 @@ import pytest from deeprank2.dataset import GraphDataset, GridDataset +from deeprank2.domain import aminoacidlist as aa from deeprank2.domain import edgestorage as Efeat from deeprank2.domain import nodestorage as Nfeat from deeprank2.domain import targetstorage as targets -from deeprank2.domain.aminoacidlist import (alanine, arginine, asparagine, - cysteine, glutamate, glycine, - leucine, lysine, phenylalanine) from deeprank2.features import components, conservation, contact, surfacearea from deeprank2.query import (ProteinProteinInterfaceQuery, QueryCollection, SingleResidueVariantQuery) @@ -146,8 +144,8 @@ def test_variant_graph_101M(): chain_ids="A", variant_residue_number=27, insertion_code=None, - wildtype_amino_acid=asparagine, - variant_amino_acid=phenylalanine, + wildtype_amino_acid=aa.asparagine, + variant_amino_acid=aa.phenylalanine, pssm_paths={"A": "tests/data/pssm/101M/101M.A.pdb.pssm"}, targets={targets.BINARY: 0}, influence_radius=5.0, @@ -179,8 +177,8 @@ def test_variant_graph_1A0Z(): chain_ids="A", variant_residue_number=125, insertion_code=None, - wildtype_amino_acid=leucine, - variant_amino_acid=arginine, + wildtype_amino_acid=aa.leucine, + variant_amino_acid=aa.arginine, pssm_paths={ "A": "tests/data/pssm/1A0Z/1A0Z.A.pdb.pssm", "B": "tests/data/pssm/1A0Z/1A0Z.B.pdb.pssm", @@ -217,8 +215,8 @@ def test_variant_graph_9API(): chain_ids="A", variant_residue_number=310, insertion_code=None, - wildtype_amino_acid=lysine, - variant_amino_acid=glutamate, + wildtype_amino_acid=aa.lysine, + variant_amino_acid=aa.glutamate, pssm_paths={ "A": "tests/data/pssm/9api/9api.A.pdb.pssm", "B": "tests/data/pssm/9api/9api.B.pdb.pssm", @@ -253,8 +251,8 @@ def test_variant_residue_graph_101M(): chain_ids="A", variant_residue_number=25, insertion_code=None, - wildtype_amino_acid=glycine, - variant_amino_acid=alanine, + wildtype_amino_acid=aa.glycine, + variant_amino_acid=aa.alanine, pssm_paths={"A": "tests/data/pssm/101M/101M.A.pdb.pssm"}, targets={targets.BINARY: 0}, ) @@ -315,8 +313,8 @@ def test_augmentation(): chain_ids="A", variant_residue_number=25, insertion_code=None, - wildtype_amino_acid=glycine, - variant_amino_acid=alanine, + wildtype_amino_acid=aa.glycine, + variant_amino_acid=aa.alanine, pssm_paths={"A": "tests/data/pssm/101M/101M.A.pdb.pssm"}, targets={targets.BINARY: 0}, )) @@ -327,8 +325,8 @@ def test_augmentation(): chain_ids="A", variant_residue_number=27, insertion_code=None, - wildtype_amino_acid=asparagine, - variant_amino_acid=phenylalanine, + wildtype_amino_acid=aa.asparagine, + variant_amino_acid=aa.phenylalanine, pssm_paths={"A": "tests/data/pssm/101M/101M.A.pdb.pssm"}, targets={targets.BINARY: 0}, influence_radius=3.0, @@ -470,8 +468,8 @@ def test_variant_query_multiple_chains(): chain_ids = "A", variant_residue_number = 14, insertion_code = None, - wildtype_amino_acid = arginine, - variant_amino_acid = cysteine, + wildtype_amino_acid = aa.arginine, + variant_amino_acid = aa.cysteine, pssm_paths = {"A": "tests/data/pssm/2g98/2g98.A.pdb.pssm"}, targets = {targets.BINARY: 1}, influence_radius = 10.0, diff --git a/tutorials/data_generation_srv.ipynb b/tutorials/data_generation_srv.ipynb index 2e7be5c8d..f8c3ffddf 100644 --- a/tutorials/data_generation_srv.ipynb +++ b/tutorials/data_generation_srv.ipynb @@ -72,23 +72,10 @@ "import matplotlib.pyplot as plt\n", "from deeprank2.query import QueryCollection\n", "from deeprank2.query import SingleResidueVariantQuery, SingleResidueVariantQuery\n", - "from deeprank2.domain.aminoacidlist import (alanine, arginine, asparagine,\n", - " serine, glycine, leucine, aspartate,\n", - " glutamine, glutamate, lysine, phenylalanine, histidine,\n", - " tyrosine, tryptophan, valine, proline,\n", - " cysteine, isoleucine, methionine, threonine)\n", + "from deeprank2.domain.aminoacidlist import amino_acids_by_code\n", "from deeprank2.features import components, contact\n", "from deeprank2.utils.grid import GridSettings, MapMethod\n", - "from deeprank2.dataset import GraphDataset\n", - "\n", - "aa_dict = {\"ALA\": alanine, \"CYS\": cysteine, \"ASP\": aspartate,\n", - " \"GLU\": glutamate, \"PHE\": phenylalanine, \"GLY\": glycine, \n", - " \"HIS\": histidine, \"ILE\": isoleucine, \"LYS\": lysine,\n", - " \"LEU\": leucine, \"MET\": methionine, \"ASN\": asparagine,\n", - " \"PRO\": proline, \"GLN\": glutamine, \"ARG\": arginine,\n", - " \"SER\": serine, \"THR\": threonine, \"VAL\": valine,\n", - " \"TRP\": tryptophan, \"TYR\": tyrosine\n", - " }" + "from deeprank2.dataset import GraphDataset\n" ] }, { @@ -220,8 +207,8 @@ "\t\tchain_ids = \"A\",\n", "\t\tvariant_residue_number = res_numbers[i],\n", "\t\tinsertion_code = None,\n", - "\t\twildtype_amino_acid = aa_dict[res_wildtypes[i]],\n", - "\t\tvariant_amino_acid = aa_dict[res_variants[i]],\n", + "\t\twildtype_amino_acid = amino_acids_by_code[res_wildtypes[i]],\n", + "\t\tvariant_amino_acid = amino_acids_by_code[res_variants[i]],\n", "\t\ttargets = {'binary': targets[i]},\n", "\t\tinfluence_radius = influence_radius,\n", "\t\tmax_edge_length = max_edge_length,\n", @@ -464,8 +451,8 @@ "\t\tchain_ids = \"A\",\n", "\t\tvariant_residue_number = res_numbers[i],\n", "\t\tinsertion_code = None,\n", - "\t\twildtype_amino_acid = aa_dict[res_wildtypes[i]],\n", - "\t\tvariant_amino_acid = aa_dict[res_variants[i]],\n", + "\t\twildtype_amino_acid = amino_acids_by_code[res_wildtypes[i]],\n", + "\t\tvariant_amino_acid = amino_acids_by_code[res_variants[i]],\n", "\t\ttargets = {'binary': targets[i]},\n", "\t\tinfluence_radius = influence_radius,\n", "\t\tmax_edge_length = max_edge_length,\n",