Dense MinCut pooling
torchmil.nn.gnns.dense_mincut_pool
dense_mincut_pool(x, adj, s, mask=None, temp=1.0)
Dense MinCut Pooling.
Adapted from the implementation in torch_geometric.
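As a reading aid, the quantities below follow the MinCut pooling formulation used by the torch_geometric function of the same name. That this adaptation matches it term by term is an assumption, not something stated on this page. Here S is the row-wise softmax of s / temp, X is x, A is adj, D is the diagonal degree matrix of A, and K is n_cluster; both losses are computed per graph and averaged over the batch.

```latex
\begin{aligned}
S &= \operatorname{softmax}\!\left(s / \mathrm{temp}\right), \qquad D = \operatorname{diag}(A\,\mathbf{1}) \\
X_{\text{pool}} &= S^{\top} X, \qquad A_{\text{pool}} = S^{\top} A\, S \\
\mathcal{L}_{\text{mincut}} &= -\,\frac{\operatorname{Tr}\!\left(S^{\top} A\, S\right)}{\operatorname{Tr}\!\left(S^{\top} D\, S\right)} \\
\mathcal{L}_{\text{ortho}} &= \left\lVert \frac{S^{\top} S}{\lVert S^{\top} S \rVert_F} - \frac{I_K}{\sqrt{K}} \right\rVert_F
\end{aligned}
```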
Parameters:
- x (Tensor) – Input tensor of shape (batch_size, n_nodes, in_dim).
- adj (Tensor) – Adjacency tensor of shape (batch_size, n_nodes, n_nodes).
- s (Tensor) – Dense learned assignments tensor of shape (batch_size, n_nodes, n_cluster).
- mask (Tensor, default: None) – Mask tensor of shape (batch_size, n_nodes).
- temp (float, default: 1.0) – Temperature.
Returns:
- x_ (Tensor) – Pooled node feature tensor of shape (batch_size, n_cluster, in_dim).
- adj_ (Tensor) – Coarsened adjacency tensor of shape (batch_size, n_cluster, n_cluster).
- mincut_loss (Tensor) – MinCut loss.
- ortho_loss (Tensor) – Orthogonality loss.
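The shape contract above can be checked with a small stand-alone sketch. The function mincut_pool_sketch below is a hypothetical name, written in plain PyTorch so it runs without torchmil; it is not the library source. It assumes the usual MinCut pooling recipe (softmax over clusters, then S^T X and S^T A S) and, for brevity, skips any post-processing of the coarsened adjacency (the torch_geometric version, for instance, zeroes its diagonal and degree-normalizes it).

```python
import torch


def mincut_pool_sketch(x, adj, s, mask=None, temp=1.0):
    """Illustrative sketch only (hypothetical helper, not the torchmil code)."""
    # Soft cluster assignments: softmax over the cluster dimension.
    s = torch.softmax(s / temp, dim=-1)

    if mask is not None:
        # Zero out masked nodes in both features and assignments.
        m = mask.unsqueeze(-1).to(x.dtype)
        x, s = x * m, s * m

    st = s.transpose(1, 2)   # (batch_size, n_cluster, n_nodes)
    x_ = st @ x              # (batch_size, n_cluster, in_dim)
    adj_ = st @ adj @ s      # (batch_size, n_cluster, n_cluster)

    # MinCut loss: -Tr(S^T A S) / Tr(S^T D S), averaged over the batch.
    deg = torch.diag_embed(adj.sum(dim=-1))
    num = torch.diagonal(adj_, dim1=-2, dim2=-1).sum(-1)
    den = torch.diagonal(st @ deg @ s, dim1=-2, dim2=-1).sum(-1)
    mincut_loss = -(num / den).mean()

    # Orthogonality loss: || S^T S / ||S^T S||_F - I_K / sqrt(K) ||_F.
    k = s.size(-1)
    ss = st @ s
    eye = torch.eye(k, dtype=ss.dtype, device=ss.device)
    ortho_loss = torch.linalg.matrix_norm(
        ss / torch.linalg.matrix_norm(ss, keepdim=True) - eye / k ** 0.5
    ).mean()

    return x_, adj_, mincut_loss, ortho_loss


torch.manual_seed(0)
batch_size, n_nodes, in_dim, n_cluster = 2, 6, 4, 3
x = torch.randn(batch_size, n_nodes, in_dim)
adj = torch.rand(batch_size, n_nodes, n_nodes)
adj = (adj + adj.transpose(1, 2)) / 2  # symmetric, non-negative weights
s = torch.randn(batch_size, n_nodes, n_cluster)
mask = torch.ones(batch_size, n_nodes, dtype=torch.bool)

x_, adj_, mincut_loss, ortho_loss = mincut_pool_sketch(x, adj, s, mask)
print(x_.shape, adj_.shape)
```

In a training loop the two scalar losses would typically be added to the task loss as regularizers. With the real function, the call would presumably take the same arguments in the order shown in the signature above.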