diff --git a/deeprank2/molstruct/residue.py b/deeprank2/molstruct/residue.py index e9938b219..552d1a015 100644 --- a/deeprank2/molstruct/residue.py +++ b/deeprank2/molstruct/residue.py @@ -97,7 +97,7 @@ def __repr__(self) -> str: @property def position(self) -> np.array: - return np.mean([atom.position for atom in self._atoms], axis=0) + return self.get_center() def get_center(self) -> NDArray: """Find the center position of a `Residue`. diff --git a/deeprank2/utils/graph.py b/deeprank2/utils/graph.py index 174eeb41c..4c51624e5 100644 --- a/deeprank2/utils/graph.py +++ b/deeprank2/utils/graph.py @@ -59,7 +59,6 @@ def __init__(self, id_: Atom | Residue): raise TypeError(type(id_)) self.id = id_ - self.features = {} @property @@ -319,8 +318,10 @@ def get_all_chains(self) -> list[str]: return list(chains) -def build_atomic_graph( # pylint: disable=too-many-locals - atoms: list[Atom], graph_id: str, max_edge_distance: float +def build_atomic_graph( + atoms: list[Atom], + graph_id: str, + max_edge_distance: float, ) -> Graph: """Builds a graph, using the atoms as nodes. @@ -333,19 +334,18 @@ def build_atomic_graph( # pylint: disable=too-many-locals distances = distance_matrix(positions, positions, p=2) neighbours = distances < max_edge_distance + atom_index_pairs = np.transpose(np.nonzero(neighbours)) graph = Graph(graph_id) - for atom1_index, atom2_index in np.transpose(np.nonzero(neighbours)): - if atom1_index != atom2_index: + for index1, index2 in atom_index_pairs: + if index1 != index2: - atom1 = atoms[atom1_index] - atom2 = atoms[atom2_index] - contact = AtomicContact(atom1, atom2) + node1 = Node(atoms[index1]) + node2 = Node(atoms[index2]) + contact = AtomicContact(node1.id, node2.id) - node1 = Node(atom1) - node2 = Node(atom2) - node1.features[Nfeat.POSITION] = atom1.position - node2.features[Nfeat.POSITION] = atom2.position + node1.features[Nfeat.POSITION] = node1.id.position + node2.features[Nfeat.POSITION] = node2.id.position graph.add_node(node1) graph.add_node(node2) @@ -355,7 +355,9 @@ def build_atomic_graph( # pylint: disable=too-many-locals def build_residue_graph( # pylint: disable=too-many-locals - residues: list[Residue], graph_id: str, max_edge_distance: float + residues: list[Residue], + graph_id: str, + max_edge_distance: float, ) -> Graph: """Builds a graph, using the residues as nodes. @@ -379,8 +381,6 @@ def build_residue_graph( # pylint: disable=too-many-locals positions[atom_index] = atom.position distances = distance_matrix(positions, positions, p=2) - - # determine which atoms are close enough neighbours = distances < max_edge_distance atom_index_pairs = np.transpose(np.nonzero(neighbours)) @@ -390,26 +390,20 @@ def build_residue_graph( # pylint: disable=too-many-locals # build the graph graph = Graph(graph_id) - for residue1_index, residue2_index in residue_index_pairs: - - residue1: Residue = residues[residue1_index] - residue2: Residue = residues[residue2_index] + for index1, index2 in residue_index_pairs: + if index1 != index2: - if residue1 != residue2: + node1 = Node(residues[index1]) + node2 = Node(residues[index2]) + contact = ResidueContact(node1.id, node2.id) - contact = ResidueContact(residue1, residue2) + node1.features[Nfeat.POSITION] = node1.id.position + node2.features[Nfeat.POSITION] = node2.id.position - node1 = Node(residue1) - node2 = Node(residue2) - edge = Edge(contact) - - node1.features[Nfeat.POSITION] = residue1.get_center() - node2.features[Nfeat.POSITION] = residue2.get_center() - - # The same residue will be added multiple times as a node, + # The same residue will be added multiple times as a node, # but the Graph class fixes this. graph.add_node(node1) graph.add_node(node2) - graph.add_edge(edge) + graph.add_edge(Edge(contact)) return graph