Sm Attention Pool
torchmil.nn.attention.SmAttentionPool
Bases: Module
Attention-based pooling with the Sm operator, as proposed in *Sm: enhanced localization in Multiple Instance Learning for medical imaging classification*.
Given an input bag \(\mathbf{X} = \left[ \mathbf{x}_1, \ldots, \mathbf{x}_N \right]^\top \in \mathbb{R}^{N \times \texttt{in_dim}}\), this model aggregates the instance features into a bag representation \(\mathbf{z} \in \mathbb{R}^{\texttt{in_dim}}\) as

\[
\mathbf{z} = \sum_{n=1}^{N} s_n \mathbf{x}_n,
\]

where \(s_n\) is the normalized attention score for the \(n\)-th instance, obtained by applying \(\operatorname{Softmax}\) to the raw attention values \(\mathbf{f} = \operatorname{SmMLP}(\mathbf{X}) \mathbf{w} \in \mathbb{R}^{N}\).

To compute the attention values, \(\operatorname{SmMLP}\) is defined as \(\operatorname{SmMLP}(\mathbf{X}) = \mathbf{Y}^L\), where

\[
\mathbf{Y}^0 = \operatorname{act}\left( \mathbf{X} \mathbf{W}^0 \right), \qquad \mathbf{Y}^l = \texttt{Sm}\left( \operatorname{act}\left( \mathbf{Y}^{l-1} \mathbf{W}^l \right) \right), \quad l = 1, \ldots, L,
\]

where \(\mathbf{W}^0 \in \mathbb{R}^{\texttt{in_dim} \times \texttt{att_dim}}\), \(\mathbf{W}^l \in \mathbb{R}^{\texttt{att_dim} \times \texttt{att_dim}}\), \(\mathbf{w} \in \mathbb{R}^{\texttt{att_dim} \times 1}\), \(\operatorname{act} \colon \mathbb{R} \to \mathbb{R}\) is the activation function, and \(\texttt{Sm}\) is the Sm operator; see Sm for more details.
Note: If `sm_pre=True`, the Sm operator is applied before \(\operatorname{SmMLP}\). If `sm_post=True`, the Sm operator is applied after \(\operatorname{SmMLP}\).
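The aggregation above can be sketched in a few lines of NumPy. This is a minimal illustration of attention-based pooling only, not torchmil's implementation: the `SmMLP` and the Sm operator are replaced here by arbitrary stand-in attention values.

```python
import numpy as np

def attention_pool(X, f):
    """Aggregate a bag X of shape (N, in_dim) into z of shape (in_dim,)
    using raw attention values f of shape (N,): z = sum_n softmax(f)_n * x_n."""
    f = f - f.max()                   # shift for numerical stability
    s = np.exp(f) / np.exp(f).sum()   # normalized attention scores s_n
    return s @ X, s                   # weighted sum of instances, plus scores

rng = np.random.default_rng(0)
X = rng.normal(size=(5, 3))   # bag of N=5 instances, in_dim=3
f = rng.normal(size=5)        # stand-in for SmMLP(X) w
z, s = attention_pool(X, f)
print(z.shape)                # (3,) -- the bag representation keeps in_dim
```

Note how the bag size \(N\) disappears after pooling while the feature dimension is preserved, which is what lets a downstream classifier operate on bags of varying size.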
__init__(in_dim, att_dim=128, act='gelu', sm_mode='approx', sm_alpha='trainable', sm_layers=1, sm_steps=10, sm_pre=False, sm_post=False, sm_spectral_norm=False)
Parameters:

- `in_dim` (`int`) – Input dimension.
- `att_dim` (`int`, default: `128`) – Attention dimension.
- `act` (`str`, default: `'gelu'`) – Activation function for attention. Possible values: `'tanh'`, `'relu'`, `'gelu'`.
- `sm_mode` (`str`, default: `'approx'`) – Mode for the Sm operator. Possible values: `'approx'`, `'exact'`.
- `sm_alpha` (`Union[float, str]`, default: `'trainable'`) – Alpha value for the Sm operator. If `'trainable'`, alpha is trainable.
- `sm_layers` (`int`, default: `1`) – Number of layers that use the Sm operator.
- `sm_steps` (`int`, default: `10`) – Number of steps for the Sm operator.
- `sm_pre` (`bool`, default: `False`) – If `True`, apply the Sm operator before the attention pooling.
- `sm_post` (`bool`, default: `False`) – If `True`, apply the Sm operator after the attention pooling.
- `sm_spectral_norm` (`bool`, default: `False`) – If `True`, apply spectral normalization to all linear layers.
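As a rough intuition for `sm_mode`, `sm_alpha`, and `sm_steps`: the Sm operator smooths per-instance values over the bag's adjacency graph, and `'approx'` mode replaces an exact solve with a fixed number of fixed-point iterations. The following NumPy sketch shows one such smoothing scheme under assumed coefficients; the actual operator used by torchmil follows the Sm paper and may differ in its exact form.

```python
import numpy as np

def sm_exact(f, A, alpha=0.5):
    """'Exact'-style smoothing: solve (I - alpha * P) u = (1 - alpha) * f,
    where P is the row-normalized adjacency matrix (assumed form)."""
    P = A / A.sum(axis=1, keepdims=True)
    n = len(f)
    return np.linalg.solve(np.eye(n) - alpha * P, (1 - alpha) * f)

def sm_approx(f, A, alpha=0.5, steps=10):
    """'Approx'-style smoothing: iterate u <- alpha * P u + (1 - alpha) * f,
    mirroring the role of sm_steps."""
    P = A / A.sum(axis=1, keepdims=True)
    u = f.copy()
    for _ in range(steps):
        u = alpha * (P @ u) + (1 - alpha) * f
    return u

# 4-node path graph with self-loops
A = np.eye(4) + np.diag(np.ones(3), 1) + np.diag(np.ones(3), -1)
f = np.array([1.0, 0.0, 0.0, 1.0])
u_exact = sm_exact(f, A)
u_approx = sm_approx(f, A, steps=50)
print(np.abs(u_exact - u_approx).max())  # tiny: the iteration converges to the solve
```

Because `P` is row-stochastic, the iteration contracts with factor `alpha < 1`, so a modest number of steps already approximates the exact solution well; this is the trade-off `sm_mode='approx'` exploits.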
forward(X, adj, mask=None, return_att=False)
Forward pass.
Parameters:

- `X` (`Tensor`) – Bag features of shape `(batch_size, bag_size, in_dim)`.
- `adj` (`Tensor`) – Adjacency matrix of shape `(batch_size, bag_size, bag_size)`.
- `mask` (`Tensor`, default: `None`) – Mask of shape `(batch_size, bag_size)`.
- `return_att` (`bool`, default: `False`) – If `True`, returns attention values (before normalization) in addition to `z`.

Returns:

- `z` (`Tensor`) – Bag representation of shape `(batch_size, in_dim)`.
- `f` (`Tensor`) – Only returned when `return_att=True`. Attention values (before normalization) of shape `(batch_size, bag_size)`.
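To make the batched shape contract of `forward` concrete, here is a shape-only NumPy stand-in (hypothetical, not the real module; actual usage constructs `SmAttentionPool` and calls `forward` on torch tensors). The fake score function and the unused adjacency are illustration-only assumptions.

```python
import numpy as np

def forward_shapes(X, adj, mask=None):
    """Shape-only stand-in for forward().
    X: (batch, bag, in_dim), adj: (batch, bag, bag), mask: (batch, bag)."""
    f = X.mean(axis=-1)                              # fake raw attention values (batch, bag)
    if mask is not None:
        f = np.where(mask.astype(bool), f, -np.inf)  # padded instances get zero weight
    s = np.exp(f - f.max(axis=1, keepdims=True))
    s = s / s.sum(axis=1, keepdims=True)             # normalized scores (batch, bag)
    z = np.einsum('bn,bnd->bd', s, X)                # bag representations (batch, in_dim)
    return z, f

batch, bag, in_dim = 2, 7, 16
X = np.random.default_rng(0).normal(size=(batch, bag, in_dim))
adj = np.ones((batch, bag, bag))   # adjacency is unused in this stand-in
z, f = forward_shapes(X, adj)
print(z.shape, f.shape)            # (2, 16) (2, 7)
```

The mask handling shows why `f` is returned *before* normalization: masked positions can be excluded from the softmax per bag, so padded instances contribute nothing to `z`.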