
GCNConv

torchmil.nn.gnns.GCNConv

Bases: Module

Implementation of a Graph Convolutional Network (GCN) layer.

Adapted from the torch_geometric implementation.
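For reference, torch_geometric's `GCNConv` computes the standard GCN propagation rule (Kipf & Welling): self-loops plus symmetric degree normalization. If torchmil's adaptation keeps the same normalization (an assumption, since this page does not state it), the layer computes roughly the following. Here $\hat{A} = A + I$ only when `add_self_loops=True`. $W$ and $b$ stand for the optional linear layer (`learn_weights`, `bias`), and $\sigma$ is `activation`:

```latex
\hat{A} =
\begin{cases}
A + I & \text{if add\_self\_loops} \\
A & \text{otherwise}
\end{cases}
\qquad
\hat{D}_{ii} = \sum_{j} \hat{A}_{ij}
\qquad
Y = \sigma\left(\hat{D}^{-1/2} \hat{A} \hat{D}^{-1/2} X W + b\right)
```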

__init__(in_dim, out_dim=None, add_self_loops=False, learn_weights=False, layer_norm=False, normalize=False, dropout=0.0, activation=torch.nn.Identity(), bias=True)

Parameters:

  • in_dim (int) –

    Input dimension.

  • out_dim (int, default: None ) –

    Output dimension.

  • add_self_loops (bool, default: False ) –

    Whether to add self-loops.

  • learn_weights (bool, default: False ) –

    Whether to use a linear layer after the convolution.

  • layer_norm (bool, default: False ) –

    Whether to use layer normalization.

  • normalize (bool, default: False ) –

    Whether to L2-normalize the output.

  • dropout (float, default: 0.0 ) –

    Dropout rate.

  • activation (Module, default: Identity() ) –

    Activation function to apply after the convolution.

  • bias (bool, default: True ) –

    Whether to use bias.
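The sketch below shows one way the constructor options could compose, using plain PyTorch on dense adjacency matrices. `GCNConvSketch` is a hypothetical stand-in, not torchmil's code. Several behaviors are assumptions and are marked as such in the comments: `out_dim=None` falling back to `in_dim`, `bias` only affecting the linear layer, and the order of activation, layer norm, dropout and L2 normalization.

```python
import torch
import torch.nn.functional as F


class GCNConvSketch(torch.nn.Module):
    """Hypothetical stand-in for GCNConv; NOT torchmil's implementation."""

    def __init__(self, in_dim, out_dim=None, add_self_loops=False, learn_weights=False,
                 layer_norm=False, normalize=False, dropout=0.0,
                 activation=torch.nn.Identity(), bias=True):
        super().__init__()
        # Assumption: out_dim=None means "same as in_dim".
        out_dim = in_dim if out_dim is None else out_dim
        if not learn_weights and out_dim != in_dim:
            # Without a linear layer, nothing can change the feature dimension.
            raise ValueError("out_dim must equal in_dim when learn_weights=False")
        self.add_self_loops = add_self_loops
        self.normalize = normalize
        # Assumption: `bias` only affects the optional linear layer.
        self.fc = torch.nn.Linear(in_dim, out_dim, bias=bias) if learn_weights else torch.nn.Identity()
        self.activation = activation
        self.layer_norm = torch.nn.LayerNorm(out_dim) if layer_norm else torch.nn.Identity()
        self.dropout = torch.nn.Dropout(dropout)

    def forward(self, x, adj):
        # x: (batch_size, n_nodes, in_dim); adj: dense (batch_size, n_nodes, n_nodes)
        if self.add_self_loops:
            adj = adj + torch.eye(adj.shape[-1], dtype=adj.dtype, device=adj.device)
        # Symmetric normalization D^{-1/2} A D^{-1/2}; zero-degree nodes stay at 0.
        deg = adj.sum(dim=-1)
        inv_sqrt = torch.where(deg > 0, deg.clamp(min=1e-12).pow(-0.5), torch.zeros_like(deg))
        adj = inv_sqrt.unsqueeze(-1) * adj * inv_sqrt.unsqueeze(-2)
        # Aggregate neighbour features, then apply the optional linear layer.
        y = self.fc(torch.bmm(adj, x))
        # Assumed post-processing order: activation -> layer norm -> dropout -> L2 norm.
        y = self.dropout(self.layer_norm(self.activation(y)))
        if self.normalize:
            y = F.normalize(y, p=2, dim=-1)  # unit L2 norm per node
        return y  # (batch_size, n_nodes, out_dim)


torch.manual_seed(0)
layer = GCNConvSketch(in_dim=8, out_dim=16, add_self_loops=True, learn_weights=True, normalize=True)
x = torch.randn(2, 5, 8)
adj = (torch.rand(2, 5, 5) > 0.5).float()
adj = ((adj + adj.transpose(1, 2)) > 0).float()  # symmetrize the random graph
y = layer(x, adj)
print(y.shape)  # torch.Size([2, 5, 16])
```

Raising an error when `learn_weights=False` and `out_dim != in_dim` is a design choice of the sketch. Without a linear layer, propagation alone cannot change the feature dimension.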

forward(x, adj)

Parameters:

  • x (Tensor) –

    Node features of shape (batch_size, n_nodes, in_dim).

  • adj (Tensor) –

    Adjacency matrix of shape (batch_size, n_nodes, n_nodes).

Returns:

  • y ( Tensor ) –

    Output tensor of shape (batch_size, n_nodes, out_dim).
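`forward` expects batched inputs of shape (batch_size, n_nodes, ...). Graphs (bags) often have different numbers of nodes, so they need a common `n_nodes`. The helper below is a minimal sketch that builds those shapes from per-graph features and undirected edge lists by zero-padding. It assumes a dense adjacency is accepted, and this page does not say how torchmil handles padding or sparse adjacency. `build_dense_batch` is a hypothetical name.

```python
import torch


def build_dense_batch(features, edge_lists):
    """Pad per-graph inputs into the shapes documented for forward(x, adj).

    features:   list of (n_i, in_dim) tensors, one per graph/bag
    edge_lists: list of (2, e_i) long tensors of undirected edges (source, target)
    Returns x of shape (batch_size, n_max, in_dim) and adj of shape
    (batch_size, n_max, n_max); padded nodes get zero features and no edges.
    """
    batch_size = len(features)
    n_max = max(f.shape[0] for f in features)
    in_dim = features[0].shape[1]
    x = torch.zeros(batch_size, n_max, in_dim, dtype=features[0].dtype)
    adj = torch.zeros(batch_size, n_max, n_max, dtype=features[0].dtype)
    for b, (feat, edges) in enumerate(zip(features, edge_lists)):
        x[b, : feat.shape[0]] = feat
        src, dst = edges
        adj[b, src, dst] = 1.0
        adj[b, dst, src] = 1.0  # store each undirected edge in both directions
    return x, adj


feats = [torch.randn(3, 4), torch.randn(5, 4)]
edges = [torch.tensor([[0, 1], [1, 2]]), torch.tensor([[0, 3], [4, 1]])]
x, adj = build_dense_batch(feats, edges)
print(x.shape, adj.shape)  # torch.Size([2, 5, 4]) torch.Size([2, 5, 5])
```

In this sketch, padded rows and columns of `adj` are all zero, so padded nodes are isolated. Whether they also need to be masked out of the layer's output depends on how the model later pools node features.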