Transformer base class
torchmil.nn.transformers.Encoder
Bases: Module
Generic Transformer encoder class.
Given an input bag \(\mathbf{X} = \left[ \mathbf{x}_1, \ldots, \mathbf{x}_N \right]^\top \in \mathbb{R}^{N \times D}\) and (optional) additional arguments, this module computes:
\[
\mathbf{X}^{l} = \operatorname{Layer}^{l}\left( \mathbf{X}^{l-1}, \ldots \right), \quad l = 1, \ldots, L,
\]
with \(\mathbf{X}^{0} = \mathbf{X}\), where \(\ldots\) denotes additional arguments.
The list of layers, \(\operatorname{Layer}^{l}\) for \(l = 1, \ldots, L\), is given by the layers argument; each layer should be a subclass of Layer.
This module outputs \(\operatorname{Encoder}(\mathbf{X}) = \mathbf{X}^{L}\) if add_self=False, and \(\operatorname{Encoder}(\mathbf{X}) = \mathbf{X}^{L} + \mathbf{X}\) if add_self=True.
__init__(layers, add_self=False)
Parameters:
- layers (ModuleList) – List of encoder layers.
- add_self (bool, default: False) – Whether to add the input to the output. If True, the input and output dimensions must match.
forward(X, return_att=False, **kwargs)
Forward method.
Parameters:
- X (Tensor) – Input tensor of shape (batch_size, bag_size, in_dim).
Returns:
- Y (Tensor) – Output tensor of shape (batch_size, bag_size, in_dim).
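A minimal usage sketch (not from the torchmil documentation). The ToyAttention module below is a hypothetical stand-in that only satisfies the shape contract documented for att_module; in practice you would plug in one of torchmil's attention modules. The sketch also assumes that Layer builds a default MLP when mlp_module is None and use_mlp is True.

```python
import torch
from torchmil.nn.transformers import Encoder, Layer


class ToyAttention(torch.nn.Module):
    """Hypothetical stand-in for an attention module: a single linear map from
    (batch_size, bag_size, att_in_dim) to (batch_size, bag_size, att_out_dim)."""

    def __init__(self, att_in_dim, att_out_dim):
        super().__init__()
        self.proj = torch.nn.Linear(att_in_dim, att_out_dim)

    def forward(self, X, **kwargs):
        return self.proj(X)


in_dim = 64
layers = torch.nn.ModuleList(
    [
        Layer(att_module=ToyAttention(in_dim, in_dim), in_dim=in_dim, att_in_dim=in_dim)
        for _ in range(2)
    ]
)
encoder = Encoder(layers, add_self=True)  # add_self=True: input and output dims must match

X = torch.randn(8, 100, in_dim)  # (batch_size, bag_size, in_dim)
Y = encoder(X)                   # (batch_size, bag_size, in_dim)
```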
torchmil.nn.transformers.Layer
Bases: Module
Generic Transformer layer class.
Given an input bag \(\mathbf{X} = \left[ \mathbf{x}_1, \ldots, \mathbf{x}_N \right]^\top \in \mathbb{R}^{N \times D}\), and (optional) additional arguments, this module computes:
\[
\mathbf{Z} = \mathbf{X} + \operatorname{Att}\left( \mathbf{X}, \ldots \right), \qquad
\mathbf{Y} = \mathbf{Z} + \operatorname{MLP}\left( \mathbf{Z} \right),
\]
where \(\ldots\) denotes additional arguments, and outputs \(\mathbf{Y}\). If use_mlp=False, the MLP step is skipped and \(\mathbf{Y} = \mathbf{Z}\).
\(\operatorname{Att}\) is given by the att_module argument, and \(\operatorname{MLP}\) is given by the mlp_module argument.
__init__(att_module, in_dim, att_in_dim, out_dim=None, att_out_dim=None, use_mlp=True, mlp_module=None, dropout=0.0)
Parameters:
- att_module (Module) – Attention module. Assumes input of shape (batch_size, seq_len, att_in_dim) and output of shape (batch_size, seq_len, att_out_dim).
- in_dim (int) – Input dimension.
- att_in_dim (int) – Input dimension for the attention module.
- out_dim (int, default: None) – Output dimension. If None, out_dim = in_dim.
- att_out_dim (int, default: None) – Output dimension for the attention module. If None, att_out_dim = in_dim.
- use_mlp (bool, default: True) – Whether to use an MLP after the attention layer.
- mlp_module (Module, default: None) – MLP module.
- dropout (float, default: 0.0) – Dropout rate.
forward(X, return_att=False, **kwargs)
Parameters:
- X (Tensor) – Input tensor of shape (batch_size, seq_len, in_dim).
- return_att (bool, default: False) – If True, returns attention weights, of shape (batch_size, n_heads, seq_len, seq_len).
- kwargs (Any, default: {}) – Additional arguments for the attention module.
Returns:
- Y (Tensor) – Output tensor of shape (batch_size, seq_len, out_dim).
- Tensor – If return_att is True, also returns attention weights, of shape (batch_size, n_heads, seq_len, seq_len).
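A Layer can also be used on its own. The sketch below (again not from the torchmil documentation) relies on a hypothetical ToyAttention stand-in, keeps all dimensions equal so that the documented defaults out_dim = in_dim and att_out_dim = in_dim apply, and disables the MLP block via use_mlp=False.

```python
import torch
from torchmil.nn.transformers import Layer


class ToyAttention(torch.nn.Module):
    """Hypothetical stand-in meeting the documented
    (batch_size, seq_len, att_in_dim) -> (batch_size, seq_len, att_out_dim) contract."""

    def __init__(self, dim):
        super().__init__()
        self.proj = torch.nn.Linear(dim, dim)

    def forward(self, X, **kwargs):
        return self.proj(X)


dim = 32
layer = Layer(
    att_module=ToyAttention(dim),
    in_dim=dim,
    att_in_dim=dim,  # out_dim and att_out_dim default to in_dim
    use_mlp=False,   # skip the MLP block
    dropout=0.1,
)

X = torch.randn(4, 50, dim)  # (batch_size, seq_len, in_dim)
Y = layer(X)                 # (batch_size, seq_len, out_dim) == (4, 50, 32)
```

Passing return_att=True additionally returns attention weights of shape (batch_size, n_heads, seq_len, seq_len); this requires an attention module that actually exposes such weights, which the stand-in above does not.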