
Multihead Self-Attention

torchmil.nn.attention.MultiheadSelfAttention

Bases: Module

The multihead self-attention module, as described in Attention Is All You Need.

Given an input bag \(\mathbf{X} = \left[ \mathbf{x}_1, \ldots, \mathbf{x}_N \right]^\top \in \mathbb{R}^{N \times \texttt{in_dim}}\), this module computes:

\[\begin{gather*} \mathbf{Q} = \mathbf{X}\mathbf{W}_Q, \quad \mathbf{K} = \mathbf{X}\mathbf{W}_K, \quad \mathbf{V} = \mathbf{X}\mathbf{W}_V,\\ \mathbf{Y} = \operatorname{Softmax}\left( \frac{\mathbf{Q} \mathbf{K}^\top}{\sqrt{d}} \right) \mathbf{V}, \end{gather*}\]

where \(d = \texttt{att_dim}\) and \(\mathbf{W}_Q, \mathbf{W}_K, \mathbf{W}_V \in \mathbb{R}^{\texttt{in_dim} \times \texttt{att_dim}}\) are learnable weight matrices.

If \(\texttt{out_dim} \neq \texttt{att_dim}\), \(\mathbf{Y}\) is passed through a linear layer with output dimension \(\texttt{out_dim}\).
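As a concrete illustration of the formula above, here is a minimal single-head sketch in plain PyTorch. The names `X`, `W_Q`, `W_K`, `W_V`, and `att_dim` mirror the notation; this is only a sketch, not the module's actual implementation, which additionally splits `att_dim` across `n_heads`:

```python
import torch
import torch.nn.functional as F

N, in_dim, att_dim = 8, 256, 512       # bag size and dimensions (illustrative values)
X = torch.randn(N, in_dim)             # input bag X of shape (N, in_dim)

W_Q = torch.randn(in_dim, att_dim)     # stand-ins for the learnable projections
W_K = torch.randn(in_dim, att_dim)
W_V = torch.randn(in_dim, att_dim)

Q, K, V = X @ W_Q, X @ W_K, X @ W_V                # queries, keys, values
att = F.softmax(Q @ K.T / att_dim ** 0.5, dim=-1)  # softmax(Q K^T / sqrt(d)), with d = att_dim
Y = att @ V                                        # output of shape (N, att_dim)
```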

__init__(in_dim, out_dim=None, att_dim=512, n_heads=4, dropout=0.0, learn_weights=True)

Parameters:

  • in_dim (int) –

    Input dimension.

  • att_dim (int, default: 512 ) –

    Attention dimension, must be divisible by n_heads.

  • out_dim (int, default: None ) –

    Output dimension. If None, out_dim = in_dim.

  • n_heads (int, default: 4 ) –

    Number of heads.

  • dropout (float, default: 0.0 ) –

    Dropout rate.

  • learn_weights (bool, default: True ) –

    If True, learn the projection weights for the query, key, and value. If False, the query, key, and value are taken directly from the input (no projection), and therefore in_dim must be divisible by n_heads.
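A hedged construction example, using the import path implied by the qualified name at the top of this page; the argument values are illustrative:

```python
from torchmil.nn.attention import MultiheadSelfAttention

# att_dim must be divisible by n_heads (here 512 / 4 = 128 per head)
attn = MultiheadSelfAttention(
    in_dim=256,
    att_dim=512,
    n_heads=4,
    dropout=0.1,
)
```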

forward(x, mask=None, return_att=False)

Forward pass.

Parameters:

  • x (Tensor) –

    Input tensor of shape (batch_size, seq_len, in_dim).

  • mask (Tensor, default: None ) –

    Mask tensor of shape (batch_size, seq_len).

  • return_att (bool, default: False ) –

    If True, also return the attention weights.

Returns:

  • y (Tensor) –

    Output tensor of shape (batch_size, seq_len, out_dim).

  • att (Tensor) –

    Only returned when return_att=True. Attention weights of shape (batch_size, n_heads, seq_len, seq_len).
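A hedged usage sketch of the forward pass, continuing the constructor example above. The mask convention (nonzero entries marking valid instances) is an assumption, not stated on this page:

```python
import torch

x = torch.randn(2, 100, 256)   # (batch_size, seq_len, in_dim)
mask = torch.ones(2, 100)      # (batch_size, seq_len); assumed: nonzero = valid instance

y, att = attn(x, mask=mask, return_att=True)
print(y.shape)    # expected (2, 100, 256): out_dim defaults to in_dim
print(att.shape)  # expected (2, 4, 100, 100): (batch_size, n_heads, seq_len, seq_len)
```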