Nyström Attention
torchmil.nn.attention.NystromAttention
Bases: Module
Nyström attention, as described in the paper Nyströmformer: A Nyström-Based Algorithm for Approximating Self-Attention.
Implementation based on the official code.
__init__(in_dim, out_dim=None, att_dim=512, n_heads=4, learn_weights=True, n_landmarks=256, pinv_iterations=6)
Parameters:
- `in_dim` (`int`) – Input dimension.
- `out_dim` (`int`, default: `None`) – Output dimension. If `None`, `out_dim = in_dim`.
- `att_dim` (`int`, default: `512`) – Attention dimension. Must be divisible by `n_heads`.
- `n_heads` (`int`, default: `4`) – Number of heads.
- `learn_weights` (`bool`, default: `True`) – If `True`, learn the weights for query, key, and value. If `False`, q, k, and v are the same as the input, and therefore `in_dim` must be divisible by `n_heads`.
- `n_landmarks` (`int`, default: `256`) – Number of landmarks.
- `pinv_iterations` (`int`, default: `6`) – Number of iterations for the Moore-Penrose pseudo-inverse.
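A minimal construction sketch based on the documented signature; the input dimension below is illustrative, and the import path is assumed from the module name shown above:

```python
import torch
from torchmil.nn.attention import NystromAttention

# att_dim (512) must be divisible by n_heads (4), as noted above.
attn = NystromAttention(
    in_dim=768,          # illustrative input feature dimension
    att_dim=512,
    n_heads=4,
    n_landmarks=256,
    pinv_iterations=6,
)
```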
forward(x, mask=None, return_att=False)
Forward pass.
Parameters:
- `x` (`Tensor`) – Input tensor of shape `(batch_size, seq_len, in_dim)`.
- `mask` (`Tensor`, default: `None`) – Mask tensor of shape `(batch_size, seq_len)`.
- `return_att` (`bool`, default: `False`) – Whether to return attention weights.
Returns:
- `y` (`Tensor`) – Output tensor of shape `(batch_size, seq_len, att_dim)`.
- `att` (`Tensor`) – Only returned when `return_att=True`. Attention weights of shape `(batch_size, n_heads, seq_len, seq_len)`.
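A forward-pass sketch following the documented shapes; the boolean mask dtype is an assumption, and the batch size, sequence length, and feature dimension are illustrative:

```python
import torch
from torchmil.nn.attention import NystromAttention

attn = NystromAttention(in_dim=768, att_dim=512, n_heads=4)

batch_size, seq_len, in_dim = 2, 1000, 768
x = torch.randn(batch_size, seq_len, in_dim)

# Mask of shape (batch_size, seq_len); a boolean mask marking valid
# positions is assumed here.
mask = torch.ones(batch_size, seq_len, dtype=torch.bool)

y, att = attn(x, mask=mask, return_att=True)
print(y.shape)    # expected: (batch_size, seq_len, att_dim) = (2, 1000, 512)
print(att.shape)  # expected: (batch_size, n_heads, seq_len, seq_len) = (2, 4, 1000, 1000)
```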