TransMIL
torchmil.models.TransMIL
Bases: MILModel
Method proposed in the paper TransMIL: Transformer based Correlated Multiple Instance Learning for Whole Slide Image Classification.
Given an input bag \(\mathbf{X} = \left[ \mathbf{x}_1, \ldots, \mathbf{x}_N \right]^\top \in \mathbb{R}^{N \times P}\), the model optionally applies a feature extractor, \(\text{FeatExt}(\cdot)\), to transform the instance features: \(\mathbf{X} = \text{FeatExt}(\mathbf{X}) \in \mathbb{R}^{N \times D}\).
Then, following Algorithm 2 in the paper, it performs sequence squaring (padding the bag so that its length is a perfect square), adds a class token, and applies the TPT module. This module consists of two Nyströmformer layers and the PPEG (Pyramid Positional Encoding Generator) layer.
Finally, a linear classifier is used to predict the bag label from the class token.
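The sequence squaring step can be illustrated in plain Python. This is a minimal sketch, assuming the padding scheme described in the paper: the first M instances are appended to the bag, where M = ceil(sqrt(N))^2 - N, so that the tokens can later be arranged on a square grid. The helper name `square_sequence` is hypothetical and is not part of torchmil.

```python
import math


def square_sequence(instances):
    """Pad a bag so that its length becomes a perfect square.

    Hypothetical helper: appends the first M instances to the end,
    where M = ceil(sqrt(N))**2 - N.
    """
    n = len(instances)
    side = math.isqrt(n)
    if side * side < n:
        side += 1  # ceil(sqrt(n)) using exact integer arithmetic
    n_pad = side * side - n
    return instances + instances[:n_pad], side


# A made-up bag with N = 7 instances.
bag = ["x1", "x2", "x3", "x4", "x5", "x6", "x7"]
padded, side = square_sequence(bag)
print(side, len(padded), padded[-2:])
```

With N = 7 the grid side is 3, so two instances are repeated to reach a length of 9.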
__init__(in_shape, att_dim=512, n_layers=2, n_heads=4, n_landmarks=None, pinv_iterations=6, dropout=0.0, use_mlp=False, feat_ext=torch.nn.Identity(), criterion=torch.nn.BCEWithLogitsLoss())
Parameters:
- in_shape (tuple) – Shape of input data expected by the feature extractor (excluding batch dimension).
- att_dim (int, default: 512) – Embedding dimension. Should be divisible by n_heads.
- n_layers (int, default: 2) – Number of Nyströmformer layers.
- n_heads (int, default: 4) – Number of heads in the Nyströmformer layer.
- n_landmarks (int, default: None) – Number of landmarks in the Nyströmformer layer.
- pinv_iterations (int, default: 6) – Number of iterations for the pseudo-inverse in the Nyströmformer layer.
- dropout (float, default: 0.0) – Dropout rate in the Nyströmformer layer.
- use_mlp (bool, default: False) – Whether to use an MLP after the Nyströmformer layer.
- feat_ext (Module, default: Identity()) – Feature extractor. By default, the identity function (no feature extraction).
- criterion (Module, default: BCEWithLogitsLoss()) – Loss function. By default, Binary Cross-Entropy loss from logits.
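The divisibility requirement on att_dim can be checked before building the model. The sketch below assumes the usual multi-head attention convention in which the embedding is split evenly across heads. The helper `head_dim` is hypothetical and is not part of torchmil.

```python
def head_dim(att_dim=512, n_heads=4):
    """Return the per-head width, or raise if att_dim is not divisible by n_heads.

    Hypothetical helper; the defaults mirror the documented ones.
    """
    if att_dim % n_heads != 0:
        raise ValueError(
            f"att_dim={att_dim} is not divisible by n_heads={n_heads}"
        )
    return att_dim // n_heads


print(head_dim())  # 128 with the documented defaults
```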
forward(X, return_att=False)
Forward pass.
Parameters:
- X (Tensor) – Input tensor of shape (batch_size, bag_size, in_dim).
- return_att (bool, default: False) – Whether to return the attention values.
Returns:
- Y_pred (Tensor) – Bag label logits of shape (batch_size,).
- att (Tensor) – Only returned when return_att=True. Attention values of shape (batch_size, bag_size).
compute_loss(Y, X)
Compute loss given true bag labels.
Parameters:
- Y (Tensor) – Bag labels of shape (batch_size,).
- X (Tensor) – Input tensor of shape (batch_size, bag_size, in_dim).
Returns:
- Y_pred (Tensor) – Bag label logits of shape (batch_size,).
- loss_dict (dict) – Dictionary containing the loss value.
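To make the default criterion concrete, the following sketch computes binary cross-entropy from logits in plain Python, using the numerically stable form max(z, 0) - z*y + log(1 + exp(-|z|)) averaged over the batch. The logits and labels are made up, and the dictionary key "loss" is an assumption: the key names torchmil actually uses are not stated here.

```python
import math


def bce_with_logits(logits, labels):
    """Mean binary cross-entropy computed directly from logits.

    Uses the numerically stable form
    max(z, 0) - z * y + log(1 + exp(-|z|)).
    """
    losses = [
        max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))
        for z, y in zip(logits, labels)
    ]
    return sum(losses) / len(losses)


Y_pred = [2.0, -1.0, 0.0]  # made-up bag logits, shape (batch_size,)
Y = [1.0, 0.0, 1.0]  # made-up bag labels, shape (batch_size,)

# Key name "loss" is hypothetical.
loss_dict = {"loss": bce_with_logits(Y_pred, Y)}
print(loss_dict)
```

Working from logits rather than probabilities avoids taking the logarithm of values that round to 0 or 1.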
predict(X, return_inst_pred=True)
Predict bag and (optionally) instance labels.
Parameters:
- X (Tensor) – Input tensor of shape (batch_size, bag_size, in_dim).
- return_inst_pred (bool, default: True) – If True, returns instance label predictions in addition to bag label predictions.
Returns:
- Y_pred (Tensor) – Bag label logits of shape (batch_size,).
- att (Tensor) – Only returned when return_inst_pred=True. Attention values of shape (batch_size, bag_size).
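Because the bag predictions are logits, they need a sigmoid to be read as probabilities when the default BCEWithLogitsLoss criterion is used. The sketch below shows that conversion in plain Python. The helper names and the 0.5 decision threshold are assumptions for illustration, not part of torchmil.

```python
import math


def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def logits_to_labels(logits, threshold=0.5):
    """Map bag logits to probabilities and hard 0/1 labels.

    Hypothetical helper; the threshold is an assumed default.
    """
    probs = [sigmoid(z) for z in logits]
    labels = [int(p >= threshold) for p in probs]
    return probs, labels


Y_pred = [2.0, -1.0, 0.0]  # made-up bag logits, shape (batch_size,)
probs, labels = logits_to_labels(Y_pred)
print(probs, labels)
```

A threshold of 0.5 on the probability is equivalent to a threshold of 0 on the logit.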