iRPE Multihead Self-Attention
torchmil.nn.attention.iRPEMultiheadSelfAttention
Bases: Module
Multihead Self-Attention with image Relative Position Encoding (iRPE), as described in Rethinking and Improving Relative Position Encoding for Vision Transformer.
The iRPE implementation is based on the official codebase.
__init__(in_dim, out_dim=None, att_dim=512, n_heads=4, dropout=0.0, learn_weights=True, rpe_ratio=1.9, rpe_method='product', rpe_mode='contextual', rpe_shared_head=True, rpe_skip=1, rpe_on='k')
Parameters:
- in_dim (int) – Input dimension.
- att_dim (int, default: 512) – Attention dimension. Must be divisible by n_heads.
- out_dim (int, default: None) – Output dimension. If None, out_dim = in_dim.
- n_heads (int, default: 4) – Number of heads.
- dropout (float, default: 0.0) – Dropout rate.
- learn_weights (bool, default: True) – If True, learn the weights for query, key, and value. If False, q, k, and v are the same as the input, and therefore in_dim must be divisible by n_heads.
- rpe_ratio (float, default: 1.9) – Relative position encoding ratio.
- rpe_method (str, default: 'product') – Relative position encoding method. Possible values: ['euc', 'quant', 'cross', 'product'].
- rpe_mode (str, default: 'contextual') – Relative position encoding mode. Possible values: [None, 'bias', 'contextual'].
- rpe_shared_head (bool, default: True) – Whether to share weights across heads.
- rpe_skip (int, default: 1) – Relative position encoding skip. Possible values: [0, 1].
- rpe_on (str, default: 'k') – Where to apply relative positional encoding. Possible values: ['q', 'k', 'v', 'qk', 'kv', 'qkv']. Note: when 'v' is in rpe_on, rpe_mode must be 'contextual'.
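A minimal construction sketch, assuming the class is importable from torchmil.nn.attention as listed above; the specific values (in_dim=256, dropout=0.1) are illustrative only:

```python
from torchmil.nn.attention import iRPEMultiheadSelfAttention  # import path assumed from the module path above

# Illustrative configuration: 256-dim inputs projected to a 512-dim attention space,
# with contextual, product-method iRPE applied to the keys (the documented defaults
# for rpe_method, rpe_mode and rpe_on), split over 4 heads.
# att_dim must be divisible by n_heads.
att = iRPEMultiheadSelfAttention(
    in_dim=256,
    att_dim=512,
    n_heads=4,
    dropout=0.1,
    rpe_method="product",
    rpe_mode="contextual",
    rpe_on="k",
)
```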
forward(x, mask=None, return_att=False, height=None, width=None)
Forward pass.
Parameters:
- x (Tensor) – Input tensor of shape (batch_size, seq_len, in_dim).
- mask (Tensor, default: None) – Mask tensor of shape (batch_size, seq_len).
- height (int, default: None) – Height of the input sequence. If None, height = floor(sqrt(seq_len)).
- width (int, default: None) – Width of the input sequence. If None, width = floor(sqrt(seq_len)).
Returns:
- y (Tensor) – Output tensor of shape (batch_size, seq_len, att_dim).
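A hedged usage sketch of the forward pass. seq_len = 196 is chosen so that the default height = width = floor(sqrt(196)) = 14 corresponds to a square patch grid; the mask convention (1 for valid positions, 0 for padding) is an assumption:

```python
import torch
from torchmil.nn.attention import iRPEMultiheadSelfAttention  # import path assumed

att = iRPEMultiheadSelfAttention(in_dim=256, att_dim=512, n_heads=4)

batch_size, seq_len = 2, 196
x = torch.randn(batch_size, seq_len, 256)    # e.g. a 14x14 grid of 256-dim patch embeddings
mask = torch.ones(batch_size, seq_len)       # assumed convention: 1 = valid token, 0 = padding

# height/width may be omitted for square grids, since they default to floor(sqrt(seq_len))
y = att(x, mask=mask, height=14, width=14)
print(y.shape)  # torch.Size([2, 196, 512]) -> (batch_size, seq_len, att_dim)
```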