
Dense MinCut pooling

torchmil.nn.gnns.dense_mincut_pool

dense_mincut_pool(x, adj, s, mask=None, temp=1.0)

Dense MinCut Pooling.

Adapted from the dense_mincut_pool implementation in PyTorch Geometric (torch_geometric).
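For reference, here is the computation written out. This assumes torchmil keeps the PyTorch Geometric formulation it adapts. X is the node feature matrix, A the adjacency matrix, D = diag(A 1) its degree matrix, S the assignment matrix after a softmax with temperature τ (the temp parameter), and C the number of clusters (n_cluster). The losses are averaged over the batch:

```latex
\mathbf{S} = \operatorname{softmax}\!\left(\mathbf{S}_{\text{in}} / \tau\right), \qquad
\mathbf{X}' = \mathbf{S}^\top \mathbf{X}, \qquad
\mathbf{A}' = \mathbf{S}^\top \mathbf{A} \mathbf{S}

\mathcal{L}_{\text{cut}} = -\frac{\operatorname{Tr}\!\left(\mathbf{S}^\top \mathbf{A} \mathbf{S}\right)}
                                 {\operatorname{Tr}\!\left(\mathbf{S}^\top \mathbf{D} \mathbf{S}\right)}, \qquad
\mathcal{L}_{\text{ortho}} = \left\| \frac{\mathbf{S}^\top \mathbf{S}}{\left\|\mathbf{S}^\top \mathbf{S}\right\|_F}
                                 - \frac{\mathbf{I}_C}{\sqrt{C}} \right\|_F
```

In the PyTorch Geometric version, the returned coarsened adjacency is additionally post-processed. Its diagonal is set to zero, and the result is normalized symmetrically as D'^{-1/2} A' D'^{-1/2}, where D' is the degree matrix of that zero-diagonal A'.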

Parameters:

  • x (Tensor) –

    Input tensor of shape (batch_size, n_nodes, in_dim).

  • adj (Tensor) –

    Adjacency tensor of shape (batch_size, n_nodes, n_nodes).

  • s (Tensor) –

    Dense learned cluster-assignment tensor of shape (batch_size, n_nodes, n_cluster).

  • mask (Tensor, default: None ) –

    Mask tensor of shape (batch_size, n_nodes) indicating which nodes are valid (i.e., not padding).

  • temp (float, default: 1.0 ) –

    Temperature of the softmax applied to the assignments s.
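Graphs in a batch usually have different numbers of nodes. The padded (batch_size, n_nodes, ...) shapes and the mask listed above can be built as shown below. This is a minimal sketch assuming zero padding. The helper pad_graphs is hypothetical and not part of torchmil.

```python
import torch


def pad_graphs(feats, adjs):
    """Hypothetical helper: stack variable-size graphs into padded tensors.

    feats: list of (n_i, in_dim) tensors; adjs: list of (n_i, n_i) tensors.
    Returns x (B, N, in_dim), adj (B, N, N) and a boolean mask (B, N),
    where N is the largest n_i and padded entries are zero / False.
    """
    batch_size = len(feats)
    n_max = max(f.size(0) for f in feats)
    in_dim = feats[0].size(1)
    x = feats[0].new_zeros(batch_size, n_max, in_dim)
    adj = adjs[0].new_zeros(batch_size, n_max, n_max)
    mask = torch.zeros(batch_size, n_max, dtype=torch.bool)
    for i, (f, a) in enumerate(zip(feats, adjs)):
        n = f.size(0)
        x[i, :n] = f          # copy real node features, rest stays zero
        adj[i, :n, :n] = a    # copy the graph's adjacency block
        mask[i, :n] = True    # mark real (non-padded) nodes
    return x, adj, mask
```

The resulting x, adj and mask match the shapes expected by the x, adj and mask parameters.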

Returns:

  • x_ (Tensor) –

    Pooled node feature tensor of shape (batch_size, n_cluster, in_dim).

  • adj_ (Tensor) –

    Coarsened adjacency tensor of shape (batch_size, n_cluster, n_cluster).

  • mincut_loss (Tensor) –

    MinCut regularization loss (scalar).

  • ortho_loss (Tensor) –

    Orthogonality regularization loss (scalar).
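The sketch below shows how the four return values can be computed in plain PyTorch. It is modelled on the PyTorch Geometric function that this one adapts and is not torchmil's source. The name mincut_pool_sketch and the eps argument are hypothetical. Three steps are assumptions carried over from PyTorch Geometric: the temperature-scaled softmax on s, averaging the losses over the batch, and zeroing and normalizing the coarsened adjacency.

```python
import torch


def mincut_pool_sketch(x, adj, s, mask=None, temp=1.0, eps=1e-15):
    """Illustrative dense MinCut pooling (hypothetical, PyG-style)."""
    batch_size, n_nodes, _ = x.shape
    k = s.size(-1)

    # Assumption (as in PyG): soft assignments via temperature-scaled softmax.
    s = torch.softmax(s / temp, dim=-1)

    if mask is not None:
        # Zero out features and assignments of padded nodes.
        m = mask.view(batch_size, n_nodes, 1).to(x.dtype)
        x, s = x * m, s * m

    st = s.transpose(1, 2)
    x_pool = st @ x              # (batch_size, k, in_dim)
    adj_pool = st @ adj @ s      # (batch_size, k, k)

    # MinCut loss: -Tr(S^T A S) / Tr(S^T D S), averaged over the batch.
    deg = torch.diag_embed(adj.sum(dim=-1))
    num = torch.diagonal(adj_pool, dim1=-2, dim2=-1).sum(-1)
    den = torch.diagonal(st @ deg @ s, dim1=-2, dim2=-1).sum(-1)
    mincut_loss = (-(num / den)).mean()

    # Orthogonality loss: || S^T S / ||S^T S||_F - I_k / sqrt(k) ||_F.
    ss = st @ s
    eye = torch.eye(k, dtype=ss.dtype, device=ss.device)
    diff = ss / torch.linalg.norm(ss, dim=(-2, -1), keepdim=True) - eye / k ** 0.5
    ortho_loss = torch.linalg.norm(diff, dim=(-2, -1)).mean()

    # Drop self-loops, then normalize symmetrically: D'^-1/2 A' D'^-1/2.
    adj_pool = adj_pool * (1 - eye)
    d = torch.sqrt(adj_pool.sum(dim=-1))[:, None] + eps   # (batch_size, 1, k)
    adj_pool = adj_pool / d / d.transpose(1, 2)

    return x_pool, adj_pool, mincut_loss, ortho_loss
```

With a symmetric, non-negative adjacency, the MinCut term stays in [-1, 0] and the orthogonality term stays in [0, 2]. The losses are typically added to the task loss as regularizers.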