TransformerProbSmoothABMIL

torchmil.models.TransformerProbSmoothABMIL

Bases: MILModel

Transformer Attention-based Multiple Instance Learning model, with probabilistic attention-based pooling. Proposed in Probabilistic Smooth Attention for Deep Multiple Instance Learning in Medical Imaging.

Given an input bag \(\mathbf{X} = \left[ \mathbf{x}_1, \ldots, \mathbf{x}_N \right]^\top \in \mathbb{R}^{N \times P}\), the model optionally applies a feature extractor, \(\text{FeatExt}(\cdot)\), to transform the instance features: \(\mathbf{X} = \text{FeatExt}(\mathbf{X}) \in \mathbb{R}^{N \times D}\). Then, it transforms the instance features using a transformer encoder,

\[ \mathbf{X} = \text{TransformerEncoder}(\mathbf{X}) \in \mathbb{R}^{N \times D}. \]

Subsequently, it aggregates the instance features into a bag representation using a probabilistic attention-based pooling mechanism, as detailed in ProbSmoothAttentionPool.

Specifically, it computes a mean vector \(\mathbf{\mu}_{\mathbf{f}} \in \mathbb{R}^N\) and a variance vector \(\mathbf{\sigma}_{\mathbf{f}}^2 \in \mathbb{R}^N\) that define the attention distribution \(q(\mathbf{f} \mid \mathbf{X}) = \mathcal{N}\left(\mathbf{f} \mid \mathbf{\mu}_{\mathbf{f}}, \operatorname{diag}(\mathbf{\sigma}_{\mathbf{f}}^2) \right)\),

\[ \mathbf{\mu}_{\mathbf{f}}, \mathbf{\sigma}_{\mathbf{f}} = \operatorname{ProbSmoothAttentionPool}(\mathbf{X}). \]

If covar_mode='zero', the variance vector \(\mathbf{\sigma}_{\mathbf{f}}^2\) is set to zero, resulting in a deterministic attention distribution.

Then, \(m\) attention vectors \(\widehat{\mathbf{F}} = \left[ \widehat{\mathbf{f}}^{(1)}, \ldots, \widehat{\mathbf{f}}^{(m)} \right]^\top \in \mathbb{R}^{m \times N}\) are sampled from the attention distribution. The bag representation \(\widehat{\mathbf{z}} \in \mathbb{R}^{m \times D}\) is then computed as:

\[ \widehat{\mathbf{z}} = \operatorname{Softmax}(\widehat{\mathbf{F}}) \mathbf{X}. \]

The bag representation \(\widehat{\mathbf{z}}\) is fed into a classifier, implemented as a linear layer, to produce bag label predictions \(Y_{\text{pred}} \in \mathbb{R}^{m}\).

Notably, the attention distribution naturally induces a distribution over the bag label predictions. This model thus generates multiple predictions for each bag, corresponding to different samples from this distribution.
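The following sketch, written in plain PyTorch with toy tensors and hypothetical dimensions (it does not use torchmil internals), illustrates how \(m\) sampled attention vectors yield \(m\) bag representations and \(m\) bag label logits for a single bag:

```python
import torch

# Toy dimensions for a single bag: N instances, D features, m attention samples.
N, D, m = 4, 16, 3
X = torch.randn(N, D)            # transformer-encoded instance features

# Attention distribution q(f | X) = N(mu_f, diag(sigma_f^2)) produced by the pooling module.
mu_f = torch.randn(N)
sigma_f = torch.rand(N)

# Sample m attention vectors F_hat in R^{m x N} via the reparameterization trick.
F_hat = mu_f + sigma_f * torch.randn(m, N)

# Bag representations z_hat in R^{m x D}: softmax-normalized attention applied to X.
z_hat = torch.softmax(F_hat, dim=-1) @ X

# A linear classifier yields m bag label logits, one per attention sample.
classifier = torch.nn.Linear(D, 1)
Y_pred = classifier(z_hat).squeeze(-1)   # shape (m,)
```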

Regularization. The probabilistic pooling mechanism introduces a regularization term to the loss function that encourages smoothness in the attention values. Given an input bag \(\mathbf{X} = \left[ \mathbf{x}_1, \ldots, \mathbf{x}_N \right]^\top \in \mathbb{R}^{N \times P}\) with adjacency matrix \(\mathbf{A} \in \mathbb{R}^{N \times N}\), the regularization term corresponds to

\[ \ell_{\text{KL}}(\mathbf{X}, \mathbf{A}) = \begin{cases} \mathbf{\mu}_{\mathbf{f}}^\top \mathbf{L} \mathbf{\mu}_{\mathbf{f}} \quad & \text{if } \texttt{covar_mode='zero'}, \\ \mathbf{\mu}_{\mathbf{f}}^\top \mathbf{L} \mathbf{\mu}_{\mathbf{f}} + \operatorname{Tr}(\mathbf{L} \mathbf{\Sigma}_{\mathbf{f}}) - \frac{1}{2}\log \det( \mathbf{\Sigma}_{\mathbf{f}} ) + \operatorname{const} \quad & \text{if } \texttt{covar_mode='diag'}, \\ \end{cases} \]

where \(\operatorname{const}\) is a constant term that does not depend on the parameters, \(\mathbf{\Sigma}_{\mathbf{f}} = \operatorname{diag}(\mathbf{\sigma}_{\mathbf{f}}^2)\), \(\mathbf{L} = \mathbf{D} - \mathbf{A}\) is the graph Laplacian matrix, and \(\mathbf{D}\) is the degree matrix of \(\mathbf{A}\). This term is then averaged for all bags in the batch and added to the loss function.
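As a concrete illustration, the sketch below computes both variants of this term for a single toy bag, using plain PyTorch and a hypothetical chain-structured adjacency matrix (it does not call into torchmil):

```python
import torch

# Toy example: N instances connected in a chain (hypothetical adjacency).
N = 4
A = torch.zeros(N, N)
idx = torch.arange(N - 1)
A[idx, idx + 1] = 1.0
A[idx + 1, idx] = 1.0

Deg = torch.diag(A.sum(dim=-1))    # degree matrix D
L = Deg - A                        # graph Laplacian L = D - A

mu_f = torch.randn(N)              # attention means
sigma_f2 = torch.rand(N) + 0.1     # attention variances (covar_mode='diag')
Sigma_f = torch.diag(sigma_f2)

# covar_mode='zero': smoothness penalty on the attention means only.
reg_zero = mu_f @ L @ mu_f

# covar_mode='diag': adds the trace and log-determinant terms (constant omitted).
reg_diag = mu_f @ L @ mu_f + torch.trace(L @ Sigma_f) - 0.5 * torch.logdet(Sigma_f)
```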

__init__(in_shape=None, pool_att_dim=128, covar_mode='diag', n_samples_train=1000, n_samples_test=5000, feat_ext=torch.nn.Identity(), transf_att_dim=512, transf_n_layers=1, transf_n_heads=8, transf_use_mlp=True, transf_add_self=True, transf_dropout=0.0, criterion=torch.nn.BCEWithLogitsLoss())

Parameters:

  • in_shape (tuple, default: None ) –

    Shape of input data expected by the feature extractor (excluding batch dimension). If not provided, it will be lazily initialized.

  • pool_att_dim (int, default: 128 ) –

    Attention dimension.

  • covar_mode (str, default: 'diag' ) –

    Covariance mode of the attention distribution. Possible values: 'zero', 'diag'.

  • n_samples_train (int, default: 1000 ) –

    Number of samples for training.

  • n_samples_test (int, default: 5000 ) –

    Number of samples for testing.

  • feat_ext (Module, default: Identity() ) –

    Feature extractor.

  • transf_att_dim (int, default: 512 ) –

    Attention dimension for transformer encoder.

  • transf_n_layers (int, default: 1 ) –

    Number of layers in transformer encoder.

  • transf_n_heads (int, default: 8 ) –

    Number of heads in transformer encoder.

  • transf_use_mlp (bool, default: True ) –

    Whether to use MLP in transformer encoder.

  • transf_add_self (bool, default: True ) –

    Whether to add input to output in transformer encoder.

  • transf_dropout (float, default: 0.0 ) –

    Dropout rate in transformer encoder.

  • criterion (Module, default: BCEWithLogitsLoss() ) –

    Loss function. By default, Binary Cross-Entropy loss from logits for binary classification.
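A minimal instantiation sketch, assuming instances are 256-dimensional feature vectors (the argument values are illustrative, and the import path follows the module name shown above):

```python
import torch
from torchmil.models import TransformerProbSmoothABMIL

model = TransformerProbSmoothABMIL(
    in_shape=(256,),           # 256-dimensional instance features (illustrative)
    pool_att_dim=128,
    covar_mode='diag',
    n_samples_train=1000,
    n_samples_test=5000,
    feat_ext=torch.nn.Identity(),
)
```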

forward(X, adj, mask=None, return_att=False, return_samples=False, return_kl_div=False)

Forward pass.

Parameters:

  • X (Tensor) –

    Bag features of shape (batch_size, bag_size, ...).

  • adj (Tensor) –

    Adjacency matrix of shape (batch_size, bag_size, bag_size). Only required when return_kl_div=True.

  • mask (Tensor, default: None ) –

    Mask of shape (batch_size, bag_size).

  • return_att (bool, default: False ) –

    If True, returns attention values (before normalization) in addition to Y_pred.

  • return_samples (bool, default: False ) –

    If True and return_att=True, the attention values returned are samples from the attention distribution.

  • return_kl_div (bool, default: False ) –

    If True, returns the KL divergence between the attention distribution and the prior distribution.

Returns:

  • Y_pred ( Tensor ) –

    Bag label logits of shape (batch_size, n_samples) if return_samples=True, else (batch_size,).

  • att ( Tensor ) –

    Only returned when return_att=True. Attention values (before normalization) of shape (batch_size, bag_size, n_samples) if return_samples=True, else (batch_size, bag_size).

  • kl_div ( Tensor ) –

    Only returned when return_kl_div=True. KL divergence between the attention distribution and the prior distribution of shape ().
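A usage sketch with toy tensors (bag sizes and feature dimension are hypothetical; `model` refers to the instantiation sketch above, and the unpacking order follows the Returns list):

```python
import torch

batch_size, bag_size, feat_dim = 2, 8, 256
X = torch.randn(batch_size, bag_size, feat_dim)
adj = torch.ones(batch_size, bag_size, bag_size)   # fully connected toy bags
mask = torch.ones(batch_size, bag_size)

# Request attention values and the KL regularization term alongside the bag logits.
Y_pred, att, kl_div = model(
    X, adj, mask=mask, return_att=True, return_samples=False, return_kl_div=True
)
# Y_pred: (batch_size,), att: (batch_size, bag_size), kl_div: scalar
```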

compute_loss(Y, X, adj, mask=None)

Compute loss given true bag labels.

Parameters:

  • Y (Tensor) –

    Bag labels of shape (batch_size,).

  • X (Tensor) –

    Bag features of shape (batch_size, bag_size, ...).

  • adj (Tensor) –

    Adjacency matrix of shape (batch_size, bag_size, bag_size).

  • mask (Tensor, default: None ) –

    Mask of shape (batch_size, bag_size).

Returns:

  • Y_pred ( Tensor ) –

    Bag label logits of shape (batch_size,).

  • loss_dict ( dict ) –

    Dictionary containing the loss value and the KL divergence between the attention distribution and the prior distribution.
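A training-step sketch, reusing the toy tensors from the forward example above (label values are illustrative; the keys of loss_dict depend on the implementation):

```python
Y = torch.randint(0, 2, (batch_size,)).float()   # toy binary bag labels

Y_pred, loss_dict = model.compute_loss(Y, X, adj, mask=mask)
for name, value in loss_dict.items():            # e.g. the criterion loss and the KL term
    print(name, float(value))
```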

predict(X, adj, mask=None, return_inst_pred=True, return_samples=False)

Predict bag and (optionally) instance labels.

Parameters:

  • X (Tensor) –

    Bag features of shape (batch_size, bag_size, ...).

  • adj (Tensor) –

    Adjacency matrix of shape (batch_size, bag_size, bag_size).

  • mask (Tensor, default: None ) –

    Mask of shape (batch_size, bag_size).

  • return_inst_pred (bool, default: True ) –

    If True, returns the attention values as instance label predictions, in addition to the bag label predictions.

  • return_samples (bool, default: False ) –

    If True and return_inst_pred=True, the instance label predictions returned are samples from the instance label distribution.

Returns:

  • Y_pred ( Tensor ) –

    Bag label logits of shape (batch_size,).

  • y_inst_pred ( Tensor ) –

    Only returned when return_inst_pred=True. Attention values (before normalization) of shape (batch_size, bag_size) if return_samples=False, else (batch_size, bag_size, n_samples).
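An inference sketch, reusing the toy tensors from the forward example above:

```python
Y_pred, y_inst_pred = model.predict(
    X, adj, mask=mask, return_inst_pred=True, return_samples=False
)
# Y_pred: (batch_size,) bag label logits
# y_inst_pred: (batch_size, bag_size) attention values used as instance-level predictions
```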