diff --git a/setup.cfg b/setup.cfg
index 4d40596..0a58f80 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -71,6 +71,7 @@ install_requires =
     logomaker >= 0.8
     pyBigWig
     ledidi
+    tangermeme
     pygenomeviz
diff --git a/src/grelu/interpret/score.py b/src/grelu/interpret/score.py
index cebc3ed..4695ae3 100644
--- a/src/grelu/interpret/score.py
+++ b/src/grelu/interpret/score.py
@@ -8,7 +8,8 @@
 import numpy as np
 import pandas as pd
 import torch
-from captum.attr import DeepLiftShap, InputXGradient, IntegratedGradients
+from captum.attr import InputXGradient, IntegratedGradients
+from tangermeme.deep_lift_shap import deep_lift_shap
 from torch import Tensor
 
 from grelu.sequence.format import convert_input_type
@@ -156,46 +157,35 @@ def get_attributions(
     # Initialize the attributer
     if method == "deepshap":
-        from bpnetlite.attributions import (
-            dinucleotide_shuffle,
-            hypothetical_attributions,
-        )
+        attributions = deep_lift_shap(
+            model,
+            X=seqs,
+            n_shuffles=n_shuffles,
+            hypothetical=hypothetical,
+            device=device,
+            random_state=seed,
+        ).numpy(force=True)
 
-        attributer = DeepLiftShap(model.to(device))
-    elif method == "integratedgradients":
-        attributer = IntegratedGradients(model.to(device))
-    elif method == "inputxgradient":
-        attributer = InputXGradient(model.to(device))
     else:
-        raise NotImplementedError
+        if method == "integratedgradients":
+            attributer = IntegratedGradients(model.to(device))
+        elif method == "inputxgradient":
+            attributer = InputXGradient(model.to(device))
+        else:
+            raise NotImplementedError
 
-    # Calculate attributions for each sequence
-    with torch.no_grad():
-        for i in range(len(seqs)):
-            X_ = seqs[i : i + 1].to(device)  # 1, 4, L
-
-            if method == "deepshap":
-                reference = dinucleotide_shuffle(
-                    X_[0].cpu(), n_shuffles=n_shuffles, random_state=seed
-                ).to(device)
-
-                attr = attributer.attribute(
-                    X_,
-                    reference,
-                    target=0,
-                    custom_attribution_func=hypothetical_attributions,
-                )
-                if not hypothetical:
-                    attr = attr * X_
-            else:
+        # Calculate attributions for each sequence
+        with torch.no_grad():
+            for i in range(len(seqs)):
+                X_ = seqs[i : i + 1].to(device)  # 1, 4, L
                 attr = attributer.attribute(X_)
+                attributions.append(attr.cpu().numpy())
 
-            attributions.append(attr.cpu().numpy())
+        attributions = np.vstack(attributions)
 
     # Remove transform
     model.reset_transform()
-
-    return np.vstack(attributions)  # N, 4, L
+    return attributions  # N, 4, L
 
 
 def run_modisco(
diff --git a/tests/test_interpret.py b/tests/test_interpret.py
index e8bc4a6..2ec8fb9 100644
--- a/tests/test_interpret.py
+++ b/tests/test_interpret.py
@@ -11,6 +11,7 @@
 )
 from grelu.interpret.score import ISM_predict, get_attention_scores, get_attributions
 from grelu.lightning import LightningModel
+from grelu.sequence.utils import generate_random_sequences
 
 cwd = os.path.realpath(os.path.dirname(__file__))
 meme_file = os.path.join(cwd, "files", "test.meme")
@@ -95,10 +96,11 @@ def test_ISM_predict():
 
 
 def test_get_attributions():
-    attrs = get_attributions(model, "GGG", hypothetical=False, n_shuffles=10)
-    assert attrs.shape == (1, 4, 3)
-    attrs = get_attributions(model, "ACG", hypothetical=True, n_shuffles=10)
-    assert attrs.shape == (1, 4, 3)
+    seq = generate_random_sequences(n=1, seq_len=50, seed=0, output_format="strings")[0]
+    attrs = get_attributions(model, seq, hypothetical=False, n_shuffles=10)
+    assert attrs.shape == (1, 4, 50)
+    attrs = get_attributions(model, seq, hypothetical=True, n_shuffles=10)
+    assert attrs.shape == (1, 4, 50)
 
 
 def test_get_attention_scores():
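
Note on usage: with this change, the "deepshap" path no longer loops over sequences with captum's DeepLiftShap and bpnetlite's dinucleotide shuffles; it delegates the whole batch to tangermeme's deep_lift_shap, which generates its own shuffled references. The sketch below mirrors the updated test; it is not part of the diff, the full get_attributions signature is not shown here, and the trained `model` object is a placeholder you must supply.

    from grelu.interpret.score import get_attributions
    from grelu.sequence.utils import generate_random_sequences

    # Placeholder: `model` is assumed to be an already-trained grelu LightningModel.
    seq = generate_random_sequences(n=1, seq_len=50, seed=0, output_format="strings")[0]

    # "deepshap" now runs tangermeme's deep_lift_shap internally; n_shuffles sets the
    # number of dinucleotide-shuffled reference sequences used per input.
    attrs = get_attributions(model, seq, method="deepshap", hypothetical=False, n_shuffles=10)
    print(attrs.shape)  # expected (1, 4, 50), i.e. N x 4 x L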