
Nyström Transformer


torchmil.nn.transformers.NystromTransformerEncoder

Bases: Encoder

Nyström Transformer encoder with skip connections and layer normalization.

Given an input bag \(\mathbf{X} = \left[ \mathbf{x}_1, \ldots, \mathbf{x}_N \right]^\top \in \mathbb{R}^{N \times D}\), it computes:

\[\begin{align*} \mathbf{X}^{0} & = \mathbf{X} \\ \mathbf{Z}^{l} & = \mathbf{X}^{l-1} + \operatorname{NystromSelfAttention}( \operatorname{LayerNorm}(\mathbf{X}^{l-1}) ), \quad l = 1, \ldots, L \\ \mathbf{X}^{l} & = \mathbf{Z}^{l} + \operatorname{MLP}(\operatorname{LayerNorm}(\mathbf{Z}^{l})), \quad l = 1, \ldots, L \\ \end{align*}\]

This module outputs \(\operatorname{TransformerEncoder}(\mathbf{X}) = \mathbf{X}^{L}\) if add_self=False, and \(\operatorname{TransformerEncoder}(\mathbf{X}) = \mathbf{X}^{L} + \mathbf{X}\) if add_self=True.
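The recursion above chains \(L\) identical pre-norm residual blocks, and add_self adds one extra skip connection around the whole stack. The NumPy sketch below covers only that outer structure. Each layer is a hypothetical stand-in (make_toy_layer is not part of torchmil), not a real NystromTransformerLayer. The input and output dimensions are kept equal so that the final addition is well defined.

```python
import numpy as np

def encoder_forward(X, layers, add_self=False):
    """Apply layers in sequence: X^0 = X, X^l = layer_l(X^{l-1}).

    Returns X^L, or X^L + X when add_self=True (shapes must then match).
    """
    H = X
    for layer in layers:
        H = layer(H)
    return H + X if add_self else H

# Hypothetical stand-in for one NystromTransformerLayer (illustration only).
def make_toy_layer(W):
    return lambda H: H + np.tanh(H @ W)

rng = np.random.default_rng(0)
N, D, L = 6, 4, 4                       # bag size, feature dim, n_layers
X = rng.normal(size=(N, D))
layers = [make_toy_layer(rng.normal(scale=0.1, size=(D, D))) for _ in range(L)]

Y = encoder_forward(X, layers)                      # X^L
Y_self = encoder_forward(X, layers, add_self=True)  # X^L + X
print(Y.shape, np.allclose(Y_self, Y + X))          # (6, 4) True
```

The add_self residual is why the output dimension has to match the input dimension whenever add_self=True.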

\(\operatorname{NystromSelfAttention}\) is implemented by the NystromAttention module.
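For intuition, here is a minimal single-head NumPy sketch of the Nyström approximation from the Nyströmformer paper. It assumes that NystromAttention follows that scheme, which this page does not spell out.

The sketch makes these choices:
- Landmarks are segment means of the queries and keys.
- The \(m \times m\) landmark kernel is inverted with np.linalg.pinv, in place of an iterative pseudo-inverse.
- The learned projections are omitted, and the bag size is assumed to be divisible by the number of landmarks.

With \(m = N\), every segment holds a single instance. The landmarks are then the instances themselves and the approximation becomes exact.

```python
import numpy as np

def softmax(S):
    # Row-wise softmax, shifted for numerical stability.
    S = S - S.max(axis=-1, keepdims=True)
    E = np.exp(S)
    return E / E.sum(axis=-1, keepdims=True)

def segment_means(X, m):
    # Landmarks: split the N rows into m contiguous segments and average each.
    # This sketch assumes N is divisible by m instead of padding the bag.
    N, d = X.shape
    assert N % m == 0, "sketch assumes bag_size divisible by n_landmarks"
    return X.reshape(m, N // m, d).mean(axis=1)

def nystrom_attention(Q, K, V, m):
    """Single-head Nystrom approximation of softmax(Q K^T / sqrt(d)) V."""
    scale = 1.0 / np.sqrt(Q.shape[-1])
    Q_l, K_l = segment_means(Q, m), segment_means(K, m)
    F = softmax(Q @ K_l.T * scale)    # (N, m)
    A = softmax(Q_l @ K_l.T * scale)  # (m, m) landmark kernel
    B = softmax(Q_l @ K.T * scale)    # (m, N)
    # Multiply right-to-left so no N x N matrix is ever formed.
    return F @ (np.linalg.pinv(A) @ (B @ V))

def exact_attention(Q, K, V):
    return softmax(Q @ K.T / np.sqrt(Q.shape[-1])) @ V

rng = np.random.default_rng(0)
N, d, m = 64, 16, 8
Q, K, V = (rng.normal(size=(N, d)) for _ in range(3))
approx = nystrom_attention(Q, K, V, m)
exact = exact_attention(Q, K, V)
print(approx.shape, np.abs(approx - exact).max())
```

The cost is dominated by the \(N \times m\) and \(m \times N\) products, so the sketch scales linearly in the bag size for a fixed number of landmarks.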

__init__(in_dim, out_dim=None, att_dim=512, n_heads=8, n_layers=4, n_landmarks=256, pinv_iterations=6, dropout=0.0, use_mlp=False, add_self=False)

Parameters:

  • in_dim (int) –

    Input dimension.

  • out_dim (int, default: None ) –

    Output dimension. If None, out_dim = in_dim.

  • att_dim (int, default: 512 ) –

    Attention dimension.

  • n_heads (int, default: 8 ) –

    Number of heads.

  • n_layers (int, default: 4 ) –

    Number of layers.

  • n_landmarks (int, default: 256 ) –

    Number of landmarks.

  • pinv_iterations (int, default: 6 ) –

    Number of iterations for the pseudo-inverse.

  • dropout (float, default: 0.0 ) –

    Dropout rate.

  • use_mlp (bool, default: False ) –

    Whether to use an MLP after the attention layer.

  • add_self (bool, default: False ) –

    Whether to add the input to the output. If True, att_dim must be equal to in_dim.
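The pinv_iterations parameter suggests that the pseudo-inverse of the landmark kernel is computed by a fixed number of matrix iterations rather than by an SVD. The sketch below shows the iterative scheme described in the Nyströmformer paper; torchmil is assumed to follow it, since this page does not confirm that. It uses the paper's initialisation \(Z_0 = A^\top / (\|A\|_1 \|A\|_\infty)\) and its third-order update. For each singular direction the error \(e\) becomes \(e^3(3+e)/4\) per step, so a handful of iterations is usually enough once \(e\) is well below 1.

```python
import numpy as np

def iterative_pinv(A, n_iter=6):
    """Approximate the Moore-Penrose pseudo-inverse of A without an SVD."""
    abs_A = np.abs(A)
    # Z0 = A^T / (||A||_1 * ||A||_inf): max column sum times max row sum.
    Z = A.T / (abs_A.sum(axis=0).max() * abs_A.sum(axis=1).max())
    I = np.eye(A.shape[0])
    for _ in range(n_iter):
        AZ = A @ Z
        # Third-order update from the Nystromformer paper.
        Z = 0.25 * Z @ (13 * I - AZ @ (15 * I - AZ @ (7 * I - AZ)))
    return Z

m = 8
A = 0.7 * np.eye(m) + 0.3 / m  # row-stochastic, like a softmax kernel; well conditioned
errors = [np.abs(iterative_pinv(A, k) - np.linalg.pinv(A)).max() for k in (1, 2, 4, 6)]
print(errors)
```

Ill-conditioned kernels converge more slowly from this starting point. This is the accuracy-versus-compute trade-off that pinv_iterations exposes.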

forward(X, mask=None, return_att=False)

Forward method.

Parameters:

  • X (Tensor) –

    Input tensor of shape (batch_size, bag_size, att_dim).

  • mask (Tensor, default: None ) –

    Mask tensor of shape (batch_size, bag_size).

  • return_att (bool, default: False ) –

    Whether to return attention weights.

Returns:

  • Y (Tensor) –

    Output tensor of shape (batch_size, bag_size, att_dim).

  • att (Tensor) –

    Only returned when return_att=True. Attention weights of shape (batch_size, n_heads, bag_size, bag_size).
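Because forward takes a dense X together with a mask of shape (batch_size, bag_size), bags of different sizes are batched by padding. The NumPy sketch below builds both arrays from a list of bags. pad_bags is a hypothetical helper, not a torchmil function. The convention that 1 marks a real instance and 0 marks padding is also an assumption, because this page does not state it.

```python
import numpy as np

def pad_bags(bags):
    """Pad variable-size bags to a common bag_size and build a padding mask.

    Hypothetical helper. Assumed convention: mask[i, j] == 1 for a real
    instance, 0 for padding.
    """
    batch_size = len(bags)
    bag_size = max(b.shape[0] for b in bags)
    dim = bags[0].shape[1]
    X = np.zeros((batch_size, bag_size, dim), dtype=np.float32)
    mask = np.zeros((batch_size, bag_size), dtype=np.int64)
    for i, bag in enumerate(bags):
        n = bag.shape[0]
        X[i, :n] = bag   # real instances first, zeros after
        mask[i, :n] = 1
    return X, mask

rng = np.random.default_rng(0)
bags = [rng.normal(size=(n, 4)) for n in (3, 5, 2)]
X, mask = pad_bags(bags)
print(X.shape, mask.shape, mask.sum(axis=1))  # (3, 5, 4) (3, 5) [3 5 2]
```

The arrays can then be converted with torch.from_numpy. Check the mask dtype and polarity that the module expects before passing them as X and mask.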


torchmil.nn.transformers.NystromTransformerLayer

Bases: Layer

One layer of the NystromTransformer encoder.

Given an input bag \(\mathbf{X} = \left[ \mathbf{x}_1, \ldots, \mathbf{x}_N \right]^\top \in \mathbb{R}^{N \times D}\), this module computes:

\[\begin{align*} \mathbf{Z} & = \mathbf{X} + \operatorname{NystromSelfAttention}( \operatorname{LayerNorm}(\mathbf{X}) ) \\ \mathbf{Y} & = \mathbf{Z} + \operatorname{MLP}(\operatorname{LayerNorm}(\mathbf{Z})), \\ \end{align*}\]

and outputs \(\mathbf{Y}\). \(\operatorname{NystromSelfAttention}\) is implemented by the NystromAttention module.
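Written out for a single bag, the layer is a pre-norm residual block. Layer normalisation is applied inside each branch, while the skip path carries \(\mathbf{X}\) (then \(\mathbf{Z}\)) through unchanged. The NumPy sketch below assumes in_dim = att_dim = out_dim so that both additions are well defined; how the module handles differing dimensions is not covered here. The attention and MLP are hypothetical stand-ins. Exact single-head attention without projections replaces NystromSelfAttention, and a one-hidden-layer ReLU network plays the MLP.

```python
import numpy as np

def layer_norm(x, gamma, beta, eps=1e-5):
    # Normalise each instance over its features, then apply an affine map.
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return gamma * (x - mu) / np.sqrt(var + eps) + beta

def transformer_layer(X, attn, mlp, ln1, ln2):
    """Pre-norm residual layer: Z = X + attn(LN(X)); Y = Z + mlp(LN(Z))."""
    Z = X + attn(layer_norm(X, *ln1))
    Y = Z + mlp(layer_norm(Z, *ln2))
    return Y

def softmax(S):
    S = S - S.max(axis=-1, keepdims=True)
    E = np.exp(S)
    return E / E.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
N, D = 6, 8
W1 = rng.normal(scale=0.1, size=(D, 2 * D))
W2 = rng.normal(scale=0.1, size=(2 * D, D))

def attn(H):
    # Stand-in for NystromSelfAttention: exact softmax attention, no projections.
    return softmax(H @ H.T / np.sqrt(D)) @ H

def mlp(H):
    # Stand-in MLP: one hidden ReLU layer.
    return np.maximum(H @ W1, 0.0) @ W2

ln = (np.ones(D), np.zeros(D))  # gamma = 1, beta = 0
X = rng.normal(size=(N, D))
Y = transformer_layer(X, attn, mlp, ln, ln)
print(Y.shape)  # (6, 8)
```

If both branches return zeros, the layer reduces to the identity. This is the property that makes deep pre-norm stacks easy to train.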

__init__(in_dim, out_dim=None, att_dim=512, n_heads=4, learn_weights=True, n_landmarks=256, pinv_iterations=6, dropout=0.0, use_mlp=False)

Parameters:

  • in_dim (int) –

    Input dimension.

  • out_dim (int, default: None ) –

    Output dimension. If None, out_dim = in_dim.

  • att_dim (int, default: 512 ) –

    Attention dimension.

  • n_heads (int, default: 4 ) –

    Number of heads.

  • n_landmarks (int, default: 256 ) –

    Number of landmarks.

  • pinv_iterations (int, default: 6 ) –

    Number of iterations for the pseudo-inverse.

  • dropout (float, default: 0.0 ) –

    Dropout rate.

  • use_mlp (bool, default: False ) –

    Whether to use an MLP after the attention layer.
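The n_heads parameter sets how many attention heads share att_dim. The sketch below assumes the usual multi-head layout, in which att_dim is split evenly across heads with att_dim // n_heads features each. It shows how per-head weights stack into the (batch_size, n_heads, bag_size, bag_size) att tensor that is returned when return_att=True. It computes exact softmax weights, not the Nyström approximation, and uses the input itself as queries and keys instead of learned projections.

```python
import numpy as np

def split_heads(X, n_heads):
    """(batch_size, bag_size, att_dim) -> (batch_size, n_heads, bag_size, head_dim)."""
    b, n, d = X.shape
    assert d % n_heads == 0, "sketch assumes att_dim divisible by n_heads"
    return X.reshape(b, n, n_heads, d // n_heads).transpose(0, 2, 1, 3)

def merge_heads(X):
    """(batch_size, n_heads, bag_size, head_dim) -> (batch_size, bag_size, att_dim)."""
    b, h, n, hd = X.shape
    return X.transpose(0, 2, 1, 3).reshape(b, n, h * hd)

def softmax(S):
    S = S - S.max(axis=-1, keepdims=True)
    E = np.exp(S)
    return E / E.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
batch_size, bag_size, att_dim, n_heads = 2, 10, 512, 4
X = rng.normal(size=(batch_size, bag_size, att_dim))
Xh = split_heads(X, n_heads)
head_dim = Xh.shape[-1]
# Exact per-head weights, using X itself as queries and keys (no projections).
att = softmax(Xh @ Xh.transpose(0, 1, 3, 2) / np.sqrt(head_dim))
print(Xh.shape, att.shape)  # (2, 4, 10, 128) (2, 4, 10, 10)
```

With the defaults att_dim=512 and n_heads=4, each head works in a 128-dimensional subspace.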

forward(X, mask=None, return_att=False)

Forward pass.

Parameters:

  • X (Tensor) –

    Input tensor of shape (batch_size, bag_size, att_dim).

  • mask (Tensor, default: None ) –

    Mask tensor of shape (batch_size, bag_size).

  • return_att (bool, default: False ) –

    Whether to return attention weights.

Returns:

  • X (Tensor) –

    Output tensor of shape (batch_size, bag_size, att_dim).

  • att (Tensor) –

    Only returned when return_att=True. Attention weights of shape (batch_size, n_heads, bag_size, bag_size).