Skip to content

Commit

Permalink
refactor: finding residue from pdb2sql key
Browse files Browse the repository at this point in the history
by using a helper function rather than repeating code
  • Loading branch information
DaniBodor committed Nov 17, 2023
1 parent a6cc17b commit 6461cfe
Showing 1 changed file with 27 additions and 40 deletions.
67 changes: 27 additions & 40 deletions deeprank2/utils/buildgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,8 @@ def get_contact_atoms(

return structure.get_atoms()

def get_residue_contact_pairs( # pylint: disable=too-many-locals

def get_residue_contact_pairs(
pdb_path: str,
structure: PDBStructure,
chain_id1: str,
Expand Down Expand Up @@ -153,50 +154,36 @@ def get_residue_contact_pairs( # pylint: disable=too-many-locals

# Map to residue objects
residue_pairs = set([])
for residue_key1, _ in contact_residues.items():
residue_chain_id1, residue_number1, residue_name1 = residue_key1

chain1 = structure.get_chain(residue_chain_id1)

residue1 = None
for residue in chain1.residues:
if (
residue.number == residue_number1
and residue.amino_acid is not None
and residue.amino_acid.three_letter_code == residue_name1
):
residue1 = residue
break
else:
raise ValueError(
f"Not found: {pdb_path} {residue_chain_id1} {residue_number1} {residue_name1}"
)

for residue_chain_id2, residue_number2, residue_name2 in contact_residues[ # pylint: disable=unnecessary-dict-index-lookup
residue_key1
]:

chain2 = structure.get_chain(residue_chain_id2)

residue2 = None
for residue in chain2.residues:
if (
residue.number == residue_number2
and residue.amino_acid is not None
and residue.amino_acid.three_letter_code == residue_name2
):
residue2 = residue
break
else:
raise ValueError(
f"Not found: {pdb_path} {residue_chain_id2} {residue_number2} {residue_name2}"
)

for residue_key1, residue_contacts in contact_residues.items():
residue1 = _get_residue_from_key(structure, residue_key1)
for residue_key2 in residue_contacts:
residue2 = _get_residue_from_key(structure, residue_key2)
residue_pairs.add(Pair(residue1, residue2))

return residue_pairs


def _get_residue_from_key(
structure: PDBStructure,
residue_key: tuple[str, int, str],
) -> Residue:
"""Returns a residue object given a pdb2sql-formatted residue key."""

residue_chain_id, residue_number, residue_name = residue_key
chain = structure.get_chain(residue_chain_id)

for residue in chain.residues:
if (
residue.number == residue_number
and residue.amino_acid is not None
and residue.amino_acid.three_letter_code == residue_name
):
return residue
raise ValueError(
f"Residue ({residue_key}) not found in {structure.id}."
)


def get_surrounding_residues(
structure: Chain | PDBStructure,
residue: Residue,
Expand Down

0 comments on commit 6461cfe

Please sign in to comment.