[MICCAI 2024 Paper & HKUST CSE Best FYP] Aligning Medical Image Analysis Tasks with General Knowledge from Large Language Models

Recently, a study on medical vision-language models by HKUST Smart Lab has been accepted for publication at the International Conference on Medical Image Computing and Computer Assisted Intervention (MICCAI) 2024. This paper is among the first batch of early accepted papers for MICCAI 2024 and has been awarded the 2024 Best Final Year Project (CSE Best FYP Award) by the Department of Computer Science and Engineering at The Hong Kong University of Science and Technology. The research proposes a visual symptom-guided prompt learning framework based on large language models to enhance the generalization and interpretability of vision-language models.

Abstract

Pre-trained vision-language models (VLMs) like CLIP have revolutionized visual representation learning using natural language supervision, demonstrating strong generalization capabilities. In this work, we propose ViP, a visual symptom-guided prompt learning framework for medical image analysis that facilitates knowledge transfer from CLIP. ViP consists of two modules: the Visual Symptom Generator and the Dual-Prompt Network. Specifically, the Visual Symptom Generator aims to extract interpretable visual symptoms from pre-trained large language models, while the Dual-Prompt Network utilizes these visual symptoms to guide the training of two learnable prompt modules, namely the Context Prompt and the Merge Prompt. These modules, combined with VLMs, effectively apply our framework to medical image analysis. Experimental results show that ViP outperforms state-of-the-art methods on the Derm7pt and Pneumonia datasets. The code is available at: https://github.com/xiaofang007/ViP

Introduction

Medical image analysis plays a crucial role in healthcare. With the advancement of deep learning technology, computer-assisted medical image analysis has achieved significant success in numerous scenarios. Current methods generally adopt a supervised learning paradigm, requiring a large amount of labeled data for model training. However, this paradigm relies on manual annotation of medical images, which is time-consuming and labor-intensive.

The emergence of vision-language models (VLMs) has made it possible to transfer knowledge from large-scale pre-trained models to specific medical image analysis tasks with limited data. For example, CLIP is pre-trained with contrastive learning on 400 million image-text pairs. Although CLIP has shown great potential for transfer learning across various tasks, directly applying it to the medical domain remains challenging: CLIP is pre-trained primarily on web-scraped natural image-text pairs, whereas medical tasks involve abstract medical terminology that CLIP may struggle to encode effectively. Inspired by recent works, we propose to address this challenge by converting abstract medical terms into visual symptoms that are shared across both natural and medical domains (e.g., color, shape, and texture). This allows VLMs to align image features with interpretable visual features, much as doctors diagnose diseases from the relevant visual features they observe in medical images.

In this paper, we propose ViP, a novel visual symptom-guided prompt learning framework to facilitate knowledge transfer from CLIP. The framework comprises two modules: the Visual Symptom Generator and the Dual-Prompt Network. The Visual Symptom Generator produces visual symptoms by querying pre-trained large language models (LLMs), and these symptoms serve as text inputs to the Dual-Prompt Network. The Dual-Prompt Network enhances CLIP's generalization capability by training two learnable prompt modules, the Context Prompt and the Merge Prompt. The Context Prompt refines visual symptoms by incorporating contextual information of the medical task, while the Merge Prompt aggregates the textual features of the visual symptoms. We evaluate the framework on the public Pneumonia and Derm7pt datasets. Experimental results demonstrate that ViP outperforms state-of-the-art methods and highlight the effectiveness of each module in our framework.

Methods

Visual Symptom-Guided Prompt Learning Framework

Figure 1. Visual Symptom-Guided Prompt Learning Framework. It consists of the Visual Symptom Generator and the Dual-Prompt Network. Visual symptoms predicted by the Visual Symptom Generator are used as inputs to the downstream Dual-Prompt Network (highlighted by blue dashed lines).

The workflow of the Visual Symptom-Guided Prompt Learning Framework is shown in Figure 1. Consider an input image \(x\) and a set of disease labels \(C = \{c_1, c_2, \ldots, c_n\}\), and let \(N = n\) denote the total number of disease categories. The input image \(x\) is first passed through the pre-trained visual encoder in the Dual-Prompt Network to obtain a feature vector \(f\). Meanwhile, for each disease category, the Visual Symptom Generator produces multiple visual symptoms. These visual symptoms are transformed by the Context Prompt module into text inputs for the Dual-Prompt Network, and the pre-trained text encoder computes the textual feature of each visual symptom. Next, the Merge Prompt module aggregates these textual features into a representative feature \(s^c\) for the visual symptoms of disease category \(c\). Traversing all disease categories yields the set of representative features \(S = \{s^{(c_1)}, s^{(c_2)}, \ldots, s^{(c_n)}\}\). Finally, the predicted disease category is the one whose representative visual symptom feature has the highest cosine similarity with the image feature \(f\).
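To make the classification rule concrete, below is a minimal PyTorch sketch of the inference step, assuming the image feature \(f\) and the representative symptom features have already been computed; the function and tensor names are ours for illustration, not the released code.

```python
import torch
import torch.nn.functional as F

def predict_disease(f: torch.Tensor, class_features: torch.Tensor) -> int:
    """f: (d,) image feature; class_features: (N, d) representative symptom features s^c."""
    f = F.normalize(f, dim=-1)               # unit-normalize the image feature
    s = F.normalize(class_features, dim=-1)  # unit-normalize each class feature
    similarities = s @ f                     # cosine similarity with every disease class, shape (N,)
    return int(similarities.argmax())        # index of the predicted disease category

# Toy usage: 4 disease categories with 512-dimensional CLIP features.
f = torch.randn(512)
S = torch.randn(4, 512)
print(predict_disease(f, S))
```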

Visual Symptom Generator

The Visual Symptom Generator aims to generate a comprehensive set of visual symptoms for each disease category. Inspired by the extensive knowledge and query capabilities of large language models (LLMs), we propose a two-stage process to construct this set by prompting large language models (e.g., GPT-4). First, we use text-only prompts to obtain a coarse set of visual symptoms. We input the following text to the language model:

Q: I am going to use CLIP, a vision-language model to detect {category} in {modality}. What are useful medical visual features for diagnosing {category}? Please list in bullet points and explain in plain words that CLIP understands. Avoid using words such as {category}.

Here, {category} needs to be replaced with the disease category, and {modality} needs to be replaced with the image modality, such as dermoscopic images. This prompt is designed to provide GPT-4 with enough background information and ensure its answers can be understood by CLIP. Next, we refine the set using GPT-4’s visual question-answering capabilities. For each disease category, we ask GPT-4 the following query with a few image samples:

Q: Please provide visual features regarding color, shape, and texture of this {category} image, which contains 16 sub-images.

This yields a set of visual features observed in the images. By intersecting the coarse visual symptom set obtained in the first stage with the visual feature set obtained from images in the second stage, we obtain the refined set of visual symptoms.
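For illustration, here is a hedged sketch of the two-stage construction; query_llm and query_vlm are hypothetical helpers standing in for the actual GPT-4 calls, and the set intersection over free-text phrases is a simplification of the refinement step described in the paper.

```python
TEXT_PROMPT = (
    "I am going to use CLIP, a vision-language model to detect {category} in {modality}. "
    "What are useful medical visual features for diagnosing {category}? Please list in "
    "bullet points and explain in plain words that CLIP understands. "
    "Avoid using words such as {category}."
)
VQA_PROMPT = (
    "Please provide visual features regarding color, shape, and texture of this "
    "{category} image, which contains 16 sub-images."
)

def query_llm(prompt: str) -> list[str]:
    """Hypothetical helper: send a text-only prompt to an LLM and parse its bullet points."""
    raise NotImplementedError

def query_vlm(prompt: str, image_grid) -> list[str]:
    """Hypothetical helper: send a prompt plus a 4x4 image grid to a VLM and parse the answer."""
    raise NotImplementedError

def build_symptom_set(category: str, modality: str, image_grid) -> set[str]:
    # Stage 1: coarse symptom set from the text-only query.
    coarse = set(query_llm(TEXT_PROMPT.format(category=category, modality=modality)))
    # Stage 2: visual features actually observed in sample images.
    observed = set(query_vlm(VQA_PROMPT.format(category=category), image_grid))
    # Keep only symptoms supported by both stages (the refined set).
    return coarse & observed
```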

Dual-Prompt Network

The Dual-Prompt Network is built on CLIP. We freeze CLIP's image and text encoders to retain the general knowledge acquired from large-scale pre-training. Unlike traditional CLIP-based methods that rely on category names as text prompts, we use the visual symptoms generated by the Visual Symptom Generator, enabling the model to align image features with visual descriptions. However, using the generated symptoms alone still limits generalization, for two reasons: the symptoms produced by LLMs may deviate from CLIP's expected text input format, and multiple visual symptoms must be effectively aggregated into a single disease representation. We therefore propose two learnable prompt modules, the Context Prompt and the Merge Prompt, to enhance the model's generalization capability.
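As a rough sketch of this setup, the snippet below loads a CLIP backbone with the OpenAI clip package and freezes both encoders; the specific backbone and package are our assumptions rather than details from the paper.

```python
import torch
import clip  # OpenAI CLIP package: https://github.com/openai/CLIP

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/16", device=device)  # backbone choice is illustrative

# Freeze both the image and text encoders; only the Context Prompt and Merge Prompt
# parameters (defined separately) remain trainable.
for param in model.parameters():
    param.requires_grad_(False)
```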

Context Prompt Module: In CLIP's text input, context plays a crucial role alongside the category name. For example, CLIP prepends the context prompt {a photo of a} to the category name. Similarly, to capture the context of medical tasks, we want to add contextual information before each visual symptom. However, because the visual symptoms generated by LLMs have complex sentence structures, it is difficult to design such templates manually. We therefore introduce a set of learnable parameters \(\{p_i\}_{i=1}^M\), where \(p_i \in \mathbb{R}^d\) and \(d\) is the dimension of text embeddings, to learn the context of medical tasks automatically in a data-driven manner. Specifically, given a category \(c \in C\) and the text embedding \(e_d\) of one of its visual symptoms, the final input \(T\) to the text encoder is the concatenation of the learnable parameters and the symptom embedding, i.e., \(T = \text{Concat}(p_1, p_2, \ldots, p_M, e_d)\).
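A minimal PyTorch sketch of this idea follows; the module name, the number of context tokens \(M\), and the assumption that each visual symptom arrives as a sequence of token embeddings are ours, not the authors' implementation.

```python
import torch
import torch.nn as nn

class ContextPrompt(nn.Module):
    """Prepends M learnable context tokens to each visual-symptom embedding."""

    def __init__(self, num_context: int = 8, embed_dim: int = 512):
        super().__init__()
        # {p_i}_{i=1..M}: learnable context vectors shared across all visual symptoms.
        self.context = nn.Parameter(torch.randn(num_context, embed_dim) * 0.02)

    def forward(self, symptom_embeddings: torch.Tensor) -> torch.Tensor:
        """symptom_embeddings: (k, L, d) token embeddings of k visual symptoms.
        Returns (k, M + L, d), i.e. T = Concat(p_1, ..., p_M, e_d) for each symptom."""
        k = symptom_embeddings.shape[0]
        ctx = self.context.unsqueeze(0).expand(k, -1, -1)   # (k, M, d)
        return torch.cat([ctx, symptom_embeddings], dim=1)  # concatenate along the token axis
```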

Merge Prompt Module: After the text encoder processes the visual symptoms, the next step is to merge the multiple visual symptoms of a disease into a single representation. We introduce learnable parameters for each disease category to learn its representative feature. Specifically, given a category \(c \in C\), let \(T = [T_1^c, T_2^c, \ldots, T_k^c]^T \in \mathbb{R}^{k \times d}\) be the text feature matrix produced by the text encoder, where \(d\) is the dimension of text embeddings, and let \(g \in \mathbb{R}^d\) be the learnable parameters. We first project \(g\) and \(T\) to a query vector \(Q \in \mathbb{R}^d\) and a key matrix \(K \in \mathbb{R}^{k \times d}\) with separate weights \(W_q \in \mathbb{R}^{d \times d}\) and \(W_k \in \mathbb{R}^{d \times d}\):

\[ Q = gW_q, \quad K = TW_k. \]

The representative feature \(s^c\) of the visual symptoms is then computed from \(g\) and \(T\) as

\[ s^c = g + \text{Softmax}\left(\frac{QK^T}{\sqrt{d}}\right)T. \]

After obtaining the representative features of visual symptoms for all disease categories, we jointly optimize the Context Prompt module and Merge Prompt module with a cross-entropy loss

\[ L_{ce} = -\log\left(\frac{\exp(f \cdot s^{(c_y)} / \gamma)}{\sum_{i=1}^N \exp(f \cdot s^{(c_i)} / \gamma)}\right), \]

where \(c_y\) is the true disease category and \(\gamma\) is a temperature parameter.
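The following is a hedged PyTorch sketch of the Merge Prompt aggregation and the training loss, written directly from the equations above; the class structure and hyperparameters such as the temperature value are illustrative assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MergePrompt(nn.Module):
    """Aggregates the k symptom features of a class into one representative feature s^c."""

    def __init__(self, num_classes: int, embed_dim: int = 512):
        super().__init__()
        self.g = nn.Parameter(torch.randn(num_classes, embed_dim) * 0.02)  # one g per class
        self.w_q = nn.Linear(embed_dim, embed_dim, bias=False)             # W_q
        self.w_k = nn.Linear(embed_dim, embed_dim, bias=False)             # W_k

    def forward(self, T: torch.Tensor, c: int) -> torch.Tensor:
        """T: (k, d) text features of the k visual symptoms of class c; returns s^c of shape (d,)."""
        g = self.g[c]                                                 # (d,)
        q = self.w_q(g)                                               # Q = g W_q
        key = self.w_k(T)                                             # K = T W_k, (k, d)
        attn = F.softmax(q @ key.t() / key.shape[-1] ** 0.5, dim=-1)  # Softmax(QK^T / sqrt(d)), (k,)
        return g + attn @ T                                           # s^c = g + weighted sum of T

def vip_loss(f: torch.Tensor, S: torch.Tensor, y: int, gamma: float = 0.07) -> torch.Tensor:
    """Cross-entropy over similarity logits: f is the (d,) image feature, S stacks the (N, d)
    class features s^{c_i}, y is the index of the true category, gamma is the temperature."""
    logits = (S @ f) / gamma                         # f · s^{c_i} / gamma for every class
    return F.cross_entropy(logits.unsqueeze(0), torch.tensor([y]))
```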

Experiments

Effectiveness of Interpretable Visual Symptoms

We conducted experiments on two public datasets: Pneumonia and Derm7pt. Detailed information about these datasets can be found in our paper.

Figure 2. (I) Zero-shot CLIP uses category names or visual symptoms as text input. (II) Diagnostic process based on cosine similarity between image features and visual symptom features.

We conducted a zero-shot experiment to evaluate the effectiveness of visual symptoms in disease diagnosis while also providing explanations for the decisions. As shown in Figure 2(I), compared to the zero-shot CLIP model, ViP achieved accuracy improvements of 0.44% and 18.73% and F1 score gains of 1.58% and 10.98% on Pneumonia and Derm7pt, respectively, indicating that LLMs can provide useful knowledge for the medical domain. We also analyzed cases where our framework predicted the correct disease category while the CLIP model failed, as shown in Figure 2(II): our framework reaches the correct diagnosis because the image feature is more similar to the visual symptom features of the correct category.

Comparison with Other Methods

Table 1. Comparison with other methods on the Pneumonia and Derm7pt datasets.

We further compared ViP with several prompt learning-based models to evaluate its generalization capability. As shown in Table 1, ViP achieved the highest accuracy (86.69% and 81.11%) and F1 scores (84.94% and 77.30%) on Pneumonia and Derm7pt, respectively. Moreover, compared to fully supervised models, ViP achieved competitive results on Pneumonia and performed even better on Derm7pt while using less training data, demonstrating its strong generalization capability in low-resource settings.

Effectiveness of Each Module

Table 2. Ablation study results. ‘Context’ and ‘Merge’ represent the Context Prompt module and Merge Prompt module, respectively. ‘Max’ and ‘Mean’ represent taking the maximum and average similarity between image features and visual symptom features during aggregation, respectively.

We explored the effectiveness of each module in ViP, as shown in Table 2. Compared to the zero-shot baseline, the integration of the Context Prompt module and Merge Prompt module showed substantial improvements, demonstrating the importance of learning the context of medical tasks and effectively aggregating visual symptoms. Additionally, compared to non-parametric aggregation methods such as average and maximum functions, our proposed Merge Prompt module performed better on both datasets, further validating the effectiveness of our approach.
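For reference, here is a small sketch of the two non-parametric aggregation baselines in Table 2; the function names are ours, and the inputs are assumed to be CLIP image and symptom features.

```python
import torch
import torch.nn.functional as F

def mean_similarity_score(f: torch.Tensor, T: torch.Tensor) -> torch.Tensor:
    """'Mean' baseline: average cosine similarity between the image feature f (d,)
    and the k visual-symptom features T (k, d) of one disease category."""
    sims = F.normalize(T, dim=-1) @ F.normalize(f, dim=-1)  # (k,)
    return sims.mean()

def max_similarity_score(f: torch.Tensor, T: torch.Tensor) -> torch.Tensor:
    """'Max' baseline: highest cosine similarity between the image feature and any symptom feature."""
    sims = F.normalize(T, dim=-1) @ F.normalize(f, dim=-1)  # (k,)
    return sims.max()
```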

Reliability of Knowledge

Figure 3. Ablation study comparing different types of knowledge.

We conducted another experiment to verify that the visual symptoms generated by LLMs provide useful knowledge for VLMs to generalize to the medical domain. As shown in Figure 3, we replaced the visual symptoms of moles with three other types of knowledge: 1) Out-of-domain knowledge, consisting of visual features unrelated to the medical domain, such as food descriptions. 2) Useless knowledge, consisting of descriptions related to the target disease that provide no diagnostically useful visual features, such as descriptions of skin structure. 3) Incorrect knowledge, created by replacing certain words in the mole descriptions with their antonyms to produce misleading visual symptoms. LLM-generated knowledge achieved the best performance among all types, indicating that accurate visual symptoms contribute to generalization in the medical domain.

Conclusion

This paper proposes ViP, a novel visual symptom-guided prompt learning framework that effectively transfers knowledge from vision-language models to medical image analysis. We use pre-trained large language models to generate visual symptoms that guide CLIP in aligning image features with interpretable visual features. In addition, ViP integrates two learnable prompt modules, the Context Prompt and the Merge Prompt, to further enhance generalization capability. Experimental results demonstrate the effectiveness of each module and the superior performance of our framework.