iRPE Multihead Self-Attention

torchmil.nn.attention.iRPEMultiheadSelfAttention

Bases: Module

Multihead Self-Attention with image Relative Position Encoding (iRPE), as described in Rethinking and Improving Relative Position Encoding for Vision Transformer.

The iRPE implementation is based on the official codebase.

__init__(in_dim, out_dim=None, att_dim=512, n_heads=4, dropout=0.0, learn_weights=True, rpe_ratio=1.9, rpe_method='product', rpe_mode='contextual', rpe_shared_head=True, rpe_skip=1, rpe_on='k')

Parameters:

  • in_dim (int) –

    Input dimension.

  • att_dim (int, default: 512 ) –

    Attention dimension. Must be divisible by n_heads.

  • out_dim (int, default: None ) –

    Output dimension. If None, out_dim = in_dim.

  • n_heads (int, default: 4 ) –

    Number of heads.

  • dropout (float, default: 0.0 ) –

    Dropout rate.

  • learn_weights (bool, default: True ) –

    If True, learn the weights for query, key, and value. If False, q, k, and v are the same as the input, and therefore in_dim must be divisible by n_heads.

  • rpe_ratio (float, default: 1.9 ) –

    Relative position encoding ratio.

  • rpe_method (str, default: 'product' ) –

    Relative position encoding method. Possible values: ['euc', 'quant', 'cross', 'product']

  • rpe_mode (str, default: 'contextual' ) –

    Relative position encoding mode. Possible values: [None, 'bias', 'contextual']

  • rpe_shared_head (bool, default: True ) –

    Whether to share weights across heads.

  • rpe_skip (int, default: 1 ) –

    Relative position encoding skip. Possible values: [0, 1].

  • rpe_on (str, default: 'k' ) –

    Where to apply relative positional encoding. Possible values: ['q', 'k', 'v', 'qk', 'kv', 'qkv'].

Note: When 'v' is included in rpe_on, rpe_mode must be 'contextual'.
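
A minimal construction sketch, assuming the module is importable under the path shown in the page title; the feature dimension and other values are hypothetical, taken from the defaults documented above:

```python
from torchmil.nn.attention import iRPEMultiheadSelfAttention

# Hypothetical setup: bags of 64 instances (an 8x8 grid of patches),
# each described by a 256-dimensional feature vector.
att = iRPEMultiheadSelfAttention(
    in_dim=256,
    att_dim=512,          # must be divisible by n_heads
    n_heads=4,
    dropout=0.0,
    rpe_method='product',
    rpe_mode='contextual',
    rpe_on='k',
)
```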

forward(x, mask=None, return_att=False, height=None, width=None)

Forward pass.

Parameters:

  • x (Tensor) –

    Input tensor of shape (batch_size, seq_len, in_dim).

  • mask (Tensor, default: None ) –

    Mask tensor of shape (batch_size, seq_len).

  • return_att (bool, default: False ) –

    If True, the attention weights are returned in addition to the output.

  • height (int, default: None ) –

    Height of the 2D grid represented by the input sequence. If None, height = floor(sqrt(seq_len)).

  • width (int, default: None ) –

    Width of the 2D grid represented by the input sequence. If None, width = floor(sqrt(seq_len)).

Returns:

  • y ( Tensor ) –

    Output tensor of shape (batch_size, seq_len, att_dim). If return_att is True, the attention weights are also returned.
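
Continuing the construction sketch above, a forward pass on a hypothetical padded batch. The mask convention (1 marks a real instance, 0 marks padding) and the tuple return when return_att=True are assumptions, not confirmed by this page:

```python
import torch

x = torch.randn(2, 64, 256)   # (batch_size, seq_len, in_dim)
mask = torch.ones(2, 64)      # assumed convention: 1 = real instance, 0 = padding
y = att(x, mask=mask, height=8, width=8)
print(y.shape)                # (batch_size, seq_len, att_dim), per the Returns section above

# Assumed behavior: with return_att=True the attention weights are returned as well.
y, att_weights = att(x, mask=mask, return_att=True, height=8, width=8)
```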