
GTP

torchmil.models.GTP

Bases: MILModel

Method proposed in the paper GTP: Graph-Transformer for Whole Slide Image Classification.

Forward pass. Given an input bag \(\mathbf{X} = \left[ \mathbf{x}_1, \ldots, \mathbf{x}_N \right]^\top \in \mathbb{R}^{N \times P}\) with adjacency matrix \(\mathbf{A} \in \mathbb{R}^{N \times N}\), the model optionally applies a feature extractor, \(\text{FeatExt}(\cdot)\), to transform the instance features: \(\mathbf{X} = \text{FeatExt}(\mathbf{X}) \in \mathbb{R}^{N \times D}\).

The bag is then processed by a Graph Convolutional Network (GCN) to extract high-level instance embeddings. The GCN operates on the graph defined by \(\mathbf{A}\), whose nodes correspond to patches and whose edges connect spatially adjacent patches:

\[ \mathbf{H} = \text{GCN}(\mathbf{X}, \mathbf{A}) \in \mathbb{R}^{N \times D}.\]
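To make the two ingredients concrete, here is a minimal NumPy sketch that (a) builds \(\mathbf{A}\) from patch grid positions and (b) applies one GCN layer. Both choices are assumptions for illustration, not torchmil's code: the hypothetical helper `grid_adjacency` links patches in an 8-neighbourhood on the slide grid, and `gcn_layer` uses the standard Kipf-Welling propagation rule \(\text{ReLU}(\tilde{\mathbf{D}}^{-1/2}(\mathbf{A}+\mathbf{I})\tilde{\mathbf{D}}^{-1/2}\mathbf{X}\mathbf{W})\).

```python
import numpy as np

def grid_adjacency(coords: np.ndarray) -> np.ndarray:
    """Hypothetical helper: link patches whose (row, col) grid positions differ
    by at most 1 along each axis (8-neighbourhood); no self-loops."""
    diff = np.abs(coords[:, None, :] - coords[None, :, :])  # (N, N, 2)
    adj = (diff.max(axis=-1) <= 1).astype(np.float64)
    np.fill_diagonal(adj, 0.0)
    return adj

def gcn_layer(X: np.ndarray, A: np.ndarray, W: np.ndarray) -> np.ndarray:
    """One Kipf-Welling GCN layer: ReLU(D^{-1/2} (A + I) D^{-1/2} X W)."""
    A_hat = A + np.eye(A.shape[0])                 # add self-loops
    d_inv_sqrt = 1.0 / np.sqrt(A_hat.sum(axis=1))  # degrees >= 1, no division by zero
    A_norm = d_inv_sqrt[:, None] * A_hat * d_inv_sqrt[None, :]
    return np.maximum(A_norm @ X @ W, 0.0)

rng = np.random.default_rng(0)
coords = np.array([[0, 0], [0, 1], [1, 0], [5, 5]])  # last patch is isolated
X = rng.normal(size=(4, 8))  # N = 4 patches, D = 8 features
W = rng.normal(size=(8, 8))  # learnable weights (random here)
A = grid_adjacency(coords)
H = gcn_layer(X, A, W)
print(H.shape)  # (4, 8)
```

The self-loops keep isolated patches (such as the last one above) from losing their own features during propagation.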

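To reduce the number of nodes while preserving structural relationships, a min-cut pooling operation softly assigns the \(N\) nodes to \(N'\) clusters, where \(N'\) is set by n_clusters:

\[ \mathbf{X}', \mathbf{A}' = \text{MinCutPool}(\mathbf{H}, \mathbf{A}), \quad \mathbf{X}' \in \mathbb{R}^{N' \times D}, \ \mathbf{A}' \in \mathbb{R}^{N' \times N'}.\]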
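A minimal sketch of the core of dense min-cut pooling (Bianchi et al.) in NumPy: a soft assignment \(\mathbf{S} = \text{softmax}(\mathbf{H}\mathbf{W}_s) \in \mathbb{R}^{N \times N'}\) gives \(\mathbf{X}' = \mathbf{S}^\top \mathbf{H}\) and \(\mathbf{A}' = \mathbf{S}^\top \mathbf{A} \mathbf{S}\). The single linear map `W_s` is an assumption (implementations typically use a small learned layer), and the PyTorch Geometric version additionally zeroes the diagonal of \(\mathbf{A}'\) and degree-normalizes it, which is omitted here.

```python
import numpy as np

def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    x = x - x.max(axis=axis, keepdims=True)  # for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def mincut_pool(H, A, W_s):
    """Core of dense MinCut pooling: softly assign N nodes to K clusters."""
    S = softmax(H @ W_s, axis=-1)  # (N, K), each row sums to 1
    X_pooled = S.T @ H             # (K, D) cluster features
    A_pooled = S.T @ A @ S         # (K, K) cluster connectivity
    return X_pooled, A_pooled, S

rng = np.random.default_rng(0)
N, D, K = 6, 8, 3  # K plays the role of N' (n_clusters)
H = rng.normal(size=(N, D))
upper = np.triu((rng.random((N, N)) > 0.5).astype(float), k=1)
A = upper + upper.T            # symmetric adjacency without self-loops
W_s = rng.normal(size=(D, K))  # assignment weights (learned in practice)
X_p, A_p, S = mincut_pool(H, A, W_s)
print(X_p.shape, A_p.shape)  # (3, 8) (3, 3)
```

Because \(N'\) is fixed by n_clusters, the transformer that follows always sees the same number of tokens, regardless of how many patches the slide has.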

The pooled graph is then passed through a Transformer encoder, where a class token is introduced:

\[ \mathbf{Z} = \text{Transformer}([\text{CLS}; \mathbf{X}']) \in \mathbb{R}^{(N' + 1) \times D}.\]

Finally, the class token representation is used for classification:

\[ \mathbf{z} = \mathbf{Z}_{0}, \quad Y_{\text{pred}} = \text{Classifier}(\mathbf{z}).\]
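The CLS-token readout can be sketched in NumPy as follows. This is a deliberately simplified encoder (one layer, one head, a residual connection, no layer norm or MLP block) and a linear classifier, all with random weights; the real encoder is configured by att_dim, n_layers, n_heads, use_mlp and dropout.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(Z, Wq, Wk, Wv):
    """Single-head scaled dot-product self-attention."""
    Q, K, V = Z @ Wq, Z @ Wk, Z @ Wv
    scores = Q @ K.T / np.sqrt(Q.shape[-1])  # (N'+1, N'+1)
    return softmax(scores, axis=-1) @ V

rng = np.random.default_rng(0)
N_prime, D = 5, 16
X_pooled = rng.normal(size=(N_prime, D))       # output of min-cut pooling
cls = rng.normal(size=(1, D))                   # learnable CLS token (random here)
Z_in = np.concatenate([cls, X_pooled], axis=0)  # [CLS; X'], shape (N'+1, D)
Wq, Wk, Wv = (rng.normal(size=(D, D)) / np.sqrt(D) for _ in range(3))
Z = Z_in + self_attention(Z_in, Wq, Wk, Wv)     # one encoder layer with residual
z = Z[0]                                        # CLS representation Z_0
w_clf = rng.normal(size=D)                      # linear classifier weights
logit = float(z @ w_clf)                        # Y_pred is a logit
print(Z.shape, logit)
```

Only row 0 of \(\mathbf{Z}\) reaches the classifier; the cluster tokens influence the prediction solely through attention.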

Optionally, GraphCAM can be used to generate class activation maps highlighting the most relevant regions for the classification decision.
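Since the transformer only sees \(N'\) cluster tokens, any cluster-level relevance has to be mapped back to the \(N\) patches to yield a map of shape (batch_size, bag_size). One plausible way to do this final step, assumed here rather than taken from torchmil, is to project through the soft assignment matrix \(\mathbf{S}\) from min-cut pooling. The relevance vector `r` below is hypothetical; GraphCAM derives it from the transformer, which this sketch does not reproduce.

```python
import numpy as np

# Soft assignment S (N = 3 instances, N' = 2 clusters), as produced by min-cut pooling.
S = np.array([[1.0, 0.0],
              [0.0, 1.0],
              [0.5, 0.5]])
# Hypothetical cluster-level relevance scores (one per cluster token).
r = np.array([1.0, 0.0])
cam = S @ r  # instance-level relevance, shape (N,)
cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-12)  # rescale to [0, 1]
print(cam)  # ≈ [1.0, 0.0, 0.5]
```

An instance shared between clusters receives a weighted mix of their scores, which is why the third patch lands halfway.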

Loss function. By default, the model is trained end-to-end using the following per-bag loss:

\[ \ell = \ell_{\text{BCE}}(Y_{\text{pred}}, Y) + \ell_{\text{MinCut}}(\mathbf{X}, \mathbf{A}) + \ell_{\text{Ortho}}(\mathbf{X}, \mathbf{A}),\]

where \(\ell_{\text{BCE}}\) is the Binary Cross-Entropy loss, \(\ell_{\text{MinCut}}\) is the MinCut loss, and \(\ell_{\text{Ortho}}\) is the Orthogonality loss, computed during the min-cut pooling operation, see Dense MinCut Pooling.
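For a single graph with soft assignment \(\mathbf{S} \in \mathbb{R}^{N \times K}\) and degree matrix \(\mathbf{D}\), the dense min-cut formulation defines \(\ell_{\text{MinCut}} = -\operatorname{Tr}(\mathbf{S}^\top \mathbf{A} \mathbf{S}) / \operatorname{Tr}(\mathbf{S}^\top \mathbf{D} \mathbf{S})\) and \(\ell_{\text{Ortho}} = \lVert \mathbf{S}^\top \mathbf{S} / \lVert \mathbf{S}^\top \mathbf{S} \rVert_F - \mathbf{I}_K / \sqrt{K} \rVert_F\). The NumPy sketch below evaluates these terms on a toy graph together with a numerically stable BCE-with-logits; it illustrates the formulas and is not torchmil's implementation.

```python
import numpy as np

def mincut_losses(S, A):
    """Auxiliary losses of dense MinCut pooling for one graph.
    S: (N, K) soft cluster assignment, A: (N, N) adjacency matrix."""
    K = S.shape[1]
    deg = np.diag(A.sum(axis=1))  # degree matrix
    l_mincut = -np.trace(S.T @ A @ S) / np.trace(S.T @ deg @ S)
    SS = S.T @ S
    # np.linalg.norm on a matrix defaults to the Frobenius norm
    l_ortho = np.linalg.norm(SS / np.linalg.norm(SS) - np.eye(K) / np.sqrt(K))
    return l_mincut, l_ortho

def bce_with_logits(logit, y):
    """Numerically stable BCE on a logit: max(x, 0) - x*y + log(1 + exp(-|x|))."""
    return max(logit, 0.0) - logit * y + np.log1p(np.exp(-abs(logit)))

# Two disconnected edges (0-1 and 2-3), each node hard-assigned to its own component.
A = np.array([[0, 1, 0, 0],
              [1, 0, 0, 0],
              [0, 0, 0, 1],
              [0, 0, 1, 0]], dtype=float)
S = np.array([[1, 0], [1, 0], [0, 1], [0, 1]], dtype=float)
l_mincut, l_ortho = mincut_losses(S, A)
l_bce = bce_with_logits(0.0, 1.0)
loss = l_bce + l_mincut + l_ortho
# expected: l_mincut = -1 (perfect cut), l_ortho ≈ 0 (balanced clusters), l_bce = log 2
print(l_mincut, l_ortho, l_bce)
```

A perfect partition reaches the minimum \(\ell_{\text{MinCut}} = -1\), while \(\ell_{\text{Ortho}}\) penalizes degenerate solutions that collapse all nodes into a few clusters.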

__init__(in_shape, att_dim=512, n_clusters=100, n_layers=1, n_heads=8, use_mlp=True, dropout=0.0, feat_ext=torch.nn.Identity(), criterion=torch.nn.BCEWithLogitsLoss())

Class constructor.

Parameters:

  • in_shape (tuple) –

    Shape of input data expected by the feature extractor (excluding batch dimension). If not provided, it will be lazily initialized.

  • att_dim (int, default: 512 ) –

    Attention dimension for transformer encoder.

  • n_clusters (int, default: 100 ) –

    Number of clusters in min-cut pooling.

  • n_layers (int, default: 1 ) –

    Number of layers in transformer encoder.

  • n_heads (int, default: 8 ) –

    Number of heads in transformer encoder.

  • use_mlp (bool, default: True ) –

    Whether to use MLP in transformer encoder.

  • dropout (float, default: 0.0 ) –

    Dropout rate in transformer encoder.

  • feat_ext (Module, default: Identity() ) –

    Feature extractor.

  • criterion (Module, default: BCEWithLogitsLoss() ) –

    Loss function. By default, Binary Cross-Entropy loss from logits for binary classification.

forward(X, adj, mask=None, return_cam=False, return_loss=False)

Forward pass.

Parameters:

  • X (Tensor) –

    Bag features of shape (batch_size, bag_size, ...).

  • adj (Tensor) –

    Adjacency matrix of shape (batch_size, bag_size, bag_size).

  • mask (Tensor, default: None ) –

    Mask of shape (batch_size, bag_size).

  • return_cam (bool, default: False ) –

    If True, returns the class activation map in addition to Y_pred.

Returns:

  • Y_pred ( Tensor ) –

    Bag label logits of shape (batch_size,).

  • cam ( Tensor ) –

    Only returned when return_cam=True. Class activation map of shape (batch_size, bag_size).
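Bags of different sizes must be padded to a common bag_size to form the batched X, adj and mask inputs documented above. The sketch below uses a hypothetical helper `pad_bags` and assumes the common convention that mask is 1 for real instances and 0 for padding; padded nodes get all-zero adjacency rows, so they stay disconnected. The resulting arrays can be turned into tensors with torch.from_numpy (followed by .float(), since NumPy defaults to float64).

```python
import numpy as np

def pad_bags(bags, adjs):
    """Hypothetical helper: pad variable-size bags to a common size.

    bags: list of (n_i, D) arrays; adjs: list of (n_i, n_i) arrays.
    Returns X (B, N, D), adj (B, N, N) and mask (B, N) with N = max n_i;
    mask is 1 for real instances and 0 for padding (assumed convention).
    """
    B, N, D = len(bags), max(b.shape[0] for b in bags), bags[0].shape[1]
    X = np.zeros((B, N, D))
    adj = np.zeros((B, N, N))
    mask = np.zeros((B, N))
    for i, (x, a) in enumerate(zip(bags, adjs)):
        n = x.shape[0]
        X[i, :n] = x
        adj[i, :n, :n] = a
        mask[i, :n] = 1.0
    return X, adj, mask

rng = np.random.default_rng(0)
bags = [rng.normal(size=(3, 4)), rng.normal(size=(5, 4))]
adjs = [np.ones((3, 3)) - np.eye(3), np.ones((5, 5)) - np.eye(5)]
X, adj, mask = pad_bags(bags, adjs)
print(X.shape, adj.shape, mask.sum(axis=1))  # (2, 5, 4) (2, 5, 5) [3. 5.]
```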

compute_loss(Y, X, adj, mask=None)

Compute loss given true bag labels.

Parameters:

  • Y (Tensor) –

    Bag labels of shape (batch_size,).

  • X (Tensor) –

    Bag features of shape (batch_size, bag_size, ...).

  • adj (Tensor) –

    Adjacency matrix of shape (batch_size, bag_size, bag_size).

  • mask (Tensor, default: None ) –

    Mask of shape (batch_size, bag_size).

Returns:

  • Y_pred ( Tensor ) –

    Bag label logits of shape (batch_size,).

  • loss_dict ( dict ) –

    Dictionary containing the values of the loss terms described in Loss function above (criterion, MinCut and orthogonality losses).

predict(X, adj, mask=None, return_inst_pred=True)

Predict bag and (optionally) instance labels.

Parameters:

  • X (Tensor) –

    Bag features of shape (batch_size, bag_size, ...).

  • adj (Tensor) –

    Adjacency matrix of shape (batch_size, bag_size, bag_size).

  • mask (Tensor, default: None ) –

    Mask of shape (batch_size, bag_size).

  • return_inst_pred (bool, default: True ) –

    If True, returns instance label predictions in addition to bag label predictions.

Returns:

  • Y_pred ( Tensor ) –

    Bag label logits of shape (batch_size,).

  • y_inst_pred ( Tensor ) –

    If return_inst_pred=True, returns instance label predictions of shape (batch_size, bag_size).
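Because Y_pred holds logits (the default criterion is BCEWithLogitsLoss), bag probabilities are obtained with a sigmoid. A minimal NumPy sketch, where the helper name `logits_to_labels` and the 0.5 threshold are illustrative assumptions:

```python
import numpy as np

def logits_to_labels(Y_pred: np.ndarray, threshold: float = 0.5):
    """Hypothetical helper: map bag logits to probabilities and hard labels."""
    probs = 1.0 / (1.0 + np.exp(-Y_pred))  # sigmoid
    labels = (probs >= threshold).astype(int)
    return probs, labels

Y_pred = np.array([-2.0, 0.0, 3.0])
probs, labels = logits_to_labels(Y_pred)
print(labels)  # [0 1 1]
```

The threshold is a free choice; on imbalanced datasets it is often tuned on a validation set rather than fixed at 0.5.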