From 2708c5cd08a642f113b7c61f5ef6fdb34da9b590 Mon Sep 17 00:00:00 2001 From: Dani Bodor Date: Fri, 22 Sep 2023 22:22:06 +0200 Subject: [PATCH] style: make build_atomic_graph and build_residue_graph look similar this is a preparation for the next commit, where the functions are unified. here only variables are renamed and spacing is changed so that it is easy to see what the similarities and differences between the functions are. --- deeprank2/molstruct/residue.py | 2 +- deeprank2/utils/graph.py | 56 +++++++++++++++------------------- 2 files changed, 26 insertions(+), 32 deletions(-) diff --git a/deeprank2/molstruct/residue.py b/deeprank2/molstruct/residue.py index e9938b219..552d1a015 100644 --- a/deeprank2/molstruct/residue.py +++ b/deeprank2/molstruct/residue.py @@ -97,7 +97,7 @@ def __repr__(self) -> str: @property def position(self) -> np.array: - return np.mean([atom.position for atom in self._atoms], axis=0) + return self.get_center() def get_center(self) -> NDArray: """Find the center position of a `Residue`. diff --git a/deeprank2/utils/graph.py b/deeprank2/utils/graph.py index 4ab0546e1..593d0dab0 100644 --- a/deeprank2/utils/graph.py +++ b/deeprank2/utils/graph.py @@ -59,7 +59,6 @@ def __init__(self, id_: Atom | Residue): raise TypeError(type(id_)) self.id = id_ - self.features = {} @property @@ -319,8 +318,10 @@ def get_all_chains(self) -> list[str]: return list(chains) -def build_atomic_graph( # pylint: disable=too-many-locals - atoms: list[Atom], graph_id: str, max_edge_length: float +def build_atomic_graph( + atoms: list[Atom], + graph_id: str, + max_edge_length: float, ) -> Graph: """Builds a graph, using the atoms as nodes. @@ -333,19 +334,18 @@ def build_atomic_graph( # pylint: disable=too-many-locals distances = distance_matrix(positions, positions, p=2) neighbours = distances < max_edge_length + atom_index_pairs = np.transpose(np.nonzero(neighbours)) graph = Graph(graph_id) - for atom1_index, atom2_index in np.transpose(np.nonzero(neighbours)): - if atom1_index != atom2_index: + for index1, index2 in atom_index_pairs: + if index1 != index2: - atom1 = atoms[atom1_index] - atom2 = atoms[atom2_index] - contact = AtomicContact(atom1, atom2) + node1 = Node(atoms[index1]) + node2 = Node(atoms[index2]) + contact = AtomicContact(node1.id, node2.id) - node1 = Node(atom1) - node2 = Node(atom2) - node1.features[Nfeat.POSITION] = atom1.position - node2.features[Nfeat.POSITION] = atom2.position + node1.features[Nfeat.POSITION] = node1.id.position + node2.features[Nfeat.POSITION] = node2.id.position graph.add_node(node1) graph.add_node(node2) @@ -355,11 +355,13 @@ def build_atomic_graph( # pylint: disable=too-many-locals def build_residue_graph( # pylint: disable=too-many-locals - residues: list[Residue], graph_id: str, max_edge_length: float + residues: list[Residue], + graph_id: str, + max_edge_length: float, ) -> Graph: """Builds a graph, using the residues as nodes. - The max edge distance is in Ångströms. + The max edge length is in Ångströms. It's the shortest interatomic distance between two residues. """ @@ -379,8 +381,6 @@ def build_residue_graph( # pylint: disable=too-many-locals positions[atom_index] = atom.position distances = distance_matrix(positions, positions, p=2) - - # determine which atoms are close enough neighbours = distances < max_edge_length atom_index_pairs = np.transpose(np.nonzero(neighbours)) @@ -390,26 +390,20 @@ def build_residue_graph( # pylint: disable=too-many-locals # build the graph graph = Graph(graph_id) - for residue1_index, residue2_index in residue_index_pairs: + for index1, index2 in residue_index_pairs: + if index1 != index2: - residue1: Residue = residues[residue1_index] - residue2: Residue = residues[residue2_index] + node1 = Node(residues[index1]) + node2 = Node(residues[index2]) + contact = ResidueContact(node1.id, node2.id) - if residue1 != residue2: + node1.features[Nfeat.POSITION] = node1.id.position + node2.features[Nfeat.POSITION] = node2.id.position - contact = ResidueContact(residue1, residue2) - - node1 = Node(residue1) - node2 = Node(residue2) - edge = Edge(contact) - - node1.features[Nfeat.POSITION] = residue1.get_center() - node2.features[Nfeat.POSITION] = residue2.get_center() - - # The same residue will be added multiple times as a node, + # The same residue will be added multiple times as a node, # but the Graph class fixes this. graph.add_node(node1) graph.add_node(node2) - graph.add_edge(edge) + graph.add_edge(Edge(contact)) return graph