
Mean Pool

torchmil.nn.MeanPool

Bases: Module

Mean pooling aggregation.

Given an input bag \(\mathbf{X} = \left[ \mathbf{x}_1, \ldots, \mathbf{x}_N \right]^\top \in \mathbb{R}^{N \times D}\), this module aggregates the instance features into a bag representation \(\mathbf{z} \in \mathbb{R}^{D}\) by averaging them:

\[ \mathbf{z} = \frac{1}{N} \sum_{n=1}^{N} \mathbf{x}_n. \]
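As a small worked check of the formula, the dependency-free sketch below computes \(\mathbf{z}\) for a single bag of \(N = 3\) instances with \(D = 2\) features each. The function name `mean_pool_bag` and the example values are illustrative only and are not part of torchmil.

```python
def mean_pool_bag(bag):
    """Return the mean of the instance vectors in `bag`.

    `bag` is a list of N instances, each a list of D floats (an N x D matrix).
    """
    n = len(bag)
    d = len(bag[0])
    # For each feature dimension j, sum over the N instances and divide by N.
    return [sum(x[j] for x in bag) / n for j in range(d)]


X = [
    [1.0, 2.0],
    [3.0, 4.0],
    [5.0, 9.0],
]
z = mean_pool_bag(X)
print(z)  # [3.0, 5.0]
```

Each coordinate of \(\mathbf{z}\) is the average of that feature across instances: \((1 + 3 + 5)/3 = 3\) and \((2 + 4 + 9)/3 = 5\).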
__init__()
forward(X, mask=None)

Forward pass.

Parameters:

  • X (Tensor) –

    Input tensor of shape (batch_size, bag_size, in_dim).

  • mask (Tensor, default: None) –

    Mask tensor of shape (batch_size, bag_size).

Returns:

  • z (Tensor) –

    Output tensor of shape (batch_size, in_dim).
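This page does not spell out how `mask` enters the computation. A common convention, assumed here and not confirmed by the documentation above, is that `mask[b][n]` is 1 for a real instance and 0 for padding, so each bag is averaged only over its valid instances. The hypothetical helper `masked_mean_pool` below is a dependency-free sketch of that assumed behavior on nested lists with the documented shapes; with `mask=None` every instance counts and it reduces to the plain mean formula.

```python
def masked_mean_pool(X, mask=None):
    """Sketch of (assumed) masked mean pooling.

    X:    nested list of shape (batch_size, bag_size, in_dim).
    mask: optional nested list of shape (batch_size, bag_size). Assumed
          convention (not confirmed by the docs): 1 = valid instance, 0 = padding.
    Returns a nested list of shape (batch_size, in_dim).
    """
    z = []
    for b, bag in enumerate(X):
        in_dim = len(bag[0])
        # With no mask, every instance in the bag gets weight 1.
        weights = [1.0] * len(bag) if mask is None else [float(m) for m in mask[b]]
        count = sum(weights)  # number of valid instances in bag b
        if count == 0:
            raise ValueError(f"bag {b} has no valid instances")
        # Sum the valid instances feature by feature, then divide by their count.
        z.append([sum(w * x[j] for w, x in zip(weights, bag)) / count
                  for j in range(in_dim)])
    return z


X = [
    [[1.0, 2.0], [3.0, 4.0], [5.0, 9.0]],  # bag 0: three real instances
    [[2.0, 2.0], [4.0, 6.0], [0.0, 0.0]],  # bag 1: last row is padding
]
mask = [[1, 1, 1], [1, 1, 0]]
print(masked_mean_pool(X, mask))  # [[3.0, 5.0], [3.0, 4.0]]
```

Dividing by the per-bag count of valid instances, rather than by bag_size, keeps padded rows from pulling the average toward zero: without the mask, bag 1 above would average to \([2, 8/3]\) instead of \([3, 4]\). Under the same assumption, and with `mask` as a float tensor, the PyTorch equivalent is typically written as `(X * mask.unsqueeze(-1)).sum(dim=1) / mask.sum(dim=1, keepdim=True)`.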