Multihead Self-Attention
torchmil.nn.attention.MultiheadSelfAttention
Bases: Module
The Multihead Self Attention module, as described in Attention is All You Need.
Given an input bag \(\mathbf{X} = \left[ \mathbf{x}_1, \ldots, \mathbf{x}_N \right]^\top \in \mathbb{R}^{N \times \texttt{in_dim}}\), this module computes:

\[
\mathbf{Q} = \mathbf{X}\mathbf{W}_Q, \quad \mathbf{K} = \mathbf{X}\mathbf{W}_K, \quad \mathbf{V} = \mathbf{X}\mathbf{W}_V,
\]

\[
\mathbf{Y} = \operatorname{softmax}\left( \frac{\mathbf{Q}\mathbf{K}^\top}{\sqrt{d}} \right) \mathbf{V},
\]

where \(d = \texttt{att_dim}\) and \(\mathbf{W}_Q, \mathbf{W}_K, \mathbf{W}_V \in \mathbb{R}^{\texttt{in_dim} \times \texttt{att_dim}}\) are learnable weight matrices.
If \(\texttt{out_dim} \neq \texttt{att_dim}\), \(\mathbf{Y}\) is passed through a linear layer with output dimension \(\texttt{out_dim}\).
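As a rough illustration of the formula above, here is a minimal self-contained PyTorch sketch of the multihead computation. It is not the torchmil implementation: the concrete sizes and the per-head scaling by \(\sqrt{\texttt{att_dim} / \texttt{n_heads}}\) (the usual multihead convention) are assumptions for the example.

```python
# Minimal sketch of the computation above in plain PyTorch (illustration only;
# not the torchmil implementation). Assumes per-head scaling by
# sqrt(att_dim / n_heads), the usual multihead convention.
import torch
import torch.nn.functional as F

N, in_dim, att_dim, n_heads = 10, 256, 512, 4
head_dim = att_dim // n_heads

X = torch.randn(N, in_dim)           # input bag of N instances
W_Q = torch.randn(in_dim, att_dim)   # learnable parameters in the module
W_K = torch.randn(in_dim, att_dim)
W_V = torch.randn(in_dim, att_dim)

# Project, then split the attention dimension into heads: (n_heads, N, head_dim)
Q = (X @ W_Q).reshape(N, n_heads, head_dim).transpose(0, 1)
K = (X @ W_K).reshape(N, n_heads, head_dim).transpose(0, 1)
V = (X @ W_V).reshape(N, n_heads, head_dim).transpose(0, 1)

# Scaled dot-product attention per head, then concatenate the heads
att = F.softmax(Q @ K.transpose(-2, -1) / head_dim ** 0.5, dim=-1)
Y = (att @ V).transpose(0, 1).reshape(N, att_dim)   # (N, att_dim)
```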
__init__(in_dim, out_dim=None, att_dim=512, n_heads=4, dropout=0.0, learn_weights=True)
Parameters:

- in_dim (int) – Input dimension.
- att_dim (int, default: 512) – Attention dimension, must be divisible by n_heads.
- out_dim (int, default: None) – Output dimension. If None, out_dim = in_dim.
- n_heads (int, default: 4) – Number of heads.
- dropout (float, default: 0.0) – Dropout rate.
- learn_weights (bool, default: True) – If True, learn the weights for query, key, and value. If False, q, k, and v are the same as the input, and therefore in_dim must be divisible by n_heads.
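A minimal construction sketch, assuming the import path shown at the top of this page; the concrete dimensions are chosen for illustration only.

```python
from torchmil.nn.attention import MultiheadSelfAttention

# Attention over 256-dimensional instance features, projected to att_dim=512
# split across 4 heads (512 is divisible by 4, as required).
attn = MultiheadSelfAttention(
    in_dim=256,
    att_dim=512,
    n_heads=4,
    dropout=0.1,
)  # out_dim is None, so it defaults to in_dim = 256
```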
forward(x, mask=None, return_att=False)
Forward pass.
Parameters:

- x (Tensor) – Input tensor of shape (batch_size, seq_len, in_dim).
- mask (Tensor, default: None) – Mask tensor of shape (batch_size, seq_len).
- return_att (bool, default: False) – If True, also return the attention weights.
Returns:

- y: Output tensor of shape (batch_size, seq_len, out_dim).
- att: Only returned when return_att=True. Attention weights of shape (batch_size, n_heads, seq_len, seq_len).
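A usage sketch of the forward pass, based on the signature and return values documented above. The mask convention (1 for real instances, 0 for padding) is an assumption for the example.

```python
import torch
from torchmil.nn.attention import MultiheadSelfAttention

batch_size, seq_len, in_dim = 2, 100, 256
attn = MultiheadSelfAttention(in_dim=in_dim, att_dim=512, n_heads=4)

x = torch.randn(batch_size, seq_len, in_dim)
mask = torch.ones(batch_size, seq_len)  # assumed convention: 1 = instance, 0 = padding

y, att = attn(x, mask=mask, return_att=True)
# y:   (batch_size, seq_len, out_dim)           -> (2, 100, 256), since out_dim = in_dim
# att: (batch_size, n_heads, seq_len, seq_len)  -> (2, 4, 100, 100)
```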