PatchGCN

torchmil.models.PatchGCN

Bases: Module

PatchGCN model, as proposed in "Whole Slide Images are 2D Point Clouds: Context-Aware Survival Prediction using Patch-based Graph Convolutional Networks".

Given an input bag \(\mathbf{X} = \left[ \mathbf{x}_1, \ldots, \mathbf{x}_N \right]^\top \in \mathbb{R}^{N \times P}\) with adjacency matrix \(\mathbf{A} \in \mathbb{R}^{N \times N}\), the model optionally applies a feature extractor, \(\text{FeatExt}(\cdot)\), to transform the instance features: \(\mathbf{X} = \text{FeatExt}(\mathbf{X}) \in \mathbb{R}^{N \times D}\).

Then, a Graph Convolutional Network (GCN) and a Multi-Layer Perceptron (MLP) are used to transform the instance features,

\[\begin{gather} \mathbf{H} = \operatorname{GCN}(\mathbf{X}, \mathbf{A}) \in \mathbb{R}^{N \times \texttt{out_gcn_dim}}, \\ \mathbf{H} = \operatorname{MLP}(\mathbf{H}) \in \mathbb{R}^{N \times \texttt{hidden_dim}}, \end{gather}\]

where \(\texttt{out_gcn_dim} = \texttt{hidden_dim} \cdot \texttt{n_gcn_layers}\). The GCN is built from DeepGCN layers (see DeepGCNLayer), each combining a GCNConv graph convolution (see GCNConv), LayerNorm, and a ReLU activation, with residual and dense connections.
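To make the dimension bookkeeping concrete, here is a minimal sketch of a densely connected GCN stack in plain PyTorch. It is not the torchmil implementation: `DenseGCNSketch` and `normalize_adj` are hypothetical names, the graph convolution is the standard symmetrically normalized form, and the dense connections are assumed to concatenate the output of every layer, which is what yields an output width of `hidden_dim * n_gcn_layers`.

```python
import torch


def normalize_adj(adj: torch.Tensor) -> torch.Tensor:
    """Add self-loops and apply symmetric normalization D^{-1/2} A D^{-1/2}."""
    eye = torch.eye(adj.shape[-1], dtype=adj.dtype, device=adj.device)
    adj = adj + eye
    deg_inv_sqrt = adj.sum(dim=-1).clamp(min=1e-12).pow(-0.5)
    return deg_inv_sqrt.unsqueeze(-1) * adj * deg_inv_sqrt.unsqueeze(-2)


class DenseGCNSketch(torch.nn.Module):
    """Hypothetical stand-in for the GCN stack (assumption, not library code)."""

    def __init__(self, in_dim: int, hidden_dim: int, n_gcn_layers: int):
        super().__init__()
        dims = [in_dim] + [hidden_dim] * n_gcn_layers
        self.linears = torch.nn.ModuleList(
            [torch.nn.Linear(dims[i], dims[i + 1]) for i in range(n_gcn_layers)]
        )
        self.norms = torch.nn.ModuleList(
            [torch.nn.LayerNorm(hidden_dim) for _ in range(n_gcn_layers)]
        )

    def forward(self, X: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        adj_norm = normalize_adj(adj)
        outputs = []
        H = X
        for i, (linear, norm) in enumerate(zip(self.linears, self.norms)):
            # Graph convolution, then LayerNorm and ReLU.
            H_new = torch.relu(norm(adj_norm @ linear(H)))
            if i > 0:
                H_new = H_new + H  # residual connection (widths match from layer 2 on)
            H = H_new
            outputs.append(H)
        # Dense connection: concatenate the output of every layer.
        return torch.cat(outputs, dim=-1)
```

With `hidden_dim=8` and `n_gcn_layers=4`, the returned tensor has a last dimension of 32, which the MLP would then map back to `hidden_dim`.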

Then, attention values \(\mathbf{f} \in \mathbb{R}^{N \times 1}\) and the bag representation \(\mathbf{z} \in \mathbb{R}^{\texttt{hidden_dim}}\) are computed using the attention pooling mechanism (see Attention Pooling),

\[\begin{equation} \mathbf{z}, \mathbf{f} = \operatorname{AttentionPool}(\mathbf{H}). \end{equation}\]
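The pooling step can be sketched as follows. This assumes the common attention-based MIL form, in which a two-layer network with a tanh nonlinearity scores each instance and a softmax over the bag turns the scores into weights; the library's Attention Pooling may differ in its details. `AttentionPoolSketch` is a hypothetical name, and the mask convention (True marks a real instance) is also an assumption.

```python
import torch


class AttentionPoolSketch(torch.nn.Module):
    """Hypothetical attention pooling (assumption, not library code)."""

    def __init__(self, in_dim: int, att_dim: int = 128):
        super().__init__()
        self.proj = torch.nn.Linear(in_dim, att_dim)
        self.score = torch.nn.Linear(att_dim, 1, bias=False)

    def forward(self, H: torch.Tensor, mask: torch.Tensor = None):
        # Unnormalized attention values f, shape (batch_size, bag_size, 1).
        f = self.score(torch.tanh(self.proj(H)))
        scores = f
        if mask is not None:
            # Padded instances get -inf so their softmax weight is exactly zero.
            scores = f.masked_fill(~mask.bool().unsqueeze(-1), float("-inf"))
        weights = torch.softmax(scores, dim=1)
        # Bag representation z: weighted sum of instance features.
        z = (weights * H).sum(dim=1)
        return z, f
```

Returning `f` before the softmax mirrors the "before normalization" wording used for the attention values in `forward`.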

Finally, the bag representation \(\mathbf{z}\) is fed into a classifier (one linear layer) to predict the bag label.

__init__(in_shape, n_gcn_layers=4, mlp_depth=1, hidden_dim=None, att_dim=128, dropout=0.0, feat_ext=torch.nn.Identity(), criterion=torch.nn.BCEWithLogitsLoss())

Parameters:

  • in_shape (tuple) –

    Shape of input data expected by the feature extractor (excluding batch dimension).

  • n_gcn_layers (int, default: 4 ) –

    Number of GCN layers.

  • mlp_depth (int, default: 1 ) –

    Number of layers in the MLP (applied after the GCN).

  • hidden_dim (int, default: None ) –

    Hidden dimension. If not provided, it will be set to the feature dimension.

  • att_dim (int, default: 128 ) –

    Attention dimension.

  • dropout (float, default: 0.0 ) –

    Dropout rate.

  • feat_ext (Module, default: Identity() ) –

    Feature extractor.

  • criterion (Module, default: BCEWithLogitsLoss() ) –

    Loss function.

forward(X, adj, mask=None, return_att=False)

Forward pass.

Parameters:

  • X (Tensor) –

    Bag features of shape (batch_size, bag_size, ...).

  • adj (Tensor) –

    Adjacency matrix of shape (batch_size, bag_size, bag_size).

  • mask (Tensor, default: None ) –

    Mask of shape (batch_size, bag_size).

  • return_att (bool, default: False ) –

    If True, returns attention values (before normalization) in addition to Y_pred.

Returns:

  • Y_pred ( Tensor ) –

    Bag label logits of shape (batch_size,).

  • att ( Tensor ) –

    Only returned when return_att=True. Attention values (before normalization) of shape (batch_size, bag_size).
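Because `X`, `adj`, and `mask` share the `bag_size` dimension, bags of different sizes have to be padded before they can be batched. The helper below is a minimal sketch of one way to do that; `collate_bags` is a hypothetical name, and it assumes that a mask value of True marks a real instance and False marks padding.

```python
import torch


def collate_bags(bags, adjs):
    """Pad variable-size bags to a common bag_size and build the mask.

    bags: list of tensors of shape (n_i, feat_dim)
    adjs: list of tensors of shape (n_i, n_i)
    """
    batch_size = len(bags)
    bag_size = max(b.shape[0] for b in bags)
    feat_dim = bags[0].shape[1]

    X = torch.zeros(batch_size, bag_size, feat_dim)
    adj = torch.zeros(batch_size, bag_size, bag_size)
    mask = torch.zeros(batch_size, bag_size, dtype=torch.bool)

    for i, (b, a) in enumerate(zip(bags, adjs)):
        n = b.shape[0]
        X[i, :n] = b          # real instances first, zeros afterwards
        adj[i, :n, :n] = a    # padded nodes have no edges
        mask[i, :n] = True    # assumption: True marks a real instance
    return X, adj, mask
```

For two bags with 3 and 5 instances and 4 features each, this produces `X` of shape (2, 5, 4), `adj` of shape (2, 5, 5), and `mask` of shape (2, 5).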

compute_loss(Y, X, adj, mask=None)

Parameters:

  • Y (Tensor) –

    Bag labels of shape (batch_size,).

  • X (Tensor) –

    Bag features of shape (batch_size, bag_size, ...).

  • adj (Tensor) –

    Adjacency matrix of shape (batch_size, bag_size, bag_size).

  • mask (Tensor, default: None ) –

    Mask of shape (batch_size, bag_size).

Returns:

  • Y_pred ( Tensor ) –

    Bag label logits of shape (batch_size,).

  • loss_dict ( dict ) –

    Dictionary containing the loss.
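As an illustration of what this step amounts to with the default criterion, the sketch below applies `torch.nn.BCEWithLogitsLoss` to logits of shape (batch_size,) and wraps the result in a dictionary. `compute_loss_sketch` is a hypothetical function, and keying the dictionary by the criterion's class name is an assumption rather than documented behavior.

```python
import torch


def compute_loss_sketch(Y_pred: torch.Tensor, Y: torch.Tensor, criterion=None):
    """Apply a criterion to bag logits and return (Y_pred, loss_dict)."""
    if criterion is None:
        criterion = torch.nn.BCEWithLogitsLoss()
    # BCEWithLogitsLoss expects floating-point targets with the same shape as the logits.
    loss = criterion(Y_pred, Y.float())
    # Assumption: the dictionary key is the criterion's class name.
    return Y_pred, {type(criterion).__name__: loss}
```

Since the criterion operates on logits, no sigmoid should be applied to `Y_pred` beforehand.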

predict(X, adj, mask=None, return_inst_pred=False)

Parameters:

  • X (Tensor) –

    Bag features of shape (batch_size, bag_size, ...).

  • adj (Tensor) –

    Adjacency matrix of shape (batch_size, bag_size, bag_size).

  • mask (Tensor, default: None ) –

    Mask of shape (batch_size, bag_size).

  • return_inst_pred (bool, default: False ) –

    If True, returns instance predictions.

Returns:

  • Y_pred ( Tensor ) –

    Bag label logits of shape (batch_size,).

  • y_inst_pred ( Tensor ) –

    Only returned when return_inst_pred=True. Instance label predictions of shape (batch_size, bag_size).
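Since `Y_pred` contains logits, a binary setup (such as the default `BCEWithLogitsLoss` criterion) needs a sigmoid to obtain probabilities and a threshold to obtain hard labels. A minimal sketch, where `logits_to_labels` and the 0.5 threshold are illustrative choices and not part of the library:

```python
import torch


def logits_to_labels(Y_pred: torch.Tensor, threshold: float = 0.5):
    """Convert bag logits into probabilities and hard 0/1 labels."""
    probs = torch.sigmoid(Y_pred)
    labels = (probs > threshold).long()
    return probs, labels


# Example with three bags.
probs, labels = logits_to_labels(torch.tensor([-2.0, 0.5, 3.0]))
```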