[MICCAI 2024 Paper] MambaMIL: Enhancing Long Sequence Modeling with Sequence Reordering in Computational Pathology

Recently, a study on computational pathology by HKUST Smart Lab was accepted at the International Conference on Medical Image Computing and Computer Assisted Intervention (MICCAI) 2024; the paper is among the first batch of early acceptances. The research incorporates the Selective Scan State Space Sequential Model (Mamba) into Multiple Instance Learning, demonstrating the ability to comprehensively understand and perceive long sequences of instances from Whole Slide Images. The code is open-sourced at https://github.com/isyangshu/MambaMIL, and the paper can be accessed at https://arxiv.org/abs/2403.06800. The details of the paper follow.

Abstract

Multiple Instance Learning (MIL) has emerged as a dominant paradigm for extracting discriminative feature representations from Whole Slide Images (WSIs) in computational pathology. Despite driving notable progress, existing MIL approaches suffer from limitations in facilitating comprehensive and efficient interactions among instances, as well as from time-consuming computations and overfitting. In this paper, we incorporate the Selective Scan State Space Sequential Model (Mamba) into Multiple Instance Learning (MIL) for long sequence modeling with linear complexity, termed MambaMIL. By inheriting the capability of vanilla Mamba, MambaMIL demonstrates the ability to comprehensively understand and perceive long sequences of instances. Furthermore, we propose the Sequence Reordering Mamba (SR-Mamba), which is aware of the order and distribution of instances and exploits the valuable information inherently embedded within the long sequences. With SR-Mamba as the core component, MambaMIL can effectively capture more discriminative features and mitigate the challenges associated with overfitting and high computational overhead. Extensive experiments on two public challenging tasks across nine diverse datasets demonstrate that our proposed framework performs favorably against state-of-the-art MIL methods.

Introduction

The digitization of pathological images into Whole Slide Images (WSIs) has paved the way for computer-aided analysis in computational pathology. However, employing deep learning methods for WSI analysis encounters unique challenges, primarily due to the high resolution of WSIs and the lack of pixel-level annotations. To address these issues, Multiple Instance Learning (MIL) has emerged as a natural solution, where each WSI is represented as a "bag" and partitioned into a sequence of tissue patches termed "instances". A bag is classified as positive if at least one of its instances is positive, and negative otherwise.
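To make the bag-instance formulation concrete, the following is a minimal sketch (our illustration, not code from the paper) of the standard MIL labeling assumption:

```python
# A minimal sketch of the standard MIL assumption: a bag (WSI) is positive
# if and only if at least one of its instances (patches) is positive.
import numpy as np

def bag_label(instance_labels: np.ndarray) -> int:
    """instance_labels: binary array of shape (num_instances,)."""
    return int(instance_labels.any())

print(bag_label(np.array([0, 0, 1, 0])))   # one positive patch -> 1
print(bag_label(np.zeros(4, dtype=int)))   # all negative       -> 0
```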

The most widely used MIL paradigm converts instances into low-dimensional features using pre-trained models, then aggregates these features into bag-level representations for subsequent analysis. Under this paradigm, MIL conceptualizes WSI analysis as a long sequence modeling problem, aiming to model the correlations between instances as well as the overall contextual information within the entire bag to capture discriminative information. Despite impressive performance, several issues remain in existing MIL methods. Attention-based methods primarily focus on instance-level information under the independent and identically distributed (i.i.d.) hypothesis. However, these methods neglect the contextual relationships among instances, resulting in inadequate representations of WSIs. Additionally, several methods utilize Transformers for their capability to explore mutual correlations between instances and model long sequences. Nonetheless, they face significant performance bottlenecks due to extensive computation and overfitting. Overall, existing methods are limited in comprehensively mining the contextual information within long sequences, which hinders their performance.

Recently, the Structured State Space Sequence model (S4) has been introduced as an efficient and effective architecture to address the bottleneck in long sequence modeling. Furthermore, the Selective Scan State Space Sequential Model, namely Mamba, advances S4 in discrete data modeling by employing an input-dependent selection mechanism and a hardware-aware algorithm, which enables Mamba to achieve linear complexity without sacrificing global receptive fields. However, for inherently non-sequential visual data, directly applying Mamba to a patchified and flattened image would inevitably constrain the receptive fields: Mamba only permits interactions between each patch and previously scanned positions, precluding the estimation of relationships with unscanned patches. Unlike typical visual modalities, WSIs contain scattered and scarce positive patches that exhibit weak spatial correlation, which makes them highly suitable for leveraging the robust sequential modeling capabilities of Mamba. Recently, S4MIL introduced the S4 model to WSI analysis as a multiple instance learner over instance sequences, demonstrating the effectiveness of state space models (SSMs) in capturing long-range dependencies. Note, however, that it directly adopts the S4 model without fully considering the unique characteristics of WSIs, resulting in sub-optimal results.
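To illustrate why such models scale linearly with sequence length, the toy recurrence below (a schematic, not Mamba's actual selective-scan kernel, with matrices assumed already discretized) shows how a state space model folds the sequence into a compressed hidden state in a single causal pass, so each position only interacts with previously scanned ones:

```python
import numpy as np

def ssm_scan(x: np.ndarray, A: np.ndarray, B: np.ndarray, C: np.ndarray) -> np.ndarray:
    """x: (L, D_in); A: (N, N); B: (N, D_in); C: (D_out, N). Returns (L, D_out)."""
    L = x.shape[0]
    h = np.zeros(A.shape[0])      # compressed hidden state
    y = np.empty((L, C.shape[0]))
    for t in range(L):            # a single pass -> O(L) complexity
        h = A @ h + B @ x[t]      # fold the current input into the state
        y[t] = C @ h              # readout depends only on the scanned prefix
    return y
```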

Motivated by the above observations, we propose an efficient and effective benchmark MIL model (MambaMIL) with the following contributions: (1) We incorporate the Mamba framework into MIL to tackle the challenges associated with long sequence modeling and overfitting, marking the first application of Mamba in computational pathology. (2) We propose the Sequence Reordering Mamba (SR-Mamba), which is aware of the order and distribution of instances and excels at capturing long-range dependencies among scattered positive instances. As the core component of MambaMIL, SR-Mamba is tailored to learn the correlations between instances in both the sequential ordering and the transpositional ordering, which significantly enhances the capability of the original Mamba to capture more discriminative features. (3) To investigate the effectiveness of MambaMIL, we conduct comprehensive experiments, including overall comparisons and ablation studies, on two challenging tasks across nine datasets, demonstrating that MambaMIL achieves superior performance against the state-of-the-art.

Overview of MambaMIL

To efficiently capture the comprehensive contextual information within long sequences of instances, we introduce a novel approach, MambaMIL, which integrates the Mamba framework into MIL, as illustrated in Fig. 1. Inheriting the attributes of Mamba, MambaMIL enables each instance to interact with any of the previously scanned instances through a compressed hidden state, effectively reducing the computational complexity. Specifically, given a WSI, we partition the tissue regions into a sequence of \(L\) patches \(\{p_1, p_2, ..., p_L\}\) and map all the patches into instance features \(X\in\mathbb{R}^{L\times D}\) with a Feature Extractor, where \(D\) refers to the feature dimension. Subsequently, the input \(X\) is passed through a Linear Projection to reduce the dimension. The output is then fed into a series of stacked SR-Mamba modules, which are responsible for modeling the long sequence. Finally, we utilize the Aggregation module to obtain bag-level representations for downstream tasks.

[Fig. 1: Overview of MambaMIL]
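The pipeline can be summarized in code. The PyTorch sketch below reflects our reading of the description above (module names and sizes are illustrative, not the official implementation); the SR-Mamba blocks are stubbed with identities here and sketched in the next section:

```python
import torch
import torch.nn as nn

class MambaMIL(nn.Module):
    def __init__(self, feat_dim=1024, hidden_dim=256, num_layers=2, num_classes=2):
        super().__init__()
        self.proj = nn.Linear(feat_dim, hidden_dim)           # Linear Projection
        # Stand-ins for the stacked SR-Mamba modules (sketched further below).
        self.blocks = nn.ModuleList([nn.Identity() for _ in range(num_layers)])
        self.attn = nn.Sequential(nn.Linear(hidden_dim, 128), nn.Tanh(),
                                  nn.Linear(128, 1))          # attention pooling
        self.head = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):              # x: (L, feat_dim), one WSI "bag"
        h = self.proj(x)               # reduce the feature dimension
        for blk in self.blocks:
            h = blk(h)                 # long-sequence modeling (SR-Mamba)
        w = torch.softmax(self.attn(h), dim=0)                # (L, 1) weights
        bag = (w * h).sum(dim=0)       # Aggregation -> bag-level representation
        return self.head(bag)          # downstream prediction
```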

To tackle the restricted receptive fields, we devise the Sequence Reordering Mamba (SR-Mamba), which is aware of the order and distribution of instances and exploits the valuable information inherently embedded within them. As illustrated in Fig. 1, considering the scattered and scarce positive patches, we establish parallel SSM-based branches upon vanilla Mamba to enhance long sequence modeling. SR-Mamba models two long sequences with distinct orderings, each associated with a unique compressed hidden state, facilitating the learning of more discriminative features.

In detail, given instance features \(X\in\mathbb{R}^{L \times D}\), we first partition the sequence of instances into non-overlapping segments of size \(R\), obtaining \(N=L/R\) segments from the entire sequence. For sequences whose lengths are not divisible by \(R\), we pad them with zeros before reordering. Then \(X\) is fed into two independent branches. For the first branch, we preserve the original ordering of \(X\), which is fed to the subsequent Causal Convolution Layer and State Space Model (SSM) for sequence modeling. The entire process can be formulated as \(X' = \operatorname{Norm}(X), \quad Y = \operatorname{SSM}(\operatorname{SiLU}(\operatorname{Conv1D}(\operatorname{Linear}(X')))).\) The normalized features \(X'\) are also used to generate the gating value for the output \(Y\) of the SSM: \(Z = \operatorname{SiLU}(\operatorname{Linear}(X')),\quad X'' = Z\odot Y.\) For the second branch, we propose a Sequence Reordering operation as the core component of SR-Mamba. Specifically, the input instance features are reshaped into a 2-D feature map, \(X\in\mathbb{R}^{L \times D} \rightarrow X'\in\mathbb{R}^{R \times N \times D}\), and we then sample instances from each non-overlapping segment successively along the second dimension of \(X'\), which can be regarded as a permutation and rearrangement operation. By doing so, we generate instance features \(X_r\) with a new ordering, which can be exploited to embed more discriminative features through the inherent position-sensitive characteristic of Mamba.
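Under this reading, the reordering and its inverse (the Sequence Restoration \(\psi\) used below) can be sketched as follows; the segment size \(R\) is a hyperparameter and the function names are ours:

```python
import torch

def reorder(x: torch.Tensor, R: int):
    """x: (L, D) instance features. Pads to a multiple of R, then interleaves
    the N segments by taking the r-th instance of every segment in turn."""
    L, D = x.shape
    pad = (-L) % R                                   # zero-pad so R divides length
    if pad:
        x = torch.cat([x, x.new_zeros(pad, D)], dim=0)
    N = x.shape[0] // R
    # (N, R, D) segment-major view -> transpose -> flatten in the new order.
    x_r = x.view(N, R, D).transpose(0, 1).reshape(N * R, D)
    return x_r, L                                    # keep L to trim padding later

def restore(x_r: torch.Tensor, R: int, L: int) -> torch.Tensor:
    """Inverse permutation (Sequence Restoration), trimming the zero padding."""
    N = x_r.shape[0] // R
    x = x_r.view(R, N, -1).transpose(0, 1).reshape(N * R, -1)
    return x[:L]
```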

[Fig. 2: Sequence Reordering]

The entire Sequence Reordering operation is depicted in Fig. 2. We then utilize the subsequent Causal Convolution Layer and State Space Model to model \(X_r\): \(X_r' = \operatorname{Norm}(X_r), \quad Y_r = \operatorname{SSM}(\operatorname{SiLU}(\operatorname{Conv1D}(\operatorname{Linear}(X_r')))).\) For the enhanced features \(Y_r\), we rearrange the sequence into its original ordering through partitioning and permutation operations, and gate the result by \(Z\): \(Y_r' = \psi(Y_r),\quad X_r'' = Z\odot Y_r',\) where \(\psi\) denotes the Sequence Restoration operation. After modeling the long sequences with distinct orderings, we obtain two discriminative instance features, \(X''\) and \(X_r''\), and aggregate them to obtain \(X_{\text{output}}\). We devise the aggregation operation as an element-wise addition of the two features followed by a linear layer and a residual connection: \(X_{\text{output}} = \operatorname{Linear}(X''+X''_r) + X.\) Distinct from the original Mamba, we maintain the sequential ordering and distribution while generating a new ordering of the instances from a global perspective for feature re-embedding.
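Putting the two branches together, one SR-Mamba block might look like the sketch below (again our interpretation of the equations: the selective SSM is abstracted behind `ssm`, defaulting to an identity stand-in, and the block reuses `reorder`/`restore` from the previous snippet):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SRMambaBlock(nn.Module):
    def __init__(self, dim, R=8, conv_kernel=4, ssm=None):
        super().__init__()
        self.R = R
        self.norm = nn.LayerNorm(dim)
        self.in_proj = nn.Linear(dim, dim)
        self.gate_proj = nn.Linear(dim, dim)
        # Depthwise causal convolution: left context only, after trimming.
        self.conv = nn.Conv1d(dim, dim, conv_kernel, padding=conv_kernel - 1, groups=dim)
        self.ssm = ssm if ssm is not None else nn.Identity()  # stand-in for the SSM
        self.out_proj = nn.Linear(dim, dim)

    def _branch(self, x):                          # x: (L, dim)
        h = self.in_proj(x).t().unsqueeze(0)       # (1, dim, L) for Conv1d
        h = self.conv(h)[..., : x.shape[0]]        # trim the right side -> causal
        h = F.silu(h).squeeze(0).t()               # back to (L, dim)
        return self.ssm(h)                         # Y = SSM(SiLU(Conv1D(Linear(X'))))

    def forward(self, x):                          # x: (L, dim)
        xn = self.norm(x)                          # X' = Norm(X)
        z = F.silu(self.gate_proj(xn))             # Z, shared by both branches
        y = self._branch(xn)                       # branch 1: original ordering
        xr, L = reorder(xn, self.R)                # branch 2: reordered sequence
        yr = restore(self._branch(xr), self.R, L)  # psi: back to original ordering
        # X_output = Linear(X'' + X_r'') + X, with X'' = Z*Y and X_r'' = Z*Y_r'.
        return self.out_proj(z * y + z * yr) + x
```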

Building upon the vanilla Mamba, SR-Mamba is tailored to robustly comprehend and perceive lengthy sequences of instances that are partitioned from WSIs. Built on stacked SR-Mamba modules, MambaMIL is capable of modeling long-range dependencies with linear complexity, resulting in effective model generalization.

Experiments

To verify the effectiveness of our proposed MambaMIL, we conduct thorough experiments on nine datasets using two distinct feature sets: ResNet-50 pre-trained on ImageNet and PLIP pre-trained on OpenPath.

Survival Prediction

We conduct comprehensive experiments on seven public cancer datasets (BLCA, BRCA, COADREAD, KIRC, KIRP, LUAD, and STAD) from TCGA, containing WSIs annotated with survival outcomes. To reduce the impact of data split on model evaluation, we implement a 5-fold cross-validation approach, partitioning the data into training and validation subsets in a 4:1 ratio. We use the cross-validated Concordance Index (C-Index), along with its standard deviation (std), to assess the model’s effectiveness.
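For reference, a cross-validated C-Index of this kind could be computed as in the sketch below (not the paper's evaluation code; `train_and_score` is a placeholder for fitting any MIL survival model and scoring the held-out slides):

```python
import numpy as np
from sklearn.model_selection import KFold
from lifelines.utils import concordance_index  # pip install lifelines

def cross_validated_cindex(features, times, events, train_and_score, seed=0):
    """features: (n, d) bag features; times: survival times; events: 1 if observed."""
    scores = []
    for tr, va in KFold(n_splits=5, shuffle=True, random_state=seed).split(features):
        risk = train_and_score(features[tr], times[tr], events[tr], features[va])
        # concordance_index expects higher scores for longer survival, so negate risk.
        scores.append(concordance_index(times[va], -risk, events[va]))
    return float(np.mean(scores)), float(np.std(scores))
```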

As presented in Table 1, we conduct comparison experiments with two distinct feature settings on seven TCGA cancer datasets. The results demonstrate that MambaMIL achieves the best performance on all benchmarks compared to state-of-the-art methods. Under the two feature sets, MambaMIL outperforms the second-best method by 2.6% and 2.7%, respectively, in mean performance across all seven datasets.

[Table 1: Survival prediction results]

Cancer Subtyping

We perform comparative experiments on two public challenging datasets: BRACS and NSCLC. To ensure robust evaluation, we employ 10-fold Monte Carlo cross-validation, which partitions the data into training, validation, and testing sets with a ratio of 8:1:1. Additionally, for fair comparison with existing methods, we also perform experiments on the official split of the BRACS dataset, marked with * in Table 2. Following the standard setting, we adopt the Area Under the Curve (AUC) and Accuracy (ACC) metrics along with their standard deviations (std) for evaluation, providing a reliable assessment that is less sensitive to class imbalance.
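Such an evaluation protocol could be implemented roughly as below (our sketch for the binary case; `fit_and_predict` is a placeholder for training any MIL model and returning positive-class probabilities on the test set):

```python
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, accuracy_score

def monte_carlo_eval(X, y, fit_and_predict, repeats=10):
    """Ten random 8:1:1 train/val/test splits; returns mean/std of AUC and ACC."""
    aucs, accs = [], []
    for seed in range(repeats):
        X_tr, X_tmp, y_tr, y_tmp = train_test_split(
            X, y, test_size=0.2, stratify=y, random_state=seed)
        X_va, X_te, y_va, y_te = train_test_split(
            X_tmp, y_tmp, test_size=0.5, stratify=y_tmp, random_state=seed)
        prob = fit_and_predict(X_tr, y_tr, X_va, y_va, X_te)  # (n_test,) probabilities
        aucs.append(roc_auc_score(y_te, prob))
        accs.append(accuracy_score(y_te, (prob > 0.5).astype(int)))
    return np.mean(aucs), np.std(aucs), np.mean(accs), np.std(accs)
```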

Table 2 shows the experimental results on the two datasets, encompassing both binary and multi-class classification tasks. Compared to the state-of-the-art, our proposed MambaMIL demonstrates outstanding performance, attaining an AUC of 80.4% on the BRACS dataset and 95.9% on the NSCLC dataset. Notably, MambaMIL employs the same aggregation module as ABMIL yet substantially outperforms it, with improvements of 3.9% and 2.1% in AUC on the BRACS and NSCLC datasets, respectively.

[Table 2: Cancer subtyping results]

Ablation Study

[Table 3: Ablation study on Mamba block variants]

To assess the effectiveness of SR-Mamba, we conduct extensive experiments on the survival prediction datasets to compare different variants of the Mamba block: vanilla Mamba, Bidirectional Mamba (Bi-Mamba), and our SR-Mamba. For a fair comparison, we train these variants with the same settings on each dataset. As shown in Table 3, SR-Mamba surpasses both Mamba and Bi-Mamba, demonstrating the effectiveness of sequence reordering. Meanwhile, overfitting poses a substantial challenge when applying MIL methods to WSI analysis, especially for Transformer-based methods such as TransMIL. As illustrated in Fig. 3, during training, TransMIL displays clear signs of overfitting on the validation set, characterized by a significant increase in validation loss alongside decreases in both ACC and AUC. In contrast, MambaMIL exhibits stable performance throughout the evaluation period, showcasing its strong ability to alleviate overfitting. This capability originates from the more discriminative representations extracted from different sequence orderings, akin to the effect of data augmentation, which significantly enhances model robustness.

[Fig. 3: Overfitting comparison between TransMIL and MambaMIL]

Conclusion

In this paper, we introduce a novel Mamba-based MIL method (MambaMIL) to tackle the challenges associated with long sequence modeling and overfitting, marking the first application of the Mamba framework in computational pathology. Our approach, built on the specially designed Sequence Reordering Mamba module, effectively leverages the intrinsic global information contained within long sequences of instances. The experimental results on nine benchmarks demonstrate that MambaMIL benefits from long sequence modeling and outperforms existing competitors. Given its excellent performance, we anticipate that MambaMIL can be extended to other modalities in computational pathology, including genomics, pathology reports, and clinical data, enabling the use of multi-modal information for effective and accurate diagnosis, prognosis, and therapeutic-response prediction.