
iRPE Transformer


torchmil.nn.transformers.iRPETransformerEncoder

Bases: Encoder

Given an input bag \(\mathbf{X} = \left[ \mathbf{x}_1, \ldots, \mathbf{x}_N \right]^\top \in \mathbb{R}^{N \times D}\), this module computes:

\[\begin{align*} \mathbf{X}^{0} & = \mathbf{X} \\ \mathbf{Z}^{l} & = \mathbf{X}^{l-1} + \operatorname{iRPESelfAttention}( \operatorname{LayerNorm}(\mathbf{X}^{l-1}) ), \quad l = 1, \ldots, L \\ \mathbf{X}^{l} & = \mathbf{Z}^{l} + \operatorname{MLP}(\operatorname{LayerNorm}(\mathbf{Z}^{l})), \quad l = 1, \ldots, L. \\ \end{align*}\]

See iRPEMultiheadSelfAttention for more details about \(\operatorname{iRPESelfAttention}\).

This module outputs \(\operatorname{TransformerEncoder}(\mathbf{X}) = \mathbf{X}^{L}\) if add_self=False, and \(\operatorname{TransformerEncoder}(\mathbf{X}) = \mathbf{X}^{L} + \mathbf{X}\) if add_self=True.
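The recurrence above can be sketched in NumPy as follows. This is a minimal, single-bag, single-head illustration, not torchmil's implementation. It makes several assumptions: \(\operatorname{iRPESelfAttention}\) is replaced by a hypothetical `plain_self_attention` stand-in with identity projections and no relative position encoding, LayerNorm has no learnable affine parameters, and the MLP weights are arbitrary inputs. It shows how the pre-norm residual updates chain across layers, how add_self adds the input bag back, and how per-layer attention maps stack along a leading layer axis. The real module additionally has batch and head axes.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Normalize each instance over its feature dimension (no learnable affine, for brevity).
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def plain_self_attention(x):
    # Hypothetical stand-in for iRPESelfAttention: single-head softmax attention with
    # identity projections and NO relative position encoding (illustration only).
    logits = x @ x.T / np.sqrt(x.shape[-1])
    logits = logits - logits.max(axis=-1, keepdims=True)  # numerical stability
    att = np.exp(logits)
    att = att / att.sum(axis=-1, keepdims=True)
    return att @ x, att

def mlp(x, w1, w2):
    # Two-layer feedforward block with a ReLU nonlinearity (hypothetical weights).
    return np.maximum(x @ w1, 0.0) @ w2

def encoder_forward(X, weights, add_self=False, return_att=False):
    # X: (N, D) bag; weights: list of (w1, w2) pairs, one per layer, so L = len(weights).
    Xl = X                                          # X^0 = X
    atts = []
    for w1, w2 in weights:
        a, att = plain_self_attention(layer_norm(Xl))
        Z = Xl + a                                  # Z^l = X^{l-1} + Att(LN(X^{l-1}))
        Xl = Z + mlp(layer_norm(Z), w1, w2)         # X^l = Z^l + MLP(LN(Z^l))
        atts.append(att)
    Y = Xl + X if add_self else Xl                  # add_self=True returns X^L + X
    if return_att:
        return Y, np.stack(atts)                    # (L, N, N) for one bag and one head
    return Y
```

For example, with three layers and a bag of five 8-dimensional instances, `encoder_forward(X, weights, return_att=True)` returns an output of shape (5, 8) and attention maps of shape (3, 5, 5), and setting add_self=True changes the output by exactly \(\mathbf{X}\).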

__init__(in_dim, out_dim=None, att_dim=512, n_heads=4, n_layers=4, use_mlp=True, add_self=False, dropout=0.0, rpe_ratio=1.9, rpe_method='product', rpe_mode='contextual', rpe_shared_head=True, rpe_skip=1, rpe_on='k')

Class constructor.

Parameters:

  • in_dim (int) –

    Input dimension.

  • att_dim (int, default: 512 ) –

    Attention dimension.

  • out_dim (int, default: None ) –

    Output dimension. If None, out_dim = in_dim.

  • n_heads (int, default: 4 ) –

    Number of heads.

  • n_layers (int, default: 4 ) –

    Number of layers.

  • use_mlp (bool, default: True ) –

    Whether to use the feedforward (MLP) layer.

  • add_self (bool, default: False ) –

    Whether to add the input bag to the output.

  • dropout (float, default: 0.0 ) –

    Dropout rate.

  • rpe_ratio (float, default: 1.9 ) –

    Relative position encoding ratio.

  • rpe_method (str, default: 'product' ) –

    Relative position encoding method. Possible values: ['euc', 'quant', 'cross', 'product']

  • rpe_mode (str, default: 'contextual' ) –

    Relative position encoding mode. Possible values: [None, 'bias', 'contextual']

  • rpe_shared_head (bool, default: True ) –

    Whether to share the relative position encoding weights across heads.

  • rpe_skip (int, default: 1 ) –

    Relative position encoding skip. Possible values: [0, 1].

  • rpe_on (str, default: 'k' ) –

    Where to apply relative positional encoding. Possible values: ['q', 'k', 'v', 'qk', 'kv', 'qkv'].
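In the iRPE paper, a relative offset is mapped to a bucket index by a piecewise function that keeps small offsets exact and log-compresses larger ones, clipping at a maximum of β. The sketch below assumes torchmil follows the reference iRPE implementation, where rpe_ratio sets α = ratio, β = 2·ratio and γ = 8·ratio, and where the 'product' method buckets the two coordinate offsets separately and then combines them into a single index. The helper names (`rpe_params`, `piecewise_index`, `product_bucket`) are hypothetical, and how torchmil assigns 2D coordinates to the instances of a bag is not shown.

```python
import numpy as np

def rpe_params(ratio=1.9):
    # Assumption: thresholds scale with `ratio` as in the reference iRPE code.
    return 1.0 * ratio, 2.0 * ratio, 8.0 * ratio  # alpha, beta, gamma

def piecewise_index(rel, alpha, beta, gamma):
    # |rel| <= alpha: rounded offset (exact for small distances).
    # |rel| >  alpha: log-compressed offset, clipped at beta, sign preserved.
    rel = np.asarray(rel, dtype=float)
    abs_rel = np.abs(rel)
    # np.maximum keeps the log argument >= 1, so the unused branch never evaluates log(0).
    far = np.sign(rel) * np.minimum(
        beta,
        np.round(alpha + np.log(np.maximum(abs_rel, alpha) / alpha)
                 / np.log(gamma / alpha) * (beta - alpha)),
    )
    # Casting to int truncates toward zero, so the clipped value beta becomes int(beta).
    return np.where(abs_rel <= alpha, np.round(rel), far).astype(int)

def product_bucket(dx, dy, ratio=1.9):
    # 'product' method: bucket each axis separately, then combine into one index.
    alpha, beta, gamma = rpe_params(ratio)
    beta_int = int(beta)
    side = 2 * beta_int + 1           # buckets per axis
    r = piecewise_index(dx, alpha, beta, gamma) + beta_int  # shift to [0, side)
    c = piecewise_index(dy, alpha, beta, gamma) + beta_int
    return r * side + c               # values in [0, side**2)
```

Under these assumptions, the default rpe_ratio=1.9 gives β = 3.8, so each axis uses 7 buckets (indices -3 to 3) and the product method uses 49 buckets in total; the zero offset falls in the central bucket, `product_bucket(0, 0) == 24`.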

forward(X, return_att=False)

Forward method.

Parameters:

  • X (Tensor) –

    Input tensor of shape (batch_size, bag_size, in_dim).

  • return_att (bool, default: False ) –

    If True, also returns the attention weights, of shape (n_layers, batch_size, n_heads, bag_size, bag_size).

Returns:

  • Y ( Tensor ) –

    Output tensor of shape (batch_size, bag_size, in_dim).


torchmil.nn.transformers.iRPETransformerLayer

Bases: Layer

Transformer layer with image Relative Position Encoding (iRPE), as described in Rethinking and Improving Relative Position Encoding for Vision Transformer.

Given an input bag \(\mathbf{X} = \left[ \mathbf{x}_1, \ldots, \mathbf{x}_N \right]^\top \in \mathbb{R}^{N \times D}\), this module computes:

\[\begin{align*} \mathbf{Z} & = \mathbf{X} + \operatorname{iRPESelfAttention}( \operatorname{LayerNorm}(\mathbf{X}) ) \\ \mathbf{Y} & = \mathbf{Z} + \operatorname{MLP}(\operatorname{LayerNorm}(\mathbf{Z})), \\ \end{align*}\]

and outputs \(\mathbf{Y}\). See iRPEMultiheadSelfAttention for more details about \(\operatorname{iRPESelfAttention}\).

__init__(in_dim, out_dim=None, att_dim=512, n_heads=4, use_mlp=True, dropout=0.0, rpe_ratio=1.9, rpe_method='product', rpe_mode='contextual', rpe_shared_head=True, rpe_skip=1, rpe_on='k')

Class constructor.

Parameters:

  • in_dim (int) –

    Input dimension. If None, in_dim = att_dim.

  • out_dim (int, default: None ) –

    Output dimension. If None, out_dim = in_dim.

  • att_dim (int, default: 512 ) –

    Attention dimension.

  • n_heads (int, default: 4 ) –

    Number of heads.

  • use_mlp (bool, default: True ) –

    Whether to use the feedforward (MLP) layer.

  • dropout (float, default: 0.0 ) –

    Dropout rate.

  • rpe_ratio (float, default: 1.9 ) –

    Relative position encoding ratio.

  • rpe_method (str, default: 'product' ) –

    Relative position encoding method. Possible values: ['euc', 'quant', 'cross', 'product']

  • rpe_mode (str, default: 'contextual' ) –

    Relative position encoding mode. Possible values: [None, 'bias', 'contextual']

  • rpe_shared_head (bool, default: True ) –

    Whether to share the relative position encoding weights across heads.

  • rpe_skip (int, default: 1 ) –

    Relative position encoding skip. Possible values: [0, 1].

  • rpe_on (str, default: 'k' ) –

    Where to apply relative positional encoding. Possible values: ['q', 'k', 'v', 'qk', 'kv', 'qkv'].
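The two rpe_mode options differ in how the bucket encoding enters the attention logits. Roughly following the formulation in the iRPE paper, 'bias' adds a learned scalar per bucket, \(b_{I(i,j)}\), while 'contextual' adds \(\mathbf{q}_i \cdot \mathbf{r}_{I(i,j)}\), so the encoding interacts with the query content. The sketch below covers rpe_on='k' for a single head only. The function `rpe_logits_on_k` and its `table` argument are hypothetical, applying the \(1/\sqrt{d}\) scaling to the whole sum is a simplification, and the bucket indices \(I(i,j)\) are assumed to be given (for example, produced by the 'product' method).

```python
import numpy as np

def rpe_logits_on_k(Q, K, bucket, mode, table):
    # Q, K: (N, d) query/key matrices for one head.
    # bucket: (N, N) integer bucket index I(i, j) for each query/key pair.
    # table: per-bucket parameters (hypothetical; learnable in a real model):
    #   mode='bias'       -> shape (num_buckets,)    one scalar per bucket
    #   mode='contextual' -> shape (num_buckets, d)  one vector r per bucket
    #   mode=None         -> ignored, plain scaled dot-product logits
    d = Q.shape[-1]
    logits = Q @ K.T
    if mode == "bias":
        logits = logits + table[bucket]                   # + b_{I(i,j)}
    elif mode == "contextual":
        # table[bucket] has shape (N, N, d); contract with q_i: + q_i . r_{I(i,j)}
        logits = logits + np.einsum("id,ijd->ij", Q, table[bucket])
    elif mode is not None:
        raise ValueError(f"unknown rpe_mode: {mode!r}")
    return logits / np.sqrt(d)
```

With rpe_shared_head=True, a single table of this kind would serve every head; otherwise each head would get its own.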

forward(X, return_att=False)

Forward method.

Parameters:

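  • X (Tensor) –

    Input tensor of shape (batch_size, bag_size, in_dim).

  • return_att (bool, default: False ) –

    If True, also returns the attention weights, of shape (batch_size, n_heads, bag_size, bag_size).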

Returns:

  • Y ( Tensor ) –

    Output tensor of shape (batch_size, bag_size, out_dim).