Skip to content

Commit

Permalink
Merge pull request #506 from dianna-ai/478-rise-for-time-series
Browse files Browse the repository at this point in the history
478 rise for time series
  • Loading branch information
Yang authored Mar 30, 2023
2 parents c755932 + ee2c4a5 commit 7987e51
Show file tree
Hide file tree
Showing 9 changed files with 2,026 additions and 15 deletions.
23 changes: 23 additions & 0 deletions dianna/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,29 @@
__version__ = "0.7.0"


def explain_timeseries(model_or_function, timeseries_data, method, labels, **kwargs):
    """Explain timeseries data given a model and a chosen method.

    The explainer is obtained from `_get_explainer` with modality "Timeseries".
    `kwargs` is passed through `utils.get_kwargs_applicable_to_function` before
    being handed to the explainer's `explain` method.

    Args:
        model_or_function (callable or str): The function that runs the model to be explained _or_
                                             the path to a ONNX model on disk.
        timeseries_data (np.ndarray): Timeseries data to be explained
        method (string): One of the supported methods: RISE, LIME or KernelSHAP
        labels (Iterable(int)): Labels to be explained
        kwargs: key word arguments

    Returns:
        One heatmap per class.
    """
    timeseries_explainer = _get_explainer(method, kwargs, modality="Timeseries")
    explain = timeseries_explainer.explain
    applicable_kwargs = utils.get_kwargs_applicable_to_function(explain, kwargs)
    return explain(model_or_function, timeseries_data, labels, **applicable_kwargs)


def explain_image(model_or_function, input_data, method, labels, **kwargs):
"""Explain an image (input_data) given a model and a chosen method.
Expand Down
3 changes: 3 additions & 0 deletions dianna/methods/rise.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
from tqdm import tqdm
from dianna import utils

# TODO: remove this import once the methods for the different input types are split up
from dianna.methods.rise_timeseries import RISETimeseries # noqa: F401 ignore unused import


def normalize(saliency, n_masks, p_keep):
"""Normalizes salience by number of masks and keep probability."""
Expand Down
70 changes: 70 additions & 0 deletions dianna/methods/rise_timeseries.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import numpy as np
from tqdm import tqdm
from dianna import utils
from dianna.utils.maskers import generate_masks
from dianna.utils.maskers import mask_data


def _make_predictions(input_data, runner, batch_size):
    """Run `runner` over `input_data` in slices of `batch_size` and concatenate the outputs.

    Args:
        input_data: Array-like whose first dimension is sliced into batches.
        runner (callable): Called once per batch with `input_data[start:start + batch_size]`.
        batch_size (int): Number of leading entries passed to `runner` per call.

    Returns:
        The per-batch outputs joined with `np.concatenate`.
    """
    total = input_data.shape[0]
    batch_starts = range(0, total, batch_size)
    outputs = [
        runner(input_data[start:start + batch_size])
        for start in tqdm(batch_starts, desc='Explaining')
    ]
    return np.concatenate(outputs)


# Duplicate code from rise.py:
def normalize(saliency, n_masks, p_keep):
    """Normalize saliency by the number of masks and then by the keep probability.

    Args:
        saliency: Raw saliency value(s).
        n_masks: Number of masks; used as the first divisor.
        p_keep: Keep probability; used as the second divisor.

    Returns:
        `saliency / n_masks / p_keep`.
    """
    per_mask = saliency / n_masks
    return per_mask / p_keep


class RISETimeseries:
    """RISE implementation for timeseries adapted from the image version of RISE."""

    def __init__(self, n_masks=1000, feature_res=8, p_keep=0.5,
                 preprocess_function=None):
        """RISE initializer.

        Args:
            n_masks (int): Number of masks to generate.
            feature_res (int): Resolution of features in masks. Stored on the instance, but not
                               read by `explain` in this class.
            p_keep (float): Fraction of input data to keep in each mask (Default: 0.5).
            preprocess_function (callable, optional): Function to preprocess input data with
        """
        self.n_masks = n_masks
        self.feature_res = feature_res
        self.p_keep = p_keep
        self.preprocess_function = preprocess_function
        # Filled in by `explain` so callers can inspect the last run.
        self.masks = None
        self.predictions = None

    def explain(self, model_or_function, input_timeseries, labels, batch_size=100, mask_type='mean'):
        """Runs the RISE explainer on a timeseries.

        The model will be called with masked timeseries,
        with a shape defined by `batch_size` and the shape of `input_timeseries`.

        Args:
            model_or_function (callable or str): The function that runs the model to be explained _or_
                                                 the path to a ONNX model on disk.
            input_timeseries (np.ndarray): Input timeseries data to be explained
            labels (Iterable(int)): Labels to be explained
            batch_size (int): Batch size to use for running the model.
            mask_type: Masking strategy for masked values. Choose from 'mean' or a callable(input_timeseries)

        Returns:
            Explanation heatmap for each class (np.ndarray).
        """
        runner = utils.get_function(model_or_function, preprocess_function=self.preprocess_function)
        self.masks = generate_masks(input_timeseries, number_of_masks=self.n_masks, p_keep=self.p_keep)
        masked = mask_data(input_timeseries, self.masks, mask_type=mask_type)

        self.predictions = _make_predictions(masked, runner, batch_size)

        selected_saliency = self._saliency_for_labels(input_timeseries.shape, labels)
        return normalize(selected_saliency, self.n_masks, self.p_keep)

    def _saliency_for_labels(self, data_shape, labels):
        """Combine stored predictions and masks into un-normalized saliency for `labels`.

        Args:
            data_shape (tuple): Shape of a single input timeseries.
            labels: Class indices to select from the first axis of the saliency array.

        Returns:
            np.ndarray: Prediction-weighted sum of the masks for the selected classes.
        """
        n_labels = self.predictions.shape[1]
        # One row per mask so the dot product sums the masks weighted by each class score.
        flat_masks = self.masks.reshape(self.masks.shape[0], -1)
        saliency = self.predictions.T.dot(flat_masks).reshape(n_labels, *data_shape)
        # NumPy reads a tuple index as one index per axis (saliency[0, 1]), not as a
        # selection of classes; convert it so tuples behave like lists of labels.
        if isinstance(labels, tuple):
            labels = list(labels)
        return saliency[labels]
18 changes: 10 additions & 8 deletions dianna/utils/maskers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import warnings
from typing import Union
import numpy as np


Expand All @@ -17,35 +18,36 @@ def generate_masks(input_data: np.array, number_of_masks: int, p_keep: float = 0
number_of_steps_masked = _determine_number_of_steps_masked(p_keep, series_length)

masked_data_shape = [number_of_masks] + list(input_data.shape)
masks = np.zeros(masked_data_shape, dtype=np.bool)
masks = np.ones(masked_data_shape, dtype=np.bool)
for i in range(number_of_masks):
steps_to_mask = np.random.choice(series_length, number_of_steps_masked, False)
masked_value = 1
masks[i, steps_to_mask] = masked_value
masks[i, steps_to_mask] = False
return masks


def mask_data(data, masks, mask_type='mean'):
def mask_data(data: np.array, masks: np.array, mask_type: Union[object, str]):
"""Mask data given using a set of masks.
Args:
data: ?
data: Input data.
masks: an array with shape [number_of_masks] + data.shape
mask_type: ?
mask_type: Masking strategy.
Returns:
Single array containing all masked input where the first dimension represents the batch.
"""
number_of_masks = masks.shape[0]
input_data_batch = np.repeat(np.expand_dims(data, 0), number_of_masks, axis=0)
result = np.empty(input_data_batch.shape)
result[~masks] = input_data_batch[~masks]
result[masks] = _get_mask_value(data, mask_type)
result[masks] = input_data_batch[masks]
result[~masks] = _get_mask_value(data, mask_type)
return result


def _get_mask_value(data: np.array, mask_type: str) -> int:
"""Calculates a masking value of the given type for the data."""
if callable(mask_type):
return mask_type(data)
if mask_type == 'mean':
return np.mean(data)
raise ValueError(f'Unknown mask_type selected: {mask_type}')
Expand Down
37 changes: 37 additions & 0 deletions tests/methods/test_rise_timeseries.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import numpy as np
import dianna
from tests.methods.time_series_test_case import average_temperature_timeseries_with_1_cold_and_1_hot_day
from tests.methods.time_series_test_case import input_train_mean
from tests.methods.time_series_test_case import run_expert_model
from tests.utils import run_model


def test_rise_timeseries_correct_output_shape():
    """Check that RISE runs on random data and returns one heatmap per requested label."""
    timeseries = np.random.random((10, 1))
    labels = [1]
    explain_kwargs = {'axis_labels': ['t', 'channels'], 'n_masks': 200, 'p_keep': .5}

    heatmaps = dianna.explain_timeseries(run_model, timeseries, "RISE", labels, **explain_kwargs)

    expected_shape = (len(labels), *timeseries.shape)
    assert heatmaps.shape == expected_shape


def test_rise_timeseries_with_expert_model_for_correct_max_and_min():
    """Check that RISE puts its extremes on the hot and cold day of the artificial example."""
    hot_day_index, cold_day_index = 6, 12
    temperature_timeseries = average_temperature_timeseries_with_1_cold_and_1_hot_day(cold_day_index, hot_day_index)

    explanations = dianna.explain_timeseries(
        run_expert_model,
        timeseries_data=temperature_timeseries,
        method='rise',
        labels=[0, 1],
        p_keep=0.1,
        n_masks=10000,
        mask_type=input_train_mean,
    )
    # Label 0 is summer and label 1 is winter in the expert model.
    summer_explanation, winter_explanation = explanations

    assert np.argmax(summer_explanation) == hot_day_index
    assert np.argmin(summer_explanation) == cold_day_index
    assert np.argmax(winter_explanation) == cold_day_index
    assert np.argmin(winter_explanation) == hot_day_index
38 changes: 38 additions & 0 deletions tests/methods/time_series_test_case.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import numpy as np


"""Test case for timeseries xai methods.
This test case is designed to show if the xai methods could provide reasonable results.
In this test case, every test instance is a 28 days by 1 channel array indicating the max temp on a day.
"""


def input_train_mean(_data):
    """Return the fixed overall mean temperature of 14; `_data` is ignored."""
    overall_mean_temperature = 14
    return overall_mean_temperature


def average_temperature_timeseries_with_1_cold_and_1_hot_day(cold_day_index, hot_day_index):
    """Build a 28x1 temperature series of 14s with one hot (30) and one cold (-2) day.

    The cold day is written last, so it wins if both indices are equal.
    """
    series = np.full((28, 1), 14.0)
    series[hot_day_index] = 30
    series[cold_day_index] = -2
    return series


def run_expert_model(data):
    """Classify a batch of timeseries as summer or winter by their average value.

    Instances whose average is strictly above 14 get the row [1.0, 0.0] (summer, class 0);
    all others get [0.0, 1.0] (winter, class 1).
    """
    # Average over axis 1 twice, leaving one value per instance.
    per_instance_mean = np.mean(np.mean(data, axis=1), axis=1)
    is_summer = per_instance_mean > 14

    # One row of class scores per instance, chosen by the summer flag.
    summer_row = [1.0, 0.0]
    winter_row = [0.0, 1.0]
    return np.where(is_summer[:, np.newaxis], summer_row, winter_row)
30 changes: 23 additions & 7 deletions tests/test_common_usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,29 @@
from tests.utils import run_model


input_data = np.random.random((224, 224, 3))
axis_labels = {-1: 'channels'}
labels = [0, 1]
def test_common_RISE_image_pipeline(): # noqa: N802 ignore case
"""No errors thrown while creating a relevance map and visualizing it."""
input_image = np.random.random((224, 224, 3))
axis_labels = {-1: 'channels'}
labels = [0, 1]

heatmap = dianna.explain_image(run_model, input_image, "RISE", labels, axis_labels=axis_labels)[0]
dianna.visualization.plot_image(heatmap, show_plot=False)
dianna.visualization.plot_image(heatmap, original_data=input_image[0], show_plot=False)


def test_common_RISE_pipeline(): # noqa: N802 ignore case
def test_common_RISE_timeseries_pipeline(): # noqa: N802 ignore case
"""No errors thrown while creating a relevance map and visualizing it."""
heatmap = dianna.explain_image(run_model, input_data, "RISE", labels, axis_labels=axis_labels)[0]
dianna.visualization.plot_image(heatmap, show_plot=False)
dianna.visualization.plot_image(heatmap, original_data=input_data[0], show_plot=False)
input_timeseries = np.random.random((31, 1))
labels = [0]

heatmap = dianna.explain_timeseries(run_model, input_timeseries, "RISE", labels)[0]
heatmap_channel = heatmap[:, 0]
segments = []
for i in range(len(heatmap_channel) - 1):
segments.append({
'index': i,
'start': i,
'stop': i + 1,
'weight': heatmap_channel[i]})
dianna.visualization.plot_timeseries(range(len(heatmap_channel)), input_timeseries[:, 0], segments, show_plot=False)
Binary file not shown.
Loading

0 comments on commit 7987e51

Please sign in to comment.