DTFDMIL

torchmil.models.DTFDMIL

Bases: MILModel

Double-Tier Feature Distillation Multiple Instance Learning (DTFD-MIL) model, proposed in the paper DTFD-MIL: Double-Tier Feature Distillation Multiple Instance Learning for Histopathology Whole Slide Image Classification.

Overview. Given an input bag \(\mathbf{X} = \left[ \mathbf{x}_1, \ldots, \mathbf{x}_N \right]^\top \in \mathbb{R}^{N \times P}\), the model optionally applies a feature extractor, \(\text{FeatExt}(\cdot)\), to transform the instance features: \(\mathbf{X} = \text{FeatExt}(\mathbf{X}) \in \mathbb{R}^{N \times D}\).

Next, the instances in a bag are randomly grouped into \(M\) pseudo-bags \(\{\mathbf{X}_1, \cdots, \mathbf{X}_M\}\) with approximately the same number of instances. Each pseudo-bag inherits the label of its parent bag, \(Y_m = Y\). The model then has two prediction tiers:

In Tier 1, the model uses an attention pool (see AttentionPool for details) and a classifier, jointly denoted \(T_1\), to predict the label of each pseudo-bag,

\[ \widehat{Y}_m = T_1(\mathbf{X}_m).\]

The loss associated with this tier is the binary cross entropy between the pseudo-bag labels \(Y_m\) and the predicted labels \(\widehat{Y}_m\).
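To make the pooling step in Tier 1 concrete, here is a minimal sketch of generic attention-based MIL pooling in plain Python. It is not torchmil's AttentionPool: the function name `attention_pool`, the weight matrices `V` and `w`, and the toy values are assumptions for illustration, and the library's module may differ (for example in gating or mask handling).

```python
import math

def attention_pool(instances, V, w):
    """Pool a bag into one vector using weights softmax(w^T tanh(V x_n))."""
    scores = []
    for x in instances:
        # Project the instance to the attention space and squash with tanh.
        hidden = [math.tanh(sum(v_i * x_i for v_i, x_i in zip(row, x))) for row in V]
        scores.append(sum(w_j * h_j for w_j, h_j in zip(w, hidden)))
    # Softmax over instances, shifted by the max for numerical stability.
    shift = max(scores)
    exps = [math.exp(s - shift) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    # Attention-weighted sum of the instance features.
    dim = len(instances[0])
    pooled = [sum(a * x[d] for a, x in zip(weights, instances)) for d in range(dim)]
    return pooled, weights

# Hypothetical toy bag: 3 instances with D = 2 features, attention dimension 3.
instances = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
V = [[0.5, -0.5], [1.0, 1.0], [0.0, 2.0]]
w = [1.0, 0.5, -1.0]
pooled, weights = attention_pool(instances, V, w)
```

The attention weights sum to one, so the pooled vector is a convex combination of the instance features; a classifier applied to it would then produce the pseudo-bag logit.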

In Tier 2, Grad-CAM (see Grad-CAM for details) is used to compute the probability of each instance. Based on that probability, a feature vector \(\mathbf{z}_m\) is distilled for the \(m\)-th pseudo-bag. Then, the model uses another attention pool and a classifier, jointly denoted \(T_2\), to predict the final label of the bag,

\[ \widehat{Y} = T_2\left( \left[ \mathbf{z}_1, \ldots, \mathbf{z}_M \right]^\top \right).\]

The loss associated with this tier is the binary cross entropy between the bag label \(Y\) and the predicted label \(\widehat{Y}\).
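As an illustration of the distillation step, the sketch below selects a feature vector for one pseudo-bag from per-instance scores. It assumes, following the description in the DTFD-MIL paper, that 'max' keeps the highest-scoring instance and 'maxmin' concatenates the highest- and lowest-scoring instances. The helper `distill` is hypothetical and is not the library's implementation; the 'afs' mode is not covered here.

```python
def distill(features, scores, mode="maxmin"):
    """Pick a distilled feature vector for one pseudo-bag from per-instance scores."""
    top = max(range(len(scores)), key=scores.__getitem__)
    if mode == "max":
        # Keep only the feature of the highest-scoring instance.
        return list(features[top])
    if mode == "maxmin":
        bottom = min(range(len(scores)), key=scores.__getitem__)
        # Concatenate the highest- and lowest-scoring instance features.
        return list(features[top]) + list(features[bottom])
    raise ValueError(f"unsupported mode: {mode}")

# Toy pseudo-bag: 3 instances with 2 features each, plus per-instance scores.
features = [[1.0, 0.0], [0.0, 2.0], [3.0, 3.0]]
scores = [0.2, 0.9, 0.1]
z_max = distill(features, scores, mode="max")        # [0.0, 2.0]
z_maxmin = distill(features, scores, mode="maxmin")  # [0.0, 2.0, 3.0, 3.0]
```

Under this assumption the distilled vector in 'maxmin' mode is twice as long as in 'max' mode, so the Tier 2 input dimension would depend on the chosen mode.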

Loss function. By default, the model is trained end-to-end using the following per-bag loss:

\[ \ell = \ell_{\text{BCE}}(Y, \widehat{Y}) + \frac{1}{M} \sum_{m=1}^{M} \ell_{\text{BCE}}(Y_m, \widehat{Y}_m),\]

where \(\ell_{\text{BCE}}\) is the binary cross entropy loss.
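The per-bag loss can be checked by hand with a small plain-Python sketch. It assumes both tiers output raw logits, in line with the default BCEWithLogitsLoss criterion; the helper names `bce_with_logits` and `dtfd_loss` are illustrative and not part of the library.

```python
import math

def bce_with_logits(y, logit):
    """Binary cross entropy for one label y in {0, 1} and one raw logit."""
    # Numerically stable form: max(z, 0) - z * y + log(1 + exp(-|z|)).
    return max(logit, 0.0) - logit * y + math.log1p(math.exp(-abs(logit)))

def dtfd_loss(y, bag_logit, pseudo_logits):
    """Tier 2 loss on the bag plus the mean Tier 1 loss over the M pseudo-bags."""
    tier2 = bce_with_logits(y, bag_logit)
    tier1 = sum(bce_with_logits(y, z) for z in pseudo_logits) / len(pseudo_logits)
    return tier2 + tier1

# With all logits at 0 (probability 0.5), each BCE term equals log 2.
loss = dtfd_loss(y=1, bag_logit=0.0, pseudo_logits=[0.0, 0.0])
```

Because every pseudo-bag shares the bag label \(Y\), the same `y` is passed to both terms.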

__init__(in_shape=None, att_dim=128, n_groups=8, distill_mode='maxmin', feat_ext=torch.nn.Identity(), criterion=torch.nn.BCEWithLogitsLoss())

Parameters:

  • in_shape (tuple, default: None ) –

    Shape of input data expected by the feature extractor (excluding batch dimension). If not provided, it will be lazily initialized.

  • att_dim (int, default: 128 ) –

    Attention dimension.

  • n_groups (int, default: 8 ) –

    Number of groups to split the bag instances.

  • distill_mode (str, default: 'maxmin' ) –

    Distillation mode. Possible values: 'maxmin', 'max', 'afs'.

  • feat_ext (Module, default: Identity() ) –

    Feature extractor.

  • criterion (Module, default: BCEWithLogitsLoss() ) –

    Loss function. By default, Binary Cross-Entropy loss from logits.
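To illustrate how n_groups relates to the pseudo-bag split described in the overview, the following sketch randomly partitions instance indices into groups of nearly equal size. The function `split_into_pseudo_bags` and its `seed` argument are assumptions for illustration; the library's own splitting routine may differ.

```python
import random

def split_into_pseudo_bags(n_instances, n_groups, seed=None):
    """Randomly partition instance indices into n_groups nearly equal groups."""
    rng = random.Random(seed)
    indices = list(range(n_instances))
    rng.shuffle(indices)
    # Striding over the shuffled indices gives group sizes that differ by at most one.
    return [indices[m::n_groups] for m in range(n_groups)]

pseudo_bags = split_into_pseudo_bags(n_instances=21, n_groups=8, seed=0)
bag_label = 1
pseudo_labels = [bag_label] * len(pseudo_bags)  # Y_m = Y for every pseudo-bag
```

With 21 instances and 8 groups, five groups hold three instances and three groups hold two.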

forward(X, mask=None, return_pseudo_pred=False, return_inst_cam=False)

Forward pass.

Parameters:

  • X (Tensor) –

    Bag features of shape (batch_size, bag_size, ...).

  • mask (Tensor, default: None ) –

    Mask of shape (batch_size, bag_size).

  • return_pseudo_pred (bool, default: False ) –

    If True, returns pseudo label logits in addition to Y_pred.

  • return_inst_cam (bool, default: False ) –

    If True, returns instance-level CAM values in addition to Y_pred.

Returns:

  • Y_pred ( Tensor ) –

    Bag label logits of shape (batch_size,).

  • inst_cam ( Tensor ) –

    Only returned when return_inst_cam=True. Instance-level CAM values of shape (batch_size, bag_size).
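Since X and mask share the leading (batch_size, bag_size) dimensions, bags of different sizes need padding before they can be batched. The sketch below shows one way to do this with nested lists. The helper `pad_bags` is hypothetical, and the convention that 1 marks a real instance and 0 marks padding is an assumption that should be checked against the library.

```python
def pad_bags(bags, pad_value=0.0):
    """Pad variable-length bags to a common bag_size and build a matching mask."""
    bag_size = max(len(bag) for bag in bags)
    dim = len(bags[0][0])
    X, mask = [], []
    for bag in bags:
        n_pad = bag_size - len(bag)
        # Append padding instances so every bag has bag_size rows.
        X.append([list(x) for x in bag] + [[pad_value] * dim for _ in range(n_pad)])
        # Assumed convention: 1 for real instances, 0 for padding.
        mask.append([1] * len(bag) + [0] * n_pad)
    return X, mask

# Two bags with 3 and 1 instances, each instance having 2 features.
bags = [[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], [[7.0, 8.0]]]
X, mask = pad_bags(bags)
```

The resulting nested lists have shapes (2, 3, 2) and (2, 3) and could be converted to tensors before being passed to the model.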

compute_loss(Y, X, mask=None)

Compute loss given true bag labels.

Parameters:

  • Y (Tensor) –

    Bag labels of shape (batch_size,).

  • X (Tensor) –

    Bag features of shape (batch_size, bag_size, ...).

  • mask (Tensor, default: None ) –

    Mask of shape (batch_size, bag_size).

Returns:

  • Y_pred ( Tensor ) –

    Bag label logits of shape (batch_size,).

  • loss_dict ( dict ) –

    Dictionary containing the loss value.

predict(X, mask=None, return_inst_pred=True)

Predict bag and (optionally) instance labels.

Parameters:

  • X (Tensor) –

    Bag features of shape (batch_size, bag_size, ...).

  • mask (Tensor, default: None ) –

    Mask of shape (batch_size, bag_size).

  • return_inst_pred (bool, default: True ) –

    If True, returns instance label predictions in addition to bag label predictions.

Returns:

  • Y_pred ( Tensor ) –

    Bag label logits of shape (batch_size,).

  • y_inst_pred ( Tensor ) –

    If return_inst_pred=True, returns instance label predictions of shape (batch_size, bag_size).
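Because the bag output is documented as logits, a caller usually applies a sigmoid and a threshold to obtain probabilities and hard labels. The sketch below shows that post-processing step in plain Python; the helper names and the 0.5 threshold are assumptions, not part of the library.

```python
import math

def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)

def logits_to_labels(logits, threshold=0.5):
    """Convert raw logits to probabilities and thresholded binary labels."""
    probs = [sigmoid(z) for z in logits]
    labels = [int(p >= threshold) for p in probs]
    return probs, labels

# Example bag logits for a batch of three bags.
probs, labels = logits_to_labels([-2.0, 0.0, 3.0])
```

A logit of 0 maps to a probability of exactly 0.5, so with this threshold it is assigned to the positive class.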