ProbSmoothABMIL

torchmil.models.ProbSmoothABMIL

Bases: MILModel

Attention-based Multiple Instance Learning (ABMIL) model with Probabilistic Smooth Attention Pooling. Proposed in "Probabilistic Smooth Attention for Deep Multiple Instance Learning in Medical Imaging" and "Smooth Attention for Deep Multiple Instance Learning: Application to CT Intracranial Hemorrhage Detection".

Overview. This model extends the ABMIL model by incorporating a probabilistic pooling mechanism.

Given an input bag \(\mathbf{X} = \left[ \mathbf{x}_1, \ldots, \mathbf{x}_N \right]^\top \in \mathbb{R}^{N \times P}\), the model optionally applies a feature extractor, \(\text{FeatExt}(\cdot)\), to transform the instance features: \(\mathbf{X} = \text{FeatExt}(\mathbf{X}) \in \mathbb{R}^{N \times D}\).

Subsequently, it aggregates the instance features into a bag representation using a probabilistic attention-based pooling mechanism, as detailed in ProbSmoothAttentionPool.

Specifically, it computes a mean vector \(\mathbf{\mu}_{\mathbf{f}} \in \mathbb{R}^N\) and a variance vector \(\mathbf{\sigma}_{\mathbf{f}}^2 \in \mathbb{R}^N\) that define the attention distribution \(q(\mathbf{f} \mid \mathbf{X}) = \mathcal{N}\left(\mathbf{f} \mid \mathbf{\mu}_{\mathbf{f}}, \operatorname{diag}(\mathbf{\sigma}_{\mathbf{f}}^2) \right)\),

\[ \mathbf{\mu}_{\mathbf{f}}, \mathbf{\sigma}_{\mathbf{f}} = \operatorname{ProbSmoothAttentionPool}(\mathbf{X}). \]

If covar_mode='zero', the variance vector \(\mathbf{\sigma}_{\mathbf{f}}^2\) is set to zero, resulting in a deterministic attention distribution.

Then, \(m\) attention vectors \(\widehat{\mathbf{F}} = \left[ \widehat{\mathbf{f}}^{(1)}, \ldots, \widehat{\mathbf{f}}^{(m)} \right]^\top \in \mathbb{R}^{m \times N}\) are sampled from the attention distribution. The bag representation \(\widehat{\mathbf{z}} \in \mathbb{R}^{m \times D}\) is then computed as:

\[ \widehat{\mathbf{z}} = \operatorname{Softmax}(\widehat{\mathbf{F}}) \mathbf{X}. \]

The bag representation \(\widehat{\mathbf{z}}\) is fed into a classifier, implemented as a linear layer, to produce bag label predictions \(Y_{\text{pred}} \in \mathbb{R}^{m}\).

Notably, the attention distribution naturally induces a distribution over the bag label predictions. This model thus generates multiple predictions for each bag, corresponding to different samples from this distribution.
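The following is a minimal, self-contained sketch of this sampling-based pooling and classification step on placeholder tensors. It mirrors the equations above (sample \(m\) attention vectors, softmax over instances, weight the features, classify each sample) and is not the library implementation; all dimensions and the linear classifier are illustrative.

```python
import torch

# Toy dimensions: N instances, D features, m attention samples
N, D, m = 16, 32, 4

X = torch.randn(N, D)          # instance features after the feature extractor
mu_f = torch.randn(N)          # attention mean, mu_f in R^N
sigma_f = torch.rand(N)        # attention standard deviations, sigma_f in R^N

# Sample m attention vectors F_hat ~ N(mu_f, diag(sigma_f^2)), shape (m, N)
F_hat = mu_f + sigma_f * torch.randn(m, N)

# Bag representations z_hat = Softmax(F_hat) X, shape (m, D)
z_hat = torch.softmax(F_hat, dim=-1) @ X

# A linear classifier maps each of the m bag representations to a bag logit,
# giving Y_pred in R^m, i.e. one prediction per attention sample
classifier = torch.nn.Linear(D, 1)
Y_pred = classifier(z_hat).squeeze(-1)   # shape (m,)
```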

Regularization. The probabilistic pooling mechanism introduces a regularization term to the loss function that encourages smoothness in the attention values. Given an input bag \(\mathbf{X} = \left[ \mathbf{x}_1, \ldots, \mathbf{x}_N \right]^\top \in \mathbb{R}^{N \times P}\) with adjacency matrix \(\mathbf{A} \in \mathbb{R}^{N \times N}\), the regularization term corresponds to

\[ \ell_{\text{KL}}(\mathbf{X}, \mathbf{A}) = \begin{cases} \mathbf{\mu}_{\mathbf{f}}^\top \mathbf{L} \mathbf{\mu}_{\mathbf{f}} \quad & \text{if } \texttt{covar_mode='zero'}, \\ \mathbf{\mu}_{\mathbf{f}}^\top \mathbf{L} \mathbf{\mu}_{\mathbf{f}} + \operatorname{Tr}(\mathbf{L} \mathbf{\Sigma}_{\mathbf{f}}) - \frac{1}{2}\log \det( \mathbf{\Sigma}_{\mathbf{f}} ) + \operatorname{const} \quad & \text{if } \texttt{covar_mode='diag'}, \\ \end{cases} \]

where \(\operatorname{const}\) is a constant term that does not depend on the parameters, \(\mathbf{\Sigma}_{\mathbf{f}} = \operatorname{diag}(\mathbf{\sigma}_{\mathbf{f}}^2)\), \(\mathbf{L} = \mathbf{D} - \mathbf{A}\) is the graph Laplacian matrix, and \(\mathbf{D}\) is the degree matrix of \(\mathbf{A}\). This term is then averaged over all bags in the batch and added to the loss function.
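As an illustration, the regularization term for a single bag under covar_mode='diag' can be computed as follows. This is a sketch on random placeholder values, not the library code; the additive constant is omitted.

```python
import torch

N = 16                                         # number of instances in the bag

# Random symmetric adjacency matrix without self-loops (placeholder)
A = (torch.rand(N, N) > 0.8).float()
A = ((A + A.T) > 0).float()
A.fill_diagonal_(0)

L = torch.diag(A.sum(dim=-1)) - A              # graph Laplacian L = D - A

mu_f = torch.randn(N)                          # attention mean
sigma2_f = torch.rand(N) + 1e-3                # diagonal of Sigma_f (variances)

# Regularizer for covar_mode='diag' (up to the additive constant)
kl = (
    mu_f @ L @ mu_f                            # smoothness of the attention mean
    + torch.sum(torch.diagonal(L) * sigma2_f)  # Tr(L Sigma_f) for diagonal Sigma_f
    - 0.5 * torch.sum(torch.log(sigma2_f))     # -1/2 log det(Sigma_f)
)
```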

__init__(in_shape=None, att_dim=128, covar_mode='diag', n_samples_train=1000, n_samples_test=5000, feat_ext=torch.nn.Identity(), criterion=torch.nn.BCEWithLogitsLoss())

Parameters:

  • in_shape (tuple, default: None ) –

    Shape of input data expected by the feature extractor (excluding batch dimension). If not provided, it will be lazily initialized.

  • att_dim (int, default: 128 ) –

    Attention dimension.

  • covar_mode (str, default: 'diag' ) –

    Covariance mode of the attention distribution. Possible values: 'zero', 'diag'.

  • n_samples_train (int, default: 1000 ) –

    Number of samples for training.

  • n_samples_test (int, default: 5000 ) –

    Number of samples for testing.

  • feat_ext (Module, default: Identity() ) –

    Feature extractor.

  • criterion (Module, default: BCEWithLogitsLoss() ) –

    Loss function. By default, Binary Cross-Entropy loss from logits for binary classification.
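A hypothetical instantiation based on the signature above; the input shape and the choice of feature extractor are illustrative.

```python
import torch
from torchmil.models import ProbSmoothABMIL

model = ProbSmoothABMIL(
    in_shape=(512,),                            # instance feature dimension (example)
    att_dim=128,
    covar_mode='diag',
    n_samples_train=1000,
    n_samples_test=5000,
    feat_ext=torch.nn.Identity(),               # or any torch.nn.Module
    criterion=torch.nn.BCEWithLogitsLoss(),
)
```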

forward(X, adj=None, mask=None, return_att=False, return_samples=False, return_kl_div=False)

Forward pass.

Parameters:

  • X (Tensor) –

    Bag features of shape (batch_size, bag_size, ...).

  • adj (Tensor, default: None ) –

    Adjacency matrix of shape (batch_size, bag_size, bag_size). Only required when return_kl_div=True.

  • mask (Tensor, default: None ) –

    Mask of shape (batch_size, bag_size).

  • return_att (bool, default: False ) –

    If True, returns attention values (before normalization) in addition to Y_pred.

  • return_samples (bool, default: False ) –

    If True and return_att=True, the attention values returned are samples from the attention distribution.

  • return_kl_div (bool, default: False ) –

    If True, returns the KL divergence between the attention distribution and the prior distribution.

Returns:

  • Y_pred ( Tensor ) –

    Bag label logits of shape (batch_size, n_samples) if return_samples=True, else (batch_size,).

  • att ( Tensor ) –

    Only returned when return_att=True. Attention values (before normalization) of shape (batch_size, bag_size, n_samples) if return_samples=True, else (batch_size, bag_size).

  • kl_div ( Tensor ) –

    Only returned when return_kl_div=True. KL divergence between the attention distribution and the prior distribution, as a scalar tensor of shape ().
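An illustrative forward pass, continuing with the model instance from the sketch above and placeholder inputs; it assumes the outputs come back in the order listed in the Returns section.

```python
import torch

batch_size, bag_size, feat_dim = 2, 50, 512
X = torch.randn(batch_size, bag_size, feat_dim)
mask = torch.ones(batch_size, bag_size)          # all instances valid (placeholder)
# Placeholder adjacency; in practice it encodes structural relations
# between instances (e.g. neighbouring patches or slices)
adj = torch.eye(bag_size).unsqueeze(0).expand(batch_size, -1, -1)

Y_pred, att, kl_div = model(
    X, adj=adj, mask=mask,
    return_att=True, return_samples=False, return_kl_div=True,
)
# Y_pred: (batch_size,), att: (batch_size, bag_size), kl_div: scalar of shape ()
```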

compute_loss(Y, X, adj, mask=None)

Compute loss given true bag labels.

Parameters:

  • Y (Tensor) –

    Bag labels of shape (batch_size,).

  • X (Tensor) –

    Bag features of shape (batch_size, bag_size, ...).

  • adj (Tensor) –

    Adjacency matrix of shape (batch_size, bag_size, bag_size).

  • mask (Tensor, default: None ) –

    Mask of shape (batch_size, bag_size).

Returns:

  • Y_pred ( Tensor ) –

    Bag label logits of shape (batch_size,).

  • loss_dict ( dict ) –

    Dictionary containing the loss value and the KL divergence between the attention distribution and the prior distribution.
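A sketch of a training step using compute_loss, reusing the model and the X, adj, mask tensors from the previous sketches. How the entries of loss_dict are combined into the final loss is an assumption here (summing them is one common pattern); check the library for the exact weighting.

```python
import torch

Y = torch.randint(0, 2, (batch_size,)).float()           # bag labels (placeholder)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

Y_pred, loss_dict = model.compute_loss(Y, X, adj, mask=mask)
loss = sum(loss_dict.values())    # assumption: total loss = sum of returned terms

optimizer.zero_grad()
loss.backward()
optimizer.step()
```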

predict(X, mask=None, return_inst_pred=True, return_samples=False)

Predict bag and (optionally) instance labels.

Parameters:

  • X (Tensor) –

    Bag features of shape (batch_size, bag_size, ...).

  • mask (Tensor, default: None ) –

    Mask of shape (batch_size, bag_size).

  • return_inst_pred (bool, default: True ) –

    If True, returns the attention values as instance label predictions, in addition to bag label predictions.

  • return_samples (bool, default: False ) –

    If True and return_inst_pred=True, the instance label predictions returned are samples from the instance label distribution.

Returns:

  • Y_pred ( Tensor ) –

    Bag label logits of shape (batch_size,).

  • y_inst_pred ( Tensor ) –

    Only returned when return_inst_pred=True. Attention values (before normalization) of shape (batch_size, bag_size) if return_samples=False, else (batch_size, bag_size, n_samples).
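An illustrative inference call, reusing the model and inputs from the earlier sketches.

```python
import torch

model.eval()
with torch.no_grad():
    # Bag-level logits and per-instance attention scores
    Y_pred, y_inst_pred = model.predict(X, mask=mask, return_inst_pred=True)
    # Y_pred: (batch_size,), y_inst_pred: (batch_size, bag_size)

    # With return_samples=True, the instance predictions are samples
    # from the instance label distribution
    Y_pred_s, y_inst_samples = model.predict(
        X, mask=mask, return_inst_pred=True, return_samples=True,
    )
    # y_inst_samples: (batch_size, bag_size, n_samples)
```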