SETMIL

torchmil.models.SETMIL

Bases: MILModel

SETMIL (Spatial Encoding Transformer-Based Multiple Instance Learning) model, proposed in the paper SETMIL: Spatial Encoding Transformer-Based Multiple Instance Learning for Pathological Image Analysis.

Given an input bag \(\mathbf{X} = \left[ \mathbf{x}_1, \ldots, \mathbf{x}_N \right]^\top \in \mathbb{R}^{N \times P}\), the model optionally applies a feature extractor, \(\text{FeatExt}(\cdot)\), to transform the instance features: \(\mathbf{X} = \text{FeatExt}(\mathbf{X}) \in \mathbb{R}^{N \times D}\).

Then, the Pyramid Multi-Scale Fusion (PMF) module enriches the representation with multi-scale context information. The PMF module consists of three T2T modules with different kernel sizes, \(k = 3, 5, 7\), whose outputs are concatenated along the feature dimension,

\[\operatorname{PMF}\left( \mathbf{X} \right) = \text{Concat}(\text{T2T}_{k=3}(\mathbf{X}), \text{T2T}_{k=5}(\mathbf{X}), \text{T2T}_{k=7}(\mathbf{X})).\]
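Conceptually, the fusion step is a feature-wise concatenation of the three T2T outputs. A minimal sketch, where `t2t_k3`, `t2t_k5` and `t2t_k7` are hypothetical stand-ins for the three T2T modules (each assumed to map a token sequence of shape `(batch_size, n_tokens, dim)` to a new token sequence):

```python
import torch

def pmf(X, t2t_k3, t2t_k5, t2t_k7):
    # Each T2T module tokenizes the bag at a different kernel size;
    # the multi-scale outputs are concatenated along the feature dimension.
    return torch.cat([t2t_k3(X), t2t_k5(X), t2t_k7(X)], dim=-1)
```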

See T2T and T2TLayer for further information.

Then, the model applies a Spatial Encoding Transformer (SET), which consists of a stack of transformer layers with image Relative Positional Encoding (iRPE). See iRPETransformer for further information.

Finally, using the class token computed by the SET module, the model predicts the bag label \(\hat{Y}\) using a linear layer.

Note. When use_pmf=True, the input bag is reshaped to a square shape and the PMF module is applied. This modifies the bag structure irreversibly, so attention values cannot be computed. If return_att=True, the attention values are set to zero.

__init__(in_shape, att_dim=512, use_pmf=False, pmf_n_heads=4, pmf_use_mlp=True, pmf_dropout=0.0, pmf_kernel_list=[(3, 3), (5, 5), (7, 7)], pmf_stride_list=[(1, 1), (1, 1), (1, 1)], pmf_padding_list=[(1, 1), (2, 2), (3, 3)], pmf_dilation_list=[(1, 1), (1, 1), (1, 1)], set_n_layers=1, set_n_heads=4, set_use_mlp=True, set_dropout=0.0, rpe_ratio=1.9, rpe_method='product', rpe_mode='contextual', rpe_shared_head=True, rpe_skip=1, rpe_on='k', feat_ext=torch.nn.Identity(), criterion=torch.nn.BCEWithLogitsLoss())

Parameters:

  • in_shape (tuple) –

    Shape of input data expected by the feature extractor (excluding batch dimension).

  • att_dim (int, default: 512 ) –

    Attention dimension used by the PMF and SET modules.

  • use_pmf (bool, default: False ) –

    If True, apply the Pyramid Multi-Scale Fusion (PMF) module before the SET module.

  • pmf_n_heads (int, default: 4 ) –

    Number of heads in the PMF module.

  • pmf_use_mlp (bool, default: True ) –

    If True, use MLP in the PMF module.

  • pmf_dropout (float, default: 0.0 ) –

    Dropout rate in the PMF module.

  • pmf_kernel_list (list[tuple[int, int]], default: [(3, 3), (5, 5), (7, 7)] ) –

    List of kernel sizes in the PMF module.

  • pmf_stride_list (list[tuple[int, int]], default: [(1, 1), (1, 1), (1, 1)] ) –

    List of stride sizes in the PMF module.

  • pmf_padding_list (list[tuple[int, int]], default: [(1, 1), (2, 2), (3, 3)] ) –

    List of padding sizes in the PMF module.

  • pmf_dilation_list (list[tuple[int, int]], default: [(1, 1), (1, 1), (1, 1)] ) –

    List of dilation sizes in the PMF module.

  • set_n_layers (int, default: 1 ) –

    Number of layers in the SET module.

  • set_n_heads (int, default: 4 ) –

    Number of heads in the SET module.

  • set_use_mlp (bool, default: True ) –

    If True, use MLP in the SET module.

  • set_dropout (float, default: 0.0 ) –

    Dropout rate in the SET module.

  • rpe_ratio (float, default: 1.9 ) –

    Ratio for relative positional encoding.

  • rpe_method (str, default: 'product' ) –

    Method for relative positional encoding. Possible values: ['euc', 'quant', 'cross', 'product']

  • rpe_mode (str, default: 'contextual' ) –

    Mode for relative positional encoding. Possible values: [None, 'bias', 'contextual']

  • rpe_shared_head (bool, default: True ) –

    If True, share weights across different heads.

  • rpe_skip (int, default: 1 ) –

    Number of tokens to skip in the relative positional encoding. Possible values: [0, 1].

  • rpe_on (str, default: 'k' ) –

    Where to apply relative positional encoding. Possible values: ['q', 'k', 'v', 'qk', 'kv', 'qkv'].

  • feat_ext (Module, default: Identity() ) –

    Feature extractor.

  • criterion (Module, default: BCEWithLogitsLoss() ) –

    Loss function. By default, Binary Cross-Entropy loss computed from logits.
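A minimal instantiation sketch, assuming bags of 512-dimensional instance features and keeping the default identity feature extractor; the parameter values below are illustrative, not recommendations:

```python
import torch
from torchmil.models import SETMIL

model = SETMIL(
    in_shape=(512,),  # assumption: 512-dim instance features, no feature extraction
    att_dim=512,
    use_pmf=False,    # keeps the bag structure intact, so attention can be returned
    set_n_layers=1,
    set_n_heads=4,
)
```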

forward(X, coords, return_att=False)

Forward pass.

Parameters:

  • X (Tensor) –

    Bag features of shape (batch_size, bag_size, feat_dim).

  • coords (Tensor) –

    Coordinates of shape (batch_size, bag_size, coord_dim).

  • return_att (bool, default: False ) –

    If True, returns attention values (before normalization) in addition to Y_pred.

Returns:

  • Y_pred ( Tensor ) –

    Bag label logits of shape (batch_size,).

  • att ( Tensor ) –

    Only returned when return_att=True. Attention values (before normalization) of shape (batch_size, bag_size).
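A usage sketch with toy tensors; the bag sizes and the integer grid coordinates are assumptions made for illustration:

```python
# Two bags of 100 instances each, with 512-dim features and 2-D patch coordinates.
X = torch.randn(2, 100, 512)
coords = torch.randint(0, 10, (2, 100, 2))

Y_pred, att = model(X, coords, return_att=True)
# Y_pred: shape (2,), bag label logits
# att:    shape (2, 100), unnormalized attention values (zeros if use_pmf=True)
```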

compute_loss(Y, X, coords)

Compute loss given true bag labels.

Parameters:

  • Y (Tensor) –

    Bag labels of shape (batch_size,).

  • X (Tensor) –

    Bag features of shape (batch_size, bag_size, ...).

  • coords (Tensor) –

    Coordinates of shape (batch_size, bag_size, coord_dim).

Returns:

  • Y_pred ( Tensor ) –

    Bag label logits of shape (batch_size,).

  • loss_dict ( dict ) –

    Dictionary containing the loss value.
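A training-step sketch, continuing the toy tensors above. The keys of loss_dict are not specified here, so the sketch simply sums its values:

```python
Y = torch.randint(0, 2, (2,)).float()  # binary bag labels for BCEWithLogitsLoss

Y_pred, loss_dict = model.compute_loss(Y, X, coords)
loss = sum(loss_dict.values())  # assumption: all dict entries are loss terms
loss.backward()
```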

predict(X, coords, return_inst_pred=True)

Predict bag and (optionally) instance labels.

Parameters:

  • X (Tensor) –

    Bag features of shape (batch_size, bag_size, ...).

  • coords (Tensor) –

    Coordinates of shape (batch_size, bag_size, coord_dim).

  • return_inst_pred (bool, default: True ) –

    If True, returns instance label predictions in addition to bag label predictions.

Returns:

  • Y_pred ( Tensor ) –

    Bag label logits of shape (batch_size,).

  • y_inst_pred ( Tensor ) –

    Only returned when return_inst_pred=True. Instance label predictions of shape (batch_size, bag_size).
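An inference sketch, again with the toy tensors from above:

```python
model.eval()
with torch.no_grad():
    Y_pred, y_inst_pred = model.predict(X, coords, return_inst_pred=True)
# Y_pred:      shape (2,), bag label logits
# y_inst_pred: shape (2, 100), instance label predictions
```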