
DSMIL

torchmil.models.DSMIL

Bases: MILModel

Dual-stream Multiple Instance Learning (DSMIL) model, proposed in the paper Dual-stream Multiple Instance Learning Network for Whole Slide Image Classification with Self-supervised Contrastive Learning.

Overview. Given an input bag \(\mathbf{X} = \left[ \mathbf{x}_1, \ldots, \mathbf{x}_N \right]^\top \in \mathbb{R}^{N \times P}\), the model optionally applies a feature extractor, \(\text{FeatExt}(\cdot)\), to transform the instance features: \(\mathbf{X} = \text{FeatExt}(\mathbf{X}) \in \mathbb{R}^{N \times D}\).

Then, two streams are used. The first stream uses an instance classifier \(c \ \colon \mathbb{R}^D \to \mathbb{R}\) (implemented as a linear layer) and retrieves the instance with the highest logit score,

\[ m = \arg \max_{i \in \{1, \ldots, N\}} c(\mathbf{x}_i). \]
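The first stream can be sketched in pure Python as below. This is a minimal illustration, assuming the instance classifier is a linear map \(c(\mathbf{x}) = \mathbf{w}^\top \mathbf{x} + b\). The weights `w` and `b` and the helper names are hypothetical, and this is not torchmil's code.

```python
def instance_logits(X, w, b):
    """Apply a linear instance classifier c(x) = w^T x + b to every instance in the bag."""
    return [sum(w_j * x_j for w_j, x_j in zip(w, x)) + b for x in X]


def critical_index(X, w, b):
    """Return m, the index of the instance with the highest classifier logit."""
    logits = instance_logits(X, w, b)
    # max over indices; ties resolve to the first maximal index
    return max(range(len(logits)), key=lambda i: logits[i])


# Hypothetical bag of N = 3 instances with D = 2 features each.
X = [[0.1, 0.2], [0.9, 0.4], [0.3, 0.8]]
w, b = [1.0, -0.5], 0.0
m = critical_index(X, w, b)
print(m)  # 1
```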

The second stream then computes the bag representation \(\mathbf{z} \in \mathbb{R}^D\) as

\[ \mathbf{z} = \sum_{i=1}^N \frac{ \exp \left( \mathbf{q}_i^\top \mathbf{q}_m \right)}{\sum_{k=1}^N \exp \left( \mathbf{q}_k^\top \mathbf{q}_m \right)} \mathbf{v}_i, \]

where \(\mathbf{q}_i = \mathbf{W}_q \mathbf{x}_i\) and \(\mathbf{v}_i = \mathbf{W}_v \mathbf{x}_i\). This is similar to self-attention, except that query-key matching is performed only against the critical instance \(\mathbf{x}_m\) rather than between all pairs of instances.

Finally, the bag representation is used to predict the bag label using a bag classifier implemented as a linear layer.
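The second stream and the bag classifier can be sketched in pure Python as follows. This is a minimal illustration that follows the formula above as written, assuming plain linear projections: `W_q`, `W_v`, `w_bag` and `b_bag` are hypothetical parameters, the critical index `m` is taken as given (for example from the first stream), and the optional nonlinearities (`nonlinear_q`, `nonlinear_v`) and dropout are omitted. It is not torchmil's implementation.

```python
import math


def matvec(W, x):
    """Multiply a matrix W (given as a list of rows) by a vector x."""
    return [sum(w_ij * x_j for w_ij, x_j in zip(row, x)) for row in W]


def dot(a, b):
    """Inner product of two equal-length vectors."""
    return sum(a_j * b_j for a_j, b_j in zip(a, b))


def dsmil_bag_logit(X, m, W_q, W_v, w_bag, b_bag):
    """Second stream plus linear bag classifier; returns (bag_logit, z, attention)."""
    Q = [matvec(W_q, x) for x in X]  # q_i = W_q x_i
    V = [matvec(W_v, x) for x in X]  # v_i = W_v x_i
    # Match every query against the critical instance's query q_m only.
    scores = [dot(q_i, Q[m]) for q_i in Q]
    # Softmax over the N instances (max subtracted for numerical stability).
    s_max = max(scores)
    exps = [math.exp(s - s_max) for s in scores]
    total = sum(exps)
    att = [e / total for e in exps]
    # z = sum_i att_i * v_i
    z = [sum(att[i] * V[i][d] for i in range(len(X))) for d in range(len(V[0]))]
    # Linear bag classifier applied to z.
    return dot(w_bag, z) + b_bag, z, att


# Hypothetical bag and parameters (identity projections, so q_i = v_i = x_i).
X = [[0.1, 0.2], [0.9, 0.4], [0.3, 0.8]]
I2 = [[1.0, 0.0], [0.0, 1.0]]
bag_logit, z, att = dsmil_bag_logit(X, m=1, W_q=I2, W_v=I2, w_bag=[1.0, 1.0], b_bag=0.0)
```

Because the weights form a softmax, \(\mathbf{z}\) is a convex combination of the values, so with identity projections each coordinate of `z` stays within the range of the corresponding instance features.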

Loss function. By default, the model is trained end-to-end using the following per-bag loss:

\[ \ell = \ell_{\text{BCE}}(Y, \hat{Y}) + \ell_{\text{BCE}}(Y, c(\mathbf{x}_m)), \]

where \(\ell_{\text{BCE}}\) is the Binary Cross-Entropy loss, \(Y\) is the true bag label, \(\hat{Y}\) is the predicted bag logit, and \(c(\mathbf{x}_m)\) is the predicted logit of the critical instance.
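Assuming the default `criterion` (`torch.nn.BCEWithLogitsLoss`), the per-bag loss above can be sketched as follows. The helper names are hypothetical; `bce_with_logits` uses the standard numerically stable form \(\max(l, 0) - lY + \log(1 + e^{-|l|})\) of binary cross-entropy on a single logit \(l\).

```python
import math


def bce_with_logits(y, logit):
    """Binary cross-entropy from a logit, in a numerically stable form."""
    # Equal to -[y*log(sigmoid(l)) + (1 - y)*log(1 - sigmoid(l))].
    return max(logit, 0.0) - logit * y + math.log1p(math.exp(-abs(logit)))


def dsmil_loss(Y, bag_logit, critical_logit):
    """Per-bag loss: BCE on the bag logit plus BCE on the critical instance logit."""
    return bce_with_logits(Y, bag_logit) + bce_with_logits(Y, critical_logit)


print(round(dsmil_loss(1, 0.0, 0.0), 4))  # 1.3863
```

Both terms share the bag label \(Y\), so the instance classifier is pushed to give the critical instance of a positive bag a high logit and every instance of a negative bag a low one.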

__init__(in_shape=None, att_dim=128, nonlinear_q=False, nonlinear_v=False, dropout=0.0, feat_ext=torch.nn.Identity(), criterion=torch.nn.BCEWithLogitsLoss())

Parameters:

  • in_shape (tuple, default: None ) –

    Shape of input data expected by the feature extractor (excluding batch dimension).

  • att_dim (int, default: 128 ) –

    Attention dimension.

  • nonlinear_q (bool, default: False ) –

    If True, apply a nonlinearity to the query.

  • nonlinear_v (bool, default: False ) –

    If True, apply a nonlinearity to the value.

  • dropout (float, default: 0.0 ) –

    Dropout rate.

  • feat_ext (Module, default: Identity() ) –

    Feature extractor.

  • criterion (Module, default: BCEWithLogitsLoss() ) –

    Loss function. By default, Binary Cross-Entropy loss from logits.

forward(X, mask=None, return_att=False, return_inst_pred=False)

Forward pass.

Parameters:

  • X (Tensor) –

    Bag features of shape (batch_size, bag_size, ...).

  • mask (Tensor, default: None ) –

    Mask of shape (batch_size, bag_size).

  • return_att (bool, default: False ) –

    If True, returns attention values (before normalization) in addition to Y_pred.

  • return_inst_pred (bool, default: False ) –

    If True, returns instance label logits in addition to Y_pred.

Returns:

  • Y_pred ( Tensor ) –

    Bag label logits of shape (batch_size,).

  • att ( Tensor ) –

    Only returned when return_att=True. Attention values (before normalization) of shape (batch_size, bag_size).

  • y_pred ( Tensor ) –

    Only returned when return_inst_pred=True. Instance label logits of shape (batch_size, bag_size).
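The shape of `mask`, (batch_size, bag_size), suggests it marks real versus padded instances when bags of different sizes are batched together. The sketch below shows one common way to honor such a mask. It assumes that 1 marks a real instance and 0 a padded one, and that padded instances are excluded from both the critical-instance argmax and the attention softmax. These conventions and the helper names are assumptions, and torchmil's actual masking may differ.

```python
import math


def masked_argmax(values, mask):
    """Index of the largest value among positions where mask is 1 (assumed convention)."""
    valid = [i for i, keep in enumerate(mask) if keep]
    return max(valid, key=lambda i: values[i])


def masked_softmax(scores, mask):
    """Softmax over positions where mask is 1; padded positions get weight 0."""
    s_max = max(s for s, keep in zip(scores, mask) if keep)
    exps = [math.exp(s - s_max) if keep else 0.0 for s, keep in zip(scores, mask)]
    total = sum(exps)
    return [e / total for e in exps]


# A bag padded to length 3: the last instance is padding.
mask = [1, 1, 0]
print(masked_argmax([0.2, -1.0, 5.0], mask))  # 0
print(masked_softmax([1.0, 1.0, 9.0], mask))  # [0.5, 0.5, 0.0]
```

Without the mask, the padded instance in this example would have been chosen as the critical instance and would have taken almost all of the attention weight.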

compute_loss(Y, X, mask=None)

Compute loss given true bag labels.

Parameters:

  • Y (Tensor) –

    Bag labels of shape (batch_size,).

  • X (Tensor) –

    Bag features of shape (batch_size, bag_size, ...).

  • mask (Tensor, default: None ) –

    Mask of shape (batch_size, bag_size).

Returns:

  • Y_pred ( Tensor ) –

    Bag label logits of shape (batch_size,).

  • loss_dict ( dict ) –

    Dictionary containing the loss value.

predict(X, mask=None, return_inst_pred=False)

Predict bag and (optionally) instance labels.

Parameters:

  • X (Tensor) –

    Bag features of shape (batch_size, bag_size, ...).

  • mask (Tensor, default: None ) –

    Mask of shape (batch_size, bag_size).

  • return_inst_pred (bool, default: False ) –

    If True, returns instance label predictions in addition to bag label predictions.

Returns:

  • Y_pred ( Tensor ) –

    Bag label logits of shape (batch_size,).

  • y_inst_pred ( Tensor ) –

    Only returned when return_inst_pred=True. Instance label predictions of shape (batch_size, bag_size).
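Because `Y_pred` is documented as logits rather than probabilities, turning it into hard labels takes a sigmoid and a threshold. A minimal sketch, assuming a 0.5 probability threshold; that threshold is a common default, not something this documentation prescribes, and the helper names are hypothetical.

```python
import math


def stable_sigmoid(logit):
    """Logistic sigmoid that avoids overflow for large negative logits."""
    if logit >= 0:
        return 1.0 / (1.0 + math.exp(-logit))
    e = math.exp(logit)
    return e / (1.0 + e)


def logits_to_labels(logits, threshold=0.5):
    """Map logits to probabilities and hard 0/1 labels using a probability threshold."""
    probs = [stable_sigmoid(logit) for logit in logits]
    labels = [int(p >= threshold) for p in probs]
    return probs, labels


probs, labels = logits_to_labels([-2.0, 0.0, 3.0])
print(labels)  # [0, 1, 1]
```

With a 0.5 threshold this reduces to checking `logit >= 0`, so the sigmoid is only needed when probabilities themselves are wanted.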