PatchGCN
torchmil.models.PatchGCN
Bases: Module
PatchGCN model, as proposed in Whole Slide Images are 2D Point Clouds: Context-Aware Survival Prediction using Patch-based Graph Convolutional Networks.
Given an input bag \(\mathbf{X} = \left[ \mathbf{x}_1, \ldots, \mathbf{x}_N \right]^\top \in \mathbb{R}^{N \times P}\) with adjacency matrix \(\mathbf{A} \in \mathbb{R}^{N \times N}\), the model optionally applies a feature extractor, \(\text{FeatExt}(\cdot)\), to transform the instance features: \(\mathbf{X} = \text{FeatExt}(\mathbf{X}) \in \mathbb{R}^{N \times D}\).
Then, a Graph Convolutional Network (GCN) and a Multi-Layer Perceptron (MLP) are used to transform the instance features,

\[
\mathbf{X} = \text{MLP}\left(\text{GCN}\left(\mathbf{X}\right)\right) \in \mathbb{R}^{N \times \texttt{hidden_dim}},
\]

where the GCN output has dimension \(\texttt{out_gcn_dim} = \texttt{hidden_dim} \cdot \texttt{n_gcn_layers}\). The GCN is implemented using DeepGCN layers (see DeepGCNLayer) with GCNConv, LayerNorm, and ReLU activation (see GCNConv), along with residual and dense connections.
Then, attention values \(\mathbf{f} \in \mathbb{R}^{N \times 1}\) and the bag representation \(\mathbf{z} \in \mathbb{R}^{\texttt{hidden_dim}}\) are computed using the attention pooling mechanism (see Attention Pooling),

\[
\mathbf{z} = \sum_{n=1}^{N} s_n \mathbf{x}_n, \quad \text{where} \quad s_n = \frac{\exp(f_n)}{\sum_{m=1}^{N} \exp(f_m)}.
\]
Finally, the bag representation \(\mathbf{z}\) is fed into a classifier (one linear layer) to predict the bag label.
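As a rough illustration of the attention pooling step described above, here is a minimal sketch of the standard mechanism (not torchmil's exact implementation; the weight names V and w are hypothetical):

```python
import torch

def attention_pool(X, V, w):
    """Sketch of standard attention pooling.

    X: (N, hidden_dim) instance features.
    V: (att_dim, hidden_dim) and w: (att_dim, 1) are learnable weights.
    Returns the bag representation z and the unnormalized attention values f.
    """
    f = torch.tanh(X @ V.T) @ w   # (N, 1) attention values before softmax
    s = torch.softmax(f, dim=0)   # (N, 1) normalized attention weights
    z = (s * X).sum(dim=0)        # (hidden_dim,) bag representation
    return z, f
```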
__init__(in_shape, n_gcn_layers=4, mlp_depth=1, hidden_dim=None, att_dim=128, dropout=0.0, feat_ext=torch.nn.Identity(), criterion=torch.nn.BCEWithLogitsLoss())
Parameters:
- in_shape (tuple) – Shape of input data expected by the feature extractor (excluding batch dimension).
- n_gcn_layers (int, default: 4) – Number of GCN layers.
- mlp_depth (int, default: 1) – Number of layers in the MLP (applied after the GCN).
- hidden_dim (int, default: None) – Hidden dimension. If not provided, it is set to the feature dimension.
- att_dim (int, default: 128) – Attention dimension.
- dropout (float, default: 0.0) – Dropout rate.
- feat_ext (Module, default: Identity()) – Feature extractor.
- criterion (Module, default: BCEWithLogitsLoss()) – Loss function.
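A minimal construction sketch using the documented parameters (the 512-dimensional feature shape is an arbitrary example, not a requirement):

```python
import torch
from torchmil.models import PatchGCN

# Instances are assumed to be 512-dimensional feature vectors (arbitrary choice).
model = PatchGCN(
    in_shape=(512,),
    n_gcn_layers=4,
    hidden_dim=256,
    att_dim=128,
    dropout=0.1,
)
```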
forward(X, adj, mask=None, return_att=False)
Forward pass.
Parameters:
- X (Tensor) – Bag features of shape (batch_size, bag_size, ...).
- adj (Tensor) – Adjacency matrix of shape (batch_size, bag_size, bag_size).
- mask (Tensor, default: None) – Mask of shape (batch_size, bag_size).
- return_att (bool, default: False) – If True, returns attention values (before normalization) in addition to Y_pred.
Returns:
- Y_pred (Tensor) – Bag label logits of shape (batch_size,).
- att (Tensor) – Only returned when return_att=True. Attention values (before normalization) of shape (batch_size, bag_size).
compute_loss(Y, X, adj, mask=None)
Parameters:
- Y (Tensor) – Bag labels of shape (batch_size,).
- X (Tensor) – Bag features of shape (batch_size, bag_size, ...).
- adj (Tensor) – Adjacency matrix of shape (batch_size, bag_size, bag_size).
- mask (Tensor, default: None) – Mask of shape (batch_size, bag_size).
Returns:
- Y_pred (Tensor) – Bag label logits of shape (batch_size,).
- loss_dict (dict) – Dictionary containing the loss.
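A sketch of one training step built on compute_loss, reusing the tensors from the forward example above with toy binary bag labels (summing the values of loss_dict is an assumption; this page does not specify the dictionary's keys):

```python
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

Y = torch.randint(0, 2, (batch_size,)).float()   # toy binary bag labels

Y_pred, loss_dict = model.compute_loss(Y, X, adj, mask=mask)
loss = sum(loss_dict.values())                   # aggregate the loss term(s)

optimizer.zero_grad()
loss.backward()
optimizer.step()
```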
predict(X, adj, mask=None, return_inst_pred=False)
Parameters:
- X (Tensor) – Bag features of shape (batch_size, bag_size, ...).
- adj (Tensor) – Adjacency matrix of shape (batch_size, bag_size, bag_size).
- mask (Tensor, default: None) – Mask of shape (batch_size, bag_size).
- return_inst_pred (bool, default: False) – If True, returns instance predictions.
Returns:
- Y_pred (Tensor) – Bag label logits of shape (batch_size,).
- y_inst_pred (Tensor) – Only returned when return_inst_pred=True. Instance label predictions of shape (batch_size, bag_size).
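An inference sketch using predict with instance-level predictions enabled, continuing the tensors defined above:

```python
model.eval()
with torch.no_grad():
    Y_pred, y_inst_pred = model.predict(X, adj, mask=mask, return_inst_pred=True)

bag_probs = torch.sigmoid(Y_pred)   # bag logits -> probabilities
print(y_inst_pred.shape)            # (batch_size, bag_size)
```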