This repository provides a TensorFlow implementation of the following paper:
A Quantitatively Interpretable Model for Alzheimer’s Disease Prediction using Deep Counterfactuals
Kwanseok Oh1, Da-Woon Heo1, Ahmad Wisnu Mulyadi2, Wonsik Jung2, Eunsong Kang2, Kunho Lee3, 4, and Heung-Il Suk1, 2
(1Department of Artificial Intelligence, Korea University)
(2Department of Brain and Cognitive Engineering, Korea University)
(3Department of Biomedical Science and Gwangju Alzheimer’s & Related Dementia Cohort Research Center, Chosun University)
(4Korea Brain Research Institute)
Published in Medical Imaging meets NeurIPS (MedNeurIPS2022): "Quantifying Explainability of Counterfactual-Guided MRI Feature for Alzheimer’s Disease Prediction"
Official Journal Version: https://arxiv.org/pdf/2310.03457.pdf

Abstract: Deep learning (DL) for predicting Alzheimer’s disease (AD) has provided timely intervention in disease progression yet still demands attentive interpretability to explain how DL models make definitive decisions. Recently, counterfactual reasoning has gained increasing attention in medical research because of its ability to provide a refined visual explanatory map. However, such visual explanatory maps based on visual inspection alone are insufficient unless we intuitively demonstrate their medical or neuroscientific validity via quantitative features. In this study, we synthesize counterfactual-labeled structural MRIs using our proposed framework, transform them into gray matter density maps, and measure their volumetric changes over parcellated regions of interest (ROIs). We also devised a lightweight linear classifier to boost the effectiveness of the constructed ROIs, promote quantitative interpretation, and achieve predictive performance comparable to DL methods. Throughout this, our framework produces an “AD-relatedness index” for each ROI and offers an intuitive understanding of brain status for an individual patient and across patient groups with respect to AD progression.
- We propose a novel methodology to develop fundamental scientific insights from a counterfactual reasoning-based explainable learning method. We demonstrate that our proposed method can be interpreted intuitively from the clinician’s perspective by converting counterfactual-guided deep features to the quantitative volumetric feature domain rather than directly inspecting DL-based visual attributions.
- We achieved performance similar to or better than that of DL-based models by designing a shallow network composed of a lightweight counterfactual-guided attentive feature representation and a linear classifier (LiCoL), which operates on the AD-effect ROIs identified as distinctive AD-related landmarks via counterfactual-guided deep features.
- By exploiting our proposed LiCoL, we provide a numerically interpretable AD-relatedness index for each patient, as well as for patient groups, with respect to anatomical variations caused by AD progression (a minimal sketch of one such index computation follows below).
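As a rough illustration of how such an index can be read off a linear model, here is a minimal sketch under stated assumptions (the index of an ROI is its normalized contribution to the decision score); all names are hypothetical and this is not the repository's implementation:

```python
# Illustrative sketch (not the repository's code): per-ROI "AD-relatedness"
# as the normalized contribution w_i * x_i of each ROI feature to the logit.
import numpy as np

def ad_relatedness_index(weights, roi_features):
    """weights: (n_rois,) learned linear weights; roi_features: (n_rois,) volumetric features."""
    contribution = weights * roi_features  # per-ROI contribution to the decision score
    return contribution / (np.abs(contribution).max() + 1e-8)  # normalize to [-1, 1]

# Toy example: 100 parcellated ROIs for a single patient
rng = np.random.default_rng(0)
index = ad_relatedness_index(rng.normal(size=100), rng.random(100))
print(index.shape)  # (100,)
```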
Visualization of a normalized AD-relatedness index at the group level (first column) and for individual patients (second and third columns)
To download the Alzheimer's Disease Neuroimaging Initiative (ADNI) dataset, visit https://adni.loni.usc.edu/
Mode: #0 Learn, #1 Explain
- Learn: pre-training the predictive model
CMG_config.py --mode="Learn"
- Set the mode to "Learn" to pre-train the predictive model (a minimal model sketch follows)
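For orientation only, a minimal sketch of the kind of 3-D CNN diagnostic classifier this step pre-trains on structural MRI; the input shape, layer sizes, and two-class setup are assumptions rather than the repository's actual architecture:

```python
# Minimal sketch of a 3-D CNN diagnostic classifier (shapes/sizes are assumptions).
import tensorflow as tf

def build_diagnostic_model(input_shape=(64, 64, 64, 1), n_classes=2):
    return tf.keras.Sequential([
        tf.keras.layers.Input(shape=input_shape),
        tf.keras.layers.Conv3D(16, 3, activation="relu", padding="same"),
        tf.keras.layers.MaxPool3D(2),
        tf.keras.layers.Conv3D(32, 3, activation="relu", padding="same"),
        tf.keras.layers.GlobalAveragePooling3D(),
        tf.keras.layers.Dense(n_classes, activation="softmax"),
    ])

model = build_diagnostic_model()
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])
# model.fit(train_mri, train_labels, epochs=100)  # then save weights for "Explain"
```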
- Explain: Counterfactual map generation using a pre-trained diagnostic model
CMG_config.py --mode="Explain" --dataset=None --scenario=None
- Change the mode from "Learn" to "Explain" in CMG_config.py
- Load the pre-trained classifier and encoder weights and freeze them during training (see the sketch after this list)
- Set the dataset and scenario variables for training
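The freezing step can look like the following minimal TensorFlow sketch, assuming Keras-saved models; the file paths and the split into encoder/classifier objects are hypothetical, not the repository's actual layout:

```python
# Minimal sketch of the weight-freezing step (paths and names are hypothetical).
import tensorflow as tf

# Load the diagnostic model pre-trained in "Learn" mode.
encoder = tf.keras.models.load_model("weights/encoder")        # hypothetical path
classifier = tf.keras.models.load_model("weights/classifier")  # hypothetical path

# Freeze: their weights stay fixed while the counterfactual map generator trains.
encoder.trainable = False
classifier.trainable = False

# During training, only the generator's variables receive gradient updates, e.g.:
# optimizer.apply_gradients(zip(grads, generator.trainable_variables))
```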
- AD-effect map acquisition based on manipulated real-/counterfactual-labeled gray matter density maps
AD-effect Map Acquisition.ipynb
- This step of AD-effect map acquisition is implemented as a Jupyter notebook
- Execute the notebook cells in order, following the markdown annotations (a sketch of the underlying computation follows)
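For intuition, the computation behind the AD-effect map can be sketched as follows; this is an illustration under stated assumptions (NumPy, a labeled ROI atlas, mean aggregation), not the notebook's exact code:

```python
# Illustrative sketch: voxel-wise AD-effect map and per-ROI aggregation.
# Names are hypothetical; aggregating by ROI mean is an assumption.
import numpy as np

def roi_effect_features(real_gm, cf_gm, atlas):
    """real_gm, cf_gm: 3-D gray matter density maps; atlas: integer ROI labels (0 = background)."""
    effect_map = cf_gm - real_gm                # voxel-wise change induced by the counterfactual
    roi_ids = np.unique(atlas)
    roi_ids = roi_ids[roi_ids > 0]              # drop background
    return np.array([effect_map[atlas == r].mean() for r in roi_ids])

# Toy example: 4 ROIs on an 8x8x8 volume
rng = np.random.default_rng(1)
atlas = rng.integers(0, 5, size=(8, 8, 8))
feats = roi_effect_features(rng.random((8, 8, 8)), rng.random((8, 8, 8)), atlas)
print(feats.shape)  # (4,)
```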
- LiCoL
LiCoL_ALL.py --dataset=None --scenario=None --data_path=None
- Set the dataset and scenario variables for training
- For example, dataset="ADNI" and scenario="CN_AD"
- Modify the data path to load the dataset (line 234); a minimal LiCoL sketch follows this list
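For reference, a LiCoL-style model can be sketched as a single linear layer over the AD-effect ROI features, so each learned weight maps directly onto one ROI; the sizes and the binary CN-vs-AD setup below are assumptions, and the repository's LiCoL additionally uses a counterfactual-guided attentive feature representation:

```python
# Minimal sketch of a linear classifier over ROI features (sizes are assumptions).
import tensorflow as tf

n_rois = 100  # number of parcellated AD-effect ROIs (assumed)
licol = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(n_rois,)),
    tf.keras.layers.Dense(1, activation="sigmoid"),  # linear decision + sigmoid
])
licol.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])
# licol.fit(roi_features, labels, epochs=50)  # roi_features: (n_subjects, n_rois)
```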
If you find this work useful for your research, please cite the following paper:
@article{oh2023quantitatively,
  title={A Quantitatively Interpretable Model for Alzheimer's Disease Prediction Using Deep Counterfactuals},
  author={Oh, Kwanseok and Heo, Da-Woon and Mulyadi, Ahmad Wisnu and Jung, Wonsik and Kang, Eunsong and Lee, Kun Ho and Suk, Heung-Il},
  journal={arXiv preprint arXiv:2310.03457},
  year={2023}
}
This work was supported by the Institute of Information & Communications Technology Planning & Evaluation (IITP) grant funded by the Korea government (MSIT) No. 2022-0-00959 ((Part 2) Few-Shot Learning of Causal Inference in Vision and Language for Decision Making) and No. 2019-0-00079 (Department of Artificial Intelligence, Korea University). This study was further supported by the KBRI basic research program through the Korea Brain Research Institute funded by the Ministry of Science and ICT (22-BR-03-05).