One paper was accepted at ICML 2024

Our paper “Post-hoc Part-prototype Networks” was accepted at ICML 2024. Congratulations to the authors!

Introduction

In recent years, deep learning models such as Convolutional Neural Networks (CNNs) and Vision Transformers (ViTs) have achieved impressive performance on computer vision tasks. However, the limited interpretability of these black-box models remains a major concern and restricts their real-world deployment. Existing interpretability methods fall mainly into two categories: post-hoc and ante-hoc. Post-hoc methods explain a model that has already been trained, whereas ante-hoc methods aim to design “interpretable-by-design” or “inherently interpretable” models (e.g., part-prototype networks).

Motivation

Post-hoc Part-prototype Networks

Post-hoc explainability methods such as Grad-CAM are popular because they do not affect the performance of a trained model. However, they mainly reveal where a model looks for a given input and fail to explain what the model looks for (e.g., what matters when classifying a bird image as a Scott Oriole?). Existing part-prototype networks leverage part-prototypes (e.g., a Scott Oriole’s characteristic wing and head) to answer both where and what, but often underperform their black-box counterparts in accuracy. A natural question therefore arises: can one construct a network that answers both where and what in a post-hoc manner, so that the model’s performance is guaranteed?

Method

We propose the first post-hoc part-prototype network by decomposing the classification head of a trained model into a set of interpretable part-prototypes. Concretely, we propose an unsupervised prototype discovery and refinement strategy that yields prototypes which precisely reconstruct the classification head while remaining interpretable.

Initial prototype discovery

This is achieved by applying Non-negative Matrix Factorization (NMF) to optimize the following objective function:

\(\begin{equation} \label{eqa:feature_err} \min_{\mathbf{E, P}} ||\mathbf{F}-\mathbf{EP}||_2^2, \end{equation}\) where \(\mathbf{F} \in \mathbb{R}^{nHW\times D}\) stacks the \(D\)-dimensional features at all \(HW\) spatial positions of all \(n\) image feature maps, \(\mathbf{P} \in \mathbb{R}^{k\times D}\) contains the \(k\) initial prototypes we find, and \(\mathbf{E}\in \mathbb{R}^{nHW\times k}\) indicates how strongly each prototype contributes to the reconstruction of the feature at each of these positions.

NMF is chosen because it allows only additive, not subtractive, combinations of components, which removes complex cancellations between positive and negative values in \(\mathbf{EP}\). Such a sign constraint on the factor matrices has been shown to yield sparse, parts-based representations. It is therefore well suited to discovering part-prototypes that indicate individual components (e.g., a Scott Oriole’s head, wing, and belly) of the target object (e.g., a Scott Oriole bird) within the part-prototype network framework. Popular matrix decomposition techniques such as Principal Component Analysis (PCA) do not have this property.
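The factorization above can be illustrated with a minimal NumPy sketch. It uses the classic multiplicative-update rules for NMF, which is an assumption made here for illustration; the solver actually used in the paper is not specified in this post. The function name `nmf_multiplicative`, the toy sizes, and the random stand-in for the feature matrix are all hypothetical.

```python
import numpy as np

def nmf_multiplicative(F, k, n_iter=500, eps=1e-9, seed=0):
    """Factor a non-negative matrix F (m x D) as E @ P with E (m x k) >= 0 and P (k x D) >= 0."""
    rng = np.random.default_rng(seed)
    m, D = F.shape
    E = rng.random((m, k)) + eps
    P = rng.random((k, D)) + eps
    for _ in range(n_iter):
        # Multiplicative updates for the squared reconstruction error.
        # Every factor stays non-negative because only products and ratios of
        # non-negative numbers are involved.
        P *= (E.T @ F) / (E.T @ E @ P + eps)
        E *= (F @ P.T) / (E @ P @ P.T + eps)
    return E, P

# Toy stand-in for the features of n images at H*W positions (assumed sizes).
rng = np.random.default_rng(1)
n, H, W, D, k = 4, 7, 7, 32, 3
F = rng.random((n * H * W, k)) @ rng.random((k, D))  # non-negative, exactly rank k

E, P = nmf_multiplicative(F, k)
rel_err = np.linalg.norm(F - E @ P) / np.linalg.norm(F)

# Each column of E can be reshaped into one activation map per image and prototype.
heatmaps = E.reshape(n, H, W, k)
```

In practice a maintained implementation such as scikit-learn's `NMF` class would be a more robust choice than this hand-written loop.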

Prototype scaling

In this step, we first try to approximately reconstruct the trained classification head \(\mathbf{v} \in \mathbb{R}^{1\times D}\) of a class \(c\) from the initially discovered part-prototypes \(\mathbf{P}\). We propose the following convex optimization: \(\begin{equation} \min_{\alpha_i} ||\mathbf{v} - \sum_{i=1}^{k}\alpha_i\mathbf{p}_i||_2^2, \end{equation}\) where \(\alpha_i\) is the coefficient that scales the \(i\)-th initial prototype with respect to class \(c\) and \(\mathbf{p}_i\) is the \(i\)-th row vector of \(\mathbf{P}\). Since \(\mathbf{v}\) and \(\mathbf{p}_i\) are fixed, this optimization is convex and the coefficients \(\alpha_i\) admit a global optimum. The optimization expresses the classification head (or target object) as a combination of the initially discovered part-prototypes. However, because the number of part-prototypes \(k\) is typically small, there is no guarantee that a linear combination of these \(k\) prototypes can precisely reconstruct the trained high-dimensional (e.g., 512-dimensional in ResNet34) classification head. We therefore further refine the prototypes to fully recover the model’s performance.
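Because the scaling objective is an ordinary least-squares problem in the coefficients \(\alpha_i\), it can be solved directly. The sketch below uses `np.linalg.lstsq` on random stand-ins for \(\mathbf{P}\) and \(\mathbf{v}\); the sizes and data are assumptions for illustration, not values from the paper.

```python
import numpy as np

rng = np.random.default_rng(0)
k, D = 3, 32
P = rng.random((k, D))            # stand-in for the k initial prototypes (rows p_i)
v = rng.standard_normal((1, D))   # stand-in for the head weights of one class

# min_alpha ||v - alpha @ P||^2 is the least-squares system P.T @ alpha ~= v.T
alpha, *_ = np.linalg.lstsq(P.T, v.ravel(), rcond=None)   # shape (k,)

# What the k scaled prototypes cannot explain is left in the residual.
R = v - alpha @ P                 # shape (1, D)
residual_norm = float(np.linalg.norm(R))
```

At the optimum the residual is orthogonal to every prototype, and with \(k \ll D\) it is generally non-zero, which is why a further refinement step is needed.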

Prototype refinement

We refine the prototypes by distributing the residual parameters among them. This step equips the interpretable prototypes with better class-discriminative power and thereby fully recovers the model’s performance.

Once the coefficients \(\alpha_i\) have been optimized, the residual parameters \(\mathbf{R} \in \mathbb{R}^{1\times D}\) are calculated as: \(\begin{equation} \mathbf{R} = \mathbf{v} - \sum_{i=1}^k \alpha_i \mathbf{p}_i. \end{equation}\)

To completely decompose the classification head into \(k\) interpretable prototypes, the parameters in \(\mathbf{R}\) should be fully absorbed by the interpretable prototypes without sacrificing their interpretability. We therefore decompose the residual parameters into

\[\begin{equation} \mathbf{R} = \mathbf{r}_1 + \mathbf{r}_2 + \dots + \mathbf{r}_k \end{equation}\]

and distribute the \(k\) components \(\mathbf{r}_i \in \mathbb{R}^{1\times D}\) to the \(k\) prototypes, one component per prototype. This shifts the part-prototypes toward more class-discriminative yet still semantically meaningful regions of the feature space. We propose the following dynamic optimization-based prototype refinement strategy with interpretability constraints:

\[\begin{equation} \begin{aligned} \begin{gathered} \min_{\mathbf{r}_i} \sum_{i=1}^{k}||\mathrm{Norm}(\mathbf{F}\mathbf{p}_i^T)-\mathrm{Norm}(\mathbf{F}\mathbf{r}_i^T)||_2^2 \\ \mathrm{s.t.} \quad \sum_{i=1}^k \mathbf{r}_i=\mathbf{R} . \end{gathered} \end{aligned} \end{equation}\]

\(\mathrm{Norm}\) denotes normalization along the spatial dimension of the images (e.g., the first dimension of \(\mathbf{Fp}_i^T\in \mathbb{R}^{nHW\times1}\)). We align the distributions after normalization because humans perceive a heatmap through the relative distribution of its values, which determines the highlighted areas. This constraint encourages the visualization maps generated by \(\mathbf{r}_i\) with respect to \(\mathbf{F}\) to be perceptually close to those of \(\mathbf{p}_i\), so that the interpretability of \(\mathbf{p}_i\) is maintained after refinement.
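To make the refinement objective concrete, the sketch below only evaluates it for a given set of candidate components; it does not implement the optimization itself. The post does not say which normalization \(\mathrm{Norm}\) is, so min-max rescaling along the first axis is assumed here, and the helper names `minmax_norm` and `refinement_loss` as well as the random data are hypothetical.

```python
import numpy as np

def minmax_norm(x, eps=1e-12):
    """Assumed Norm: rescale each column to [0, 1] along the first (spatial) axis."""
    lo = x.min(axis=0, keepdims=True)
    hi = x.max(axis=0, keepdims=True)
    return (x - lo) / (hi - lo + eps)

def refinement_loss(F, P, r):
    """sum_i ||Norm(F p_i^T) - Norm(F r_i^T)||_2^2 for candidate components r (k x D)."""
    diff = minmax_norm(F @ P.T) - minmax_norm(F @ r.T)   # (m, k), one column per prototype
    return float(np.sum(diff ** 2))

rng = np.random.default_rng(0)
m, D, k = 196, 32, 3
F = rng.random((m, D))            # stand-in features at m = n*H*W positions
P = rng.random((k, D))            # stand-in prototypes
R = rng.standard_normal((1, D))   # stand-in residual parameters

# A naive candidate that satisfies the constraint sum_i r_i = R: an even split.
r_uniform = np.repeat(R / k, k, axis=0)
loss_uniform = refinement_loss(F, P, r_uniform)

# Sanity check of the objective only: a positive multiple of p_i produces the same
# normalized map as p_i, so its loss is (numerically) zero. It does not satisfy
# the constraint and is therefore not a feasible solution.
loss_aligned = refinement_loss(F, P, 2.0 * P)
```

An actual solver would search over feasible components for the lowest loss; the even split merely provides a reference value.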

After all of the above steps, the classification head for class \(c\) is completely decomposed into \(k\) interpretable prototypes \(\tilde{\mathbf{p}}_i\):

\[\begin{equation} \mathbf{v} = \sum_{i=1}^k \tilde{\mathbf{p}}_i, \quad \tilde{\mathbf{p}}_i = \alpha_i\mathbf{p}_i+\mathbf{r}_i. \end{equation}\]

Repeating the above steps for every class decomposes the entire trained classification head of the model.
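The reconstruction is exact for any components that sum to \(\mathbf{R}\), which the following sketch checks numerically. It uses a naive even split of the residual in place of the paper's interpretability-constrained refinement, and it assumes a bias-free linear head applied after global average pooling; both are illustrative assumptions, as are the random stand-in data.

```python
import numpy as np

rng = np.random.default_rng(0)
m, D, k = 49, 32, 3               # m = H*W positions of a single image (assumed)
F = rng.random((m, D))            # stand-in feature map, one row per position
P = rng.random((k, D))            # stand-in initial prototypes
v = rng.standard_normal((1, D))   # stand-in head weights of one class

# Scaling step followed by the residual.
alpha, *_ = np.linalg.lstsq(P.T, v.ravel(), rcond=None)
R = v - alpha @ P

# Naive feasible split of the residual (not the paper's optimized refinement).
r = np.repeat(R / k, k, axis=0)

# Refined prototypes: p~_i = alpha_i * p_i + r_i
P_tilde = alpha[:, None] * P + r  # shape (k, D)

# Logit of the original head under the assumed pooling, and its per-prototype split.
logit_original = float((F @ v.T).mean())
per_prototype = (F @ P_tilde.T).mean(axis=0)   # shape (k,)
```

Since the refined prototypes sum to \(\mathbf{v}\), the per-prototype contributions add up to the original logit, so the decomposition leaves the prediction unchanged under these assumptions.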

Experimental results

We show that our method is more faithful when explaining where the model looks and yields prototypes of better quality when explaining what the model looks for.

Evaluating “where” the model looks

Sensitivity: The attribution should be zero if the prediction does not depend on a certain feature and non-zero otherwise.

Completeness: The sum of contributions of all features equals the prediction.

Linearity: If a model is a linear combination of multiple models, its attributions should be the same linear combination of those models’ attributions.

Symmetry preserving: If two features play exactly the same role in the network, they should receive the same attribution.

Evaluating “what” the model looks for

Prototype quality is evaluated by the consistency and stability scores. The consistency score is the percentage of learned prototypes that consistently activate the same semantic area (e.g., a bird’s head) across different images, measured using the keypoint annotations provided by the dataset. The stability score is the percentage of prototypes that activate the same areas after the images are perturbed by noise. We use the base version of ProtoEval for a fair comparison with prior works. The table shows that with the Res34, Res152, VGG19, and Dense121 backbones, our method outperforms prior works in consistency score by +9.4%, +9.3%, +5.1%, and +3.8%, respectively. With the Res34, Res152, Dense121, and Dense161 backbones, our stability score outperforms prior works by an even larger margin (+12%, +11%, +23.7%, and +19.2%, respectively). Such simply trained models with no special design easily surpass nearly all prior works in prediction accuracy.

Conclusion

In this work, we ask: can we construct a model that explains both where the model looks and what it looks for in a post-hoc manner, so that performance is guaranteed? To this end, we propose the first post-hoc part-prototype network via a novel prototype discovery and refinement strategy that precisely reconstructs the trained classification head. Besides guaranteeing performance, our method is more faithful when explaining where and yields more consistent and stable prototypes when explaining what, compared with prior works. This work demonstrates the value of exploring the existing feature space for interpretability and opens up a new paradigm for designing part-prototype networks in a post-hoc manner.

References

Andong Tan, Fengtao Zhou, and Hao Chen. “Post-hoc Part-prototype Networks”. ICML 2024