PyTorch implementation of Counterfactual Explanation Based on Gradual Construction for Deep Networks
You will need a machine with a GPU and CUDA installed.
Then, prepare the runtime environment:
pip install -r requirements.txt
If you get the error message ModuleNotFoundError: No module named 'urllib3'
in a conda environment, try installing spaCy as follows:
conda install spacy==2.3.2
Run the following command to get a counterfactual explanation for MNIST data.
python main.py --dataset=mnist --model_path=./models/saved/mnist_cnn.pt --data_path=example/MNIST/0.png --d=4 --target_prob=0.9
Results are saved in the result folder.
For the IMDB dataset, you need to download the spaCy 'en' model with the following command:
python -m spacy download en
Run the following command to get a counterfactual explanation for IMDB data.
python main.py --dataset=imdb --model_path=./models/saved/tut4-model.pt --data_path="This film is good" --d=1 --target_prob=0.9 --n_iter=700
Run the following command to get a counterfactual explanation for HELOC data.
python main.py --dataset=heloc --model_path=./models/saved/MLP_pytorch_HELOC_allRemoved.pt --data_path=./example/HELOC/1.csv --d=1 --target_prob=0.7
For HELOC, a target probability above 0.7 is not allowed because of the limited capacity of the pre-trained model.
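If you want to script the three example runs (for instance, to try several target probabilities in a loop), here is a minimal sketch. build_command and EXAMPLES are hypothetical helpers, not part of this repository. They only reuse the flags from the commands shown for MNIST, IMDB and HELOC, and the 0.7 check for HELOC mirrors the note about the pre-trained model rather than any validation known to exist in main.py. The sketch assumes it is run from the repository root.

```python
import shlex
import subprocess  # only needed if you uncomment the run line at the bottom

DATASETS = ("mnist", "imdb", "heloc")


def build_command(dataset, model_path, data_path, **options):
    """Return an argv list for main.py (hypothetical helper).

    Extra keyword arguments such as d=4, target_prob=0.9 or n_iter=700
    are turned into --key=value flags, matching the README commands.
    """
    if dataset not in DATASETS:
        raise ValueError(f"unknown dataset: {dataset!r}")
    # Assumption: mirror the README note that HELOC only supports target_prob <= 0.7.
    if dataset == "heloc" and float(options.get("target_prob", 0.0)) > 0.7:
        raise ValueError("target_prob above 0.7 is not supported for heloc")
    argv = [
        "python", "main.py",
        f"--dataset={dataset}",
        f"--model_path={model_path}",
        f"--data_path={data_path}",
    ]
    for key, value in options.items():
        argv.append(f"--{key}={value}")
    return argv


# The three example runs from this README, expressed as argv lists.
EXAMPLES = {
    "mnist": build_command("mnist", "./models/saved/mnist_cnn.pt",
                           "example/MNIST/0.png", d=4, target_prob=0.9),
    "imdb": build_command("imdb", "./models/saved/tut4-model.pt",
                          "This film is good", d=1, target_prob=0.9, n_iter=700),
    "heloc": build_command("heloc", "./models/saved/MLP_pytorch_HELOC_allRemoved.pt",
                           "./example/HELOC/1.csv", d=1, target_prob=0.7),
}

if __name__ == "__main__":
    for name, argv in EXAMPLES.items():
        # shlex.join quotes arguments with spaces, e.g. the IMDB sentence.
        print(name, shlex.join(argv))
        # subprocess.run(argv, check=True)
```

Passing an argv list instead of a single shell string avoids quoting problems with the IMDB input sentence, which contains spaces.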
Arguments:
dataset
- Choose the experiment dataset
- Available list: ['mnist', 'imdb', 'heloc']
data_path
- Input data (path)
l2_coeff
- Coefficient of the L2 regularization
tv_beta
- Exponent used in the total variation (TV) regularization
tv_coeff
- Coefficient of the TV regularization
n_iter
- Number of iterations
lr
- Learning rate
target_class
- Choose the target class
- 0: the class with the highest probability for the original data
- 1: the class with the second highest probability for the original data
- 2: the class with the third highest probability for the original data
target_prob
- Target probability of the target class
d
- Determines the size of the mask
model_path
- Saved model path
- Available list: ['mnist_cnn.pt', 'tut4-model.pt', 'MLP_pytorch_HELOC_allRemoved.pt']