Skip to content

Commit

Permalink
make atomic and residue graph similar
Browse files Browse the repository at this point in the history
  • Loading branch information
DaniBodor committed Sep 22, 2023
1 parent 5ab1247 commit 17df418
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 31 deletions.
2 changes: 1 addition & 1 deletion deeprank2/molstruct/residue.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def __repr__(self) -> str:

@property
def position(self) -> np.array:
return np.mean([atom.position for atom in self._atoms], axis=0)
return self.get_center()

def get_center(self) -> NDArray:
"""Find the center position of a `Residue`.
Expand Down
54 changes: 24 additions & 30 deletions deeprank2/utils/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ def __init__(self, id_: Atom | Residue):
raise TypeError(type(id_))

self.id = id_

self.features = {}

@property
Expand Down Expand Up @@ -319,8 +318,10 @@ def get_all_chains(self) -> list[str]:
return list(chains)


def build_atomic_graph( # pylint: disable=too-many-locals
atoms: list[Atom], graph_id: str, max_edge_distance: float
def build_atomic_graph(
atoms: list[Atom],
graph_id: str,
max_edge_distance: float,
) -> Graph:
"""Builds a graph, using the atoms as nodes.
Expand All @@ -333,19 +334,18 @@ def build_atomic_graph( # pylint: disable=too-many-locals

distances = distance_matrix(positions, positions, p=2)
neighbours = distances < max_edge_distance
atom_index_pairs = np.transpose(np.nonzero(neighbours))

graph = Graph(graph_id)
for atom1_index, atom2_index in np.transpose(np.nonzero(neighbours)):
if atom1_index != atom2_index:
for index1, index2 in atom_index_pairs:
if index1 != index2:

atom1 = atoms[atom1_index]
atom2 = atoms[atom2_index]
contact = AtomicContact(atom1, atom2)
node1 = Node(atoms[index1])
node2 = Node(atoms[index2])
contact = AtomicContact(node1.id, node2.id)

node1 = Node(atom1)
node2 = Node(atom2)
node1.features[Nfeat.POSITION] = atom1.position
node2.features[Nfeat.POSITION] = atom2.position
node1.features[Nfeat.POSITION] = node1.id.position
node2.features[Nfeat.POSITION] = node2.id.position

graph.add_node(node1)
graph.add_node(node2)
Expand All @@ -355,7 +355,9 @@ def build_atomic_graph( # pylint: disable=too-many-locals


def build_residue_graph( # pylint: disable=too-many-locals
residues: list[Residue], graph_id: str, max_edge_distance: float
residues: list[Residue],
graph_id: str,
max_edge_distance: float,
) -> Graph:
"""Builds a graph, using the residues as nodes.
Expand All @@ -379,8 +381,6 @@ def build_residue_graph( # pylint: disable=too-many-locals
positions[atom_index] = atom.position

distances = distance_matrix(positions, positions, p=2)

# determine which atoms are close enough
neighbours = distances < max_edge_distance

atom_index_pairs = np.transpose(np.nonzero(neighbours))
Expand All @@ -390,26 +390,20 @@ def build_residue_graph( # pylint: disable=too-many-locals

# build the graph
graph = Graph(graph_id)
for residue1_index, residue2_index in residue_index_pairs:

residue1: Residue = residues[residue1_index]
residue2: Residue = residues[residue2_index]
for index1, index2 in residue_index_pairs:
if index1 != index2:

if residue1 != residue2:
node1 = Node(residues[index1])
node2 = Node(residues[index2])
contact = ResidueContact(node1.id, node2.id)

contact = ResidueContact(residue1, residue2)
node1.features[Nfeat.POSITION] = node1.id.position
node2.features[Nfeat.POSITION] = node2.id.position

node1 = Node(residue1)
node2 = Node(residue2)
edge = Edge(contact)

node1.features[Nfeat.POSITION] = residue1.get_center()
node2.features[Nfeat.POSITION] = residue2.get_center()

# The same residue will be added multiple times as a node,
# The same residue will be added multiple times as a node,
# but the Graph class fixes this.
graph.add_node(node1)
graph.add_node(node2)
graph.add_edge(edge)
graph.add_edge(Edge(contact))

return graph

0 comments on commit 17df418

Please sign in to comment.