TransformerProbSmoothABMIL
torchmil.models.TransformerProbSmoothABMIL
Bases: MILModel
Transformer Attention-based Multiple Instance Learning model, with probabilistic attention-based pooling. Proposed in Probabilistic Smooth Attention for Deep Multiple Instance Learning in Medical Imaging.
Given an input bag \(\mathbf{X} = \left[ \mathbf{x}_1, \ldots, \mathbf{x}_N \right]^\top \in \mathbb{R}^{N \times P}\), the model optionally applies a feature extractor, \(\text{FeatExt}(\cdot)\), to transform the instance features: \(\mathbf{X} = \text{FeatExt}(\mathbf{X}) \in \mathbb{R}^{N \times D}\). Then, it transforms the instance features using a transformer encoder, \(\mathbf{X} = \text{TransformerEncoder}(\mathbf{X}) \in \mathbb{R}^{N \times D}\).
Subsequently, it aggregates the instance features into a bag representation using a probabilistic attention-based pooling mechanism, as detailed in ProbSmoothAttentionPool.
Specifically, it computes a mean vector \(\mathbf{\mu}_{\mathbf{f}} \in \mathbb{R}^N\) and a variance vector \(\mathbf{\sigma}_{\mathbf{f}}^2 \in \mathbb{R}^N\) that define the attention distribution \(q(\mathbf{f} \mid \mathbf{X}) = \mathcal{N}\left(\mathbf{f} \mid \mathbf{\mu}_{\mathbf{f}}, \operatorname{diag}(\mathbf{\sigma}_{\mathbf{f}}^2) \right)\). If covar_mode='zero', the variance vector \(\mathbf{\sigma}_{\mathbf{f}}^2\) is set to zero, resulting in a deterministic attention distribution.
Then, \(m\) attention vectors \(\widehat{\mathbf{F}} = \left[ \widehat{\mathbf{f}}^{(1)}, \ldots, \widehat{\mathbf{f}}^{(m)} \right]^\top \in \mathbb{R}^{m \times N}\) are sampled from the attention distribution. The bag representation \(\widehat{\mathbf{z}} \in \mathbb{R}^{m \times D}\) is then computed as \(\widehat{\mathbf{z}} = \operatorname{Softmax}\left(\widehat{\mathbf{F}}\right) \mathbf{X}\), where the softmax normalizes each sampled attention vector across the \(N\) instances.
The bag representation \(\widehat{\mathbf{z}}\) is fed into a classifier, implemented as a linear layer, to produce bag label predictions \(Y_{\text{pred}} \in \mathbb{R}^{m}\).
Notably, the attention distribution naturally induces a distribution over the bag label predictions. This model thus generates multiple predictions for each bag, corresponding to different samples from this distribution.
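As a pedagogical sketch of this sampling-and-pooling step for a single bag with diagonal covariance (this is not torchmil's internal implementation; all shapes and values below are illustrative):

```python
import torch

# Sketch of probabilistic attention pooling for one bag (covar_mode='diag').
N, D, m = 100, 512, 16                     # instances, feature dim, samples
X = torch.randn(N, D)                      # transformed instance features
mu = torch.randn(N)                        # attention mean mu_f
var = torch.rand(N) + 1e-6                 # attention variance sigma_f^2

eps = torch.randn(m, N)                    # standard normal noise
F_hat = mu + eps * var.sqrt()              # m samples from q(f | X), shape (m, N)
Z_hat = torch.softmax(F_hat, dim=-1) @ X   # bag representations, shape (m, D)
```

Each row of `Z_hat` is one sampled bag representation; passing them through the classifier yields one bag logit per sample.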
Regularization. The probabilistic pooling mechanism introduces a regularization term to the loss function that encourages smoothness in the attention values. Given an input bag \(\mathbf{X} = \left[ \mathbf{x}_1, \ldots, \mathbf{x}_N \right]^\top \in \mathbb{R}^{N \times P}\) with adjacency matrix \(\mathbf{A} \in \mathbb{R}^{N \times N}\), the regularization term corresponds to

\[ \frac{1}{2} \mathbf{\mu}_{\mathbf{f}}^\top \mathbf{L} \mathbf{\mu}_{\mathbf{f}} + \frac{1}{2} \operatorname{tr}\left( \mathbf{\Sigma}_{\mathbf{f}} \mathbf{L} \right) - \frac{1}{2} \log \det \mathbf{\Sigma}_{\mathbf{f}} + \operatorname{const}, \]

where \(\operatorname{const}\) is a constant term that does not depend on the parameters, \(\mathbf{\Sigma}_{\mathbf{f}} = \operatorname{diag}(\mathbf{\sigma}_{\mathbf{f}}^2)\), \(\mathbf{L} = \mathbf{D} - \mathbf{A}\) is the graph Laplacian matrix, and \(\mathbf{D}\) is the degree matrix of \(\mathbf{A}\). This term is then averaged over all bags in the batch and added to the loss function.
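For a single bag with diagonal covariance this term can be written out directly. The sketch below is illustrative (it assumes covar_mode='diag' and drops the constant); it is not the library's internal code:

```python
import torch

def smoothness_regularizer(mu, var, adj):
    """Illustrative per-bag regularizer for covar_mode='diag' (const dropped).

    mu:  (N,) attention mean mu_f
    var: (N,) attention variance sigma_f^2
    adj: (N, N) adjacency matrix A
    """
    deg = torch.diag(adj.sum(dim=-1))             # degree matrix D
    lap = deg - adj                               # graph Laplacian L = D - A
    quad = 0.5 * mu @ lap @ mu                    # (1/2) mu^T L mu
    trace = 0.5 * (lap.diagonal() * var).sum()    # (1/2) tr(Sigma_f L), Sigma_f diagonal
    logdet = -0.5 * torch.log(var).sum()          # -(1/2) log det Sigma_f
    return quad + trace + logdet
```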
__init__(in_shape=None, pool_att_dim=128, covar_mode='diag', n_samples_train=1000, n_samples_test=5000, feat_ext=torch.nn.Identity(), transf_att_dim=512, transf_n_layers=1, transf_n_heads=8, transf_use_mlp=True, transf_add_self=True, transf_dropout=0.0, criterion=torch.nn.BCEWithLogitsLoss())
Parameters:
- in_shape (tuple, default: None) – Shape of input data expected by the feature extractor (excluding batch dimension). If not provided, it will be lazily initialized.
- pool_att_dim (int, default: 128) – Attention dimension.
- covar_mode (str, default: 'diag') – Covariance mode for the attention distribution. Possible values: 'zero', 'diag', 'full'.
- n_samples_train (int, default: 1000) – Number of samples for training.
- n_samples_test (int, default: 5000) – Number of samples for testing.
- feat_ext (Module, default: Identity()) – Feature extractor.
- transf_att_dim (int, default: 512) – Attention dimension for the transformer encoder.
- transf_n_layers (int, default: 1) – Number of layers in the transformer encoder.
- transf_n_heads (int, default: 8) – Number of heads in the transformer encoder.
- transf_use_mlp (bool, default: True) – Whether to use an MLP in the transformer encoder.
- transf_add_self (bool, default: True) – Whether to add the input to the output in the transformer encoder.
- transf_dropout (float, default: 0.0) – Dropout rate in the transformer encoder.
- criterion (Module, default: BCEWithLogitsLoss()) – Loss function. By default, binary cross-entropy loss from logits for binary classification.
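For instance, a model for 512-dimensional instance features (the feature dimension here is illustrative) could be constructed as follows:

```python
from torchmil.models import TransformerProbSmoothABMIL

# Construction sketch: 512-dimensional instance features,
# diagonal covariance for the attention distribution.
model = TransformerProbSmoothABMIL(
    in_shape=(512,),
    pool_att_dim=128,
    covar_mode='diag',
    n_samples_train=1000,
    n_samples_test=5000,
)
```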
forward(X, adj, mask=None, return_att=False, return_samples=False, return_kl_div=False)
Forward pass.
Parameters:
- X (Tensor) – Bag features of shape (batch_size, bag_size, ...).
- adj (Tensor) – Adjacency matrix of shape (batch_size, bag_size, bag_size). Only required when return_kl_div=True.
- mask (Tensor, default: None) – Mask of shape (batch_size, bag_size).
- return_att (bool, default: False) – If True, returns attention values (before normalization) in addition to Y_pred.
- return_samples (bool, default: False) – If True and return_att=True, the attention values returned are samples from the attention distribution.
- return_kl_div (bool, default: False) – If True, returns the KL divergence between the attention distribution and the prior distribution.
Returns:
- Y_pred (Tensor) – Bag label logits of shape (batch_size, n_samples) if return_samples=True, else (batch_size,).
- att (Tensor) – Only returned when return_att=True. Attention values (before normalization) of shape (batch_size, bag_size, n_samples) if return_samples=True, else (batch_size, bag_size).
- kl_div (Tensor) – Only returned when return_kl_div=True. KL divergence between the attention distribution and the prior distribution, of shape ().
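A minimal forward-pass sketch, continuing the construction example above; the random data and the output order Y_pred, att, kl_div are assumptions for illustration:

```python
import torch

# Forward-pass sketch: 2 bags of 100 instances with 512 features each.
# A dense, symmetrized 0/1 adjacency is used purely for illustration.
X = torch.randn(2, 100, 512)
A = (torch.rand(2, 100, 100) < 0.05).float()
adj = ((A + A.transpose(-1, -2)) > 0).float()
mask = torch.ones(2, 100)

Y_pred, att, kl_div = model(X, adj, mask=mask, return_att=True, return_kl_div=True)
# Y_pred: (2,), att: (2, 100), kl_div: scalar of shape ()

# With return_samples=True, one logit per attention sample is returned,
# from which a predictive mean and spread can be estimated.
Y_samples = model(X, adj, mask=mask, return_samples=True)  # (2, n_samples)
Y_mean, Y_std = Y_samples.mean(dim=-1), Y_samples.std(dim=-1)
```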
compute_loss(Y, X, adj, mask=None)
Compute loss given true bag labels.
Parameters:
- Y (Tensor) – Bag labels of shape (batch_size,).
- X (Tensor) – Bag features of shape (batch_size, bag_size, ...).
- adj (Tensor) – Adjacency matrix of shape (batch_size, bag_size, bag_size).
- mask (Tensor, default: None) – Mask of shape (batch_size, bag_size).
Returns:
- Y_pred (Tensor) – Bag label logits of shape (batch_size,).
- loss_dict (dict) – Dictionary containing the loss value and the KL divergence between the attention distribution and the prior distribution.
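A training-step sketch built on compute_loss, reusing model, X, adj, and mask from the snippets above. The contents of loss_dict, and summing its values to form the total objective, are assumptions based on the description of the regularizer:

```python
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

Y = torch.randint(0, 2, (2,)).float()            # binary bag labels
Y_pred, loss_dict = model.compute_loss(Y, X, adj, mask=mask)
loss = sum(loss_dict.values())                   # criterion loss + KL term (assumed)
optimizer.zero_grad()
loss.backward()
optimizer.step()
```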
predict(X, adj, mask=None, return_inst_pred=True, return_samples=False)
Predict bag and (optionally) instance labels.
Parameters:
- X (Tensor) – Bag features of shape (batch_size, bag_size, ...).
- adj (Tensor) – Adjacency matrix of shape (batch_size, bag_size, bag_size).
- mask (Tensor, default: None) – Mask of shape (batch_size, bag_size).
- return_inst_pred (bool, default: True) – If True, returns the attention values as instance label predictions, in addition to bag label predictions.
- return_samples (bool, default: False) – If True and return_inst_pred=True, the instance label predictions returned are samples from the instance label distribution.
Returns:
- Y_pred (Tensor) – Bag label logits of shape (batch_size,).
- y_inst_pred (Tensor) – Only returned when return_inst_pred=True. Attention values (before normalization) of shape (batch_size, bag_size) if return_samples=False, else (batch_size, bag_size, n_samples).
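An inference sketch using predict, reusing model, X, adj, and mask from the snippets above; the attention values act as instance-level predictions:

```python
model.eval()
with torch.no_grad():
    Y_pred, y_inst_pred = model.predict(X, adj, mask=mask, return_inst_pred=True)
# Y_pred: (2,) bag logits; y_inst_pred: (2, 100) instance scores
bag_prob = torch.sigmoid(Y_pred)  # probabilities for binary classification
```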