-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #506 from dianna-ai/478-rise-for-time-series
478 rise for time series
- Loading branch information
Showing
9 changed files
with
2,026 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
import numpy as np | ||
from tqdm import tqdm | ||
from dianna import utils | ||
from dianna.utils.maskers import generate_masks | ||
from dianna.utils.maskers import mask_data | ||
|
||
|
||
def _make_predictions(input_data, runner, batch_size):
    """Feed `input_data` to `runner` in consecutive slices and join the outputs.

    Args:
        input_data: Array-like with a `shape`; it is sliced along its first axis.
        runner (callable): Called once per slice; its return values are concatenated.
        batch_size (int): Number of first-axis entries per slice.

    Returns:
        The per-slice outputs joined with `np.concatenate`.
    """
    total = input_data.shape[0]
    # tqdm wraps the slice start offsets so progress is reported per batch.
    batch_starts = tqdm(range(0, total, batch_size), desc='Explaining')
    outputs = [runner(input_data[start:start + batch_size]) for start in batch_starts]
    return np.concatenate(outputs)
|
||
|
||
# Duplicate code from rise.py: | ||
def normalize(saliency, n_masks, p_keep):
    """Scale saliency by the number of masks and then by the keep probability.

    Args:
        saliency: Value or array to scale.
        n_masks: Number of masks the saliency was accumulated over.
        p_keep: Keep probability used for the masks.

    Returns:
        `saliency` divided by `n_masks` and then by `p_keep`.
    """
    # Two successive divisions (rather than one by the product) keep the
    # floating-point result identical to the reference implementation.
    per_mask = saliency / n_masks
    return per_mask / p_keep
|
||
|
||
class RISETimeseries:
    """RISE explainer for timeseries, adapted from the image version of RISE."""

    def __init__(self, n_masks=1000, feature_res=8, p_keep=0.5,
                 preprocess_function=None):
        """Store the RISE configuration.

        Args:
            n_masks (int): Number of masks to generate.
            feature_res (int): Resolution of features in masks. It is stored on the
                instance; `explain` does not forward it to `generate_masks`.
            p_keep (float): Fraction of input data to keep in each mask (default: 0.5).
            preprocess_function (callable, optional): Function to preprocess input data with.
        """
        self.n_masks = n_masks
        self.feature_res = feature_res
        self.p_keep = p_keep
        self.preprocess_function = preprocess_function
        # Both are populated by `explain` and kept for later inspection.
        self.masks = None
        self.predictions = None

    def explain(self, model_or_function, input_timeseries, labels, batch_size=100, mask_type='mean'):
        """Run the RISE explainer on a timeseries.

        The model is called with masked copies of `input_timeseries`, at most
        `batch_size` of them per call.

        Args:
            model_or_function (callable or str): The function that runs the model to be explained _or_
                the path to a ONNX model on disk.
            input_timeseries (np.ndarray): Input timeseries data to be explained.
            labels (Iterable(int)): Labels to be explained; used to index the per-label saliency.
            batch_size (int): Batch size to use for running the model.
            mask_type: Masking strategy for masked values, forwarded to `mask_data`.
                Choose from 'mean' or a callable(input_timeseries).

        Returns:
            Explanation heatmap for each selected label (np.ndarray).
        """
        runner = utils.get_function(model_or_function, preprocess_function=self.preprocess_function)

        self.masks = generate_masks(input_timeseries, number_of_masks=self.n_masks, p_keep=self.p_keep)
        masked_inputs = mask_data(input_timeseries, self.masks, mask_type=mask_type)
        self.predictions = _make_predictions(masked_inputs, runner, batch_size)

        # Second axis of the predictions holds one score per label.
        n_labels = self.predictions.shape[1]

        # Weight every flattened mask by the score it produced for each label,
        # then restore the input's shape for each label.
        flat_masks = self.masks.reshape(self.n_masks, -1)
        weighted_masks = self.predictions.T.dot(flat_masks)
        saliency = weighted_masks.reshape(n_labels, *input_timeseries.shape)

        return normalize(saliency[labels], self.n_masks, self.p_keep)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
import numpy as np | ||
import dianna | ||
from tests.methods.time_series_test_case import average_temperature_timeseries_with_1_cold_and_1_hot_day | ||
from tests.methods.time_series_test_case import input_train_mean | ||
from tests.methods.time_series_test_case import run_expert_model | ||
from tests.utils import run_model | ||
|
||
|
||
def test_rise_timeseries_correct_output_shape():
    """Check that RISE runs on random data and returns one heatmap per requested label."""
    timeseries = np.random.random((10, 1))
    requested_labels = [1]
    expected_shape = (len(requested_labels), *timeseries.shape)

    heatmaps = dianna.explain_timeseries(
        run_model,
        timeseries,
        "RISE",
        requested_labels,
        axis_labels=['t', 'channels'],
        n_masks=200,
        p_keep=.5,
    )

    assert heatmaps.shape == expected_shape
|
||
|
||
def test_rise_timeseries_with_expert_model_for_correct_max_and_min():
    """Check that RISE puts its extremes on the hot and cold day of the artificial example."""
    hot_day_index, cold_day_index = 6, 12
    temperature_timeseries = average_temperature_timeseries_with_1_cold_and_1_hot_day(
        cold_day_index, hot_day_index)

    explanations = dianna.explain_timeseries(
        run_expert_model,
        timeseries_data=temperature_timeseries,
        method='rise',
        labels=[0, 1],
        p_keep=0.1,
        n_masks=10000,
        mask_type=input_train_mean,
    )
    # Label 0 is summer and label 1 is winter, in the order requested above.
    summer_explanation, winter_explanation = explanations

    assert np.argmax(summer_explanation) == hot_day_index
    assert np.argmin(summer_explanation) == cold_day_index
    assert np.argmax(winter_explanation) == cold_day_index
    assert np.argmin(winter_explanation) == hot_day_index
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
import numpy as np | ||
|
||
|
||
"""Test case for timeseries xai methods. | ||
This test case is designed to show if the xai methods could provide reasonable results. | ||
In this test case, every test instance is a 28 days by 1 channel array indicating the max temp on a day. | ||
""" | ||
|
||
|
||
def input_train_mean(_data):
    """Return the overall mean temperature, 14, regardless of `_data`."""
    overall_mean_temperature = 14
    return overall_mean_temperature
|
||
|
||
def average_temperature_timeseries_with_1_cold_and_1_hot_day(cold_day_index, hot_day_index):
    """Build a 28 x 1 temperature series of 14s with one hot (30) and one cold (-2) day.

    Args:
        cold_day_index: Row index that is set to -2.
        hot_day_index: Row index that is set to 30.

    Returns:
        Float array of shape (28, 1).
    """
    series = np.full((28, 1), 14.0)
    # Hot day is written first, so the cold value wins if both indices coincide.
    series[hot_day_index] = 30
    series[cold_day_index] = -2
    return series
|
||
|
||
def run_expert_model(data):
    """Classify every timeseries in a batch as summer or winter.

    An instance whose average is above 14 is summer (class 0); every other
    instance is winter (class 1).

    Args:
        data: Batch of timeseries as a 3-dimensional array; the first axis
            enumerates the instances.

    Returns:
        Array of shape (number of instances, 2) with a one-hot row per instance.
    """
    # Average over the second axis, then over the remaining non-batch axis.
    first_pass_mean = np.mean(data, axis=1)
    instance_mean = np.mean(first_pass_mean, axis=1)
    is_summer = instance_mean > 14

    # Pick the one-hot row per instance; the column vector broadcasts against it.
    summer_row = [1.0, 0.0]
    winter_row = [0.0, 1.0]
    return np.where(is_summer[:, np.newaxis], summer_row, winter_row)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
Oops, something went wrong.