Prob Smooth Attention Pool

torchmil.nn.attention.ProbSmoothAttentionPool

Bases: Module

Probabilistic Smooth Attention Pooling, proposed in Probabilistic Smooth Attention for Deep Multiple Instance Learning in Medical Imaging and Smooth Attention for Deep Multiple Instance Learning: Application to CT Intracranial Hemorrhage Detection.

Given an input bag \(\mathbf{X} = \left[ \mathbf{x}_1, \ldots, \mathbf{x}_N \right]^\top \in \mathbb{R}^{N \times \texttt{in_dim}}\), this model computes an attention distribution \(q(\mathbf{f} \mid \mathbf{X}) = \mathcal{N}\left(\mathbf{f} \mid \mathbf{\mu}_{\mathbf{f}}, \operatorname{diag}(\mathbf{\sigma}_{\mathbf{f}}^2) \right)\), where:

\[\begin{gather} \mathbf{H} = \operatorname{MLP}(\mathbf{X}) \in \mathbb{R}^{N \times 2\texttt{att_dim}}, \\ \mathbf{\mu}_{\mathbf{f}} = \mathbf{H}\mathbf{w}_{\mu} \in \mathbb{R}^{N}, \\ \log \mathbf{\sigma}_{\mathbf{f}}^2 = \mathbf{H}\mathbf{w}_{\sigma} \in \mathbb{R}^{N}, \end{gather}\]

where \(\operatorname{MLP}\) is a multi-layer perceptron, and \(\mathbf{w}_{\mu},\mathbf{w}_{\sigma} \in \mathbb{R}^{2\texttt{att_dim} \times 1}\). If covar_mode='zero', the variance vector \(\mathbf{\sigma}_{\mathbf{f}}^2\) is set to zero, resulting in a deterministic attention distribution.
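As a concrete reference, the following sketch computes the attention-distribution parameters for a single bag. The MLP architecture (one linear layer with Tanh) and all variable names are illustrative assumptions, not torchmil's actual implementation:

```python
import torch
import torch.nn as nn

# Sketch: compute the attention-distribution parameters for one bag.
# The MLP (a single linear layer with Tanh) is an assumption for
# illustration; torchmil's internal architecture may differ.
N, in_dim, att_dim = 16, 512, 128
X = torch.randn(N, in_dim)                       # bag of N instances

mlp = nn.Sequential(nn.Linear(in_dim, 2 * att_dim), nn.Tanh())
w_mu = nn.Linear(2 * att_dim, 1, bias=False)     # w_mu in R^{2*att_dim x 1}
w_sigma = nn.Linear(2 * att_dim, 1, bias=False)  # w_sigma in R^{2*att_dim x 1}

H = mlp(X)                                       # (N, 2*att_dim)
mu_f = w_mu(H).squeeze(-1)                       # (N,) mean of q(f | X)
log_var_f = w_sigma(H).squeeze(-1)               # (N,) log-variance of q(f | X)
```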

Then, \(M\) samples from the attention distribution are drawn as \(\widehat{\mathbf{f}}^{(m)} \sim q(\mathbf{f} \mid \mathbf{X})\). With these samples, the bag representation is computed as:

\[ \widehat{\mathbf{z}} = \mathbf{X}^\top \operatorname{Softmax}(\widehat{\mathbf{F}}) \in \mathbb{R}^{\texttt{in_dim} \times M}, \]

where \(\widehat{\mathbf{F}} = \left[ \widehat{\mathbf{f}}^{(1)}, \ldots, \widehat{\mathbf{f}}^{(M)} \right] \in \mathbb{R}^{N \times M}\) and the softmax is applied column-wise, i.e., independently over the \(N\) instances for each sample.
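A minimal sketch of the sampling and pooling step, drawing the \(M\) samples via the reparameterization trick (variable names and the random placeholder values for mu_f and log_var_f are ours, not torchmil's):

```python
import torch

# Sketch: draw M samples via the reparameterization trick and pool.
# mu_f and log_var_f stand in for the outputs of the previous sketch;
# the random values here are placeholders.
N, in_dim, M = 16, 512, 1000
X = torch.randn(N, in_dim)
mu_f, log_var_f = torch.randn(N), torch.randn(N)

eps = torch.randn(N, M)
F_hat = mu_f[:, None] + eps * (0.5 * log_var_f).exp()[:, None]  # (N, M)
att = torch.softmax(F_hat, dim=0)   # normalize over the N instances, per sample
z_hat = X.T @ att                   # (in_dim, M) bag representation
```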

Kullback-Leibler Divergence. Given a bag with adjacency matrix \(\mathbf{A}\), the KL divergence between the attention distribution and the prior distribution is computed as:

\[ \ell_{\text{KL}} = \begin{cases} \mathbf{\mu}_{\mathbf{f}}^\top \mathbf{L} \mathbf{\mu}_{\mathbf{f}} \quad & \text{if } \texttt{covar_mode='zero'}, \\ \mathbf{\mu}_{\mathbf{f}}^\top \mathbf{L} \mathbf{\mu}_{\mathbf{f}} + \operatorname{Tr}(\mathbf{L} \mathbf{\Sigma}_{\mathbf{f}}) - \frac{1}{2}\log \det( \mathbf{\Sigma}_{\mathbf{f}} ) + \operatorname{const} \quad & \text{if } \texttt{covar_mode='diag'}, \\ \end{cases} \]

where \(\operatorname{const}\) is a constant term that does not depend on the parameters, \(\mathbf{\Sigma}_{\mathbf{f}} = \operatorname{diag}(\mathbf{\sigma}_{\mathbf{f}}^2)\), \(\mathbf{L} = \mathbf{D} - \mathbf{A}\) is the graph Laplacian matrix, and \(\mathbf{D}\) is the degree matrix of \(\mathbf{A}\).
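For illustration, the covar_mode='diag' case of this KL term can be computed as follows. This is a sketch in our own notation, with the constant term omitted and a random symmetric adjacency matrix purely for demonstration:

```python
import torch

# Sketch: the covar_mode='diag' KL term, with the constant omitted.
# The random symmetric adjacency matrix is purely for demonstration.
N = 16
mu_f, log_var_f = torch.randn(N), torch.randn(N)

A = torch.bernoulli(torch.full((N, N), 0.2))
A = ((A + A.T) > 0).float()
A.fill_diagonal_(0)                              # no self-loops
L = torch.diag(A.sum(dim=1)) - A                 # graph Laplacian L = D - A

sigma2 = log_var_f.exp()                         # diagonal of Sigma_f
kl = (mu_f @ L @ mu_f                            # mu_f^T L mu_f
      + (torch.diagonal(L) * sigma2).sum()       # Tr(L Sigma_f), Sigma_f diagonal
      - 0.5 * log_var_f.sum())                   # -1/2 log det(Sigma_f)
```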

__init__(in_dim=None, att_dim=128, covar_mode='diag', n_samples_train=1000, n_samples_test=5000)

Parameters:

  • in_dim (int, default: None ) –

    Input dimension. If not provided, it will be lazily initialized.

  • att_dim (int, default: 128 ) –

    Attention dimension.

  • covar_mode (str, default: 'diag' ) –

    Covariance mode. Must be 'diag' or 'zero'.

  • n_samples_train (int, default: 1000 ) –

    Number of samples during training.

  • n_samples_test (int, default: 5000 ) –

    Number of samples during testing.
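Example instantiation with the documented defaults; since in_dim is omitted, it is lazily initialized on the first forward pass:

```python
from torchmil.nn.attention import ProbSmoothAttentionPool

# in_dim is omitted, so it is lazily initialized on the first forward pass
pool = ProbSmoothAttentionPool(
    att_dim=128,
    covar_mode='diag',
    n_samples_train=1000,
    n_samples_test=5000,
)
```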

forward(X, adj=None, mask=None, return_att_samples=False, return_att_dist=False, return_kl_div=False, n_samples=None)

In the following, if covar_mode='zero', then n_samples is automatically set to 1 and diag_Sigma_f is set to None.

Parameters:

  • X (Tensor) –

    Bag features of shape (batch_size, bag_size, dim).

  • adj (Tensor, default: None ) –

    Adjacency matrix of shape (batch_size, bag_size, bag_size). Only required when return_kl_div=True.

  • mask (Tensor, default: None ) –

    Mask of shape (batch_size, bag_size).

  • return_att_samples (bool, default: False ) –

    If True, returns samples from the attention distribution f in addition to z.

  • return_att_dist (bool, default: False ) –

    If True, returns the attention distribution (mu_f, diag_Sigma_f) in addition to z.

  • return_kl_div (bool, default: False ) –

    If True, returns the KL divergence between the attention distribution and the prior distribution.

  • n_samples (int, default: None ) –

    Number of samples to draw. If not provided, it will use n_samples_train during training and n_samples_test during testing.

Returns:

  • z ( Tensor ) –

    Bag representation of shape (batch_size, dim, n_samples).

  • f ( Tensor ) –

    Samples from the attention distribution of shape (batch_size, bag_size, n_samples). Only returned when return_att_samples=True.

  • mu_f ( Tensor ) –

    Mean of the attention distribution of shape (batch_size, bag_size, 1). Only returned when return_att_dist=True.

  • diag_Sigma_f ( Tensor ) –

    Covariance of the attention distribution of shape (batch_size, bag_size, 1). Only returned when return_att_dist=True.

  • kl_div ( Tensor ) –

    KL divergence between the attention distribution and the prior distribution, of shape (). Only returned when return_kl_div=True.
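Example usage of forward with the documented signature. Note that the order of the returned tensors when extra outputs are requested is an assumption based on the list above, and the random adjacency matrix is purely for demonstration:

```python
import torch
from torchmil.nn.attention import ProbSmoothAttentionPool

batch_size, bag_size, dim = 2, 16, 512
X = torch.randn(batch_size, bag_size, dim)
adj = torch.randint(0, 2, (batch_size, bag_size, bag_size)).float()
adj = ((adj + adj.transpose(1, 2)) > 0).float()  # symmetric adjacency

pool = ProbSmoothAttentionPool(in_dim=dim)
z, kl_div = pool(X, adj=adj, return_kl_div=True)
print(z.shape)       # (batch_size, dim, n_samples_train) in training mode
print(kl_div.shape)  # scalar, shape ()
```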