Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tangermeme implementation of DeepLift/SHAP #5

Merged
merged 5 commits into from
Jun 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ install_requires =
logomaker >= 0.8
pyBigWig
ledidi
tangermeme
pygenomeviz


Expand Down
56 changes: 23 additions & 33 deletions src/grelu/interpret/score.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
import numpy as np
import pandas as pd
import torch
from captum.attr import DeepLiftShap, InputXGradient, IntegratedGradients
from captum.attr import InputXGradient, IntegratedGradients
from tangermeme.deep_lift_shap import deep_lift_shap
from torch import Tensor

from grelu.sequence.format import convert_input_type
Expand Down Expand Up @@ -156,46 +157,35 @@ def get_attributions(

# Initialize the attributer
if method == "deepshap":
from bpnetlite.attributions import (
dinucleotide_shuffle,
hypothetical_attributions,
)
attributions = deep_lift_shap(
model,
X=seqs,
n_shuffles=n_shuffles,
hypothetical=hypothetical,
device=device,
random_state=seed,
).numpy(force=True)

attributer = DeepLiftShap(model.to(device))
elif method == "integratedgradients":
attributer = IntegratedGradients(model.to(device))
elif method == "inputxgradient":
attributer = InputXGradient(model.to(device))
else:
raise NotImplementedError
if method == "integratedgradients":
attributer = IntegratedGradients(model.to(device))
elif method == "inputxgradient":
attributer = InputXGradient(model.to(device))
else:
raise NotImplementedError

# Calculate attributions for each sequence
with torch.no_grad():
for i in range(len(seqs)):
X_ = seqs[i : i + 1].to(device) # 1, 4, L

if method == "deepshap":
reference = dinucleotide_shuffle(
X_[0].cpu(), n_shuffles=n_shuffles, random_state=seed
).to(device)

attr = attributer.attribute(
X_,
reference,
target=0,
custom_attribution_func=hypothetical_attributions,
)
if not hypothetical:
attr = attr * X_
else:
# Calculate attributions for each sequence
with torch.no_grad():
for i in range(len(seqs)):
X_ = seqs[i : i + 1].to(device) # 1, 4, L
attr = attributer.attribute(X_)
attributions.append(attr.cpu().numpy())

attributions.append(attr.cpu().numpy())
attributions = np.vstack(attributions)

# Remove transform
model.reset_transform()

return np.vstack(attributions) # N, 4, L
return attributions # N, 4, L


def run_modisco(
Expand Down
10 changes: 6 additions & 4 deletions tests/test_interpret.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
)
from grelu.interpret.score import ISM_predict, get_attention_scores, get_attributions
from grelu.lightning import LightningModel
from grelu.sequence.utils import generate_random_sequences

cwd = os.path.realpath(os.path.dirname(__file__))
meme_file = os.path.join(cwd, "files", "test.meme")
Expand Down Expand Up @@ -95,10 +96,11 @@ def test_ISM_predict():


def test_get_attributions():
attrs = get_attributions(model, "GGG", hypothetical=False, n_shuffles=10)
assert attrs.shape == (1, 4, 3)
attrs = get_attributions(model, "ACG", hypothetical=True, n_shuffles=10)
assert attrs.shape == (1, 4, 3)
seq = generate_random_sequences(n=1, seq_len=50, seed=0, output_format="strings")[0]
attrs = get_attributions(model, seq, hypothetical=False, n_shuffles=10)
assert attrs.shape == (1, 4, 50)
attrs = get_attributions(model, seq, hypothetical=True, n_shuffles=10)
assert attrs.shape == (1, 4, 50)


def test_get_attention_scores():
Expand Down
Loading