
TransMIL

torchmil.models.TransMIL

Bases: MILModel

Method proposed in the paper TransMIL: Transformer based Correlated Multiple Instance Learning for Whole Slide Image Classification.

Given an input bag \(\mathbf{X} = \left[ \mathbf{x}_1, \ldots, \mathbf{x}_N \right]^\top \in \mathbb{R}^{N \times P}\), the model optionally applies a feature extractor, \(\text{FeatExt}(\cdot)\), to transform the instance features: \(\mathbf{X} = \text{FeatExt}(\mathbf{X}) \in \mathbb{R}^{N \times D}\).
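As a minimal sketch of the feature-extraction step, the snippet below uses a hypothetical per-instance MLP as FeatExt, mapping P = 1024 input features to D = 256. It assumes the extractor is applied to the whole (batch_size, bag_size, P) tensor and acts on the last dimension. The sizes are illustrative only, and the default torch.nn.Identity() leaves the features unchanged (D = P).

```python
import torch

# Hypothetical sizes for illustration: 2 bags, 100 instances each, P = 1024, D = 256.
batch_size, bag_size, P, D = 2, 100, 1024, 256

# Instance-wise extractor: nn.Linear acts on the last dimension, so each
# instance x_i in R^P is mapped independently to R^D.
feat_ext = torch.nn.Sequential(
    torch.nn.Linear(P, D),
    torch.nn.ReLU(),
)

X = torch.randn(batch_size, bag_size, P)
H = feat_ext(X)
print(H.shape)  # torch.Size([2, 100, 256])
```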

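Then, following Algorithm 2 in the paper, it squares the sequence (the bag is padded by repeating its first instances until the number of instances is a perfect square, so that the instance tokens can be arranged in a 2D grid), prepends a class token, and applies the TPT module. This module consists of Nyströmformer layers (n_layers, two by default) and the PPEG (Pyramid Positional Encoding Generator) layer, which applies depthwise convolutions of different kernel sizes to the 2D grid of instance tokens.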
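The sequence-squaring, class-token and PPEG steps can be sketched in plain PyTorch as follows. This illustrates Algorithm 2 under stated assumptions and is not torchmil's implementation: square_sequence and PPEGSketch are hypothetical names, the PPEG uses depthwise convolutions with kernel sizes 7, 5 and 3 as described in the paper, and the Nyströmformer layers that surround it are omitted.

```python
import math

import torch
import torch.nn as nn


def square_sequence(H: torch.Tensor) -> torch.Tensor:
    """Pad (batch, N, D) to (batch, K*K, D), K = ceil(sqrt(N)), by repeating the first instances."""
    N = H.shape[1]
    K = math.ceil(math.sqrt(N))
    M = K * K - N  # number of extra tokens needed to fill the K x K grid
    return torch.cat([H, H[:, :M]], dim=1)


class PPEGSketch(nn.Module):
    """Hypothetical PPEG: depthwise convs (kernels 7, 5, 3) over the token grid plus identity.

    The class token (position 0) is passed through unchanged.
    """

    def __init__(self, dim: int):
        super().__init__()
        # groups=dim makes each convolution depthwise (one filter per channel).
        self.proj7 = nn.Conv2d(dim, dim, 7, padding=3, groups=dim)
        self.proj5 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
        self.proj3 = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        cls_tok, feats = tokens[:, :1], tokens[:, 1:]
        B, L, D = feats.shape
        K = math.isqrt(L)  # L is a perfect square after square_sequence
        grid = feats.transpose(1, 2).reshape(B, D, K, K)
        grid = grid + self.proj7(grid) + self.proj5(grid) + self.proj3(grid)
        feats = grid.flatten(2).transpose(1, 2)  # back to (B, K*K, D)
        return torch.cat([cls_tok, feats], dim=1)


torch.manual_seed(0)
batch_size, N, D = 2, 10, 32
H = torch.randn(batch_size, N, D)
H_sq = square_sequence(H)  # (2, 16, 32): ceil(sqrt(10)) = 4, so 6 tokens are repeated

cls_token = nn.Parameter(torch.zeros(1, 1, D))
tokens = torch.cat([cls_token.expand(batch_size, -1, -1), H_sq], dim=1)  # (2, 17, 32)
out = PPEGSketch(D)(tokens)
print(out.shape)  # torch.Size([2, 17, 32])
```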

Finally, a linear classifier is used to predict the bag label from the class token.
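The classification step amounts to reading the class token (position 0) from the TPT output and applying a linear layer. A minimal sketch, assuming one output logit per bag to match the (batch_size,) shape documented for forward; any normalization torchmil may apply before the classifier is not shown.

```python
import torch
import torch.nn as nn

batch_size, num_tokens, att_dim = 2, 17, 512

# Stand-in for the TPT output: class token at position 0, then instance tokens.
tokens = torch.randn(batch_size, num_tokens, att_dim)

classifier = nn.Linear(att_dim, 1)            # one logit per bag (binary task)
cls_repr = tokens[:, 0]                       # (batch_size, att_dim)
Y_logits = classifier(cls_repr).squeeze(-1)   # (batch_size,)
print(Y_logits.shape)  # torch.Size([2])
```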

__init__(in_shape, att_dim=512, n_layers=2, n_heads=4, n_landmarks=None, pinv_iterations=6, dropout=0.0, use_mlp=False, feat_ext=torch.nn.Identity(), criterion=torch.nn.BCEWithLogitsLoss())

Parameters:

  • in_shape (tuple) –

    Shape of input data expected by the feature extractor (excluding batch dimension).

  • att_dim (int, default: 512 ) –

    Embedding dimension. Should be divisible by n_heads.

  • n_layers (int, default: 2 ) –

    Number of Nyströmformer layers.

  • n_heads (int, default: 4 ) –

    Number of heads in the Nyströmformer layer.

  • n_landmarks (int, default: None ) –

    Number of landmarks in the Nyströmformer layer.

  • pinv_iterations (int, default: 6 ) –

    Number of iterations for the pseudo-inverse in the Nyströmformer layer.

  • dropout (float, default: 0.0 ) –

    Dropout rate in the Nyströmformer layer.

  • use_mlp (bool, default: False ) –

    Whether to use an MLP after the Nyströmformer layer.

  • feat_ext (Module, default: Identity() ) –

    Feature extractor. By default, the identity function (no feature extraction).

  • criterion (Module, default: BCEWithLogitsLoss() ) –

    Loss function. By default, Binary Cross-Entropy loss from logits.
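To make n_landmarks and pinv_iterations concrete, here is a hedged single-head sketch of Nyström attention in plain PyTorch, following the Nyströmformer paper rather than torchmil's code. iterative_pinv and nystrom_attention are hypothetical helpers; landmarks are computed as segment means of the queries and keys, and the sketch assumes the sequence length is divisible by n_landmarks (practical implementations pad the sequence otherwise). The small m×m landmark matrix is inverted with an iterative Moore-Penrose scheme, whose iteration count plays the role of pinv_iterations.

```python
import torch


def iterative_pinv(A: torch.Tensor, iters: int = 6) -> torch.Tensor:
    """Approximate Moore-Penrose pseudo-inverse of a (batch of) square matrices."""
    abs_A = A.abs()
    col_max = abs_A.sum(dim=-2).amax(dim=-1)  # ||A||_1 (max column sum) per matrix
    row_max = abs_A.sum(dim=-1).amax(dim=-1)  # ||A||_inf (max row sum) per matrix
    # Initialization used in the Nyströmformer paper to ensure convergence.
    Z = A.transpose(-1, -2) / (col_max * row_max)[..., None, None]
    I = torch.eye(A.shape[-1], dtype=A.dtype, device=A.device)
    for _ in range(iters):
        AZ = A @ Z
        # Z <- 1/4 Z (13I - AZ (15I - AZ (7I - AZ)))
        Z = 0.25 * Z @ (13 * I - AZ @ (15 * I - AZ @ (7 * I - AZ)))
    return Z


def nystrom_attention(Q, K, V, n_landmarks: int, pinv_iterations: int = 6):
    """Single-head Nyström approximation of softmax(Q K^T / sqrt(d)) V."""
    B, n, d = Q.shape
    scale = d ** -0.5
    # Landmarks: means over n_landmarks contiguous segments of the sequence.
    Q_l = Q.reshape(B, n_landmarks, n // n_landmarks, d).mean(dim=2)
    K_l = K.reshape(B, n_landmarks, n // n_landmarks, d).mean(dim=2)
    F = torch.softmax(Q @ K_l.transpose(-1, -2) * scale, dim=-1)    # (B, n, m)
    A = torch.softmax(Q_l @ K_l.transpose(-1, -2) * scale, dim=-1)  # (B, m, m)
    Bm = torch.softmax(Q_l @ K.transpose(-1, -2) * scale, dim=-1)   # (B, m, n)
    return F @ iterative_pinv(A, pinv_iterations) @ (Bm @ V)


torch.manual_seed(0)
Q, K, V = (torch.randn(2, 16, 8) for _ in range(3))
out = nystrom_attention(Q, K, V, n_landmarks=4, pinv_iterations=6)
print(out.shape)  # torch.Size([2, 16, 8])
```

Each extra iteration refines the approximate inverse at the cost of a few more m×m matrix products, which stays cheap because m (the number of landmarks) is much smaller than the bag size.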

forward(X, return_att=False)

Forward pass.

Parameters:

  • X (Tensor) –

    Input tensor of shape (batch_size, bag_size, in_dim).

  • return_att (bool, default: False ) –

    Whether to return the attention values.

Returns:

  • Y_pred ( Tensor ) –

    Bag label logits of shape (batch_size,).

  • att ( Tensor ) –

    Only returned when return_att=True. Attention values of shape (batch_size, bag_size).
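Downstream code often converts the bag logits into probabilities and uses the attention values to rank instances. In this sketch, Y_pred and att are random placeholders with the shapes documented above (not real model outputs), and the sigmoid assumes the default binary setting with BCEWithLogitsLoss.

```python
import torch

torch.manual_seed(0)
batch_size, bag_size = 2, 50

# Placeholders standing in for: Y_pred, att = model(X, return_att=True)
Y_pred = torch.randn(batch_size)        # bag logits, shape (batch_size,)
att = torch.rand(batch_size, bag_size)  # attention values, shape (batch_size, bag_size)

probs = torch.sigmoid(Y_pred)           # bag-level probabilities in (0, 1)

# Values and indices of the 5 most-attended instances in each bag, highest first.
top_vals, top_idx = att.topk(k=5, dim=1)
```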

compute_loss(Y, X)

Compute loss given true bag labels.

Parameters:

  • Y (Tensor) –

    Bag labels of shape (batch_size,).

  • X (Tensor) –

    Input tensor of shape (batch_size, bag_size, in_dim).

Returns:

  • Y_pred ( Tensor ) –

    Bag label logits of shape (batch_size,).

  • loss_dict ( dict ) –

    Dictionary containing the loss value.
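Assuming compute_loss applies the default criterion directly to the bag logits, the reported loss corresponds to the computation below. The labels are cast to floating point because BCEWithLogitsLoss expects targets of the same shape and a floating dtype; whether torchmil performs this cast internally is not stated here.

```python
import torch

criterion = torch.nn.BCEWithLogitsLoss()  # the documented default criterion

Y_pred = torch.tensor([2.0, -1.0, 0.5], dtype=torch.float64)  # bag logits, shape (batch_size,)
Y = torch.tensor([1, 0, 0])                                   # integer bag labels

# Targets must be floating point with the same shape as the logits.
loss = criterion(Y_pred, Y.to(Y_pred.dtype))

# The same value written out: mean binary cross-entropy on sigmoid probabilities.
p = torch.sigmoid(Y_pred)
manual = -(Y * torch.log(p) + (1 - Y) * torch.log(1 - p)).mean()
```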

predict(X, return_inst_pred=True)

Predict bag and (optionally) instance labels.

Parameters:

  • X (Tensor) –

    Input tensor of shape (batch_size, bag_size, in_dim).

  • return_inst_pred (bool, default: True ) –

    If True, returns instance label predictions in addition to bag label predictions.

Returns:

  • Y_pred ( Tensor ) –

    Bag label logits of shape (batch_size,).

  • y_inst_pred ( Tensor ) –

    Only returned when return_inst_pred=True. Instance label predictions of shape (batch_size, bag_size).
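The outputs of predict can be post-processed as in this sketch, which uses placeholder tensors with the documented shapes. Thresholding the logits at 0 is equivalent to thresholding sigmoid probabilities at 0.5. The per-bag min-max rescaling of instance predictions (for example, to draw a heatmap over a slide) is an illustrative choice, not part of torchmil.

```python
import torch

torch.manual_seed(0)

# Placeholder bag logits standing in for the Y_pred output of predict(X).
Y_pred = torch.tensor([1.3, -0.2, 0.0, 4.0])
labels_from_logits = (Y_pred > 0).long()
labels_from_probs = (torch.sigmoid(Y_pred) > 0.5).long()

# Placeholder instance predictions, shape (batch_size, bag_size).
y_inst_pred = torch.rand(4, 30)
mins = y_inst_pred.amin(dim=1, keepdim=True)
maxs = y_inst_pred.amax(dim=1, keepdim=True)
# Rescale each bag's scores to [0, 1]; the epsilon avoids division by zero for constant bags.
heat = (y_inst_pred - mins) / (maxs - mins + 1e-12)
```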