DeepGCNLayer
torchmil.nn.gnns.DeepGCNLayer
Bases: Module
Implementation of a DeepGCN layer.
Adapts the implementation from torch_geometric.
__init__(conv=None, norm=None, act=None, block='plain', dropout=0.0)
Parameters:
- conv (Module, default: None) – Convolutional layer.
- norm (Module, default: None) – Normalization layer.
- act (Module, default: None) – Activation layer.
- block (str, default: 'plain') – Skip connection type. Possible values: 'res', 'res+', 'dense', 'plain'.
- dropout (float, default: 0.0) – Dropout rate.
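The four `block` values can be read as different orderings of convolution, normalization, activation, and skip connection. The sketch below follows the ordering used by torch_geometric's `DeepGCNLayer`, which this class adapts. Whether torchmil matches it step for step is an assumption. `DenseGraphConv` and `deepgcn_block` are hypothetical names used only for illustration and are not part of torchmil.

```python
import torch
import torch.nn.functional as F
from torch import nn


class DenseGraphConv(nn.Module):
    """Hypothetical stand-in conv: linear projection, then aggregation with adj."""

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.lin = nn.Linear(in_dim, out_dim)

    def forward(self, x, adj):
        # x: (batch_size, n_nodes, in_dim), adj: (batch_size, n_nodes, n_nodes)
        return torch.bmm(adj, self.lin(x))


def deepgcn_block(x, adj, conv, norm=None, act=None, block="plain",
                  dropout=0.0, training=False):
    """Illustrative ordering of operations for each skip connection type."""
    if block not in ("res", "res+", "dense", "plain"):
        raise ValueError(f"Unknown block type: {block!r}")

    if block == "res+":
        # Pre-activation residual: norm -> act -> dropout -> conv -> add input.
        h = x
        if norm is not None:
            h = norm(h)
        if act is not None:
            h = act(h)
        h = F.dropout(h, p=dropout, training=training)
        h = conv(h, adj)
        return x + h

    # Remaining types: conv -> norm -> act, then the skip connection.
    h = conv(x, adj)
    if norm is not None:
        h = norm(h)
    if act is not None:
        h = act(h)
    if block == "res":
        h = x + h                       # additive skip, needs in_dim == out_dim
    elif block == "dense":
        h = torch.cat([x, h], dim=-1)   # concatenation, output grows by in_dim
    # 'plain' applies no skip connection.
    return F.dropout(h, p=dropout, training=training)


torch.manual_seed(0)
x = torch.randn(2, 5, 8)                # (batch_size, n_nodes, in_dim)
adj = torch.eye(5).repeat(2, 1, 1)      # (batch_size, n_nodes, n_nodes)
conv, norm, act = DenseGraphConv(8, 8), nn.LayerNorm(8), nn.ReLU()

shapes = {
    b: tuple(deepgcn_block(x, adj, conv, norm, act, block=b).shape)
    for b in ("res", "res+", "dense", "plain")
}
```

In this sketch, 'res' and 'res+' require the convolution to preserve the feature dimension, while 'dense' concatenates input and output features, so the last dimension becomes `in_dim + out_dim`.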
forward(x, adj)
Forward method.
Parameters:
- x (Tensor) – Node features of shape (batch_size, n_nodes, in_dim).
- adj (Tensor) – Adjacency matrix of shape (batch_size, n_nodes, n_nodes).
Returns:
- y (Tensor) – Output tensor of shape (batch_size, n_nodes, out_dim).
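As a rough illustration of the input shapes that `forward(x, adj)` documents, the snippet below builds a batched dense adjacency matrix and matching node features using plain PyTorch. The helper `dense_adjacency` and the example edge lists are hypothetical. Whether the layer expects self-loops or a normalized adjacency is not stated here, so neither is applied.

```python
import torch


def dense_adjacency(edges, n_nodes):
    """Build a symmetric 0/1 adjacency matrix from an undirected edge list."""
    adj = torch.zeros(n_nodes, n_nodes)
    if edges:
        src, dst = zip(*edges)
        src_idx, dst_idx = torch.tensor(src), torch.tensor(dst)
        adj[src_idx, dst_idx] = 1.0
        adj[dst_idx, src_idx] = 1.0     # mirror each edge to keep adj symmetric
    return adj


batch_size, n_nodes, in_dim = 2, 4, 16

# One edge list per graph in the batch (illustrative data only).
graphs = [
    [(0, 1), (1, 2), (2, 3)],
    [(0, 2), (1, 3)],
]

adj = torch.stack([dense_adjacency(e, n_nodes) for e in graphs])
x = torch.randn(batch_size, n_nodes, in_dim)

# adj: (batch_size, n_nodes, n_nodes), x: (batch_size, n_nodes, in_dim)
```

Tensors built this way have the shapes listed for `x` and `adj` above.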