Prob Smooth Attention Pool
torchmil.nn.attention.ProbSmoothAttentionPool
Bases: Module
Probabilistic Smooth Attention Pooling, proposed in *Probabilistic Smooth Attention for Deep Multiple Instance Learning in Medical Imaging* and *Smooth Attention for Deep Multiple Instance Learning: Application to CT Intracranial Hemorrhage Detection*.
Given an input bag \(\mathbf{X} = \left[ \mathbf{x}_1, \ldots, \mathbf{x}_N \right]^\top \in \mathbb{R}^{N \times \texttt{in_dim}}\), this model computes an attention distribution \(q(\mathbf{f} \mid \mathbf{X}) = \mathcal{N}\left(\mathbf{f} \mid \mathbf{\mu}_{\mathbf{f}}, \operatorname{diag}(\mathbf{\sigma}_{\mathbf{f}}^2) \right)\), where:

$$\mathbf{\mu}_{\mathbf{f}} = \operatorname{MLP}(\mathbf{X})\mathbf{w}_{\mu}, \qquad \mathbf{\sigma}_{\mathbf{f}}^2 = \exp\left( \operatorname{MLP}(\mathbf{X})\mathbf{w}_{\sigma} \right),$$

where \(\operatorname{MLP} \colon \mathbb{R}^{N \times \texttt{in_dim}} \to \mathbb{R}^{N \times 2\texttt{att_dim}}\) is a multi-layer perceptron, and \(\mathbf{w}_{\mu},\mathbf{w}_{\sigma} \in \mathbb{R}^{2\texttt{att_dim} \times 1}\).
If `covar_mode='zero'`, the variance vector \(\mathbf{\sigma}_{\mathbf{f}}^2\) is set to zero, resulting in a deterministic attention distribution.
Then, \(M\) samples are drawn from the attention distribution as \(\widehat{\mathbf{f}}^{(m)} \sim q(\mathbf{f} \mid \mathbf{X})\). With these samples, the bag representation is computed as:

$$\mathbf{z} = \mathbf{X}^\top \operatorname{Softmax}\left( \widehat{\mathbf{F}} \right) \in \mathbb{R}^{\texttt{in_dim} \times M},$$

where \(\widehat{\mathbf{F}} = \left[ \widehat{\mathbf{f}}^{(1)}, \ldots, \widehat{\mathbf{f}}^{(M)} \right] \in \mathbb{R}^{N \times M}\) and \(\operatorname{Softmax}\) is applied column-wise, i.e., over the \(N\) instances of each sample.
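The following is a minimal, unbatched sketch of this sampling-and-pooling step, written directly from the formulas above; it is not torchmil's internal implementation, and the function name is hypothetical:

```python
import torch

def prob_smooth_pool_sketch(X, mu_f, sigma2_f, n_samples):
    """Sample attention values and pool one bag (sketch of the math above).

    X: (N, in_dim) bag features; mu_f, sigma2_f: (N,) attention mean and variance.
    """
    eps = torch.randn(X.shape[0], n_samples)                # standard normal noise, (N, M)
    f_hat = mu_f[:, None] + sigma2_f.sqrt()[:, None] * eps  # samples f_hat ~ q(f | X), (N, M)
    att = torch.softmax(f_hat, dim=0)                       # Softmax over the N instances
    return X.T @ att                                        # z = X^T Softmax(F_hat), (in_dim, M)
```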
**Kullback-Leibler divergence.** Given a bag with adjacency matrix \(\mathbf{A}\), the KL divergence between the attention distribution and the prior distribution is computed as:

$$\operatorname{KL}\left( q(\mathbf{f} \mid \mathbf{X}) \;\|\; p(\mathbf{f}) \right) = \frac{1}{2} \left( \mathbf{\mu}_{\mathbf{f}}^\top \mathbf{L} \mathbf{\mu}_{\mathbf{f}} + \operatorname{tr}\left( \mathbf{L} \mathbf{\Sigma}_{\mathbf{f}} \right) - \log \det \mathbf{\Sigma}_{\mathbf{f}} \right) + \operatorname{const},$$

where \(\operatorname{const}\) is a constant term that does not depend on the parameters, \(\mathbf{\Sigma}_{\mathbf{f}} = \operatorname{diag}(\mathbf{\sigma}_{\mathbf{f}}^2)\), \(\mathbf{L} = \mathbf{D} - \mathbf{A}\) is the graph Laplacian matrix, and \(\mathbf{D}\) is the degree matrix of \(\mathbf{A}\).
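A minimal single-bag sketch of this KL term, assuming the smoothness-inducing Gaussian prior \(p(\mathbf{f}) \propto \exp\left( -\tfrac{1}{2} \mathbf{f}^\top \mathbf{L} \mathbf{f} \right)\) implied by the formula above (the function name is hypothetical, not part of torchmil):

```python
import torch

def smooth_kl_sketch(mu_f, sigma2_f, A):
    """KL term for one bag, dropping the constant (sketch of the formula above).

    mu_f, sigma2_f: (N,) attention mean and variance; A: (N, N) adjacency matrix.
    """
    L = torch.diag(A.sum(dim=-1)) - A             # graph Laplacian L = D - A
    quad = mu_f @ L @ mu_f                        # mu_f^T L mu_f
    trace = (torch.diagonal(L) * sigma2_f).sum()  # tr(L Sigma_f), Sigma_f = diag(sigma_f^2)
    logdet = torch.log(sigma2_f).sum()            # log det Sigma_f
    return 0.5 * (quad + trace - logdet)
```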
__init__(in_dim=None, att_dim=128, covar_mode='diag', n_samples_train=1000, n_samples_test=5000)
Parameters:

- `in_dim` (`int`, default: `None`) – Input dimension. If not provided, it will be lazily initialized.
- `att_dim` (`int`, default: `128`) – Attention dimension.
- `covar_mode` (`str`, default: `'diag'`) – Covariance mode. Must be `'diag'` or `'zero'`.
- `n_samples_train` (`int`, default: `1000`) – Number of samples during training.
- `n_samples_test` (`int`, default: `5000`) – Number of samples during testing.
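For example, a minimal usage sketch (the bag sizes and the 256-dimensional features are made up for illustration):

```python
import torch
from torchmil.nn.attention import ProbSmoothAttentionPool

pool = ProbSmoothAttentionPool(in_dim=256, att_dim=128, covar_mode='diag')

X = torch.randn(4, 20, 256)  # 4 bags, 20 instances each, 256-dim features
z = pool(X)                  # bag representation, (4, 256, n_samples_train) in training mode
```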
forward(X, adj=None, mask=None, return_att_samples=False, return_att_dist=False, return_kl_div=False, n_samples=None)
In the following, if `covar_mode='zero'`, then `n_samples` is automatically set to 1 and `diag_Sigma_f` is set to `None`.
Parameters:

- `X` (`Tensor`) – Bag features of shape `(batch_size, bag_size, dim)`.
- `mask` (`Tensor`, default: `None`) – Mask of shape `(batch_size, bag_size)`.
- `adj` (`Tensor`, default: `None`) – Adjacency matrix of shape `(batch_size, bag_size, bag_size)`. Only required when `return_kl_div=True`.
- `return_att_samples` (`bool`, default: `False`) – If True, returns samples from the attention distribution `f` in addition to `z`.
- `return_att_dist` (`bool`, default: `False`) – If True, returns the attention distribution (`mu_f`, `diag_Sigma_f`) in addition to `z`.
- `return_kl_div` (`bool`, default: `False`) – If True, returns the KL divergence between the attention distribution and the prior distribution.
- `n_samples` (`int`, default: `None`) – Number of samples to draw. If not provided, it will use `n_samples_train` during training and `n_samples_test` during testing.
Returns:

- `z` (`Tensor`) – Bag representation of shape `(batch_size, dim, n_samples)`.
- `f` (`Tensor`) – Samples from the attention distribution of shape `(batch_size, bag_size, n_samples)`. Only returned when `return_att_samples=True`.
- `mu_f` (`Tensor`) – Mean of the attention distribution of shape `(batch_size, bag_size, 1)`. Only returned when `return_att_dist=True`.
- `diag_Sigma_f` (`Tensor`) – Covariance of the attention distribution of shape `(batch_size, bag_size, 1)`. Only returned when `return_att_dist=True`.
- `kl_div` (`Tensor`) – KL divergence between the attention distribution and the prior distribution, of shape `()`. Only returned when `return_kl_div=True`.
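Putting it together, a training-style sketch that also retrieves the attention distribution and the KL term; the chain-graph adjacency is a made-up example, and the outputs are assumed to unpack in the order listed above:

```python
import torch
from torchmil.nn.attention import ProbSmoothAttentionPool

pool = ProbSmoothAttentionPool(in_dim=256)

X = torch.randn(4, 20, 256)   # 4 bags, 20 instances each
adj = torch.zeros(4, 20, 20)  # toy chain graph: instance i <-> i+1
idx = torch.arange(19)
adj[:, idx, idx + 1] = 1.0
adj[:, idx + 1, idx] = 1.0

z, mu_f, diag_Sigma_f, kl_div = pool(
    X, adj=adj, return_att_dist=True, return_kl_div=True
)
# z: (4, 256, 1000), mu_f and diag_Sigma_f: (4, 20, 1), kl_div: scalar
# kl_div is typically added to the bag-level loss, weighted by a hyperparameter
```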