Exploring Explainable AI in the Era of Foundation Models: Initial Efforts by HKUST Smart Lab in Healthcare

Deep learning has achieved significant advancements in various medical tasks such as disease diagnosis, drug discovery, and survival analysis. However, the “black box” nature of deep learning models has raised concerns about their transparency, interpretability, and accountability in real-world clinical applications. Explainable Artificial Intelligence (XAI) aims to provide clear and understandable explanations for decision-making processes, enabling doctors and patients to understand and trust medical decisions. The advent of the Foundation Model era has brought new opportunities for the development of XAI.

HKUST Smart Lab’s research on explainable AI in healthcare follows the approach of “optimizing XAI with foundation models and building stronger foundation models with XAI.” The lab has made significant progress in this cutting-edge field, with results published in top international journals and conferences such as Nature Communications, AAAI, ECCV, ICML, and MICCAI. It is committed to applying XAI technology to critical scenarios such as medical diagnosis and treatment decision-making, providing more reliable intelligent support for doctors and patients.

01. Nature Communications: Development and Validation of a Multimodal Interpretable Model for Ovarian Cancer Diagnosis

Development and validation of an interpretable model integrating multimodal information for improving ovarian cancer diagnosis

[Figure: OvcaFinder]

Ovarian cancer is a heterogeneous disease with the highest mortality rate among gynecological malignancies, so accurate and early diagnosis is critical. Transvaginal ultrasound (TVUS) is the most commonly used imaging tool for diagnosing adnexal masses. Although several deep learning-based classification systems have been proposed, there is still room for improvement. First, deep learning models are often criticized as “black boxes” due to their lack of transparency and interpretability, making it difficult for radiologists to understand what the model has learned from the training images. Second, previous deep learning models did not incorporate readily available clinical variables that may aid ovarian cancer diagnosis, such as the serum biomarker CA125. A diagnosis usually draws on multiple modalities of information, yet no prior work had integrated multimodal information into ovarian cancer risk stratification.

This study developed and validated an interpretable model, OvcaFinder, which integrates deep learning-based predictions from ultrasound images, radiologists’ O-RADS scores, and conventional clinical variables to differentiate between benign and malignant ovarian tumors. Specifically, OvcaFinder is built on a random forest (RF) algorithm and takes as input a 5-dimensional vector comprising three clinical factors (patient age, tumor diameter, and CA125 concentration), the radiologists’ O-RADS score, and the deep learning-based prediction. The RF model consists of N estimators, each trained on a bootstrapped sample of 1000 instances drawn from the training set; each bootstrapped sample is used to grow a decision tree, yielding a forest of N different decision trees whose final output is obtained by majority voting.
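To make the fusion step concrete, the following is a minimal sketch of a random forest over the 5-dimensional input described above. It is not the authors’ released code; the library choice (scikit-learn), the feature ranges, the toy labels, and the hyperparameters are illustrative assumptions.

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier

rng = np.random.default_rng(0)
n = 1000  # hypothetical number of training lesions

# Hypothetical 5-dimensional inputs:
# [age, tumor diameter (mm), CA125 (U/mL), O-RADS score, DL malignancy probability]
X_train = np.column_stack([
    rng.uniform(18, 80, n),
    rng.uniform(10, 150, n),
    rng.uniform(5, 1000, n),
    rng.integers(2, 6, n),
    rng.uniform(0, 1, n),
])
y_train = rng.integers(0, 2, n)  # 0 = benign, 1 = malignant (toy labels)

# N bootstrapped decision trees combined by voting, mirroring the ensemble described above.
fusion_model = RandomForestClassifier(n_estimators=500, bootstrap=True, random_state=0)
fusion_model.fit(X_train, y_train)

# One new lesion (all values hypothetical): 54-year-old patient, 62 mm mass,
# CA125 = 410 U/mL, O-RADS 4, deep learning probability 0.83.
lesion = np.array([[54.0, 62.0, 410.0, 4.0, 0.83]])
print("Predicted malignancy probability:", fusion_model.predict_proba(lesion)[0, 1])
```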

The study collected 3972 images of 724 lesions from Sun Yat-sen University Cancer Center between February 2011 and May 2021. The dataset was randomly divided into training, validation, and internal test sets in a 7:1:2 ratio. The external validation set consisted of 2200 images of 387 lesions collected from Chongqing University Cancer Hospital between December 2018 and June 2021. The experimental results showed that OvcaFinder achieved AUCs of 0.978 and 0.947 on the internal and external test sets, respectively, outperforming any single model or radiologist. OvcaFinder assisted radiologists in improving diagnostic performance and reducing false positive rates. In addition to identifying ovarian cancer, OvcaFinder can highlight the most important regions through heatmaps and reveal the impact of each parameter using Shapley values, providing explanations for its predictions.
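The paper reports Shapley-value attributions over OvcaFinder’s inputs; the snippet below shows how such per-feature contributions could be computed with the shap library on a toy stand-in for the fusion model. The feature names, their ordering, and the toy data are assumptions.

```python
import numpy as np
import shap
from sklearn.ensemble import RandomForestClassifier

rng = np.random.default_rng(0)
feature_names = ["age", "diameter (mm)", "CA125", "O-RADS", "DL score"]  # assumed ordering

# Toy stand-in for the trained fusion model and its 5-dimensional inputs.
X = np.column_stack([
    rng.uniform(18, 80, 300),   # age
    rng.uniform(10, 150, 300),  # tumor diameter
    rng.uniform(5, 1000, 300),  # CA125
    rng.integers(2, 6, 300),    # O-RADS score
    rng.uniform(0, 1, 300),     # deep learning probability
])
y = rng.integers(0, 2, 300)
model = RandomForestClassifier(n_estimators=100, random_state=0).fit(X, y)

# Per-feature Shapley contributions for a single lesion.
explainer = shap.TreeExplainer(model)
sv = explainer.shap_values(X[:1])
# Older shap versions return a list per class; newer ones an array (n, features, classes).
contrib = sv[1][0] if isinstance(sv, list) else sv[0, :, 1]
for name, value in zip(feature_names, contrib):
    print(f"{name:>14s}: {value:+.3f}")
```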

Authors: Huiling Xiang, Yongjie Xiao, Fang Li, Chunyan Li, Lixian Liu, Tingting Deng, Cuiju Yan, Fengtao Zhou, Xi Wang, Jinjing Ou, Qingguang Lin, Ruixia Hong, Lishu Huang, Luyang Luo, Huangjing Lin, Xi Lin, Hao Chen

Paper Link: Nature Communications

02. AAAI 2024: MICA: Towards Explainable Skin Lesion Diagnosis via Multi-Level Image-Concept Alignment


[Figure: MICA]

Black-box deep learning models have shown strong performance in medical image analysis, but their interpretability remains underexplored. Existing concept-based interpretability methods typically use concept annotations only at a coarse-grained level, ignoring the fine-grained semantic associations between local regions of medical images and medical concepts. Concept information is therefore under-utilized, making it difficult to balance performance and interpretability in ante-hoc (interpretable-by-design) frameworks such as concept bottleneck architectures. This paper therefore proposes an interpretable disease diagnosis framework that aligns medical images and medical concepts at multiple levels of granularity, fully exploiting fine-grained concept information to achieve strong diagnostic performance while providing textual and visual explanations that improve interpretability.

The proposed framework consists of two stages. The first stage performs multi-granularity alignment of images and concepts at the global image level, the local region level, and the concept subspace level. Specifically, at the global image level, the features of matching image-concept pairs are aligned; at the local region level, a cross-attention mechanism aligns local image regions with individual concepts; and concept activation vectors map features into the concept subspace, where they are supervised with concept annotations. The second stage performs explainable disease diagnosis: the image encoder trained in the first stage extracts features from the input image, a highly interpretable concept bottleneck architecture predicts the medical concepts present in the image, and the predicted concepts are used to derive the final diagnosis.
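The two ideas at the heart of this design, region-concept cross-attention and a concept bottleneck head, can be sketched in PyTorch as follows. This is an illustration rather than the MICA implementation; the module names, feature dimensions, and the numbers of regions, concepts, and classes are assumptions.

```python
import torch
import torch.nn as nn

class RegionConceptAlignment(nn.Module):
    """Aligns local region features with concept embeddings via cross-attention."""
    def __init__(self, dim: int, num_heads: int = 4):
        super().__init__()
        self.cross_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)

    def forward(self, region_feats, concept_embeds):
        # Queries are concepts, keys/values are image regions, so each concept
        # attends to the local regions that best support it.
        attended, attn_weights = self.cross_attn(concept_embeds, region_feats, region_feats)
        return attended, attn_weights

class ConceptBottleneckHead(nn.Module):
    """Predicts concept scores from a global image feature, then the disease from concepts."""
    def __init__(self, dim: int, num_concepts: int, num_classes: int):
        super().__init__()
        self.concept_predictor = nn.Linear(dim, num_concepts)
        self.disease_classifier = nn.Linear(num_concepts, num_classes)

    def forward(self, image_feat):
        concept_logits = self.concept_predictor(image_feat)   # interpretable bottleneck
        disease_logits = self.disease_classifier(torch.sigmoid(concept_logits))
        return concept_logits, disease_logits

# Toy forward pass: 49 region tokens and 8 dermoscopic concepts of width 256, 5 classes.
regions = torch.randn(2, 49, 256)
concepts = torch.randn(2, 8, 256)
aligned, weights = RegionConceptAlignment(256)(regions, concepts)
concept_logits, disease_logits = ConceptBottleneckHead(256, 8, 5)(regions.mean(dim=1))
print(weights.shape, concept_logits.shape, disease_logits.shape)
```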

The method was validated on three skin lesion datasets with medical concept annotations and showed superior diagnostic performance compared to other explainable methods. The paper further analyzes interpretability from multiple angles, including concept intervention, concept contribution, and visualization, showing that the method provides textual and visual explanations while maintaining high predictive performance.

Authors: Yequan Bie, Luyang Luo, Hao Chen

Paper Link: AAAI

03. MICCAI 2024: XCoOp: Explainable Prompt Learning for Computer-Aided Diagnosis via Concept-guided Context Optimization


[Figure: XCoOp]

Leveraging the powerful representation capabilities of large vision-language models (VLMs) for various downstream tasks has attracted increasing attention. Among these approaches, soft prompt learning has become a representative way to efficiently adapt large VLMs (such as CLIP) to tasks like image classification. However, most existing prompt learning methods learn text tokens that are difficult to interpret, which fails to meet the strict interpretability requirements of XAI in high-stakes settings such as healthcare. This paper therefore proposes an explainable prompt learning framework that exploits medical knowledge by aligning the semantics of images, learnable (soft) prompts, and clinical prompts at multiple levels of granularity. In addition, the method sidesteps the high cost of concept annotation by eliciting domain knowledge from large language models. The work also highlights how foundation models can advance research on explainable AI.

The method first uses a large language model (such as GPT-4) to generate clinical concepts corresponding to the target disease categories and composes these concepts into medical (concept) prompts. The medical prompts then guide the training of the soft prompts, with learnable prompts and medical prompts aligned at both the token and prompt levels. Semantic alignment between images and prompts is further performed at both the global and local levels, so that the learned prompts efficiently adapt the large VLM CLIP to downstream medical tasks.
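A minimal sketch of the prompt-alignment idea is given below, with the CLIP text encoder stubbed out by random tensors. It is not the XCoOp implementation; the token count, embedding width, loss formulation, and optimizer settings are assumptions.

```python
import torch
import torch.nn.functional as F

num_tokens, dim = 8, 512

# Learnable context tokens (soft prompt), as in CoOp-style prompt learning.
soft_prompt = torch.nn.Parameter(torch.randn(num_tokens, dim) * 0.02)

# Stand-in for the token embeddings of a clinical prompt built from LLM-generated
# concepts such as "irregular streaks, blue-white veil, atypical pigment network".
clinical_prompt_tokens = torch.randn(num_tokens, dim)

def prompt_alignment_loss(soft, clinical):
    # Token-level alignment: match each learnable token to its clinical counterpart.
    token_loss = 1 - F.cosine_similarity(soft, clinical, dim=-1).mean()
    # Prompt-level alignment: match the pooled representations of the two prompts.
    prompt_loss = 1 - F.cosine_similarity(soft.mean(dim=0), clinical.mean(dim=0), dim=0)
    return token_loss + prompt_loss

optimizer = torch.optim.Adam([soft_prompt], lr=1e-3)
for step in range(100):
    loss = prompt_alignment_loss(soft_prompt, clinical_prompt_tokens)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
print("final alignment loss:", loss.item())
```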

The method was validated on four datasets with and without medical concept annotations (including skin lesion datasets and chest X-ray datasets), showing superior performance compared to other prompt learning methods. Additionally, the paper analyzes the model’s interpretability from multiple dimensions, including the medical semantics of soft prompts, knowledge intervention, and visualization, demonstrating that the method provides sufficient interpretability while improving model performance.

Authors: Yequan Bie, Luyang Luo, Zhixuan Chen, Hao Chen

Paper Link: arXiv

04. ECCV 2024: Explain via Any Concept: Concept Bottleneck Model with Open Vocabulary Concepts


The Concept Bottleneck Model (CBM) is a “designed-to-be-interpretable” framework that first predicts a set of interpretable concepts and then predicts class labels based on the given concepts. Existing CBMs use a fixed set of concepts for training (concepts are either annotated by the dataset or queried from a language model). However, in practice, this closed-world assumption is unrealistic, as users may want to understand the role of any concept in decision-making after model deployment. Inspired by the recent success of vision-language pre-trained models (such as CLIP) in zero-shot classification, this paper proposes “OpenCBM” to enable CBM to use concepts from any open vocabulary.

[Figure: OpenCBM]

During training, the model’s visual encoder is aligned with the feature space of CLIP’s visual encoder. The classification head of the trained model is then reconstructed from the high-dimensional features of any user-specified concepts obtained from CLIP’s text encoder, and the residual head parameters that these concepts cannot explain are accounted for by iteratively selecting the most similar concepts from an arbitrary concept set.
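A rough numpy illustration of this reconstruction idea follows; it is not the OpenCBM code. The least-squares decomposition, the single greedy residual-matching step, and all dimensions are assumptions made for the sketch.

```python
import numpy as np

rng = np.random.default_rng(0)
dim, num_classes, num_concepts = 512, 10, 40

class_weights = rng.normal(size=(num_classes, dim))    # trained linear classification head
concept_embeds = rng.normal(size=(num_concepts, dim))  # user-specified concept text features

# Least squares: find per-concept contributions A such that A @ concept_embeds ≈ class_weights.
coef_T, *_ = np.linalg.lstsq(concept_embeds.T, class_weights.T, rcond=None)
contributions = coef_T.T                               # (num_classes, num_concepts)
residual = class_weights - contributions @ concept_embeds

# Supplement the unexplained residual by picking the most similar concept
# from a larger open-vocabulary pool (one greedy step shown here).
open_vocab_embeds = rng.normal(size=(500, dim))
for c in range(num_classes):
    sims = open_vocab_embeds @ residual[c] / (
        np.linalg.norm(open_vocab_embeds, axis=1) * np.linalg.norm(residual[c]) + 1e-8)
    best = int(np.argmax(sims))
    print(f"class {c}: residual best matched by open-vocabulary concept #{best} "
          f"(cosine similarity {sims[best]:.3f})")
```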

[Figure: OpenCBM]

Experimental results show that OpenCBM significantly exceeds the previous best CBMs in prediction accuracy on the CUB-200-2011 and CIFAR100 datasets, and qualitative results demonstrate the interpretability of the obtained concepts with respect to the downstream tasks.

Authors: Andong Tan, Fengtao Zhou, Hao Chen

05. ICML 2024: Post-hoc Part-prototype Networks


[Figure: Post-hoc part-prototype network]

Post-hoc interpretability methods (such as Grad-CAM) are popular because they do not affect the performance of trained models. However, they mainly reveal “where” the model looks in an image for a given input, but not “what” it is looking for (e.g., which features matter for classifying a bird image as a Scott’s Oriole). Existing part-prototype networks use part prototypes (e.g., the wing and head features of a Scott’s Oriole) to answer both “where” and “what”, but often fall short of the accuracy of the corresponding black-box models. A natural question therefore arises: can we construct a network that answers “where” and “what” post hoc while preserving model performance?

This paper proposes the first post-hoc part-prototype network, obtained by decomposing the classification head of a trained model into a set of interpretable part prototypes. Specifically, an unsupervised prototype discovery and refinement strategy ensures that the discovered part prototypes are interpretable and can accurately reconstruct the classification head. The strategy uses non-negative matrix factorization for initial prototype discovery, prototype scaling for a preliminary reconstruction of the classification head, and dynamic residual parameter allocation for prototype refinement.
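The sketch below illustrates these three steps on toy data. It is not the paper’s implementation: the use of scikit-learn’s NMF, the least-squares prototype scaling, and the particular residual-allocation formula are assumptions chosen to make the idea concrete.

```python
import numpy as np
from sklearn.decomposition import NMF

rng = np.random.default_rng(0)
num_samples, dim, num_prototypes = 600, 256, 8

deep_feats = np.abs(rng.normal(size=(num_samples, dim)))  # stand-in for non-negative ReLU features

# Step 1: unsupervised prototype discovery with NMF (prototypes = rows of components_).
nmf = NMF(n_components=num_prototypes, init="nndsvda", max_iter=500, random_state=0)
nmf.fit(deep_feats)
prototypes = nmf.components_                               # (num_prototypes, dim)

# Step 2: prototype scaling — least-squares coefficients that best reconstruct
# one class's weight vector from the prototypes.
class_weight = rng.normal(size=dim)                        # one row of a trained classifier head
coef, *_ = np.linalg.lstsq(prototypes.T, class_weight, rcond=None)
reconstruction = coef @ prototypes

# Step 3: residual allocation — fold the unexplained residual back onto the
# prototypes so the classification head is reconstructed (almost) exactly.
residual = class_weight - reconstruction
refined_prototypes = prototypes + np.outer(coef, residual) / (np.sum(coef**2) + 1e-8)
print("reconstruction error after refinement:",
      np.linalg.norm(coef @ refined_prototypes - class_weight))
```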

Experimental results on the CUB-200-2011 dataset show that the obtained prototypes exhibit better consistency and stability, and that the method surpasses existing models in classification accuracy across multiple baselines.

Authors: Andong Tan, Fengtao Zhou, Hao Chen

Paper Link: OpenReview

06. MICCAI 2024: Concept-Attention Whitening for Interpretable Skin Lesion Diagnosis


[Figure: CAW]

The black-box nature of deep learning models has raised concerns about their interpretability in real-world clinical applications, so there is an urgent need for XAI techniques that make decision-making processes more transparent and understandable. In medicine, concepts such as lesions or abnormal features are key evidence for deriving a diagnosis; for example, blue-white veil, atypical pigment network, and irregular streaks are important concepts for diagnosing melanoma. Existing concept-based XAI models mainly treat concepts as occurring independently and require fine-grained concept annotations such as bounding boxes. In practice, however, multiple concepts usually appear simultaneously in medical images, and fine-grained concept annotations are difficult to obtain.

This paper interprets the representations of deep neural networks by aligning the axes of the latent space with known concepts of interest. A novel Concept-Attention Whitening (CAW) framework is proposed, consisting of a disease diagnosis branch and a concept alignment branch. (1) In the disease diagnosis branch, a convolutional neural network with a CAW layer is trained on a disease dataset to perform skin lesion diagnosis; the CAW layer decorrelates features through a whitening transformation and assigns specific conceptual meanings to feature dimensions through an orthogonal transformation. (2) In the concept alignment branch, a concept dataset is used to compute the orthogonal matrix from concept features: a weakly supervised concept mask generator highlights the local regions most relevant to each concept to obtain representative concept features, and the orthogonal matrix is then obtained by solving an optimization problem.
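The snippet below illustrates the two transformations performed by a CAW-style layer, whitening followed by an orthogonal rotation toward concept directions, on toy data. It is not the authors’ code; in particular, obtaining the orthogonal matrix via orthogonal Procrustes and using ZCA whitening are assumptions standing in for the optimization described in the paper.

```python
import numpy as np
from scipy.linalg import orthogonal_procrustes

rng = np.random.default_rng(0)
batch, dim, num_concepts = 256, 64, 8

# Correlated toy features standing in for CNN activations.
feats = rng.normal(size=(batch, dim)) @ rng.normal(size=(dim, dim))

# Whitening: decorrelate feature dimensions (ZCA whitening shown here).
mean = feats.mean(axis=0)
centered = feats - mean
cov = centered.T @ centered / (batch - 1)
eigvals, eigvecs = np.linalg.eigh(cov)
whitening = eigvecs @ np.diag(1.0 / np.sqrt(eigvals + 1e-5)) @ eigvecs.T
white = centered @ whitening

# Orthogonal transformation: rotate the whitened space so that the first axes
# point at representative concept directions (e.g., blue-white veil, streaks).
concept_feats = rng.normal(size=(num_concepts, dim))   # stand-in concept features
concept_white = (concept_feats - mean) @ whitening
targets = np.eye(num_concepts, dim)                    # one latent axis per concept
Q, _ = orthogonal_procrustes(concept_white, targets)   # orthogonal matrix from an optimization
aligned = white @ Q

# A sample's concept activations are its coordinates along the concept-aligned axes.
print("sample 0 concept activations:", aligned[0, :num_concepts].round(2))
```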

Experiments on two public skin lesion diagnosis datasets, Derm7pt and Skincon, show that the proposed CAW method achieves state-of-the-art disease diagnosis performance, with AUC scores of 0.886 and 0.804, respectively. CAW also demonstrates superior concept detection performance: it provides text explanations describing image lesions based on the concept activation values in the features and generates visual explanations highlighting the corresponding concept regions, greatly enhancing the interpretability of the diagnosis process.

Authors: Junlin Hou, Jilan Xu, Hao Chen

Paper Link: arXiv