Nyström Attention

torchmil.nn.attention.NystromAttention

Bases: Module

Nyström attention, as described in the paper Nyströmformer: A Nyström-Based Algorithm for Approximating Self-Attention.

Implementation based on the official code released with the paper.

__init__(in_dim, out_dim=None, att_dim=512, n_heads=4, learn_weights=True, n_landmarks=256, pinv_iterations=6)

Parameters:

  • in_dim (int) –

    Input dimension.

  • out_dim (int, default: None ) –

    Output dimension. If None, out_dim = in_dim.

  • att_dim (int, default: 512 ) –

    Attention dimension. Must be divisible by n_heads.

  • n_heads (int, default: 4 ) –

    Number of heads.

  • learn_weights (bool, default: True ) –

    If True, learn the projection weights for query, key, and value. If False, the queries, keys, and values are the input itself (no projection is applied), so in_dim must be divisible by n_heads.

  • n_landmarks (int, default: 256 ) –

    Number of landmarks.

  • pinv_iterations (int, default: 6 ) –

    Number of iterations for Moore-Penrose pseudo-inverse.
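
A minimal instantiation sketch (not taken from the library's examples): it assumes the class is importable as shown on this page, and the dimensions are illustrative.

```python
import torch
from torchmil.nn.attention import NystromAttention

# Illustrative dimensions: in_dim should match your feature extractor,
# and att_dim must be divisible by n_heads.
att = NystromAttention(
    in_dim=1024,
    att_dim=512,
    n_heads=4,
    n_landmarks=256,
    pinv_iterations=6,
)

x = torch.randn(8, 500, 1024)  # (batch_size, seq_len, in_dim)
y = att(x)                     # (batch_size, seq_len, att_dim) = (8, 500, 512)
```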

forward(x, mask=None, return_att=False)

Forward pass.

Parameters:

  • x (Tensor) –

    Input tensor of shape (batch_size, seq_len, in_dim).

  • mask (Tensor, default: None ) –

    Mask tensor of shape (batch_size, seq_len).

  • return_att (bool, default: False ) –

    Whether to return attention weights.

Returns:

  • y ( Tensor ) –

    Output tensor of shape (batch_size, seq_len, att_dim).

  • att ( Tensor ) –

    Only returned when return_att=True. Attention weights of shape (batch_size, n_heads, seq_len, seq_len).
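
A sketch of a masked forward pass that also returns the attention weights. The exact mask convention (dtype and whether nonzero marks valid positions) is an assumption; check the implementation if padding behaviour matters.

```python
import torch
from torchmil.nn.attention import NystromAttention

att = NystromAttention(in_dim=256, att_dim=512, n_heads=4)

x = torch.randn(2, 100, 256)                 # (batch_size, seq_len, in_dim)
mask = torch.ones(2, 100, dtype=torch.bool)  # assumed convention: True = valid instance
mask[1, 80:] = False                         # e.g. second bag padded beyond 80 instances

y, att_weights = att(x, mask=mask, return_att=True)
# y:           (2, 100, 512)    -> (batch_size, seq_len, att_dim)
# att_weights: (2, 4, 100, 100) -> (batch_size, n_heads, seq_len, seq_len)
```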