IIBMIL
torchmil.models.IIBMIL
Bases: Module
Integrated Instance-Level and Bag-Level Multiple Instance Learning (IIB-MIL) model, proposed in the paper IIB-MIL: Integrated Instance-Level and Bag-Level Multiple Instances Learning with Label Disambiguation for Pathological Image Analysis.
Given an input bag \(\mathbf{X} = \left[ \mathbf{x}_1, \ldots, \mathbf{x}_N \right]^\top \in \mathbb{R}^{N \times P}\), the model optionally applies a feature extractor, \(\text{FeatExt}(\cdot)\), to transform the instance features: \(\mathbf{X} = \text{FeatExt}(\mathbf{X}) \in \mathbb{R}^{N \times D}\).
Then, a TransformerEncoder is applied to transform the instance features using context information. Subsequently, the model uses bag-level and instance-level supervision:
Bag-level supervision: The instances are aggregated into a class token using a transformer decoder. A linear layer is then applied to predict the bag label.
Instance-level supervision: Consists of four steps.
- Using an instance classifier, obtain the probability of instance \(i\) belonging to class \(c\), denoted as \(p_{i,c}\).
- The prototype \(\mathbf{p}_{c,t} \in \mathbb{R}^{D}\) of class \(c\) at time \(t\) is updated using a momentum update rule based on the set of instances with the top \(k\) highest probabilities of belonging to class \(c\). Writing \(\mathbf{P}_t = \left[ \mathbf{p}_{1,t}, \ldots, \mathbf{p}_{C,t} \right]^\top \in \mathbb{R}^{C \times D}\), the prototype label \(z_{i}\) of each instance is obtained as \(z_{i} = \text{argmax}_{c} \ \left( \mathbf{P}_t \mathbf{x}_i \right)_c\).
- Compute instance-level soft labels using the prototype labels and a momentum update.
- Compute the instance-level cross-entropy loss using the soft labels and the instance classifier.
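The four instance-level steps above can be sketched for a single bag as follows. This is an illustrative reading of the description, not the torchmil implementation: the names `inst_classifier`, `soft_labels`, `label_m` and `k`, the zero initialization of the prototypes, and the choice to average the top-\(k\) features before the momentum update are all assumptions made for the example.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

N, D, C, k = 12, 8, 2, 3            # instances, embedding dim, classes, top-k (illustrative)
proto_m, label_m = 0.9, 0.9         # momentum coefficients (assumed values)

X = torch.randn(N, D)                        # encoded instance features of one bag
inst_classifier = torch.nn.Linear(D, C)      # hypothetical instance classifier
P = torch.zeros(C, D)                        # prototypes from the previous step, one row per class
soft_labels = torch.full((N, C), 1.0 / C)    # running soft labels, uniform at start

# Step 1: class probabilities p_{i,c} from the instance classifier.
logits = inst_classifier(X)                  # (N, C)
probs = logits.softmax(dim=-1)               # (N, C)

with torch.no_grad():
    # Step 2a: momentum update of each prototype with the mean of its top-k instances.
    for c in range(C):
        top_idx = probs[:, c].topk(k).indices
        P[c] = proto_m * P[c] + (1.0 - proto_m) * X[top_idx].mean(dim=0)

    # Step 2b: prototype label z_i = argmax_c (P x_i)_c.
    z = (X @ P.T).argmax(dim=-1)             # (N,)

    # Step 3: momentum update of the soft labels towards the one-hot prototype labels.
    soft_labels = label_m * soft_labels + (1.0 - label_m) * F.one_hot(z, C).float()

# Step 4: cross-entropy between the soft labels and the classifier output.
inst_loss = -(soft_labels * logits.log_softmax(dim=-1)).sum(dim=-1).mean()
```

Only the loss in step 4 carries gradients in this sketch; the prototypes and soft labels are treated as running statistics.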
__init__(in_shape=None, att_dim=256, n_layers_encoder=1, n_layers_decoder=1, use_mlp_encoder=True, use_mlp_decoder=False, n_heads=4, feat_ext=torch.nn.Identity(), criterion=torch.nn.BCEWithLogitsLoss())
Parameters:
- in_shape (tuple, default: None) – Shape of input data expected by the feature extractor (excluding batch dimension). If not provided, it will be lazily initialized.
- att_dim (int, default: 256) – Attention dimension.
- n_layers_encoder (int, default: 1) – Number of layers in the transformer encoder.
- n_layers_decoder (int, default: 1) – Number of layers in the transformer decoder.
- use_mlp_encoder (bool, default: True) – If True, uses a multi-layer perceptron (MLP) in the encoder.
- use_mlp_decoder (bool, default: False) – If True, uses a multi-layer perceptron (MLP) in the decoder.
- n_heads (int, default: 4) – Number of attention heads.
- feat_ext (Module, default: Identity()) – Feature extractor.
- criterion (Module, default: BCEWithLogitsLoss()) – Loss function. By default, Binary Cross-Entropy loss from logits.
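As a rough illustration of the last two arguments, the sketch below builds a hypothetical feature extractor and applies the default criterion to stand-in logits. The dimensions, the `Linear` plus `ReLU` extractor, and the random logits are assumptions for the example; nothing here calls torchmil itself.

```python
import torch

P, D = 512, 256                     # illustrative input and embedding dimensions

# Hypothetical feature extractor: a Module mapping (..., P) to (..., D).
feat_ext = torch.nn.Sequential(
    torch.nn.Linear(P, D),
    torch.nn.ReLU(),
)

X = torch.randn(4, 10, P)           # (batch_size, bag_size, P)
X_feat = feat_ext(X)                # (batch_size, bag_size, D); Linear acts on the last dimension

# The default criterion compares raw logits with float labels in {0, 1}.
criterion = torch.nn.BCEWithLogitsLoss()
Y_logits = torch.randn(4)           # stand-in for bag logits of shape (batch_size,)
Y = torch.tensor([0.0, 1.0, 1.0, 0.0])
loss = criterion(Y_logits, Y)       # scalar tensor
```

Because `BCEWithLogitsLoss` applies the sigmoid internally, the logits are passed in unnormalized.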
forward(X, mask=None, return_inst_pred=False, return_X_enc=False)
Forward pass.
Parameters:
- X (Tensor) – Bag features of shape (batch_size, bag_size, ...).
- mask (Tensor, default: None) – Mask of shape (batch_size, bag_size).
- return_inst_pred (bool, default: False) – If True, returns instance label logits in addition to Y_pred.
- return_X_enc (bool, default: False) – If True, returns instance embeddings in addition to Y_pred.
Returns:
- Y_pred (Tensor) – Bag label logits of shape (batch_size,).
- y_inst_pred (Tensor) – Only returned when return_inst_pred=True. Instance label logits of shape (batch_size, bag_size).
- X_enc (Tensor) – Only returned when return_X_enc=True. Instance embeddings of shape (batch_size, bag_size, att_dim).
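Bags usually differ in size, so a batch of shape (batch_size, bag_size, ...) generally needs padding together with a mask. The following is a minimal sketch of one way to build both. The convention that the mask is True for real instances and False for padding is an assumption, as is the boolean dtype; check the library's data utilities for the expected format.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

P = 16                                             # illustrative feature dimension
bags = [torch.randn(n, P) for n in (5, 3, 7)]      # three bags of different sizes

# Zero-pad to the largest bag: shape (batch_size, max_bag_size, P).
X = pad_sequence(bags, batch_first=True)

# Assumed convention: mask[b, i] is True for real instances, False for padding.
lengths = torch.tensor([bag.shape[0] for bag in bags])
mask = torch.arange(X.shape[1]).unsqueeze(0) < lengths.unsqueeze(1)   # (batch_size, max_bag_size)
```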
compute_loss(Y, X, mask=None)
Compute loss given true bag labels.
Parameters:
- Y (Tensor) – Bag labels of shape (batch_size,).
- X (Tensor) – Bag features of shape (batch_size, bag_size, ...).
- mask (Tensor, default: None) – Mask of shape (batch_size, bag_size).
Returns:
- Y_pred (Tensor) – Bag label logits of shape (batch_size,).
- loss_dict (dict) – Dictionary containing the loss value.
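A training step built around this interface might look like the sketch below. `ToyMIL` is a hypothetical stand-in that only mimics the documented call pattern (arguments `Y, X, mask` and a `(Y_pred, loss_dict)` return value) so the example runs without torchmil; the dictionary key and the choice to sum the dictionary values are likewise assumptions.

```python
import torch

class ToyMIL(torch.nn.Module):
    """Hypothetical stand-in exposing the compute_loss interface described above."""

    def __init__(self, in_dim):
        super().__init__()
        self.fc = torch.nn.Linear(in_dim, 1)
        self.criterion = torch.nn.BCEWithLogitsLoss()

    def compute_loss(self, Y, X, mask=None):
        if mask is None:
            mask = torch.ones(X.shape[:2], dtype=torch.bool, device=X.device)
        w = mask.float().unsqueeze(-1)                               # (batch_size, bag_size, 1)
        pooled = (X * w).sum(dim=1) / w.sum(dim=1).clamp(min=1.0)    # masked mean over instances
        Y_pred = self.fc(pooled).squeeze(-1)                         # (batch_size,)
        return Y_pred, {"BCEWithLogitsLoss": self.criterion(Y_pred, Y.float())}

model = ToyMIL(in_dim=16)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

X = torch.randn(4, 10, 16)          # (batch_size, bag_size, P)
Y = torch.tensor([0, 1, 1, 0])      # bag labels of shape (batch_size,)

optimizer.zero_grad()
Y_pred, loss_dict = model.compute_loss(Y, X)
loss = sum(loss_dict.values())      # assumption: every entry is a scalar loss term
loss.backward()
optimizer.step()
```

Note the argument order: labels come first, then the bag features.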
predict(X, mask=None, return_inst_pred=True)
Predict bag and (optionally) instance labels.
Parameters:
- X (Tensor) – Bag features of shape (batch_size, bag_size, ...).
- mask (Tensor, default: None) – Mask of shape (batch_size, bag_size).
- return_inst_pred (bool, default: True) – If True, returns instance label predictions in addition to bag label predictions.
Returns:
- Y_pred (Tensor) – Bag label logits of shape (batch_size,).
- y_inst_pred (Tensor) – If return_inst_pred=True, returns instance label predictions of shape (batch_size, bag_size).
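If the returned values are logits, as stated for Y_pred, they can be turned into probabilities and hard labels with a sigmoid and a threshold. The sketch below uses made-up tensors in place of real model outputs; treating y_inst_pred as logits, the 0.5 threshold, and the True-for-real-instance mask convention are assumptions.

```python
import torch

# Stand-ins for the outputs described above (illustrative values).
Y_pred = torch.tensor([2.0, -1.0])                      # (batch_size,)
y_inst_pred = torch.tensor([[3.0, -2.0, 0.5],
                            [-1.0, 4.0, 0.0]])          # (batch_size, bag_size)
mask = torch.tensor([[True, True, True],
                     [True, True, False]])              # last instance of the second bag is padding

bag_prob = torch.sigmoid(Y_pred)                        # probabilities in (0, 1)
bag_label = (bag_prob > 0.5).long()                     # same decision as Y_pred > 0

# Padded positions are forced to False so they never count as positive instances.
inst_label = (torch.sigmoid(y_inst_pred) > 0.5) & mask
```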
update_prototypes(X, mask=None, proto_m=0.9)
Update prototypes.
Parameters:
- X (Tensor) – Bag features of shape (batch_size, bag_size, ...).
- mask (Tensor, default: None) – Mask of shape (batch_size, bag_size).
- proto_m (float, default: 0.9) – Momentum for updating prototypes.
Returns:
- None
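To give a feel for what proto_m controls, the sketch below assumes the momentum rule is a standard exponential moving average, which this page does not spell out. The `target` vector is a fixed stand-in for whatever instance statistic the prototype is pulled towards.

```python
import torch

proto_m = 0.9
prototype = torch.zeros(4)          # initial prototype (illustrative)
target = torch.ones(4)              # fixed stand-in for the selected instance features

for _ in range(10):
    # Exponential moving average: keep proto_m of the old value, blend in the rest.
    prototype = proto_m * prototype + (1.0 - proto_m) * target

# Fraction of the initial gap that is still left after 10 updates: proto_m ** 10.
remaining = (target - prototype).norm() / target.norm()
```

Under this assumption, values of proto_m closer to 1 make the prototypes change more slowly and smooth out noise from individual batches.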