DSMIL
torchmil.models.DSMIL
Bases: MILModel
Dual-stream Multiple Instance Learning (DSMIL) model, proposed in the paper Dual-stream Multiple Instance Learning Network for Whole Slide Image Classification with Self-supervised Contrastive Learning.
Overview. Given an input bag \(\mathbf{X} = \left[ \mathbf{x}_1, \ldots, \mathbf{x}_N \right]^\top \in \mathbb{R}^{N \times P}\), the model optionally applies a feature extractor, \(\text{FeatExt}(\cdot)\), to transform the instance features: \(\mathbf{X} = \text{FeatExt}(\mathbf{X}) \in \mathbb{R}^{N \times D}\).
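Then, two streams are used. The first stream uses an instance classifier \(c \ \colon \mathbb{R}^D \to \mathbb{R}\) (implemented as a linear layer) and retrieves the index \(m\) of the instance with the highest logit score (the critical instance),
\[ m = \arg\max_{i \in \{1, \ldots, N\}} c(\mathbf{x}_i). \]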
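The second stream then computes the bag representation \(\mathbf{z} \in \mathbb{R}^D\) as
\[ \mathbf{z} = \sum_{i=1}^{N} \frac{\exp\left( \mathbf{q}_i^\top \mathbf{q}_m \right)}{\sum_{k=1}^{N} \exp\left( \mathbf{q}_k^\top \mathbf{q}_m \right)} \mathbf{v}_i, \]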
where \(\mathbf{q}_i = \mathbf{W}_q \mathbf{x}_i\) and \(\mathbf{v}_i = \mathbf{W}_v \mathbf{x}_i\). This is similar to self-attention with the difference that query-key matching is performed only with the critical instance.
Finally, the bag representation is used to predict the bag label using a bag classifier implemented as a linear layer.
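The two streams described above can be sketched for a single bag in dependency-free Python. This is only an illustration of the arithmetic, not the torchmil implementation: the function and argument names (`dsmil_bag_forward`, `w_inst`, `W_q`, `W_v`, `w_bag`) are hypothetical, biases and the optional nonlinearities and dropout are omitted, and the softmax normalisation over query matches is an assumption based on the self-attention analogy in the text.

```python
import math


def dot(a, b):
    """Inner product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def matvec(W, x):
    """Matrix-vector product, with W given as a list of rows."""
    return [dot(row, x) for row in W]


def dsmil_bag_forward(X, w_inst, W_q, W_v, w_bag):
    """Sketch of the two streams for one bag X (a list of N instance vectors).

    Returns (bag_logit, inst_logits, att, m), where att holds the
    normalised attention weights and m is the critical-instance index.
    """
    # Stream 1: linear instance classifier c, then pick the critical instance.
    inst_logits = [dot(w_inst, x) for x in X]
    m = max(range(len(X)), key=lambda i: inst_logits[i])

    # Stream 2: queries and values for every instance.
    Q = [matvec(W_q, x) for x in X]
    V = [matvec(W_v, x) for x in X]

    # Match every query against the critical instance's query only.
    scores = [dot(q, Q[m]) for q in Q]

    # Softmax over instances (subtracting the max for numerical stability).
    s_max = max(scores)
    exps = [math.exp(s - s_max) for s in scores]
    total = sum(exps)
    att = [e / total for e in exps]

    # Bag representation z: attention-weighted sum of the values.
    z = [sum(a * v[j] for a, v in zip(att, V)) for j in range(len(V[0]))]

    # Linear bag classifier on z.
    bag_logit = dot(w_bag, z)
    return bag_logit, inst_logits, att, m
```

With identity matrices for `W_q` and `W_v`, the instance with the largest instance logit also receives the largest attention weight whenever it has the largest inner product with itself, which makes the role of the critical instance easy to inspect on toy inputs.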
Loss function. By default, the model is trained end-to-end using the following per-bag loss:
\[ \ell = \ell_{\text{BCE}}(Y, \hat{Y}) + \ell_{\text{BCE}}(Y, c(\mathbf{x}_m)), \]
where \(\ell_{\text{BCE}}\) is the Binary Cross-Entropy loss, \(Y\) is the true bag label, \(\hat{Y}\) is the predicted bag label, and \(c(\mathbf{x}_m)\) is the predicted label of the critical instance.
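As a rough illustration of the per-bag loss, the sketch below adds a Binary Cross-Entropy term on the bag logit to one on the critical-instance logit. The helper names (`bce_with_logits`, `dsmil_bag_loss`) are hypothetical, both predictions are assumed to be logits (matching the default `BCEWithLogitsLoss` criterion), and the equal weighting of the two terms is an assumption.

```python
import math


def bce_with_logits(logit, target):
    """Numerically stable BCE on a single logit, with target in {0.0, 1.0}."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def dsmil_bag_loss(bag_logit, inst_logits, y):
    """Bag-level BCE plus BCE on the critical (highest-logit) instance."""
    crit_logit = max(inst_logits)  # c(x_m): logit of the critical instance
    return bce_with_logits(bag_logit, y) + bce_with_logits(crit_logit, y)
```

The second term pushes the highest-scoring instance of a positive bag towards a positive prediction and, for a negative bag, pushes every instance logit down, since the maximum bounds all of them.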
__init__(in_shape=None, att_dim=128, nonlinear_q=False, nonlinear_v=False, dropout=0.0, feat_ext=torch.nn.Identity(), criterion=torch.nn.BCEWithLogitsLoss())
Parameters:
- in_shape (tuple, default: None) – Shape of input data expected by the feature extractor (excluding batch dimension).
- att_dim (int, default: 128) – Attention dimension.
- nonlinear_q (bool, default: False) – If True, apply nonlinearity to the query.
- nonlinear_v (bool, default: False) – If True, apply nonlinearity to the value.
- dropout (float, default: 0.0) – Dropout rate.
- feat_ext (Module, default: Identity()) – Feature extractor.
- criterion (Module, default: BCEWithLogitsLoss()) – Loss function. By default, Binary Cross-Entropy loss from logits.
forward(X, mask=None, return_att=False, return_inst_pred=False)
Forward pass.
Parameters:
- X (Tensor) – Bag features of shape (batch_size, bag_size, ...).
- mask (Tensor, default: None) – Mask of shape (batch_size, bag_size).
- return_att (bool, default: False) – If True, returns attention values (before normalization) in addition to Y_pred.
- return_inst_pred (bool, default: False) – If True, returns instance label logits in addition to Y_pred.
Returns:
- Y_pred (Tensor) – Bag label logits of shape (batch_size,).
- att (Tensor) – Only returned when return_att=True. Attention values (before normalization) of shape (batch_size, bag_size).
- y_pred (tuple[Tensor, Tensor]) – Only returned when return_inst_pred=True. Instance label logits of shape (batch_size, bag_size).
compute_loss(Y, X, mask=None)
Compute loss given true bag labels.
Parameters:
- Y (Tensor) – Bag labels of shape (batch_size,).
- X (Tensor) – Bag features of shape (batch_size, bag_size, ...).
- mask (Tensor, default: None) – Mask of shape (batch_size, bag_size).
Returns:
- Y_pred (Tensor) – Bag label logits of shape (batch_size,).
- loss_dict (dict) – Dictionary containing the loss value.
predict(X, mask=None, return_inst_pred=False)
Predict bag and (optionally) instance labels.
Parameters:
- X (Tensor) – Bag features of shape (batch_size, bag_size, ...).
- mask (Tensor, default: None) – Mask of shape (batch_size, bag_size).
- return_inst_pred (bool, default: False) – If True, returns instance label predictions in addition to bag label predictions.
Returns:
- Y_pred (Tensor) – Bag label logits of shape (batch_size,).
- y_inst_pred (Tensor) – Only returned when return_inst_pred=True. Instance label predictions of shape (batch_size, bag_size).
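Because the bag outputs are documented as logits, a caller would typically map them to probabilities and hard labels. The helper below is a minimal, hypothetical sketch (`logits_to_labels` is not part of torchmil) that applies a sigmoid and a threshold to a plain list of logits; the 0.5 threshold is an assumption and should be tuned per task.

```python
import math


def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def logits_to_labels(logits, threshold=0.5):
    """Convert logits to (probabilities, binary labels)."""
    probs = [sigmoid(z) for z in logits]
    labels = [1 if p >= threshold else 0 for p in probs]
    return probs, labels
```

With tensors, the same mapping is usually a call to `torch.sigmoid` followed by a comparison against the threshold.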