
DeepGCNLayer

torchmil.nn.gnns.DeepGCNLayer

Bases: Module

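Implementation of a DeepGCN layer: a graph convolution combined with optional normalization, activation and dropout, plus a configurable skip connection.

Adapted from the DeepGCNLayer implementation in torch_geometric.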

__init__(conv=None, norm=None, act=None, block='plain', dropout=0.0)

Parameters:

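  • conv (Module, default: None) –

    Graph convolutional layer, applied to the node features and the adjacency matrix.

  • norm (Module, default: None) –

    Normalization layer.

  • act (Module, default: None) –

    Activation layer.

  • block (str, default: 'plain') –

    Skip connection type. Possible values: 'res' (residual connection), 'res+' (pre-activation residual connection), 'dense' (dense connection, concatenating input and output features) and 'plain' (no skip connection).

  • dropout (float, default: 0.0) –

    Dropout probability.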
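The four block types differ only in where the convolution, normalization, activation and dropout sit relative to the skip connection. The following is a minimal sketch, assuming torchmil keeps the ordering used by torch_geometric's DeepGCNLayer. The helper deepgcn_block and the parameter-free sum_neighbors convolution are hypothetical names written for illustration; they are not part of torchmil.

```python
import torch
import torch.nn.functional as F


def deepgcn_block(x, adj, conv, norm=None, act=None, block="plain",
                  dropout=0.0, training=False):
    """Hypothetical helper mirroring the block wiring of torch_geometric's
    DeepGCNLayer (assumed, not confirmed, to match torchmil).

    x:    (batch_size, n_nodes, in_dim) node features
    adj:  (batch_size, n_nodes, n_nodes) adjacency matrix
    conv: any callable used as conv(h, adj)
    """
    block = block.lower()
    if block not in ("res+", "res", "dense", "plain"):
        raise ValueError(f"unknown block type: {block!r}")

    if block == "res+":
        # Pre-activation residual: norm -> act -> dropout -> conv, then add x.
        h = x
        if norm is not None:
            h = norm(h)
        if act is not None:
            h = act(h)
        h = F.dropout(h, p=dropout, training=training)
        return x + conv(h, adj)

    # 'res', 'dense' and 'plain': conv -> norm -> act first.
    h = conv(x, adj)
    if norm is not None:
        h = norm(h)
    if act is not None:
        h = act(h)
    if block == "res":
        h = x + h  # residual: conv must keep the feature dimension
    elif block == "dense":
        h = torch.cat([x, h], dim=-1)  # dense: input and output features concatenated
    # 'plain': no skip connection
    return F.dropout(h, p=dropout, training=training)


def sum_neighbors(h, adj):
    """Hypothetical parameter-free stand-in for a graph convolution: A @ H."""
    return torch.bmm(adj, h)


x = torch.randn(2, 5, 8)             # 2 graphs, 5 nodes, 8 features each
adj = torch.eye(5).repeat(2, 1, 1)   # self-loops only
print(deepgcn_block(x, adj, sum_neighbors, block="dense").shape)
# torch.Size([2, 5, 16])
```

Under these assumptions, 'dense' grows the feature dimension to in_dim plus the convolution's output dimension, while 'res' and 'res+' add x to the convolution output and therefore need a convolution that preserves the feature dimension.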

forward(x, adj)

Forward pass: applies the layer to a batch of graphs given as node features and adjacency matrices.

Parameters:

  • x (Tensor) –

    Node features of shape (batch_size, n_nodes, in_dim).

  • adj (Tensor) –

    Adjacency matrix of shape (batch_size, n_nodes, n_nodes).

Returns:

  • y (Tensor) –

    Output tensor of shape (batch_size, n_nodes, out_dim).
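Since forward takes an adjacency matrix of shape (batch_size, n_nodes, n_nodes) rather than an edge list, the module passed as conv has to accept batched tensors in those shapes. A minimal compatible convolution might look like the sketch below. DenseGraphConv is a hypothetical module written for illustration (it computes A H W + b); it is not a torchmil class.

```python
import torch
from torch import nn


class DenseGraphConv(nn.Module):
    """Hypothetical graph convolution: H' = A @ H @ W + b.

    x:   (batch_size, n_nodes, in_dim)
    adj: (batch_size, n_nodes, n_nodes)
    returns (batch_size, n_nodes, out_dim)
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim)

    def forward(self, x, adj):
        # Aggregate neighbour features with the adjacency matrix,
        # then project them to out_dim.
        return self.linear(torch.bmm(adj, x))


batch_size, n_nodes, in_dim = 2, 6, 16
x = torch.randn(batch_size, n_nodes, in_dim)
adj = (torch.rand(batch_size, n_nodes, n_nodes) > 0.5).float()
conv = DenseGraphConv(in_dim, in_dim)  # keep in_dim so 'res'/'res+' can add x
print(conv(x, adj).shape)
# torch.Size([2, 6, 16])
```

Such a module could then be passed as conv, for example together with norm=nn.LayerNorm(16) and act=nn.ReLU(). nn.LayerNorm normalizes over the last dimension, which fits the (batch_size, n_nodes, dim) layout; nn.BatchNorm1d would instead expect channels in the second dimension.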