Skip to content

Commit

Permalink
make build_atomic_graph and build_residue_graph look similar
Browse files Browse the repository at this point in the history
this is a preparation for the next commit, where the functions are unified.
here only variables are renamed and spacing is changed so that it is easy to see what
the similarities and differences between the functions are.
  • Loading branch information
DaniBodor committed Nov 18, 2023
1 parent 6825a02 commit 5675fbf
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 32 deletions.
2 changes: 1 addition & 1 deletion deeprank2/molstruct/residue.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def __repr__(self) -> str:

@property
def position(self) -> np.array:
return np.mean([atom.position for atom in self._atoms], axis=0)
return self.get_center()

def get_center(self) -> NDArray:
"""Find the center position of a `Residue`.
Expand Down
56 changes: 25 additions & 31 deletions deeprank2/utils/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ def __init__(self, id_: Atom | Residue):
raise TypeError(type(id_))

self.id = id_

self.features = {}

@property
Expand Down Expand Up @@ -319,8 +318,10 @@ def get_all_chains(self) -> list[str]:
return list(chains)


def build_atomic_graph( # pylint: disable=too-many-locals
atoms: list[Atom], graph_id: str, max_edge_length: float
def build_atomic_graph(
atoms: list[Atom],
graph_id: str,
max_edge_length: float,
) -> Graph:
"""Builds a graph, using the atoms as nodes.
Expand All @@ -333,19 +334,18 @@ def build_atomic_graph( # pylint: disable=too-many-locals

distances = distance_matrix(positions, positions, p=2)
neighbours = distances < max_edge_length
atom_index_pairs = np.transpose(np.nonzero(neighbours))

graph = Graph(graph_id)
for atom1_index, atom2_index in np.transpose(np.nonzero(neighbours)):
if atom1_index != atom2_index:
for index1, index2 in atom_index_pairs:
if index1 != index2:

atom1 = atoms[atom1_index]
atom2 = atoms[atom2_index]
contact = AtomicContact(atom1, atom2)
node1 = Node(atoms[index1])
node2 = Node(atoms[index2])
contact = AtomicContact(node1.id, node2.id)

node1 = Node(atom1)
node2 = Node(atom2)
node1.features[Nfeat.POSITION] = atom1.position
node2.features[Nfeat.POSITION] = atom2.position
node1.features[Nfeat.POSITION] = node1.id.position
node2.features[Nfeat.POSITION] = node2.id.position

graph.add_node(node1)
graph.add_node(node2)
Expand All @@ -355,11 +355,13 @@ def build_atomic_graph( # pylint: disable=too-many-locals


def build_residue_graph( # pylint: disable=too-many-locals
residues: list[Residue], graph_id: str, max_edge_length: float
residues: list[Residue],
graph_id: str,
max_edge_length: float,
) -> Graph:
"""Builds a graph, using the residues as nodes.
The max edge distance is in Ångströms.
The max edge length is in Ångströms.
It's the shortest interatomic distance between two residues.
"""

Expand All @@ -379,8 +381,6 @@ def build_residue_graph( # pylint: disable=too-many-locals
positions[atom_index] = atom.position

distances = distance_matrix(positions, positions, p=2)

# determine which atoms are close enough
neighbours = distances < max_edge_length

atom_index_pairs = np.transpose(np.nonzero(neighbours))
Expand All @@ -390,26 +390,20 @@ def build_residue_graph( # pylint: disable=too-many-locals

# build the graph
graph = Graph(graph_id)
for residue1_index, residue2_index in residue_index_pairs:
for index1, index2 in residue_index_pairs:
if index1 != index2:

residue1: Residue = residues[residue1_index]
residue2: Residue = residues[residue2_index]
node1 = Node(residues[index1])
node2 = Node(residues[index2])
contact = ResidueContact(node1.id, node2.id)

if residue1 != residue2:
node1.features[Nfeat.POSITION] = node1.id.position
node2.features[Nfeat.POSITION] = node2.id.position

contact = ResidueContact(residue1, residue2)

node1 = Node(residue1)
node2 = Node(residue2)
edge = Edge(contact)

node1.features[Nfeat.POSITION] = residue1.get_center()
node2.features[Nfeat.POSITION] = residue2.get_center()

# The same residue will be added multiple times as a node,
# The same residue will be added multiple times as a node,
# but the Graph class fixes this.
graph.add_node(node1)
graph.add_node(node2)
graph.add_edge(edge)
graph.add_edge(Edge(contact))

return graph

0 comments on commit 5675fbf

Please sign in to comment.