This repository is a limited implementation of Sparsity Probe: Analysis tool for Deep Learning Models by I. Ben-Shaul and S. Dekel (2021).
git clone https://github.com/idobenshaul10/SparsityProbe.git
pip install -r requirements.txt
torch==1.7.0
umap_learn==0.4.6
matplotlib==3.3.2
tqdm==4.49.0
seaborn==0.11.0
torchvision==0.8.1
numpy==1.19.2
scikit_learn==0.24.2
umap==0.1.1
Start with the CIFAR10 Example, which runs the Sparsity-Probe at selected layers of a Resnet18 trained on the CIFAR10 dataset.
Create a new environment in the environments directory, inheriting from BaseEnviorment. This environment should include the train and test datasets (including the matching transforms), the model layers on which to measure the alpha-scores (see the cifar10_env example), and the trained model.
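As a rough illustration of the four pieces an environment bundles, here is a minimal sketch. The base class below is a local stand-in, and the method names (get_dataset, get_test_dataset, get_model, get_layers) are assumptions made for this example, not the repository's confirmed interface; plain Python objects replace the torch datasets and modules. See cifar10_env for the real thing.

```python
from abc import ABC, abstractmethod


class BaseEnviorment(ABC):
    """Hypothetical stand-in for the repository's base class (interface assumed)."""

    @abstractmethod
    def get_dataset(self):
        """Return the training dataset with its transforms applied."""

    @abstractmethod
    def get_test_dataset(self):
        """Return the test dataset with the matching transforms."""

    @abstractmethod
    def get_model(self):
        """Return the trained model."""

    @abstractmethod
    def get_layers(self, model):
        """Return the layers whose outputs are scored."""


class ToyEnv(BaseEnviorment):
    """Toy environment: lists and callables stand in for datasets and layers."""

    def get_dataset(self):
        # (features, label) pairs in place of a torch Dataset.
        return [([0.0, 1.0], 0), ([1.0, 0.0], 1)]

    def get_test_dataset(self):
        return [([0.5, 0.5], 0)]

    def get_model(self):
        # Mapping of layer name -> callable, in place of a trained network.
        return {"scale": lambda x: [2 * v for v in x], "sum": lambda x: [sum(x)]}

    def get_layers(self, model):
        # Select the layers to probe, in forward order.
        return [model["scale"], model["sum"]]


env = ToyEnv()
model = env.get_model()
layers = env.get_layers(model)
```

Keeping the layer selection in the environment, next to the model and data it belongs to, means the scoring script only needs the environment's name.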
It is possible to train a basic model with the train.py script, which uses an environment to load the model and the datasets.
Example Usage:
python train/train_mnist.py --output_path "results" --batch_size 32 --epochs 100
Alpha-scores are computed with the DL_smoothness.py script.
Arguments:
trees - number of trees in the forest.
depth - maximum depth of each tree.
batch_size - batch size used in the forward pass (when computing the layer outputs).
env_name - environment that is loaded to measure alpha-scores on.
epsilon_1 - the epsilon_low used for the numerical approximation. By default, epsilon_high is initialized as 4*epsilon_low.
only_umap - only create UMAPs of the intermediate layers (without computing alpha-scores).
use_clustering - run KMeans on intermediate layers.
calc_test - calculate test accuracy (more metrics coming soon).
output_folder - location where all outputs are saved.
feature_dimension - to reduce computation costs, the alpha-scores are computed on the features after dimensionality reduction. Currently, if dim(layer_outputs) > feature_dimension, TruncatedSVD is used to reduce dim(layer_outputs) to feature_dimension. The default feature_dimension is 2500.
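The feature_dimension rule can be sketched as follows. This is a minimal illustration using scikit-learn's TruncatedSVD, not the script's actual code: the helper name reduce_features, the flattening of each sample's output to a vector, and the fixed random seed are all assumptions made for the example.

```python
import numpy as np
from sklearn.decomposition import TruncatedSVD


def reduce_features(layer_outputs, feature_dimension=2500, seed=0):
    """Flatten per-sample outputs and apply TruncatedSVD only if they are too wide.

    Hypothetical helper; the flattening step and the seed are assumptions.
    """
    flat = np.asarray(layer_outputs).reshape(len(layer_outputs), -1)
    if flat.shape[1] <= feature_dimension:
        # Already within budget: leave the features untouched.
        return flat
    svd = TruncatedSVD(n_components=feature_dimension, random_state=seed)
    return svd.fit_transform(flat)


rng = np.random.default_rng(0)
outputs = rng.normal(size=(64, 8, 4, 4))  # 64 samples, 8*4*4 = 128 features each

reduced = reduce_features(outputs, feature_dimension=16)      # 128 > 16, SVD applied
unchanged = reduce_features(outputs, feature_dimension=2500)  # 128 <= 2500, no SVD
print(reduced.shape, unchanged.shape)
```

With the default of 2500, only layers whose flattened outputs exceed 2500 features are reduced; narrower layers are scored as they are.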
Result plots can be created using this script.
Our pretrained CIFAR10 Resnet18 network used in the example is taken from This Repo.
This repository is MIT licensed, as found in the LICENSE file.