diff --git a/setup.cfg b/setup.cfg index f7bad31..55ecb8c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -67,7 +67,6 @@ install_requires = genomepy bioframe >= 0.4 captum == 0.5.0 - bpnet-lite == 0.5.7 logomaker >= 0.8 pyBigWig ledidi diff --git a/src/grelu/data/preprocess.py b/src/grelu/data/preprocess.py index b3d9a7d..271b2bb 100644 --- a/src/grelu/data/preprocess.py +++ b/src/grelu/data/preprocess.py @@ -3,7 +3,6 @@ """ import os import subprocess -import tempfile from typing import Callable, List, Optional, Union import bioframe as bf @@ -479,7 +478,6 @@ def get_gc_matched_intervals( genome: str, binwidth: float = 0.1, chroms: str = "autosomes", - gc_bw_file: str = None, blacklist: str = "hg38", seed: Optional[int] = None, ) -> pd.DataFrame: @@ -491,15 +489,13 @@ def get_gc_matched_intervals( genome: Name of the genome corresponding to intervals binwidth: Resolution of GC content chroms: Chromosomes to search for matched intervals - gc_bw_file: Path to a bigWig file of genomewide GC content. - If None, will be created. blacklist: Blacklist file of regions to exclude seed: Random seed Returns: A pandas dataframe containing GC-matched negative intervals. """ - from bpnetlite.negatives import calculate_gc_genomewide, extract_matching_loci + from tangermeme.match import extract_matching_loci from grelu.io.genome import get_genome from grelu.sequence.utils import get_unique_length @@ -510,25 +506,17 @@ def get_gc_matched_intervals( # Get seq_len seq_len = get_unique_length(intervals) - # Get bigWig file of GC content - if gc_bw_file is None: - gc_bw_file = "gc_{}_{}.bw".format(genome.name, seq_len) - print("Calculating GC content genomewide and saving to {}".format(gc_bw_file)) - calculate_gc_genomewide( - fasta=genome.genome_file, - bigwig=gc_bw_file, - width=seq_len, - include_chroms=chroms, - verbose=True, - ) - print("Extracting matching intervals") - _, tmpfile = tempfile.mkstemp() - intervals.iloc[:, :3].to_csv(tmpfile, sep="\t", index=False, header=False) matched_loci = extract_matching_loci( - bed=tmpfile, bigwig=gc_bw_file, width=seq_len, bin_width=binwidth, verbose=True + intervals, + fasta=genome.genome_file, + in_window=seq_len, + gc_bin_width=binwidth, + chroms=chroms, + verbose=False, + random_state=seed, ) - os.remove(tmpfile) + print("Filtering blacklist") if blacklist is not None: matched_loci = filter_blacklist(matched_loci, blacklist) diff --git a/src/grelu/interpret/motifs.py b/src/grelu/interpret/motifs.py index 20e1a26..5e8ed44 100644 --- a/src/grelu/interpret/motifs.py +++ b/src/grelu/interpret/motifs.py @@ -225,6 +225,7 @@ def marginalize_patterns( genome=genome, rc=rc, n_shuffles=n_shuffles, + seed=seed, ) # Get predictions on the sequences before motif insertion diff --git a/src/grelu/sequence/format.py b/src/grelu/sequence/format.py index 5ba944d..ec6e82a 100644 --- a/src/grelu/sequence/format.py +++ b/src/grelu/sequence/format.py @@ -256,12 +256,15 @@ def strings_to_indices( ).astype(np.int8) -def indices_to_one_hot(indices: np.ndarray) -> Tensor: +def indices_to_one_hot(indices: np.ndarray, add_batch_axis: bool = False) -> Tensor: """ Convert integer-encoded DNA sequences to one-hot encoded format. Args: indices: Integer-encoded DNA sequences. + add_batch_axis: If True, a batch axis will be included in the output for single + sequences. If False, the output for a single sequence will be a 2-dimensional + tensor. Returns: The one-hot encoded sequences. @@ -271,9 +274,12 @@ def indices_to_one_hot(indices: np.ndarray) -> Tensor: # Convert a single sequence if indices.ndim == 1: - return one_hot(torch.LongTensor(indices.copy()), num_classes=5)[:, :4].T.type( + one_hot = one_hot(torch.LongTensor(indices.copy()), num_classes=5)[ + :, :4 + ].T.type( torch.float32 ) # Output shape: 4, L + return one_hot.unsqueeze(0) if add_batch_axis else one_hot # Convert multiple sequences else: @@ -367,6 +373,7 @@ def convert_input_type( output_type: str = "indices", genome: Optional[str] = None, add_batch_axis: bool = False, + input_type: Optional[str] = None, ) -> Union[pd.DataFrame, str, List[str], np.ndarray, Tensor]: """ Convert input DNA sequence data into the desired format. @@ -378,6 +385,7 @@ def convert_input_type( add_batch_axis: If True, a batch axis will be included in the output for single sequences. If False, the output for a single sequence will be a 2-dimensional tensor. + input_type: Format of the input sequence (optional) Returns: The converted DNA sequence(s) in the desired format. @@ -387,7 +395,7 @@ def convert_input_type( """ # Determine input type - input_type = get_input_type(inputs) + input_type = input_type or get_input_type(inputs) # If no conversion needed, return inputs as is if input_type == output_type: @@ -416,7 +424,7 @@ def convert_input_type( # Convert indices if input_type == "indices": if output_type == "one_hot": - return indices_to_one_hot(inputs) + return indices_to_one_hot(inputs, add_batch_axis=add_batch_axis) elif output_type == "strings": return indices_to_strings(inputs) diff --git a/src/grelu/sequence/utils.py b/src/grelu/sequence/utils.py index 36d925f..65672f4 100644 --- a/src/grelu/sequence/utils.py +++ b/src/grelu/sequence/utils.py @@ -375,6 +375,8 @@ def reverse_complement( def dinuc_shuffle( seqs: Union[pd.DataFrame, np.ndarray, List[str]], n_shuffles: int = 1, + start=0, + end=-1, input_type: Optional[str] = None, seed: Optional[int] = None, genome: Optional[str] = None, @@ -393,32 +395,22 @@ def dinuc_shuffle( Returns: Shuffled sequences in the same format as the input """ - import torch - from bpnetlite.attributions import dinucleotide_shuffle + from einops import rearrange + from tangermeme.ersatz import dinucleotide_shuffle # Input format input_type = input_type or get_input_type(seqs) # One-hot encode - seqs = convert_input_type(seqs, "one_hot", genome=genome) # N, 4, L + seqs = convert_input_type( + seqs, "one_hot", genome=genome, add_batch_axis=True + ) # B, 4, L # Shuffle sequences as many times as required - if n_shuffles > 0: - if seqs.ndim == 2: # 4, L - shuf_seqs = dinucleotide_shuffle( - seqs, n_shuffles=n_shuffles, random_state=seed - ) # N, 4, L - else: - shuf_seqs = torch.vstack( - [ - dinucleotide_shuffle(seq, n_shuffles=n_shuffles, random_state=seed) - for seq in seqs - ] - ) # B, 4, L - - # If no shuffling is required, return the original sequences - else: - return seqs + shuf_seqs = dinucleotide_shuffle( + X=seqs, start=start, end=end, n=n_shuffles, random_state=seed, verbose=False + ) # B, n, 4, L + shuf_seqs = rearrange(shuf_seqs, "b n t l -> (b n) t l") return convert_input_type(shuf_seqs, input_type) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 42d2c4b..1075441 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -797,34 +797,37 @@ def test_ism_dataset(): def test_marginalize_dataset_variants(): # Marginalize variants ds = VariantMarginalizeDataset( - variants=variants, genome="hg38", seq_len=6, n_shuffles=2, seed=0 + variants=variants, genome="hg38", seq_len=12, n_shuffles=2, seed=0 ) assert ( (ds.n_shuffles == 2) - and (ds.seq_len == 6) + and (ds.seq_len == 12) and (ds.n_seqs == 2) and (ds.ref.shape == (2, 1)) and (ds.alt.shape == (2, 1)) and (len(ds) == 8) and (ds.n_augmented == 2) + and (np.allclose(ds.ref, np.array([[2], [2]]))) + and (np.allclose(ds.alt, np.array([[0], [0]]))) ) + assert convert_input_type(ds.seqs, "strings") == ["CATACGTGAGGC", "AGGAGGCCAAAG"] xs = [convert_input_type(ds[i], "strings") for i in range(len(ds))] assert xs == [ - "ACGTGA", - "ACATGA", - "ACGTGA", - "ACATGA", - "AGGCCA", - "AGACCA", - "AGGCCA", - "AGACCA", + "CACGTGTGAGGC", + "CACGTATGAGGC", + "CACGAGAGTGGC", + "CACGAAAGTGGC", + "AAGGGGGCCAAG", + "AAGGGAGCCAAG", + "AAGAGGGCCAAG", + "AAGAGAGCCAAG", ] def test_marginalize_dataset_motifs(): # Marginalize motifs ds = PatternMarginalizeDataset( - seqs=["ACCTACACT"], patterns=["AAA"], n_shuffles=2, seed=0 + seqs=["AAGACATACAACGCGCGCTAACATAGCAAC"], patterns=["AAA"], n_shuffles=2, seed=0 ) assert ( (ds.n_shuffles == 2) @@ -836,7 +839,12 @@ def test_marginalize_dataset_motifs(): ) xs = [convert_input_type(ds[i], "strings") for i in range(len(ds))] - assert xs == ["ACACCGACG", "ACAAAAACG", "ACACGACCG", "ACAAAACCG"] + assert xs == [ + "ACGCATACGAGCGCTACAGCAACATAAAAC", + "ACGCATACGAGCGAAACAGCAACATAAAAC", + "ACTAACAACAGCACGCGCGATATAAGCAAC", + "ACTAACAACAGCAAAAGCGATATAAGCAAC", + ] # Test Motif scanning dataset diff --git a/tests/test_interpret.py b/tests/test_interpret.py index 2ec8fb9..ee73e9b 100644 --- a/tests/test_interpret.py +++ b/tests/test_interpret.py @@ -71,7 +71,7 @@ def test_trim_pwm(): def test_marginalize_patterns(): - seqs = ["ACTGT", "GATCC"] + seqs = ["CATACGTGAGGC", "AGGAGGCCAAAG"] preds_before, preds_after = marginalize_patterns( model, patterns=["A"], @@ -81,9 +81,14 @@ def test_marginalize_patterns(): compare_func=None, ) assert preds_before.shape == (2, 3, 1) - assert np.allclose(preds_before.squeeze(), [[0.4, 0.4, 0.4], [0, 0, 0]]) + assert np.allclose( + preds_before.squeeze(), [[0.5, 0.5, 0.5], [1.3333334, 1.3333334, 1.3333334]] + ) assert preds_after.shape == (2, 3, 1) - assert np.allclose(preds_after.squeeze(), [[1.2, 1.2, 1.2], [0.8, 0.8, 0.8]]) + assert np.allclose( + preds_after.squeeze(), + [[0.5, 0.8333333, 0.8333333], [1.3333334, 1.6666666, 1.6666666]], + ) def test_ISM_predict(): diff --git a/tests/test_sequence.py b/tests/test_sequence.py index d0ac0b7..4e12c42 100644 --- a/tests/test_sequence.py +++ b/tests/test_sequence.py @@ -125,6 +125,9 @@ def test_seq_formatting(): # indices to one-hot assert torch.allclose(convert_input_type(indices, "one_hot"), batch) + assert torch.allclose( + convert_input_type(indices[0], "one_hot", add_batch_axis=True), batch[[0]] + ) # Test Metrics functions