Our paper “VoCo: A Simple-yet-Effective Volume Contrastive Learning Framework for 3D Medical Image Analysis” was accepted at CVPR 2024. Code, dataset, and checkpoints are available at https://github.com/Luffy03/VoCo. The preprint is available at https://arxiv.org/abs/2402.17300. Congratulations to the authors!
Deep learning has demonstrated outstanding achievements in 3D medical image analysis, yet it is heavily hampered by the high cost of the required expert annotations. To address this problem, Self-Supervised Learning (SSL) has received significant attention for its promising ability to learn representations without annotations, making it an important label-efficient solution for 3D medical image analysis. However, the lack of high-level semantics in pre-training still heavily hinders the performance of downstream tasks. In this paper, we aim to introduce stronger high-level semantics to further improve medical image pre-training.
Existing medical image pre-training methods are mostly based on information reconstruction to learn augmentation-invariant representations of 3D medical images: they first apply strong data augmentation to the images and then reconstruct the raw information. Specifically, rotate-and-reconstruct methods randomly rotate 3D volumetric images and learn to recover them, which encourages models to learn rotation-invariant features. More recent methods further restore information across different views of the image. Mask-and-reconstruct methods, introduced from MAE, are also widely used; they learn representations by masking images and reconstructing the missing pixels. Although promising results have been demonstrated, previous works have shown that the lack of high-level semantics in pre-training heavily hinders the performance of downstream tasks.
Fig.1 Our observation and motivation.
To this end, we argue that the contextual position priors of 3D medical images should be further exploited. As shown in Fig.1(a), we observe that in 3D medical images, different organs (semantic regions) occupy relatively consistent contextual positions and exhibit relatively consistent anatomical characteristics (shapes). Thus, the consistency of geometric relations between different organs offers a natural way to learn consistent semantic representations for 3D medical image pre-training. In this paper, we propose a pretext task of contextual position prediction, which encodes contextual position priors into the model's representations and effectively improves the performance of downstream tasks that require high-level semantics.
Fig.2 The overall framework of VoCo.
The overall framework contains a contextual position prediction branch and a regularization branch. The prediction branch predicts the contextual positions between different cropped volumes. Specifically, given an input volume, we first crop it into non-overlapping base volumes that cover the whole input volume. Then, we randomly crop a volume and transform it into a high-dimensional feature space using a typical backbone. The goal is to predict the contextual positions between the randomly cropped volume and the base volumes. In this paper, instead of training a linear classifier to predict positions as in previous works, we propose to achieve this goal by volume contrast. We develop a loss function \(L_{\text{pred}}\) to supervise the final predictions. In addition, we use a loss function \(L_{\text{reg}}\) to regularize the feature discrepancy between different bases by enlarging their distance, aiming to learn more discriminative class assignments.
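As a rough illustration of the regularization branch, below is a minimal PyTorch sketch of one plausible form of \(L_{\text{reg}}\), assuming it penalizes the pairwise cosine similarity between the projected base features \(q\); the exact formulation is given in the paper.

```python
import torch
import torch.nn.functional as F

def regularization_loss(q: torch.Tensor) -> torch.Tensor:
    """Illustrative L_reg: push the n base (prototype) features apart.

    q: (n, d) projected features of the n non-overlapping base crops.
    Returns a scalar that shrinks as the bases become mutually dissimilar.
    This is a sketch of one plausible form; the exact loss is in the paper.
    """
    q = F.normalize(q, dim=-1)                       # unit-norm features
    sim = q @ q.t()                                  # (n, n) pairwise cosine similarities
    n = q.shape[0]
    off_diag = sim - torch.eye(n, device=q.device)   # drop self-similarity terms
    # Penalize any remaining similarity between different bases.
    return off_diag.abs().sum() / (n * (n - 1))
```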
The process of position label generation is shown in the following figure. When we generate \(n = 4 \times 4\) base crops, there are \(n\) class assignments. We then calculate the overlap area between a randomly cropped volume and the \(n\) base crops. The proportions of the overlap area are assigned as position labels \(y\), which also range from 0 to 1. Thus, we can easily supervise the model by calculating the distance between the prediction logits \(l\) and the position labels \(y\).
Fig.3 The process of Position label generation.
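To make the label computation concrete, here is a minimal sketch of how the overlap proportions could be computed for a \(4 \times 4\) in-plane grid of base crops. The function name, the square in-plane crops, and the coordinate convention are illustrative assumptions, not the exact implementation.

```python
import numpy as np

def position_labels(crop_xy, crop_size, input_size, grid=4):
    """Overlap proportions between a random crop and a grid x grid set of base crops.

    crop_xy:    (x, y) top-left corner of the random crop (in-plane).
    crop_size:  side length of the (square, in-plane) random crop.
    input_size: side length of the (square, in-plane) input volume.
    Returns y of shape (grid * grid,), each entry in [0, 1], where y_i is the
    fraction of the random crop's area that overlaps the i-th base crop.
    """
    base = input_size / grid                       # side length of each base crop
    x0, y0 = crop_xy
    x1, y1 = x0 + crop_size, y0 + crop_size
    y_labels = np.zeros(grid * grid, dtype=np.float32)
    for i in range(grid):                          # base-crop row
        for j in range(grid):                      # base-crop column
            bx0, by0 = j * base, i * base
            bx1, by1 = bx0 + base, by0 + base
            dx = max(0.0, min(x1, bx1) - max(x0, bx0))
            dy = max(0.0, min(y1, by1) - max(y0, by0))
            y_labels[i * grid + j] = (dx * dy) / (crop_size * crop_size)
    return y_labels

# Example: a 32-voxel crop centered in a 128-voxel slice overlaps the four
# central base crops equally, each with label 0.25.
# position_labels(crop_xy=(48, 48), crop_size=32, input_size=128)
```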
Given an input volume, we first crop it into \(n\) non-overlapping base volumes that cover the whole input volume. We then employ the extracted features \(z\) as class assignments (we call them bases), which represent prototype-level features from different positions. Following previous SSL works, a projector with linear layers is used to project \(z\) into latent features \(q\). We then randomly crop a volume and transform it into the high-dimensional feature space as \(p\), using the same backbone and projector to project the features of the randomly cropped volume.
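A minimal sketch of this feature-extraction step is shown below; the generic 3D encoder interface, the feature dimensions, and the two-layer MLP projector are assumptions for illustration, and the actual backbone and projector configurations are described in the paper.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class VoCoFeatures(nn.Module):
    """Sketch: project base crops and the random crop into a shared latent space.

    `backbone` is assumed to be any 3D encoder mapping volumes (B, 1, D, H, W)
    to feature vectors (B, feat_dim); the 2-layer MLP projector is an assumption.
    """
    def __init__(self, backbone: nn.Module, feat_dim: int = 768, proj_dim: int = 256):
        super().__init__()
        self.backbone = backbone
        self.projector = nn.Sequential(
            nn.Linear(feat_dim, proj_dim),
            nn.ReLU(inplace=True),
            nn.Linear(proj_dim, proj_dim),
        )

    def forward(self, base_crops: torch.Tensor, rand_crop: torch.Tensor):
        # base_crops: (n, 1, D, H, W) non-overlapping base volumes
        # rand_crop:  (1, 1, D, H, W) randomly cropped volume
        z = self.backbone(base_crops)                           # (n, feat_dim) prototype-level features
        q = F.normalize(self.projector(z), dim=-1)              # (n, proj_dim) bases
        p = F.normalize(self.projector(self.backbone(rand_crop)), dim=-1)  # (1, proj_dim)
        return q, p
```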
Then, we calculate the similarity logits \(l\) between \(q_i\) and \(p\). Specifically, we use cosine similarity to compute \(l\), where \(q_i\) is the projected feature of the \(i\)-th base crop. \(l_i\) denotes the similarity between \(q_i\) and \(p\) and ranges from 0 to 1. Intuitively, a higher \(l_i\) indicates that \(p\) is more likely to share overlap regions with \(q_i\). We then use a softmax function to normalize \(l\) into a probability distribution, and the final prediction is supervised by the position labels \(y\). In this way, we can explicitly associate the similarity values with the position information, i.e., a \(p\) with higher \(l_i\) is more likely to be located in the region of the \(i\)-th base crop. Thus, instead of training a black-box linear layer, we predict the contextual positions by volume contrast, which is more intuitive and effective. The details can be found in our paper.
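Putting the pieces together, here is a minimal sketch of the contextual position prediction step described above. Clamping negative cosine similarities to zero and using a cross-entropy-style distance between the normalized logits and the labels are assumptions made for illustration; the exact \(L_{\text{pred}}\) is defined in the paper.

```python
import torch
import torch.nn.functional as F

def prediction_loss(q: torch.Tensor, p: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Sketch of L_pred: predict contextual position by volume contrast.

    q: (n, d) projected base features, p: (1, d) projected random-crop feature,
    y: (n,) position labels (overlap proportions in [0, 1]).
    """
    q = F.normalize(q, dim=-1)
    p = F.normalize(p, dim=-1)
    logits = (p @ q.t()).squeeze(0)            # (n,) cosine similarities l_i
    logits = logits.clamp(min=0.0)             # assumption: keep l_i within [0, 1]
    probs = F.softmax(logits, dim=-1)          # normalize l into a distribution
    targets = y / y.sum().clamp(min=1e-8)      # assumption: treat labels as a distribution
    # Cross-entropy-style distance between predicted and label distributions.
    return -(targets * torch.log(probs + 1e-8)).sum()
```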
We collect 1.6k and 10k CT scans to build two pre-training datasets. After pre-training, we evaluate the performance on six downstream tasks, as shown below. The details are presented in our main paper.
Table 1. Performance with different pre-training datasets.
Fig.4 Overall comparisons on six downstream datasets.
It can be seen that pre-training consistently improves performance, and our proposed method outperforms previous pre-training methods by a clear margin. Some visualization results are shown in the following figure. Detailed results and ablation studies can be found in our main paper.
Fig.5 Visualization results on BTCV dataset.
There are still some significant limitations in our current work: (1) The improvements are still limited on some datasets that are unseen during pre-training. (2) When transferring the model to MRI data, the improvements become marginal. (3) For tumor datasets, which require extensive expert knowledge, self-supervised learning without annotations cannot bring obvious improvements. (4) Some recent supervised pre-training methods have demonstrated promising results using less data than our self-supervised pre-training.
In the future, we will consider extending our work in the following directions: (1) We will scale up the dataset and model capacity to train stronger models that benefit downstream tasks (a 100k CT pre-training benchmark is coming soon). (2) We will involve more downstream tasks for evaluation. (3) For MRI data, we are also developing an MRI pre-training benchmark. (4) For tumor datasets, we found that large-scale tumor synthesis can significantly improve performance (a large-scale tumor synthesis and segmentation benchmark will soon be available). (5) We will further explore omni-supervised pre-training (self-, semi-, weakly-, and fully-supervised learning) to further boost performance.
Linshan Wu, Jiaxin Zhuang, and Hao Chen. “VoCo: A Simple-yet-Effective Volume Contrastive Learning Framework for 3D Medical Image Analysis.” In CVPR, 2024.