
Attention Pool

torchmil.nn.attention.AttentionPool

Bases: Module

Attention-based pooling, as proposed in the paper Attention-based Deep Multiple Instance Learning.

Given an input bag \(\mathbf{X} = \left[ \mathbf{x}_1, \ldots, \mathbf{x}_N \right]^\top \in \mathbb{R}^{N \times \texttt{in_dim}}\), this model aggregates the instance features into a bag representation \(\mathbf{z} \in \mathbb{R}^{\texttt{in_dim}}\) as,

\[ \mathbf{z} = \mathbf{X}^\top \operatorname{Softmax}(\mathbf{f}) = \sum_{n=1}^N s_n \mathbf{x}_n, \]

where \(\mathbf{f} = \operatorname{MLP}(\mathbf{X}) \in \mathbb{R}^{N}\) are the attention values and \(s_n = \exp(f_n) / \sum_{m=1}^N \exp(f_m)\), the \(n\)-th entry of \(\operatorname{Softmax}(\mathbf{f})\), is the normalized attention score of the \(n\)-th instance.
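As a quick numerical check of the pooling formula, the sketch below computes \(\mathbf{z} = \mathbf{X}^\top \operatorname{Softmax}(\mathbf{f})\) for a single bag with NumPy and compares it with the explicit weighted sum \(\sum_n s_n \mathbf{x}_n\). The helpers `softmax` and `attention_pool` are illustrative names, not part of torchmil, and random values stand in for real instance features and attention values.

```python
import numpy as np


def softmax(f: np.ndarray) -> np.ndarray:
    # Subtracting the max improves numerical stability without changing the result.
    e = np.exp(f - f.max())
    return e / e.sum()


def attention_pool(X: np.ndarray, f: np.ndarray) -> np.ndarray:
    """z = X^T Softmax(f) for one bag X of shape (N, in_dim) and f of shape (N,)."""
    s = softmax(f)   # normalized scores s_n: non-negative, summing to 1
    return X.T @ s   # bag representation, shape (in_dim,)


rng = np.random.default_rng(0)
X = rng.normal(size=(4, 3))  # N = 4 instances, in_dim = 3
f = rng.normal(size=4)       # stand-in for the MLP's attention values
z = attention_pool(X, f)

# The matrix form equals the explicit weighted sum over instances.
s = softmax(f)
z_sum = sum(s[n] * X[n] for n in range(X.shape[0]))
print(np.allclose(z, z_sum))  # True
```

With equal attention values, every \(s_n = 1/N\) and the pooling reduces to the mean of the instance features.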

To compute the attention values, the \(\operatorname{MLP}\) is defined as

\[ \operatorname{MLP}(\mathbf{X}) = \begin{cases} \operatorname{act}(\mathbf{X}\mathbf{W}_1)\mathbf{w}, & \text{if }\texttt{gated=False}, \\ \left(\operatorname{act}(\mathbf{X}\mathbf{W}_1)\odot\operatorname{sigm}(\mathbf{X}\mathbf{W}_2)\right)\mathbf{w}, & \text{if }\texttt{gated=True}, \end{cases} \]

where \(\mathbf{W}_1 \in \mathbb{R}^{\texttt{in_dim} \times \texttt{att_dim}}\), \(\mathbf{W}_2 \in \mathbb{R}^{\texttt{in_dim} \times \texttt{att_dim}}\), \(\mathbf{w} \in \mathbb{R}^{\texttt{att_dim}}\), \(\operatorname{act} \ \colon \mathbb{R} \to \mathbb{R}\) is the activation function, \(\operatorname{sigm} \ \colon \mathbb{R} \to \left] 0, 1 \right[\) is the sigmoid function, and \(\odot\) denotes element-wise multiplication.
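The two branches of the MLP can be sketched for one bag as follows. This is a literal NumPy transcription of the equation above, assuming no bias terms (none appear in the formula) and using `np.tanh` as \(\operatorname{act}\); `attention_values` is a hypothetical helper, not torchmil's implementation.

```python
import numpy as np


def sigmoid(a: np.ndarray) -> np.ndarray:
    return 1.0 / (1.0 + np.exp(-a))


def attention_values(X, W1, w, W2=None, act=np.tanh):
    """f = MLP(X) for one bag X of shape (N, in_dim); returns shape (N,).

    Passing W2 selects the gated variant (gated=True); otherwise gated=False.
    """
    H = act(X @ W1)                # act(X W_1), shape (N, att_dim)
    if W2 is not None:
        H = H * sigmoid(X @ W2)    # element-wise gate with entries in (0, 1)
    return H @ w                   # attention values, shape (N,)


rng = np.random.default_rng(0)
N, in_dim, att_dim = 5, 8, 4
X = rng.normal(size=(N, in_dim))
W1 = rng.normal(size=(in_dim, att_dim))
W2 = rng.normal(size=(in_dim, att_dim))
w = rng.normal(size=att_dim)

f_plain = attention_values(X, W1, w)         # gated=False
f_gated = attention_values(X, W1, w, W2=W2)  # gated=True
print(f_plain.shape, f_gated.shape)  # (5,) (5,)
```

Because each gate entry lies in \(\left] 0, 1 \right[\), the gated branch can only shrink, never amplify, the corresponding entry of \(\operatorname{act}(\mathbf{X}\mathbf{W}_1)\).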

__init__(in_dim=None, att_dim=128, act='tanh', gated=False)

Parameters:

  • in_dim (int, default: None ) –

    Input dimension. If not provided, it is inferred lazily from the input on the first forward pass.

  • att_dim (int, default: 128 ) –

    Attention dimension, i.e., the number of columns of \(\mathbf{W}_1\) and \(\mathbf{W}_2\).

  • act (str, default: 'tanh' ) –

    Activation function \(\operatorname{act}\) used in the attention MLP. Possible values: 'tanh', 'relu', 'gelu'.

  • gated (bool, default: False ) –

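    If True, use gated attention, which multiplies \(\operatorname{act}(\mathbf{X}\mathbf{W}_1)\) element-wise by \(\operatorname{sigm}(\mathbf{X}\mathbf{W}_2)\).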
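To illustrate how the constructor parameters shape the attention MLP, here is a hypothetical PyTorch sketch; `AttentionMLPSketch` is not a torchmil class. It assumes that `in_dim=None` is handled with `torch.nn.LazyLinear`, which infers its input size on the first call, and it omits bias terms to match the equation; the actual library may differ on both points.

```python
import torch
from torch import nn

# Assumed mapping from the documented `act` strings to PyTorch activation modules.
_ACTS = {"tanh": nn.Tanh, "relu": nn.ReLU, "gelu": nn.GELU}


class AttentionMLPSketch(nn.Module):
    """Hypothetical attention MLP: (batch_size, bag_size, in_dim) -> (batch_size, bag_size)."""

    def __init__(self, in_dim=None, att_dim=128, act="tanh", gated=False):
        super().__init__()
        if act not in _ACTS:
            raise ValueError(f"act must be one of {sorted(_ACTS)}, got {act!r}")

        def proj(out_features):
            # Without in_dim, LazyLinear infers in_features on the first forward call.
            if in_dim is None:
                return nn.LazyLinear(out_features, bias=False)
            return nn.Linear(in_dim, out_features, bias=False)

        self.W1 = proj(att_dim)                     # W_1: in_dim -> att_dim
        self.W2 = proj(att_dim) if gated else None  # W_2: only used when gated=True
        self.act = _ACTS[act]()
        self.w = nn.Linear(att_dim, 1, bias=False)  # w: att_dim -> 1

    def forward(self, X):
        H = self.act(self.W1(X))                    # (batch_size, bag_size, att_dim)
        if self.W2 is not None:
            H = H * torch.sigmoid(self.W2(X))       # element-wise gate in (0, 1)
        return self.w(H).squeeze(-1)                # (batch_size, bag_size)


mlp = AttentionMLPSketch(att_dim=16, gated=True)    # in_dim left to lazy inference
X = torch.randn(2, 7, 32)                           # 2 bags of 7 instances, in_dim = 32
f = mlp(X)
print(f.shape)  # torch.Size([2, 7])
```

In this sketch, the weight of `W1` only takes its final shape `(att_dim, in_dim)` after the first forward pass, which is what lazy initialization buys when the feature size is not known up front.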

forward(X, mask=None, return_att=False)

Forward pass.

Parameters:

  • X (Tensor) –

    Bag features of shape (batch_size, bag_size, in_dim).

  • mask (Tensor, default: None ) –

    Mask of shape (batch_size, bag_size).

  • return_att (bool, default: False ) –

    If True, returns attention values (before normalization) in addition to z.

Returns:

  • z ( Tensor ) –

    Bag representation of shape (batch_size, in_dim).

  • f ( Tensor ) –

    Only returned when return_att=True. Attention values (before normalization) of shape (batch_size, bag_size).
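The masking and `return_att` behavior of `forward` can be sketched as below. The function `attention_pool_forward` is hypothetical: it takes the attention values `f` as an argument instead of computing them, and it assumes that `mask` holds 1 for real instances and 0 for padded ones, which this page does not state explicitly. Padded positions are set to \(-\infty\) before the softmax so that they receive zero weight, while the returned `f` is the raw attention before normalization.

```python
import torch


def attention_pool_forward(X, f, mask=None, return_att=False):
    """Hypothetical sketch of the pooling step of AttentionPool.forward.

    X: (batch_size, bag_size, in_dim) bag features.
    f: (batch_size, bag_size) attention values from the attention MLP.
    mask: (batch_size, bag_size), assumed 1 for real instances and 0 for padding.
    """
    logits = f
    if mask is not None:
        # Padded instances get -inf, so their softmax weight is exactly zero.
        logits = f.masked_fill(mask == 0, float("-inf"))
    s = torch.softmax(logits, dim=1)      # normalized scores, (batch_size, bag_size)
    z = (s.unsqueeze(-1) * X).sum(dim=1)  # z = X^T s for each bag, (batch_size, in_dim)
    if return_att:
        return z, f                       # f is returned before normalization
    return z


X = torch.randn(2, 4, 3)
f = torch.randn(2, 4)
mask = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]])  # second bag has two padded instances
z, att = attention_pool_forward(X, f, mask, return_att=True)
print(z.shape, att.shape)  # torch.Size([2, 3]) torch.Size([2, 4])
```

A bag whose mask is entirely zero would yield NaN weights in this sketch, since the softmax of a row of \(-\infty\) values is undefined.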