Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

replace bpnetlite with tangermeme #15

Merged
merged 4 commits into from
Aug 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ install_requires =
genomepy
bioframe >= 0.4
captum == 0.5.0
bpnet-lite == 0.5.7
logomaker >= 0.8
pyBigWig
ledidi
Expand Down
30 changes: 9 additions & 21 deletions src/grelu/data/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
"""
import os
import subprocess
import tempfile
from typing import Callable, List, Optional, Union

import bioframe as bf
Expand Down Expand Up @@ -479,7 +478,6 @@ def get_gc_matched_intervals(
genome: str,
binwidth: float = 0.1,
chroms: str = "autosomes",
gc_bw_file: str = None,
blacklist: str = "hg38",
seed: Optional[int] = None,
) -> pd.DataFrame:
Expand All @@ -491,15 +489,13 @@ def get_gc_matched_intervals(
genome: Name of the genome corresponding to intervals
binwidth: Resolution of GC content
chroms: Chromosomes to search for matched intervals
gc_bw_file: Path to a bigWig file of genomewide GC content.
If None, will be created.
blacklist: Blacklist file of regions to exclude
seed: Random seed

Returns:
A pandas dataframe containing GC-matched negative intervals.
"""
from bpnetlite.negatives import calculate_gc_genomewide, extract_matching_loci
from tangermeme.match import extract_matching_loci

from grelu.io.genome import get_genome
from grelu.sequence.utils import get_unique_length
Expand All @@ -510,25 +506,17 @@ def get_gc_matched_intervals(
# Get seq_len
seq_len = get_unique_length(intervals)

# Get bigWig file of GC content
if gc_bw_file is None:
gc_bw_file = "gc_{}_{}.bw".format(genome.name, seq_len)
print("Calculating GC content genomewide and saving to {}".format(gc_bw_file))
calculate_gc_genomewide(
fasta=genome.genome_file,
bigwig=gc_bw_file,
width=seq_len,
include_chroms=chroms,
verbose=True,
)

print("Extracting matching intervals")
_, tmpfile = tempfile.mkstemp()
intervals.iloc[:, :3].to_csv(tmpfile, sep="\t", index=False, header=False)
matched_loci = extract_matching_loci(
bed=tmpfile, bigwig=gc_bw_file, width=seq_len, bin_width=binwidth, verbose=True
intervals,
fasta=genome.genome_file,
in_window=seq_len,
gc_bin_width=binwidth,
chroms=chroms,
verbose=False,
random_state=seed,
)
os.remove(tmpfile)

print("Filtering blacklist")
if blacklist is not None:
matched_loci = filter_blacklist(matched_loci, blacklist)
Expand Down
1 change: 1 addition & 0 deletions src/grelu/interpret/motifs.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ def marginalize_patterns(
genome=genome,
rc=rc,
n_shuffles=n_shuffles,
seed=seed,
)

# Get predictions on the sequences before motif insertion
Expand Down
16 changes: 12 additions & 4 deletions src/grelu/sequence/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,12 +256,15 @@ def strings_to_indices(
).astype(np.int8)


def indices_to_one_hot(indices: np.ndarray) -> Tensor:
def indices_to_one_hot(indices: np.ndarray, add_batch_axis: bool = False) -> Tensor:
"""
Convert integer-encoded DNA sequences to one-hot encoded format.

Args:
indices: Integer-encoded DNA sequences.
add_batch_axis: If True, a batch axis will be included in the output for single
sequences. If False, the output for a single sequence will be a 2-dimensional
tensor.

Returns:
The one-hot encoded sequences.
Expand All @@ -271,9 +274,12 @@ def indices_to_one_hot(indices: np.ndarray) -> Tensor:

# Convert a single sequence
if indices.ndim == 1:
return one_hot(torch.LongTensor(indices.copy()), num_classes=5)[:, :4].T.type(
one_hot = one_hot(torch.LongTensor(indices.copy()), num_classes=5)[
:, :4
].T.type(
torch.float32
) # Output shape: 4, L
return one_hot.unsqueeze(0) if add_batch_axis else one_hot

# Convert multiple sequences
else:
Expand Down Expand Up @@ -367,6 +373,7 @@ def convert_input_type(
output_type: str = "indices",
genome: Optional[str] = None,
add_batch_axis: bool = False,
input_type: Optional[str] = None,
) -> Union[pd.DataFrame, str, List[str], np.ndarray, Tensor]:
"""
Convert input DNA sequence data into the desired format.
Expand All @@ -378,6 +385,7 @@ def convert_input_type(
add_batch_axis: If True, a batch axis will be included in the output for single
sequences. If False, the output for a single sequence will be a 2-dimensional
tensor.
input_type: Format of the input sequence (optional)

Returns:
The converted DNA sequence(s) in the desired format.
Expand All @@ -387,7 +395,7 @@ def convert_input_type(

"""
# Determine input type
input_type = get_input_type(inputs)
input_type = input_type or get_input_type(inputs)

# If no conversion needed, return inputs as is
if input_type == output_type:
Expand Down Expand Up @@ -416,7 +424,7 @@ def convert_input_type(
# Convert indices
if input_type == "indices":
if output_type == "one_hot":
return indices_to_one_hot(inputs)
return indices_to_one_hot(inputs, add_batch_axis=add_batch_axis)
elif output_type == "strings":
return indices_to_strings(inputs)

Expand Down
30 changes: 11 additions & 19 deletions src/grelu/sequence/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,8 @@ def reverse_complement(
def dinuc_shuffle(
seqs: Union[pd.DataFrame, np.ndarray, List[str]],
n_shuffles: int = 1,
start=0,
end=-1,
input_type: Optional[str] = None,
seed: Optional[int] = None,
genome: Optional[str] = None,
Expand All @@ -393,32 +395,22 @@ def dinuc_shuffle(
Returns:
Shuffled sequences in the same format as the input
"""
import torch
from bpnetlite.attributions import dinucleotide_shuffle
from einops import rearrange
from tangermeme.ersatz import dinucleotide_shuffle

# Input format
input_type = input_type or get_input_type(seqs)

# One-hot encode
seqs = convert_input_type(seqs, "one_hot", genome=genome) # N, 4, L
seqs = convert_input_type(
seqs, "one_hot", genome=genome, add_batch_axis=True
) # B, 4, L

# Shuffle sequences as many times as required
if n_shuffles > 0:
if seqs.ndim == 2: # 4, L
shuf_seqs = dinucleotide_shuffle(
seqs, n_shuffles=n_shuffles, random_state=seed
) # N, 4, L
else:
shuf_seqs = torch.vstack(
[
dinucleotide_shuffle(seq, n_shuffles=n_shuffles, random_state=seed)
for seq in seqs
]
) # B, 4, L

# If no shuffling is required, return the original sequences
else:
return seqs
shuf_seqs = dinucleotide_shuffle(
X=seqs, start=start, end=end, n=n_shuffles, random_state=seed, verbose=False
) # B, n, 4, L
shuf_seqs = rearrange(shuf_seqs, "b n t l -> (b n) t l")

return convert_input_type(shuf_seqs, input_type)

Expand Down
32 changes: 20 additions & 12 deletions tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -797,34 +797,37 @@ def test_ism_dataset():
def test_marginalize_dataset_variants():
# Marginalize variants
ds = VariantMarginalizeDataset(
variants=variants, genome="hg38", seq_len=6, n_shuffles=2, seed=0
variants=variants, genome="hg38", seq_len=12, n_shuffles=2, seed=0
)
assert (
(ds.n_shuffles == 2)
and (ds.seq_len == 6)
and (ds.seq_len == 12)
and (ds.n_seqs == 2)
and (ds.ref.shape == (2, 1))
and (ds.alt.shape == (2, 1))
and (len(ds) == 8)
and (ds.n_augmented == 2)
and (np.allclose(ds.ref, np.array([[2], [2]])))
and (np.allclose(ds.alt, np.array([[0], [0]])))
)
assert convert_input_type(ds.seqs, "strings") == ["CATACGTGAGGC", "AGGAGGCCAAAG"]
xs = [convert_input_type(ds[i], "strings") for i in range(len(ds))]
assert xs == [
"ACGTGA",
"ACATGA",
"ACGTGA",
"ACATGA",
"AGGCCA",
"AGACCA",
"AGGCCA",
"AGACCA",
"CACGTGTGAGGC",
"CACGTATGAGGC",
"CACGAGAGTGGC",
"CACGAAAGTGGC",
"AAGGGGGCCAAG",
"AAGGGAGCCAAG",
"AAGAGGGCCAAG",
"AAGAGAGCCAAG",
]


def test_marginalize_dataset_motifs():
# Marginalize motifs
ds = PatternMarginalizeDataset(
seqs=["ACCTACACT"], patterns=["AAA"], n_shuffles=2, seed=0
seqs=["AAGACATACAACGCGCGCTAACATAGCAAC"], patterns=["AAA"], n_shuffles=2, seed=0
)
assert (
(ds.n_shuffles == 2)
Expand All @@ -836,7 +839,12 @@ def test_marginalize_dataset_motifs():
)

xs = [convert_input_type(ds[i], "strings") for i in range(len(ds))]
assert xs == ["ACACCGACG", "ACAAAAACG", "ACACGACCG", "ACAAAACCG"]
assert xs == [
"ACGCATACGAGCGCTACAGCAACATAAAAC",
"ACGCATACGAGCGAAACAGCAACATAAAAC",
"ACTAACAACAGCACGCGCGATATAAGCAAC",
"ACTAACAACAGCAAAAGCGATATAAGCAAC",
]


# Test Motif scanning dataset
Expand Down
11 changes: 8 additions & 3 deletions tests/test_interpret.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def test_trim_pwm():


def test_marginalize_patterns():
seqs = ["ACTGT", "GATCC"]
seqs = ["CATACGTGAGGC", "AGGAGGCCAAAG"]
preds_before, preds_after = marginalize_patterns(
model,
patterns=["A"],
Expand All @@ -81,9 +81,14 @@ def test_marginalize_patterns():
compare_func=None,
)
assert preds_before.shape == (2, 3, 1)
assert np.allclose(preds_before.squeeze(), [[0.4, 0.4, 0.4], [0, 0, 0]])
assert np.allclose(
preds_before.squeeze(), [[0.5, 0.5, 0.5], [1.3333334, 1.3333334, 1.3333334]]
)
assert preds_after.shape == (2, 3, 1)
assert np.allclose(preds_after.squeeze(), [[1.2, 1.2, 1.2], [0.8, 0.8, 0.8]])
assert np.allclose(
preds_after.squeeze(),
[[0.5, 0.8333333, 0.8333333], [1.3333334, 1.6666666, 1.6666666]],
)


def test_ISM_predict():
Expand Down
3 changes: 3 additions & 0 deletions tests/test_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,9 @@ def test_seq_formatting():

# indices to one-hot
assert torch.allclose(convert_input_type(indices, "one_hot"), batch)
assert torch.allclose(
convert_input_type(indices[0], "one_hot", add_batch_axis=True), batch[[0]]
)


# Test Metrics functions
Expand Down
Loading