CLAM

torchmil.models.CLAM_SB

Bases: MILModel

Clustering-constrained Attention Multiple Instance Learning (CLAM), proposed in the paper Data-Efficient and Weakly Supervised Computational Pathology on Whole-Slide Images. The suffix SB denotes the single-attention-branch variant of CLAM.

Overview. The forward pass of CLAM is identical to that of ABMIL. The difference lies in the instance-level regularization applied during training, which we describe below.

Instance-level regularization. CLAM adds a binary clustering objective during training. In the binary MIL setting, two clustering classifiers are considered: \(c_0 \colon \mathbb{R}^D \to \mathbb{R}\) and \(c_1 \colon \mathbb{R}^D \to \mathbb{R}\). To supervise this objective, the attention values computed by the attention pooling are used to generate pseudo labels.

Given an input bag \(\mathbf{X} = \left[ \mathbf{x}_1, \ldots, \mathbf{x}_N \right]^\top \in \mathbb{R}^{N \times D}\) with label \(Y\) and attention values \(\mathbf{f} = \left[ f_1, \ldots, f_N \right]^\top \in \mathbb{R}^{N}\), the instance-level regularization is performed as follows:

  1. The \(k\) instances with the highest attention values are selected as in-the-class instances. The \(k\) instances with the lowest attention values are selected as out-of-the-class instances,

    \[\begin{gather} D_{\text{in}} = \left\{ \mathbf{x}_i \mid f_i \in \text{TopK}(\mathbf{f}, k) \right\}, \\ D_{\text{out}} = \left\{ \mathbf{x}_i \mid f_i \in \text{BottomK}(\mathbf{f}, k) \right\}. \end{gather}\]
  2. The instances in \(D_{\text{in}}\) are assigned a pseudo label of 1 for \(c_Y\), and a pseudo label of 0 for \(c_{1-Y}\). The instances in \(D_{\text{out}}\) are assigned a pseudo label of 0 for \(c_Y\). The pseudo labels are used to train the clustering classifiers,

    \[\begin{gather} \ell_{\text{in}} = \frac{1}{2k} \left( \sum_{\mathbf{x} \in D_{\text{in}}}\ell_{\text{inst}}(c_Y(\mathbf{x}), 1) + \sum_{\mathbf{x} \in D_{\text{out}}}\ell_{\text{inst}}(c_{Y}(\mathbf{x}), 0) \right), \\ \ell_{\text{out}} = \frac{1}{k} \sum_{\mathbf{x} \in D_{\text{in}}}\ell_{\text{inst}}(c_{1-Y}(\mathbf{x}), 0), \end{gather}\]

where \(\ell_{\text{inst}}\) is the instance-level loss function (the default is SmoothTop1SVM) and \(Y\) is the true bag label. The total instance-level loss is \(\ell_{\text{in}} + \ell_{\text{out}}\), which is added to the bag-level loss to train the model.
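
The sketch below illustrates this regularization in plain PyTorch. It is a minimal re-derivation from the equations above, not the torchmil implementation: binary cross-entropy with logits stands in for SmoothTop1SVM, and the function name and the classifier heads `c0`/`c1` are hypothetical.

```python
import torch

def instance_regularization(f, X, Y, c0, c1, k):
    """Illustrative sketch of CLAM's instance-level regularization.

    f:  (N,) attention values, X: (N, D) instance features, Y: bag label in {0, 1}.
    c0, c1: clustering classifiers R^D -> R, e.g. torch.nn.Linear(D, 1).
    BCE-with-logits is a stand-in for SmoothTop1SVM here.
    """
    inst_loss = torch.nn.functional.binary_cross_entropy_with_logits
    top = torch.topk(f, k).indices   # D_in: k highest-attention instances
    bot = torch.topk(-f, k).indices  # D_out: k lowest-attention instances
    c_in, c_out = (c1, c0) if Y == 1 else (c0, c1)  # c_Y and c_{1-Y}
    ones, zeros = torch.ones(k), torch.zeros(k)
    # l_in: c_Y sees D_in with pseudo label 1 and D_out with pseudo label 0,
    # averaged over the 2k instances (BCE's mean reduction gives the 1/2k factor).
    l_in = 0.5 * (inst_loss(c_in(X[top]).squeeze(-1), ones)
                  + inst_loss(c_in(X[bot]).squeeze(-1), zeros))
    # l_out: c_{1-Y} sees D_in with pseudo label 0, averaged over k instances.
    l_out = inst_loss(c_out(X[top]).squeeze(-1), zeros)
    return l_in + l_out
```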

__init__(in_shape=None, att_dim=128, att_act='tanh', k_sample=10, gated=False, inst_loss_name='SmoothTop1SVM', feat_ext=torch.nn.Identity(), criterion=torch.nn.BCEWithLogitsLoss())

Parameters:

  • in_shape (tuple, default: None ) –

    Shape of input data expected by the feature extractor (excluding batch dimension).

  • att_dim (int, default: 128 ) –

    Attention dimension.

  • att_act (str, default: 'tanh' ) –

    Activation function for attention. Possible values: 'tanh', 'relu', 'gelu'.

  • k_sample (int, default: 10 ) –

    Number of instances \(k\) to sample from the top and bottom of the attention distribution.

  • gated (bool, default: False ) –

    If True, use gated attention in the attention pooling.

  • inst_loss_name (str, default: 'SmoothTop1SVM' ) –

    Name of the instance-level loss function \(\ell_{\text{inst}}\).

  • feat_ext (Module, default: Identity() ) –

    Feature extractor.

  • criterion (Module, default: BCEWithLogitsLoss() ) –

    Bag-level loss function. By default, binary cross-entropy with logits.
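
A minimal construction sketch, assuming pre-extracted 512-dimensional instance features; the feature shape and hyperparameters here are illustrative, not recommendations.

```python
import torch
from torchmil.models import CLAM_SB

model = CLAM_SB(
    in_shape=(512,),  # instance feature shape, excluding batch dimension
    att_dim=128,
    att_act='tanh',
    k_sample=10,
    gated=False,
)
```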

forward(X, mask=None, return_att=False, return_emb=False)

Forward pass.

Parameters:

  • X (Tensor) –

    Bag features of shape (batch_size, bag_size, ...).

  • mask (Tensor, default: None ) –

    Mask of shape (batch_size, bag_size).

  • return_att (bool, default: False ) –

    If True, returns attention values (before normalization) in addition to Y_pred.

  • return_emb (bool, default: False ) –

    If True, returns embeddings in addition to Y_pred.

Returns:

  • Y_pred ( Tensor ) –

    Bag label logits of shape (batch_size,).

  • att ( Tensor ) –

    Only returned when return_att=True. Attention values (before normalization) of shape (batch_size, bag_size).

  • emb ( Tensor ) –

    Only returned when return_emb=True. Embeddings of shape (batch_size, bag_size, feat_dim).
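
Continuing the construction sketch above, a forward pass on a dummy bag might look as follows (the batch size, bag size, and the all-ones float mask are illustrative assumptions):

```python
X = torch.randn(4, 100, 512)  # (batch_size, bag_size, feat_dim)
mask = torch.ones(4, 100)     # 1 marks valid instances

Y_pred, att = model(X, mask=mask, return_att=True)
print(Y_pred.shape)  # torch.Size([4])
print(att.shape)     # torch.Size([4, 100])
```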

compute_loss(Y, X, mask=None)

Compute loss given true bag labels.

Parameters:

  • Y (Tensor) –

    Bag labels of shape (batch_size,).

  • X (Tensor) –

    Bag features of shape (batch_size, bag_size, ...).

  • mask (Tensor, default: None ) –

    Mask of shape (batch_size, bag_size).

Returns:

  • Y_pred ( Tensor ) –

    Bag label logits of shape (batch_size,).

  • loss_dict ( dict ) –

    Dictionary containing the loss value.
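
A training-step sketch using compute_loss. The keys of loss_dict are not specified in this section, so the sketch simply sums all of its entries.

```python
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
Y = torch.randint(0, 2, (4,)).float()  # binary bag labels

Y_pred, loss_dict = model.compute_loss(Y, X, mask=mask)
loss = sum(loss_dict.values())  # bag-level + instance-level terms
optimizer.zero_grad()
loss.backward()
optimizer.step()
```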

predict(X, mask=None, return_inst_pred=True)

Predict bag labels.

Parameters:

  • X (Tensor) –

    Bag features of shape (batch_size, bag_size, ...).

  • mask (Tensor, default: None ) –

    Mask of shape (batch_size, bag_size).

  • return_inst_pred (bool, default: True ) –

    If True, returns predicted instance labels in addition to Y_pred.

Returns:

  • Y_pred ( Tensor ) –

    Predicted bag labels of shape (batch_size,).

  • y_inst_pred ( Tensor ) –

    Only returned when return_inst_pred=True. Predicted instance labels of shape (batch_size, bag_size).
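
An inference sketch using predict, which also returns per-instance predictions by default:

```python
model.eval()
with torch.no_grad():
    Y_pred, y_inst_pred = model.predict(X, mask=mask, return_inst_pred=True)
print(Y_pred.shape)       # torch.Size([4])
print(y_inst_pred.shape)  # torch.Size([4, 100])
```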