Skip to content

Commit

Permalink
prevent deepshap from being used with enformer models
Browse files Browse the repository at this point in the history
  • Loading branch information
avantikalal committed Jun 18, 2024
1 parent 623fee8 commit e714ed4
Show file tree
Hide file tree
Showing 3 changed files with 346 additions and 299 deletions.
332 changes: 164 additions & 168 deletions docs/tutorials/4_design.ipynb

Large diffs are not rendered by default.

282 changes: 161 additions & 121 deletions docs/tutorials/5_variant.ipynb

Large diffs are not rendered by default.

31 changes: 21 additions & 10 deletions src/grelu/interpret/score.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from tangermeme.deep_lift_shap import deep_lift_shap
from torch import Tensor

from grelu.model.models import EnformerModel, EnformerPretrainedModel
from grelu.sequence.format import convert_input_type


Expand Down Expand Up @@ -122,6 +123,7 @@ def get_attributions(
hypothetical: bool = False,
n_shuffles: int = 20,
seed=None,
**kwargs,
) -> np.array:
"""
Get per-nucleotide importance scores for sequences using Captum.
Expand All @@ -133,10 +135,11 @@ def get_attributions(
prediction_transform: A module to transform the model output
devices: Indices of the devices to use for inference
method: One of "deepshap", "saliency", "inputxgradient" or "integratedgradients"
hypothetical: whether to calculate hypothetical importance scores
set to True to obtain input for tf-modisco, False otherwise
hypothetical: whether to calculate hypothetical importance scores.
Set this to True to obtain input for tf-modisco, False otherwise
n_shuffles: Number of times to dinucleotide shuffle sequence
seed: Random seed
**kwargs: Additional arguments to pass to tangermeme.deep_lift_shap.deep_lift_shap
Returns:
Per-nucleotide importance scores as numpy array of shape (B, 4, L).
Expand All @@ -157,14 +160,22 @@ def get_attributions(

# Initialize the attributer
if method == "deepshap":
attributions = deep_lift_shap(
model,
X=seqs,
n_shuffles=n_shuffles,
hypothetical=hypothetical,
device=device,
random_state=seed,
).numpy(force=True)
if isinstance(model.model, EnformerModel) or isinstance(
model.model, EnformerPretrainedModel
):
raise NotImplementedError(
"DeepShap currently cannot be applied to Enformer models."
)
else:
attributions = deep_lift_shap(
model,
X=seqs,
n_shuffles=n_shuffles,
hypothetical=hypothetical,
device=device,
random_state=seed,
**kwargs,
).numpy(force=True)

else:
if method == "integratedgradients":
Expand Down

0 comments on commit e714ed4

Please sign in to comment.