
CAMIL

torchmil.models.CAMIL

Bases: MILModel

Context-Aware Multiple Instance Learning (CAMIL) model, presented in the paper CAMIL: Context-Aware Multiple Instance Learning for Cancer Detection and Subtyping in Whole Slide Images.

Given an input bag \(\mathbf{X} = \left[ \mathbf{x}_1, \ldots, \mathbf{x}_N \right]^\top \in \mathbb{R}^{N \times P}\), the model optionally applies a feature extractor, \(\text{FeatExt}(\cdot)\), to transform the instance features: \(\mathbf{X} = \text{FeatExt}(\mathbf{X}) \in \mathbb{R}^{N \times D}\).

Then, a global bag representation is computed using a NystromTransformerLayer,

\[ \mathbf{T} = \operatorname{NystromTransformerLayer}(\mathbf{X})\]

Next, a local bag representation is computed using the CAMILSelfAttention layer,

\[ \mathbf{L} = \operatorname{CAMILSelfAttention}(\mathbf{T}) \]

Finally, the local and global information is fused as

\[ \mathbf{M} = \operatorname{sigmoid}(\mathbf{L}) \odot \mathbf{L} + (1 - \operatorname{sigmoid}(\mathbf{L})) \odot \mathbf{T},\]

where \(\odot\) denotes element-wise multiplication and \(\operatorname{sigmoid}\) is the sigmoid function.
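In code, this fusion step is just a sigmoid-gated convex combination of the two representations. A minimal sketch with placeholder tensors (the shapes are assumptions):

```python
import torch

# Placeholder local (L) and global (T) representations: (batch_size, bag_size, dim).
L = torch.randn(1, 100, 512)
T = torch.randn(1, 100, 512)

gate = torch.sigmoid(L)        # element-wise gate in (0, 1)
M = gate * L + (1 - gate) * T  # fused representation, same shape as L and T
```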

Lastly, the final bag representation is computed using the CAMILAttentionPool, a modification of the Gated Attention Pool mechanism. The bag representation is then fed into a linear classifier to predict the bag label.

__init__(in_shape, nystrom_att_dim=512, pool_att_dim=128, gated_pool=False, n_heads=4, n_landmarks=None, pinv_iterations=6, dropout=0.0, use_mlp=False, feat_ext=torch.nn.Identity(), criterion=torch.nn.BCEWithLogitsLoss())

Parameters:

  • in_shape (tuple) –

    Shape of input data expected by the feature extractor (excluding batch dimension).

  • nystrom_att_dim (int, default: 512 ) –

    Attention dimension for the Nystrom Transformer layer.

  • pool_att_dim (int, default: 128 ) –

    Attention dimension for the attention pooling layer.

  • gated_pool (bool, default: False ) –

    If True, use gated attention pooling.

  • n_heads (int, default: 4 ) –

    Number of attention heads in the Nystrom Transformer layer.

  • n_landmarks (int, default: None ) –

    Number of landmarks in the Nystrom Transformer layer.

  • pinv_iterations (int, default: 6 ) –

    Number of iterations for computing the pseudo-inverse in the Nystrom Transformer layer.

  • dropout (float, default: 0.0 ) –

    Dropout rate of the Nystrom Transformer Layer.

  • use_mlp (bool, default: False ) –

    If True, use MLP in the Nystrom Transformer layer.

  • feat_ext (Module, default: Identity() ) –

    Feature extractor.

  • criterion (Module, default: BCEWithLogitsLoss() ) –

    Loss function. By default, Binary Cross-Entropy loss from logits.
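A minimal construction sketch. The 512-dimensional input (e.g., precomputed patch embeddings) is an assumption, which is why the default torch.nn.Identity() feature extractor suffices:

```python
import torch
from torchmil.models import CAMIL

model = CAMIL(
    in_shape=(512,),     # instances are assumed to be 512-dim embeddings
    nystrom_att_dim=512,
    pool_att_dim=128,
    gated_pool=False,
    n_heads=4,
)
```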

forward(X, adj, mask=None, return_att=False)

Forward pass.

Parameters:

  • X (Tensor) –

    Bag features of shape (batch_size, bag_size, ...).

  • adj (Tensor) –

    Adjacency matrix of shape (batch_size, bag_size, bag_size).

  • mask (Tensor, default: None ) –

    Mask of shape (batch_size, bag_size).

  • return_att (bool, default: False ) –

    If True, returns attention values (before normalization) in addition to Y_pred.

Returns:

  • Y_pred ( Tensor ) –

    Bag label logits of shape (batch_size,).

  • att ( Tensor ) –

    Only returned when return_att=True. Attention values (before normalization) of shape (batch_size, bag_size).
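Continuing the construction sketch above, a forward pass with random inputs (the dense all-ones adjacency is only a placeholder; in practice it encodes the spatial neighbourhood of the patches):

```python
batch_size, bag_size, dim = 2, 100, 512

X = torch.randn(batch_size, bag_size, dim)        # bag features
adj = torch.ones(batch_size, bag_size, bag_size)  # placeholder adjacency matrix
mask = torch.ones(batch_size, bag_size)           # all instances are valid

Y_pred, att = model(X, adj, mask=mask, return_att=True)
# Y_pred: (batch_size,), att: (batch_size, bag_size)
```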

compute_loss(Y, X, adj, mask=None)

Compute loss given true bag labels.

Parameters:

  • Y (Tensor) –

    Bag labels of shape (batch_size,).

  • X (Tensor) –

    Bag features of shape (batch_size, bag_size, ...).

  • adj (Tensor) –

    Adjacency matrix of shape (batch_size, bag_size, bag_size).

  • mask (Tensor, default: None ) –

    Mask of shape (batch_size, bag_size).

Returns:

  • Y_pred ( Tensor ) –

    Bag label logits of shape (batch_size,).

  • loss_dict ( dict ) –

    Dictionary containing the loss value.
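A training-step sketch, reusing model, X, adj, and mask from the forward example. Summing the dictionary values into a single scalar objective is an assumption about how loss_dict is meant to be consumed:

```python
Y = torch.randint(0, 2, (batch_size,)).float()  # binary bag labels

Y_pred, loss_dict = model.compute_loss(Y, X, adj, mask=mask)
loss = sum(loss_dict.values())  # assumes the values are scalar loss tensors
loss.backward()
```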

predict(X, adj, mask=None, return_inst_pred=True)

Predict bag and (optionally) instance labels.

Parameters:

  • X (Tensor) –

    Bag features of shape (batch_size, bag_size, ...).

  • adj (Tensor) –

    Adjacency matrix of shape (batch_size, bag_size, bag_size).

  • mask (Tensor, default: None ) –

    Mask of shape (batch_size, bag_size).

  • return_inst_pred (bool, default: True ) –

If True, returns instance label predictions in addition to bag label predictions.

Returns:

  • Y_pred ( Tensor ) –

    Bag label logits of shape (batch_size,).

  • y_inst_pred ( Tensor ) –

Only returned when return_inst_pred=True. Instance label predictions of shape (batch_size, bag_size).
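And an inference sketch, again reusing the tensors defined above:

```python
with torch.no_grad():
    Y_pred, y_inst_pred = model.predict(X, adj, mask=mask, return_inst_pred=True)
# Y_pred: (batch_size,), y_inst_pred: (batch_size, bag_size)
```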


torchmil.models.camil.CAMILSelfAttention

Bases: Module

Self-attention layer used in CAMIL: Context-Aware Multiple Instance Learning for Cancer Detection and Subtyping in Whole Slide Images. This layer computes the self-attention values using the local information of the bag. The local information is captured using an adjacency matrix, which measures the similarity between the embeddings of instances in the bag.

Given an input bag \(\mathbf{X} = \left[ \mathbf{x}_1, \ldots, \mathbf{x}_N \right]^\top \in \mathbb{R}^{N \times P}\), and an adjacency matrix \(\mathbf{A} \in \mathbb{R}^{N \times N}\), this layer computes

\[ \mathbf{l}_i = \frac{\exp\left(\sum_{j=1}^N a_{ij} \mathbf{q}_i^\top \mathbf{k}_j \right)}{\sum_{k=1}^N \exp \left(\sum_{j=1}^N a_{kj} \mathbf{q}_k^\top \mathbf{k}_j \right)} \mathbf{v}_i,\]

where \(\mathbf{q}_i = \mathbf{W_q}\mathbf{x}_i\), \(\mathbf{k}_i = \mathbf{W_k}\mathbf{x}_i\), and \(\mathbf{v}_i = \mathbf{W_v}\mathbf{x}_i\) are the query, key, and value vectors, respectively. Finally, it returns \(\mathbf{L} = \left[ \mathbf{l}_1, \ldots, \mathbf{l}_N \right]^\top\).
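A minimal PyTorch sketch of this equation, as an illustration of the math rather than the layer's actual implementation (the projection matrices Wq, Wk, Wv stand in for the learnable weights):

```python
import torch

def camil_self_attention(X, adj, Wq, Wk, Wv):
    # X: (batch_size, bag_size, in_dim), adj: (batch_size, bag_size, bag_size),
    # Wq, Wk, Wv: (in_dim, in_dim) projection matrices.
    Q, K, V = X @ Wq, X @ Wk, X @ Wv           # queries, keys, values
    pair = torch.einsum('bid,bjd->bij', Q, K)  # q_i^T k_j for all pairs (i, j)
    scores = (adj * pair).sum(dim=-1)          # sum_j a_ij q_i^T k_j
    weights = torch.softmax(scores, dim=-1)    # normalize over instances i
    return weights.unsqueeze(-1) * V           # l_i = weight_i * v_i
```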

forward(X, adj, mask=None)

Forward pass.

Parameters:

  • X (Tensor) –

Bag features of shape (batch_size, bag_size, in_dim).

  • adj (Tensor) –

    Adjacency matrix of shape (batch_size, bag_size, bag_size).

  • mask (Tensor, default: None ) –

    Mask of shape (batch_size, bag_size). If None, no masking is applied.

Returns:

  • L ( Tensor ) –

Self-attention vectors of shape (batch_size, bag_size, in_dim).


torchmil.models.camil.CAMILAttentionPool

Bases: Module

Attention pooling layer as described in CAMIL: Context-Aware Multiple Instance Learning for Cancer Detection and Subtyping in Whole Slide Images.

Given a bag of features \(\mathbf{T} = \left[ \mathbf{t}_1, \ldots, \mathbf{t}_N \right]^\top \in \mathbb{R}^{N \times D}\) and \(\mathbf{M} = \left[ \mathbf{m}_1, \ldots, \mathbf{m}_N \right]^\top \in \mathbb{R}^{N \times D}\), this layer computes the final bag representation \(\mathbf{z}\) as

\[\begin{gather} \mathbf{f} = \mathbf{w}^\top \tanh(\mathbf{T} \mathbf{W} ) \odot \operatorname{sigmoid}(\mathbf{T} \mathbf{U}), \\ \mathbf{s} = \text{softmax}(\mathbf{f}), \\ \mathbf{z} = \mathbf{M}^\top \mathbf{s}, \end{gather}\]

where \(\mathbf{W}, \mathbf{U}\) and \(\mathbf{w}\) are learnable parameters. Note the difference with the conventional AttentionPool layer, where the attention values and the bag representation are computed from the same set of features.
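A sketch of these equations (illustration only; it applies the gating before the projection by \(\mathbf{w}\), as in standard gated attention pooling):

```python
import torch

def camil_attention_pool(T, M, W, U, w):
    # T, M: (batch_size, bag_size, in_dim); W, U: (in_dim, att_dim); w: (att_dim,)
    f = (torch.tanh(T @ W) * torch.sigmoid(T @ U)) @ w  # per-instance scores from T
    s = torch.softmax(f, dim=-1)                        # attention over the bag
    return torch.einsum('bn,bnd->bd', s, M)             # z = M^T s, computed from M
```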

forward(T, M, mask=None, return_att=False)

Forward pass.

Parameters:

  • T (Tensor) –

Bag features of shape (batch_size, bag_size, in_dim), used to compute the attention values.

  • M (Tensor) –

Bag features of shape (batch_size, bag_size, in_dim), used to compute the bag representation.

  • mask (Tensor, default: None ) –

Mask of shape (batch_size, bag_size). If None, no masking is applied.

  • return_att (bool, default: False ) –

    If True, returns attention values in addition to z.

Returns:

  • z ( Tensor ) –

Bag representation of shape (batch_size, in_dim).

  • f ( Tensor ) –

    Attention values of shape (batch_size, bag_size). Only returned when return_att=True.