IIBMIL

torchmil.models.IIBMIL

Bases: Module

Integrated Instance-Level and Bag-Level Multiple Instance Learning (IIB-MIL) model, proposed in the paper IIB-MIL: Integrated Instance-Level and Bag-Level Multiple Instances Learning with Label Disambiguation for Pathological Image Analysis.

Given an input bag \(\mathbf{X} = \left[ \mathbf{x}_1, \ldots, \mathbf{x}_N \right]^\top \in \mathbb{R}^{N \times P}\), the model optionally applies a feature extractor, \(\text{FeatExt}(\cdot)\), to transform the instance features: \(\mathbf{X} = \text{FeatExt}(\mathbf{X}) \in \mathbb{R}^{N \times D}\).

Then, a TransformerEncoder is applied to transform the instance features using context information. Subsequently, the model uses bag-level and instance-level supervision:

Bag-level supervision: The instances are aggregated into a class token using a transformer decoder. A linear layer is then applied to predict the bag label.
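
The aggregation step can be pictured as a class token attending over the encoded instances. The sketch below is a simplified, dependency-free illustration, not the torchmil implementation: it uses a single attention head and omits the learned query/key/value projections, residual connections, and MLP of a real transformer decoder layer. The function name `class_token_attention` and the mask convention (nonzero marks a real instance) are assumptions made for this example.

```python
import math

def class_token_attention(class_token, instances, mask=None):
    """Pool instances into one vector via scaled dot-product attention.

    Hypothetical helper: single head, no learned projections.
    Assumes at least one instance is not masked out.
    """
    d = len(class_token)
    # Scaled dot product between the class token (query) and each instance (key).
    scores = [sum(q * x_d for q, x_d in zip(class_token, x)) / math.sqrt(d)
              for x in instances]
    if mask is not None:
        # Padded instances get -inf so that their softmax weight is exactly 0.
        scores = [s if m else float("-inf") for s, m in zip(scores, mask)]
    top = max(scores)
    exp_scores = [math.exp(s - top) for s in scores]  # numerically stable softmax
    total = sum(exp_scores)
    weights = [e / total for e in exp_scores]
    # Weighted sum of the instances (values) gives the pooled bag representation.
    pooled = [sum(w * x[j] for w, x in zip(weights, instances)) for j in range(d)]
    return pooled, weights

# Toy bag: two real instances and one padded instance.
class_token = [1.0, 0.0]
instances = [[2.0, 0.0], [0.0, 2.0], [9.0, 9.0]]
mask = [1, 1, 0]
pooled, weights = class_token_attention(class_token, instances, mask)
```

In the model, a linear layer applied to the pooled vector would then produce the bag logit.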

Instance-level supervision: This consists of four steps.

  1. Using an instance classifier, obtain the probability of instance \(i\) belonging to class \(c\), denoted as \(p_{i,c}\).
  2. The prototype \(\mathbf{p}_{c,t} \in \mathbb{R}^{D}\) of class \(c\) at time \(t\) is updated using a momentum update rule based on the set of instances with the top \(k\) highest probabilities of belonging to class \(c\). Writing \(\mathbf{P}_t = \left[ \mathbf{p}_{1,t}, \ldots, \mathbf{p}_{C,t} \right]^\top \in \mathbb{R}^{C \times D}\), the prototype label \(z_{i}\) of each instance is obtained as \(z_{i} = \text{argmax}_{c} \ \left( \mathbf{P}_t \mathbf{x}_i \right)_c\).
  3. Compute instance-level soft labels using the prototype labels and a momentum update.
  4. Compute the instance-level cross-entropy loss using the soft labels and the instance classifier.
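
Steps 2 to 4 can be sketched in plain Python as follows. This is an illustration under stated assumptions, not the library's code: the soft-label rule \(\mathbf{s}_i \leftarrow m \, \mathbf{s}_i + (1 - m) \, \text{onehot}(z_i)\), the momentum name `label_m`, and the uniform initialisation of the soft labels are assumed here, since the description above does not spell them out.

```python
import math

def prototype_labels(prototypes, embeddings):
    """Return z_i = argmax_c <p_c, x_i> for each instance embedding x_i."""
    labels = []
    for x in embeddings:
        scores = [sum(p_d * x_d for p_d, x_d in zip(p, x)) for p in prototypes]
        labels.append(max(range(len(scores)), key=scores.__getitem__))
    return labels

def update_soft_labels(soft_labels, proto_labels, label_m=0.9):
    """Assumed rule: s_i <- label_m * s_i + (1 - label_m) * onehot(z_i)."""
    updated = []
    for s, z in zip(soft_labels, proto_labels):
        updated.append([label_m * s_c + (1.0 - label_m) * (1.0 if c == z else 0.0)
                        for c, s_c in enumerate(s)])
    return updated

def soft_cross_entropy(soft_labels, probs, eps=1e-12):
    """Mean over instances of -sum_c s_{i,c} * log p_{i,c}."""
    total = 0.0
    for s, p in zip(soft_labels, probs):
        total -= sum(s_c * math.log(p_c + eps) for s_c, p_c in zip(s, p))
    return total / len(soft_labels)

# Toy example: C = 2 classes, D = 2 features, N = 3 instances.
prototypes = [[1.0, 0.0], [0.0, 1.0]]
embeddings = [[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]]
probs = [[0.7, 0.3], [0.4, 0.6], [0.5, 0.5]]    # instance classifier output p_{i,c}
soft_labels = [[0.5, 0.5] for _ in embeddings]  # uniform start (assumption)

z = prototype_labels(prototypes, embeddings)                    # step 2
soft_labels = update_soft_labels(soft_labels, z, label_m=0.9)   # step 3
loss = soft_cross_entropy(soft_labels, probs)                   # step 4
```

Because each update is a convex combination of two probability vectors, the soft labels remain valid distributions over the classes.
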
__init__(in_shape=None, att_dim=256, n_layers_encoder=1, n_layers_decoder=1, use_mlp_encoder=True, use_mlp_decoder=False, n_heads=4, feat_ext=torch.nn.Identity(), criterion=torch.nn.BCEWithLogitsLoss())

Parameters:

  • in_shape (tuple, default: None ) –

    Shape of input data expected by the feature extractor (excluding batch dimension). If not provided, it will be lazily initialized.

  • att_dim (int, default: 256 ) –

    Attention dimension.

  • n_layers_encoder (int, default: 1 ) –

    Number of layers in the transformer encoder.

  • n_layers_decoder (int, default: 1 ) –

    Number of layers in the transformer decoder.

  • use_mlp_encoder (bool, default: True ) –

    If True, uses a multi-layer perceptron (MLP) in the encoder.

  • use_mlp_decoder (bool, default: False ) –

    If True, uses a multi-layer perceptron (MLP) in the decoder.

  • n_heads (int, default: 4 ) –

    Number of attention heads.

  • feat_ext (Module, default: Identity() ) –

    Feature extractor.

  • criterion (Module, default: BCEWithLogitsLoss() ) –

    Loss function. By default, Binary Cross-Entropy loss from logits.

forward(X, mask=None, return_inst_pred=False, return_X_enc=False)

Forward pass.

Parameters:

  • X (Tensor) –

    Bag features of shape (batch_size, bag_size, ...).

  • mask (Tensor, default: None ) –

    Mask of shape (batch_size, bag_size).

  • return_inst_pred (bool, default: False ) –

    If True, returns instance label logits in addition to Y_pred.

  • return_X_enc (bool, default: False ) –

    If True, returns instance embeddings in addition to Y_pred.

Returns:

  • Y_pred ( Tensor ) –

    Bag label logits of shape (batch_size,).

  • y_inst_pred ( Tensor ) –

    Only returned when return_inst_pred=True. Instance label logits of shape (batch_size, bag_size).

  • X_enc ( Tensor ) –

    Only returned when return_X_enc=True. Instance embeddings of shape (batch_size, bag_size, att_dim).
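
Bags usually differ in size, so they have to be padded to a common bag_size before being stacked into X. The helper below is a hypothetical, dependency-free sketch of that step: `pad_bags` is not part of torchmil, and the convention that 1 marks a real instance and 0 marks padding is an assumption that should be checked against the library.

```python
def pad_bags(bags, pad_value=0.0):
    """Pad a list of bags (each a list of feature vectors) to a common size.

    Returns nested lists X of shape (batch_size, bag_size, feat_dim) and
    mask of shape (batch_size, bag_size). Assumed convention: 1 = real
    instance, 0 = padding.
    """
    bag_size = max(len(bag) for bag in bags)
    feat_dim = len(bags[0][0])
    X, mask = [], []
    for bag in bags:
        n_pad = bag_size - len(bag)
        padding = [[pad_value] * feat_dim for _ in range(n_pad)]
        X.append([list(x) for x in bag] + padding)
        mask.append([1] * len(bag) + [0] * n_pad)
    return X, mask

# Two bags with 3 and 1 instances, each instance having 2 features.
bags = [
    [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
    [[7.0, 8.0]],
]
X, mask = pad_bags(bags)
```

The nested lists can then be converted with torch.tensor(X) and torch.tensor(mask), giving tensors of shape (batch_size, bag_size, feat_dim) and (batch_size, bag_size).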

compute_loss(Y, X, mask=None)

Compute loss given true bag labels.

Parameters:

  • Y (Tensor) –

    Bag labels of shape (batch_size,).

  • X (Tensor) –

    Bag features of shape (batch_size, bag_size, ...).

  • mask (Tensor, default: None ) –

    Mask of shape (batch_size, bag_size).

Returns:

  • Y_pred ( Tensor ) –

    Bag label logits of shape (batch_size,).

  • loss_dict ( dict ) –

    Dictionary containing the loss value.
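
With the default criterion, the bag-level term is the binary cross-entropy computed directly from logits. The sketch below writes out the standard numerically stable form of that loss with mean reduction, in plain Python for illustration. It covers only the bag-level criterion; how the instance-level loss is weighted and combined with it is not described here, so it is left out.

```python
import math

def bce_with_logits(logits, targets):
    """Mean binary cross-entropy from raw logits.

    Uses the stable identity
        loss(z, y) = max(z, 0) - z * y + log(1 + exp(-|z|)),
    which equals -[y * log(sigmoid(z)) + (1 - y) * log(1 - sigmoid(z))].
    """
    losses = [max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))
              for z, y in zip(logits, targets)]
    return sum(losses) / len(losses)

# A batch of three bags: logits Y_pred and true bag labels Y.
Y_pred = [2.0, -3.0, 0.0]
Y = [1.0, 0.0, 1.0]
bag_loss = bce_with_logits(Y_pred, Y)
```

A logit of 0 corresponds to a predicted probability of 0.5, so a single such bag contributes a loss of \(\log 2\) regardless of its label.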

predict(X, mask=None, return_inst_pred=True)

Predict bag and (optionally) instance labels.

Parameters:

  • X (Tensor) –

    Bag features of shape (batch_size, bag_size, ...).

  • mask (Tensor, default: None ) –

    Mask of shape (batch_size, bag_size).

  • return_inst_pred (bool, default: True ) –

    If True, returns instance label predictions in addition to bag label predictions.

Returns:

  • Y_pred ( Tensor ) –

    Bag label logits of shape (batch_size,).

  • y_inst_pred ( Tensor ) –

    Only returned when return_inst_pred=True. Instance label predictions of shape (batch_size, bag_size).
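
If the returned values are logits, as documented above, a sigmoid maps them to probabilities and a threshold yields hard labels. The following is a minimal sketch for a binary task; `logits_to_labels` and the threshold of 0.5 are assumptions for illustration, not part of torchmil.

```python
import math

def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)

def logits_to_labels(logits, threshold=0.5):
    """Map logits to probabilities and to hard 0/1 labels."""
    probs = [sigmoid(z) for z in logits]
    labels = [int(p >= threshold) for p in probs]
    return probs, labels

# Example bag logits for a batch of three bags.
bag_logits = [-2.0, 0.3, 4.1]
bag_probs, bag_labels = logits_to_labels(bag_logits)
```

With a threshold of 0.5, this is equivalent to checking whether each logit is non-negative.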

update_prototypes(X, mask=None, proto_m=0.9)

Update prototypes.

Parameters:

  • X (Tensor) –

    Bag features of shape (batch_size, bag_size, ...).

  • mask (Tensor, default: None ) –

    Mask of shape (batch_size, bag_size).

  • proto_m (float, default: 0.9 ) –

    Momentum for updating prototypes.

Returns:

  • None

    None
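
The role of proto_m can be illustrated with a small, dependency-free sketch of a momentum update driven by the top \(k\) instances per class. This is not the torchmil implementation: averaging the selected embeddings, the parameter `k`, the function name, and the mask convention (nonzero marks a real instance) are all assumptions made for the example.

```python
def momentum_update_prototypes(prototypes, embeddings, probs,
                               k=2, proto_m=0.9, mask=None):
    """Return prototypes updated as p_c <- proto_m * p_c + (1 - proto_m) * mean_c.

    mean_c is the average embedding of the k unmasked instances with the
    highest probability of belonging to class c (assumed aggregation).
    """
    valid = [i for i in range(len(embeddings)) if mask is None or mask[i]]
    updated = []
    for c, proto in enumerate(prototypes):
        # Indices of the k valid instances most confidently assigned to class c.
        top = sorted(valid, key=lambda i: probs[i][c], reverse=True)[:k]
        if not top:
            updated.append(list(proto))  # nothing to learn from: keep prototype
            continue
        mean = [sum(embeddings[i][d] for i in top) / len(top)
                for d in range(len(proto))]
        updated.append([proto_m * p + (1.0 - proto_m) * m
                        for p, m in zip(proto, mean)])
    return updated

# Toy example: C = 2 classes, D = 2 features, N = 4 instances (last one padded).
prototypes = [[0.0, 0.0], [0.0, 0.0]]
embeddings = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [5.0, 5.0]]
probs = [[0.9, 0.1], [0.2, 0.8], [0.6, 0.4], [0.99, 0.99]]
mask = [1, 1, 1, 0]

new_prototypes = momentum_update_prototypes(
    prototypes, embeddings, probs, k=2, proto_m=0.9, mask=mask)
```

A value of proto_m close to 1 makes the prototypes change slowly, which damps the effect of noisy instance predictions early in training.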