
SmTransformerABMIL

torchmil.models.SmTransformerABMIL

Bases: MILModel

Transformer Attention-based Multiple Instance Learning model with the \(\texttt{Sm}\) operator, proposed in the paper "Sm: enhanced localization in Multiple Instance Learning for medical imaging classification".

Given an input bag \(\mathbf{X} = \left[ \mathbf{x}_1, \ldots, \mathbf{x}_N \right]^\top \in \mathbb{R}^{N \times P}\) with adjacency matrix \(\mathbf{A} \in \mathbb{R}^{N \times N}\), the model optionally applies a feature extractor, \(\text{FeatExt}(\cdot)\), to transform the instance features: \(\mathbf{X} = \text{FeatExt}(\mathbf{X}) \in \mathbb{R}^{N \times D}\).
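The feature extractor is applied to every instance independently. The sketch below is a minimal illustration that uses a hypothetical two-layer extractor built from standard `torch.nn` modules and assumes a flat instance shape `in_shape = (P,)`. The dummy forward pass at the end is one common way to infer \(D\); it is not necessarily how torchmil does it.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes: P raw features per instance, D features after extraction.
P, D = 64, 32
feat_ext = nn.Sequential(nn.Linear(P, D), nn.ReLU())  # illustrative extractor

# Two bags of 10 instances each, shape (batch_size, bag_size, P).
X = torch.randn(2, 10, P)

# nn.Linear acts on the last dimension, so every instance is mapped independently.
X_feat = feat_ext(X)
print(X_feat.shape)  # torch.Size([2, 10, 32])

# One common way to infer D: run a dummy input of shape in_shape (assumed (P,)).
in_shape = (P,)
with torch.no_grad():
    out_dim = feat_ext(torch.zeros(1, *in_shape)).shape[-1]
print(out_dim)  # 32
```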

Then, it transforms the instance features using a transformer encoder with the \(\texttt{Sm}\) operator,

\[ \mathbf{X} = \text{SmTransformerEncoder}(\mathbf{X}, \mathbf{A}) \in \mathbb{R}^{N \times D}. \]
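The \(\texttt{Sm}\) operator smooths a signal over the bag graph given by \(\mathbf{A}\). This is what the 'approx'/'exact' modes and the `*_sm_steps` parameters refer to. The sketch below reflects our reading of the operator and is not torchmil's code. It assumes the closed form \((1-\alpha)(\mathbf{I} - \alpha \hat{\mathbf{A}})^{-1}\mathbf{f}\) with a symmetrically normalized adjacency \(\hat{\mathbf{A}}\), approximated by the fixed-point iteration \(\mathbf{u} \leftarrow \alpha \hat{\mathbf{A}} \mathbf{u} + (1-\alpha)\mathbf{f}\). The function names are hypothetical, and the normalization used by torchmil may differ.

```python
import torch

def sym_normalize(adj: torch.Tensor) -> torch.Tensor:
    """Return D^{-1/2} A D^{-1/2}; rows/columns of isolated nodes become zero."""
    deg = adj.sum(dim=-1)
    inv_sqrt = torch.where(deg > 0, deg.pow(-0.5), torch.zeros_like(deg))
    return inv_sqrt.unsqueeze(-1) * adj * inv_sqrt.unsqueeze(-2)

def sm_approx(f: torch.Tensor, adj: torch.Tensor, alpha: float, steps: int) -> torch.Tensor:
    """Fixed-point iteration u <- alpha * A u + (1 - alpha) * f, started at u = f."""
    u = f
    for _ in range(steps):
        u = alpha * (adj @ u) + (1 - alpha) * f
    return u

def sm_exact(f: torch.Tensor, adj: torch.Tensor, alpha: float) -> torch.Tensor:
    """Closed form (1 - alpha) (I - alpha A)^{-1} f via a linear solve."""
    eye = torch.eye(adj.shape[-1], dtype=adj.dtype)
    return (1 - alpha) * torch.linalg.solve(eye - alpha * adj, f)

# Toy bag: 4 instances on a chain graph 0-1-2-3, with a 2-dimensional signal.
torch.manual_seed(0)
A = torch.tensor([[0., 1., 0., 0.],
                  [1., 0., 1., 0.],
                  [0., 1., 0., 1.],
                  [0., 0., 1., 0.]])
A_norm = sym_normalize(A)
f = torch.randn(4, 2)

u_approx = sm_approx(f, A_norm, alpha=0.5, steps=50)
u_exact = sm_exact(f, A_norm, alpha=0.5)
print(torch.allclose(u_approx, u_exact, atol=1e-5))  # True
```

The iteration trades accuracy for cost. Each step needs one sparse-friendly product with \(\hat{\mathbf{A}}\), while the exact mode solves an \(N \times N\) linear system.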

Subsequently, it aggregates the instance features into a bag representation \(\mathbf{z} \in \mathbb{R}^{D}\) using an attention-based pooling mechanism that incorporates the \(\texttt{Sm}\) operator,

\[ \mathbf{z}, \mathbf{f} = \operatorname{SmAttentionPool}(\mathbf{X}, \mathbf{A}), \]

where \(\mathbf{f} \in \mathbb{R}^{N}\) are the attention values. Finally, the bag representation \(\mathbf{z}\) is fed into a classifier (a single linear layer) to predict the bag label.
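The pooling and classification steps can be sketched as follows. `ToySmAttentionPool` is a hypothetical, simplified stand-in for SmAttentionPool. It uses a tanh attention MLP and smooths the unnormalized attention values over the graph just before the softmax. Where torchmil actually places the \(\texttt{Sm}\) operator (controlled by `pool_sm_layers`, `pool_sm_pre` and `pool_sm_post`) may differ. The adjacency is assumed to be already normalized.

```python
import torch
import torch.nn as nn

class ToySmAttentionPool(nn.Module):
    """Hypothetical, simplified stand-in for SmAttentionPool (not torchmil's implementation)."""

    def __init__(self, in_dim: int, att_dim: int = 128, alpha: float = 0.5, steps: int = 10):
        super().__init__()
        self.proj = nn.Linear(in_dim, att_dim)
        self.score = nn.Linear(att_dim, 1, bias=False)
        self.alpha, self.steps = alpha, steps

    def forward(self, X, adj, mask=None):
        # X: (B, N, D); adj: (B, N, N), assumed normalized; mask: (B, N), 1 = real instance.
        f0 = self.score(torch.tanh(self.proj(X)))  # (B, N, 1) unnormalized attention
        f = f0
        for _ in range(self.steps):  # approximate Sm smoothing of the attention values
            f = self.alpha * torch.bmm(adj, f) + (1 - self.alpha) * f0
        f = f.squeeze(-1)  # (B, N)
        logits = f if mask is None else f.masked_fill(mask == 0, float("-inf"))
        s = torch.softmax(logits, dim=1)  # normalized attention weights
        z = torch.bmm(s.unsqueeze(1), X).squeeze(1)  # (B, D) bag representation
        return z, f

torch.manual_seed(0)
B, N, D = 2, 5, 8
pool = ToySmAttentionPool(in_dim=D)
classifier = nn.Linear(D, 1)  # the single linear layer

X = torch.randn(B, N, D)
adj = torch.eye(N).repeat(B, 1, 1)  # trivial graph, for illustration only
z, f = pool(X, adj)
Y_pred = classifier(z).squeeze(-1)  # bag logits, shape (B,)
print(z.shape, f.shape, Y_pred.shape)  # torch.Size([2, 8]) torch.Size([2, 5]) torch.Size([2])
```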

See SmAttentionPool for more details on the attention-based pooling, and SmTransformerEncoder for more details on the transformer encoder.

__init__(in_shape, pool_att_dim=128, pool_act='tanh', pool_sm_mode='approx', pool_sm_alpha='trainable', pool_sm_layers=1, pool_sm_steps=10, pool_sm_pre=False, pool_sm_post=False, pool_sm_spectral_norm=False, feat_ext=torch.nn.Identity(), transf_att_dim=512, transf_n_layers=1, transf_n_heads=4, transf_use_mlp=True, transf_add_self=True, transf_dropout=0.0, transf_sm_alpha='trainable', transf_sm_mode='approx', transf_sm_steps=10, criterion=torch.nn.BCEWithLogitsLoss())

Class constructor.

Parameters:

  • in_shape (tuple) –

    Shape of input data expected by the feature extractor (excluding batch dimension).

  • pool_att_dim (int, default: 128 ) –

    Attention dimension for pooling.

  • pool_act (str, default: 'tanh' ) –

    Activation function for pooling. Possible values: 'tanh', 'relu', 'gelu'.

  • pool_sm_mode (str, default: 'approx' ) –

    Mode for the Sm operator in pooling. Possible values: 'approx', 'exact'.

  • pool_sm_alpha (Union[float, str], default: 'trainable' ) –

    Alpha value for the Sm operator in pooling. If 'trainable', alpha is learned during training.

  • pool_sm_layers (int, default: 1 ) –

    Number of layers that use the Sm operator in pooling.

  • pool_sm_steps (int, default: 10 ) –

    Number of steps for the Sm operator in pooling.

  • pool_sm_pre (bool, default: False ) –

    If True, apply the Sm operator before the attention pooling.

  • pool_sm_post (bool, default: False ) –

    If True, apply the Sm operator after the attention pooling.

  • pool_sm_spectral_norm (bool, default: False ) –

    If True, apply spectral normalization to all linear layers in pooling.

  • feat_ext (Module, default: Identity() ) –

    Feature extractor.

  • transf_att_dim (int, default: 512 ) –

    Attention dimension for transformer encoder.

  • transf_n_layers (int, default: 1 ) –

    Number of layers in transformer encoder.

  • transf_n_heads (int, default: 4 ) –

    Number of heads in transformer encoder.

  • transf_use_mlp (bool, default: True ) –

    Whether to use an MLP in the transformer encoder.

  • transf_add_self (bool, default: True ) –

    Whether to add the input to the output (a residual connection) in the transformer encoder.

  • transf_dropout (float, default: 0.0 ) –

    Dropout rate in transformer encoder.

  • transf_sm_alpha (Union[float, str], default: 'trainable' ) –

    Alpha value for the Sm operator in the transformer encoder. If 'trainable', alpha is learned during training.

  • transf_sm_mode (str, default: 'approx' ) –

    Mode for the Sm operator in the transformer encoder. Possible values: 'approx', 'exact'.

  • transf_sm_steps (int, default: 10 ) –

    Number of steps for the Sm operator in transformer encoder.

  • criterion (Module, default: BCEWithLogitsLoss() ) –

    Loss function. By default, Binary Cross-Entropy loss from logits for binary classification.
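The `pool_sm_alpha` and `transf_sm_alpha` arguments accept either a number or the string 'trainable'. The sketch below shows one hypothetical way such a `Union[float, str]` argument can be handled; it is an assumption, not torchmil's code. A free parameter passed through a sigmoid keeps the learned alpha in \((0, 1)\), while a fixed alpha is stored as a non-trainable buffer.

```python
import torch
import torch.nn as nn

class HypotheticalSmAlpha(nn.Module):
    """Illustrative handling of a Union[float, str] alpha argument (an assumption, not torchmil's code)."""

    def __init__(self, alpha="trainable"):
        super().__init__()
        if alpha == "trainable":
            # Free parameter, learned with the rest of the model; sigmoid keeps alpha in (0, 1).
            self.raw = nn.Parameter(torch.zeros(()))
        else:
            # Fixed alpha: stored as a non-trainable buffer in the same logit space.
            self.register_buffer("raw", torch.logit(torch.tensor(float(alpha))))

    def forward(self) -> torch.Tensor:
        return torch.sigmoid(self.raw)

learned = HypotheticalSmAlpha("trainable")
fixed = HypotheticalSmAlpha(0.8)
print(len(list(learned.parameters())), len(list(fixed.parameters())))  # 1 0
print(round(learned().item(), 4), round(fixed().item(), 4))  # 0.5 0.8
```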

forward(X, adj, mask=None, return_att=False)

Forward pass.

Parameters:

  • X (Tensor) –

    Bag features of shape (batch_size, bag_size, ...).

  • adj (Tensor) –

    Adjacency matrix of shape (batch_size, bag_size, bag_size).

  • mask (Tensor, default: None ) –

    Mask of shape (batch_size, bag_size).

  • return_att (bool, default: False ) –

    If True, returns attention values (before normalization) in addition to Y_pred.

Returns:

  • Y_pred ( Tensor ) –

    Bag label logits of shape (batch_size,).

  • att ( Tensor ) –

    Only returned when return_att=True. Attention values (before normalization) of shape (batch_size, bag_size).
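Bags in a batch usually have different sizes, so `X`, `adj` and `mask` must be padded to a common `bag_size`. The sketch below shows one way to build these inputs with plain PyTorch. It assumes a dense `adj` (as the documented shape suggests) and a `mask` convention of 1 for real instances and 0 for padding. `collate_bags` and `chain_adj` are hypothetical helpers, not part of torchmil.

```python
import torch

def collate_bags(bags, adjs):
    """Hypothetical helper: pad variable-size bags into X, adj and mask batch tensors."""
    batch_size = len(bags)
    max_n = max(b.shape[0] for b in bags)
    feat_dim = bags[0].shape[1]
    X = torch.zeros(batch_size, max_n, feat_dim)
    adj = torch.zeros(batch_size, max_n, max_n)
    mask = torch.zeros(batch_size, max_n)
    for i, (b, a) in enumerate(zip(bags, adjs)):
        n = b.shape[0]
        X[i, :n] = b        # real instances first, zero padding after
        adj[i, :n, :n] = a  # padded instances have no edges
        mask[i, :n] = 1     # assumed convention: 1 = real instance, 0 = padding
    return X, adj, mask

def chain_adj(n):
    """Illustrative adjacency: instance k is linked to instances k-1 and k+1."""
    a = torch.zeros(n, n)
    idx = torch.arange(n - 1)
    a[idx, idx + 1] = 1
    a[idx + 1, idx] = 1
    return a

torch.manual_seed(0)
bags = [torch.randn(3, 4), torch.randn(5, 4)]
X, adj, mask = collate_bags(bags, [chain_adj(3), chain_adj(5)])
print(X.shape, adj.shape, mask.shape)
# torch.Size([2, 5, 4]) torch.Size([2, 5, 5]) torch.Size([2, 5])
```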

compute_loss(Y, X, adj, mask=None)

Compute loss given true bag labels.

Parameters:

  • Y (Tensor) –

    Bag labels of shape (batch_size,).

  • X (Tensor) –

    Bag features of shape (batch_size, bag_size, ...).

  • adj (Tensor) –

    Adjacency matrix of shape (batch_size, bag_size, bag_size).

  • mask (Tensor, default: None ) –

    Mask of shape (batch_size, bag_size).

Returns:

  • Y_pred ( Tensor ) –

    Bag label logits of shape (batch_size,).

  • loss_dict ( dict ) –

    Dictionary containing the loss value.
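With the default criterion, the loss compares bag logits of shape (batch_size,) with bag labels of the same shape. The sketch below illustrates that step with plain PyTorch and checks it against the definition of binary cross-entropy with logits. `toy_compute_loss` is hypothetical, and the dictionary key "loss" is an assumed name; this section does not document which keys torchmil uses.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def toy_compute_loss(Y_pred, Y, criterion=nn.BCEWithLogitsLoss()):
    """Hypothetical sketch: apply the criterion to bag logits and wrap it in a dict."""
    # BCEWithLogitsLoss expects float targets with the same shape as the logits.
    loss = criterion(Y_pred, Y.float())
    return Y_pred, {"loss": loss}  # the key "loss" is an assumed name

Y_pred = torch.tensor([2.0, -1.0, 0.0])  # bag logits, shape (batch_size,)
Y = torch.tensor([1, 0, 1])              # bag labels, shape (batch_size,)
_, loss_dict = toy_compute_loss(Y_pred, Y)

# Same value from the definition: -[y log sigmoid(s) + (1 - y) log sigmoid(-s)], averaged.
manual = -(Y * F.logsigmoid(Y_pred) + (1 - Y) * F.logsigmoid(-Y_pred)).mean()
print(round(loss_dict["loss"].item(), 4), round(manual.item(), 4))  # 0.3778 0.3778
```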

predict(X, adj, mask=None, return_inst_pred=True)

Predict bag and (optionally) instance labels.

Parameters:

  • X (Tensor) –

    Bag features of shape (batch_size, bag_size, ...).

  • adj (Tensor) –

    Adjacency matrix of shape (batch_size, bag_size, bag_size).

  • mask (Tensor, default: None ) –

    Mask of shape (batch_size, bag_size).

  • return_inst_pred (bool, default: True ) –

    If True, returns instance label predictions in addition to bag label predictions.

Returns:

  • Y_pred ( Tensor ) –

    Bag label logits of shape (batch_size,).

  • y_inst_pred ( Tensor ) –

    Only returned when return_inst_pred=True. Instance label predictions of shape (batch_size, bag_size).
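Because `Y_pred` holds logits, turning it into probabilities requires a sigmoid. Thresholding the probability at 0.5 is the same as thresholding the logit at 0. The sketch below post-processes tensors in the documented shapes. The values are made up. The `mask` convention (1 for real instances, 0 for padding) is an assumption, and the instance predictions are treated only as per-instance scores.

```python
import torch

# Made-up outputs in the documented shapes: Y_pred (B,), y_inst_pred (B, N).
Y_pred = torch.tensor([1.5, -0.3])
y_inst_pred = torch.tensor([[0.2, 0.9, 0.1], [0.4, 0.3, 0.0]])
mask = torch.tensor([[1, 1, 1], [1, 1, 0]])  # assumed: 1 = real instance, 0 = padding

# Logits -> probabilities -> hard bag labels.
bag_prob = torch.sigmoid(Y_pred)
bag_label = (bag_prob > 0.5).long()
print(bag_label.tolist())  # [1, 0]

# Ignore padded positions when picking the highest-scoring instance of each bag.
inst_scores = y_inst_pred.masked_fill(mask == 0, float("-inf"))
top_inst = inst_scores.argmax(dim=1)
print(top_inst.tolist())  # [1, 0]
```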