Skip to content

Commit

Permalink
Merge pull request #69 from Genentech/tangermeme_04
Browse files Browse the repository at this point in the history
  • Loading branch information
avantikalal authored Nov 4, 2024
2 parents 8b1431e + 7317b33 commit a1c3be4
Show file tree
Hide file tree
Showing 11 changed files with 872 additions and 235 deletions.
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ RUN pip install black flake8 isort
RUN pip install captum==0.5.0 wandb tensorboard plotnine

RUN pip install bioframe biopython genomepy scanpy \
pyjaspar pymemesuite pyBigWig pyfaidx pytabix
pyjaspar pyBigWig pyfaidx pytabix
RUN pip install bpnet-lite>=0.5.7 ledidi enformer-pytorch genomepy
RUN pip install pygenomeviz

Expand Down
1 change: 0 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ install_requires =
plotnine >= 0.8
anndata >= 0.8
scikit-learn
pymemesuite
torch >= 2.0
pytorch-lightning >= 2.0
torchmetrics >= 1.1
Expand Down
174 changes: 97 additions & 77 deletions src/grelu/interpret/motifs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,79 +2,84 @@
Functions related to manipulating sequence motifs and scanning DNA sequences with motifs.
"""

from typing import Callable, Generator, List, Optional, Tuple, Union
from typing import Callable, Dict, Generator, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
from pymemesuite.common import Motif
from torch import Tensor

from grelu.io.motifs import read_meme_file
from grelu.utils import make_list


def motifs_to_strings(
motifs: Union[Motif, List[Motif], str],
motifs: Union[np.ndarray, Dict[str, np.ndarray], str],
names: Optional[List[str]] = None,
sample: bool = False,
rng: Optional[Generator] = None,
) -> str:
"""
Extracts a matching DNA sequence from a motif
Extracts a matching DNA sequence from a motif. If sample=True, the best match sequence
is returned, otherwise a sequence is sampled from the probability distribution at each
position of the motif.
Args:
motifs: A pymemesuite.common.Motif object, a list of such objects,
or the path to a MEME file.
names: A list of motif names to read from the MEME file, in case a MEME
file is supplied in motifs. If None, all motifs in the file will be read.
motifs: Either a numpy array containing a Position Probability
Matrix (PPM) of shape (4, L), or a dictionary containing
motif names as keys and PPMs of shape (4, L) as values, or the
path to a MEME file.
names: A list of motif names to read from the MEME file, in case a
MEME file is supplied in motifs. If None, all motifs in the
file will be read.
sample: If True, a sequence will be sampled from the motif.
Otherwise, the best match sequence will be returned.
rng: np.random.RandomState object
Returns:
DNA sequence(s) as strings
"""
from grelu.io.meme import read_meme_file
from grelu.sequence.format import indices_to_strings

# Set random seed
rng = rng or np.random.RandomState(seed=None)

# Convert a single motif
if isinstance(motifs, Motif):
# Extract probabilities
probs = np.array(motifs.frequencies)
if isinstance(motifs, np.ndarray):

# Extract sequence as indices
if sample:
indices = np.array(
[rng.multinomial(1, p).argmax() for p in probs], dtype=np.int8
[rng.multinomial(1, pos).argmax() for pos in motifs.T], dtype=np.int8
)
else:
indices = probs.argmax(1).astype(np.int8)
indices = motifs.argmax(0).astype(np.int8)

# Return strings
return indices_to_strings(indices)

# Convert multiple motifs
elif isinstance(motifs, List):
return [motifs_to_strings(motif, rng=rng, sample=sample) for motif in motifs]
elif isinstance(motifs, Dict):
return [
motifs_to_strings(motif, rng=rng, sample=sample)
for motif in motifs.values()
]
else:
motifs, _ = read_meme_file(motifs, names=make_list(names))
motifs = read_meme_file(motifs, names=make_list(names))
return motifs_to_strings(motifs, rng=rng, sample=sample)


def trim_pwm(
pwm: np.array,
pwm: np.ndarray,
trim_threshold: float = 0.3,
padding: int = 0,
return_indices: bool = False,
) -> Union[Tuple[int], np.array]:
) -> Union[Tuple[int], np.ndarray]:
"""
Trims the edges of a PWM based on information content.
Trims the edges of a Position Weight Matrix (PWM) based on the
information content of each position.
Args:
pwm: PWM array of shape (L, 4)
pwm: A numpy array of shape (4, L) containing the PWM
trim_threshold: Threshold ranging from 0 to 1 to trim edge positions
padding: Number of low-information positions on either end to allow
return_indices: If True, only the indices of the positions to keep
will be returned. If False, the trimmed motif will be returned.
Expand All @@ -84,7 +89,7 @@ def trim_pwm(
(if return_indices = False).
"""
# Get per position score
score = np.sum(np.abs(pwm), axis=1)
score = np.sum(np.abs(pwm), axis=0)

# Calculate score threshold
trim_thresh = np.max(score) * trim_threshold
Expand All @@ -93,81 +98,100 @@ def trim_pwm(
pass_inds = np.where(score >= trim_thresh)[0]

# Get the start and end of the trimmed motif
start = max(np.min(pass_inds) - padding, 0)
end = min(np.max(pass_inds) + padding + 1, len(score) + 1)
start = max(np.min(pass_inds), 0)
end = min(np.max(pass_inds) + 1, len(score) + 1)

if return_indices:
return start, end
else:
return pwm[start:end]
return pwm[:, start:end]


def scan_sequences(
seqs: List[str],
motifs: Union[Motif, List[Motif], str],
seqs: Union[str, List[str]],
motifs: Union[str, Dict[str, np.ndarray]],
names: Optional[List[str]] = None,
bg=None,
seq_ids: Optional[List[str]] = None,
pthresh: float = 1e-3,
rc: bool = True,
bin_size=0.1,
eps=0.0001,
):
"""
Scan a DNA sequence using motifs
Scan a DNA sequence using motifs. Based on
https://github.com/jmschrei/tangermeme/blob/main/tangermeme/tools/fimo.py.
Args:
seqs: A list of DNA sequences as strings
motifs: A list of pymemesuite.common.Motif objects,
or the path to a MEME file.
seqs: A string or a list of DNA sequences as strings
motifs: A dictionary whose values are Position Probability Matrices
(PPMs) of shape (4, L), or the path to a MEME file.
names: A list of motif names to read from the MEME file.
If None, all motifs in the file will be read.
bg: A background distribution for motif p-value calculations.
Only needed if a list of Motif objects is supplied instead
of a MEME file.
seq_ids: Optional list of IDs for sequences
pthresh: p-value cutoff for binding sites
rc: If True, both the sequence and its reverse complement will be
scanned. If False, only the given sequence will be scanned.
bin_size: The size of the bins discretizing the PWM scores. The smaller
the bin size the higher the resolution, but the less data may be
available to support it. Default is 0.1.
eps: A small pseudocount to add to the motif PPMs before taking the log.
Default is 0.0001.
Returns:
pd.DataFrame containing columns 'motif', 'sequence', 'start', 'end',
'strand', 'score' and 'pval'.
'strand', 'score', 'pval', and 'matched_seq'.
"""
from collections import defaultdict
from tangermeme.tools.fimo import fimo

from pymemesuite.common import Sequence
from pymemesuite.fimo import FIMO

from grelu.io.meme import read_meme_file

# Load motifs
if isinstance(motifs, str):
motifs, bg = read_meme_file(motifs, names=names)
from grelu.sequence.format import strings_to_one_hot

# Format sequences
seqs = make_list(seqs)
if seq_ids is None:
seq_ids = [str(i) for i in range(len(seqs))]
sequences = [Sequence(seq, name=id.encode()) for id, seq in zip(seq_ids, seqs)]

# Setup FIMO
fimo = FIMO(both_strands=rc, threshold=pthresh)

# Empty dictionary for output
out = defaultdict(list)

# Scan
for motif in motifs:
match = fimo.score_motif(motif, sequences, bg).matched_elements
for m in match:
out["motif"].append(motif.name.decode())
out["sequence"].append(m.source.accession.decode())
out["start"].append(m.start)
out["end"].append(m.stop)
out["strand"].append(m.strand)
out["score"].append(m.score)
out["pval"].append(m.pvalue)

return pd.DataFrame(out)
seq_ids = seq_ids or [str(i) for i in range(len(seqs))]

# Format motifs
if isinstance(motifs, Dict):
motifs = {k: Tensor(v) for k, v in motifs.items()}

# Scan each sequence in seqs
results = pd.DataFrame()
for seq, seq_id in zip(seqs, seq_ids):
one_hot = strings_to_one_hot(seq, add_batch_axis=True)
curr_results = fimo(
motifs,
sequences=one_hot,
alphabet=["A", "C", "G", "T"],
bin_size=bin_size,
eps=eps,
threshold=pthresh,
reverse_complement=rc,
dim=1,
)
if len(curr_results) == 1:
curr_results = curr_results[0]
curr_results["sequence"] = seq_id
curr_results["matched_seq"] = curr_results.apply(
lambda row: seq[row.start : row.end], axis=1
)
curr_results = curr_results[
[
"motif_name",
"sequence",
"start",
"end",
"strand",
"score",
"p-value",
"matched_seq",
]
]
results = pd.concat([results, curr_results])

# Concatenate results from all sequences
if len(results) > 0:
results = results.reset_index(drop=True)
results = results.rename(columns={"motif_name": "motif"})
return results


def marginalize_patterns(
Expand Down Expand Up @@ -252,12 +276,11 @@ def marginalize_patterns(

def compare_motifs(
ref_seq: Union[str, pd.DataFrame],
motifs: Union[Motif, List[Motif], str],
motifs: Union[str, np.ndarray, Dict[str, np.ndarray]],
alt_seq: Optional[str] = None,
alt_allele: Optional[str] = None,
pos: Optional[int] = None,
names: Optional[List[str]] = None,
bg=None,
pthresh: float = 1e-3,
rc: bool = True,
) -> pd.DataFrame:
Expand All @@ -267,8 +290,8 @@ def compare_motifs(
Args:
ref_seq: The reference sequence as a string
motifs: A list of pymemesuite.common.Motif objects,
or the path to a MEME file.
motifs: A dictionary whose values are Position Probability Matrices
(PPMs) of shape (4, L), or the path to a MEME file.
alt_seq: The alternate sequence as a string
ref_allele: The alternate allele as a string. Only used if
alt_seq is not supplied.
Expand All @@ -278,9 +301,6 @@ def compare_motifs(
Only needed if alt_seq is not supplied.
names: A list of motif names to read from the MEME file.
If None, all motifs in the file will be read.
bg: A background distribution for motif p-value calculations.
Only needed if a list of Motif objects is supplied instead
of a MEME file.
pthresh: p-value cutoff for binding sites
rc: If True, both the sequence and its reverse complement will be
scanned. If False, only the given sequence will be scanned.
Expand Down
Loading

0 comments on commit a1c3be4

Please sign in to comment.