Sm operator

torchmil.nn.Sm

Bases: Module

The \(\texttt{Sm}\) operator, proposed in the paper "\(\texttt{Sm}\): enhanced localization in Multiple Instance Learning for medical imaging classification".

Given an input graph with node features \(\mathbf{U} \in \mathbb{R}^{N \times D}\) and adjacency matrix \(\mathbf{A} \in \mathbb{R}^{N \times N}\), in the exact mode the \(\texttt{Sm}\) operator is defined as:

\[\begin{align} \texttt{Sm}(\mathbf{U}) = ( \mathbf{I} + \gamma \mathbf{L} )^{-1} \mathbf{U}, \end{align}\]

where \(\gamma \in (0, \infty)\) is a hyperparameter, \(\mathbf{L} = \mathbf{D} - \mathbf{A}\) is the graph Laplacian, and \(\mathbf{D}\) is the degree matrix. If mode='approx', the \(\texttt{Sm}\) operator is approximated as \(\texttt{Sm}(\mathbf{U}) = G(T)\), where

\[\begin{align} G(0) = \mathbf{U}, \quad G(t) = \alpha ( \mathbf{I} - \mathbf{L} ) G(t-1) + (1-\alpha) \mathbf{U}, \end{align}\]

for \(t \in \{1, \ldots, T\}\), where \(T\) is the number of steps (num_steps) and \(\alpha \in (0, 1)\) is a hyperparameter.
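The two modes can be related through the fixed point of the recursion. As a sketch, assuming the iteration converges (that is, the spectral radius of \(\alpha ( \mathbf{I} - \mathbf{L} )\) is below 1), its limit \(G^{\star}\) satisfies:

\[\begin{align} G^{\star} &= \alpha ( \mathbf{I} - \mathbf{L} ) G^{\star} + (1-\alpha) \mathbf{U} \\ \left( (1-\alpha) \mathbf{I} + \alpha \mathbf{L} \right) G^{\star} &= (1-\alpha) \mathbf{U} \\ G^{\star} &= \left( \mathbf{I} + \tfrac{\alpha}{1-\alpha} \mathbf{L} \right)^{-1} \mathbf{U}. \end{align}\]

Under that assumption the limit matches the exact definition with \(\gamma = \alpha / (1-\alpha)\). This follows from the two formulas alone; whether the implementation applies this mapping internally is not stated on this page.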

__init__(alpha='trainable', num_steps=10, mode='approx')

Parameters:

  • alpha (Union[float, str], default: 'trainable' ) –

    Alpha value for the Sm operator. If 'trainable', alpha is a trainable parameter.

  • num_steps (int, default: 10 ) –

    Number of iterations \(T\) used to approximate the exact Sm operator.

  • mode (str, default: 'approx' ) –

    Mode of the Sm operator. Possible values: 'approx', 'exact'.
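To give a feel for how num_steps affects the approximation, the following NumPy sketch iterates the recursion from the definition above and records the gap to the exact formula after each step. It is an illustration of the math, not torchmil's implementation: the toy graph, the value of alpha, and the correspondence gamma = alpha / (1 - alpha) are assumptions made for this example.

```python
import numpy as np

# Hypothetical toy graph (not from torchmil): a 3-node path with edge weight 0.5.
A = np.array([[0.0, 0.5, 0.0],
              [0.5, 0.0, 0.5],
              [0.0, 0.5, 0.0]])
L = np.diag(A.sum(axis=1)) - A        # graph Laplacian L = D - A
I = np.eye(3)
U = np.array([[1.0], [0.0], [0.0]])   # one feature per node

alpha = 0.5
gamma = alpha / (1.0 - alpha)         # assumed correspondence between the two modes
exact = np.linalg.solve(I + gamma * L, U)

# Iterate G(t) = alpha (I - L) G(t-1) + (1 - alpha) U and track the error.
G = U.copy()
errors = []
for t in range(1, 31):
    G = alpha * (I - L) @ G + (1.0 - alpha) * U
    errors.append(float(np.abs(G - exact).max()))

for T in (1, 5, 10, 30):
    print(f"num_steps={T:2d}  max abs error={errors[T - 1]:.2e}")
```

On this graph the eigenvalues of \(\mathbf{I} - \mathbf{L}\) lie in \([-0.5, 1]\), so the iteration contracts and the error shrinks as the number of steps grows. For graphs with larger degrees the unnormalized recursion may not converge, which is one reason to check how the adjacency matrix is scaled.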

forward(f, adj_mat)

Forward method.

Parameters:

  • f (Tensor) –

    Input tensor of shape (batch_size, bag_size, ...).

  • adj_mat (Tensor) –

    Adjacency matrix tensor of shape (batch_size, bag_size, bag_size). Sparse tensors are supported.

Returns:

  • g ( Tensor ) –

    Output tensor of shape (batch_size, bag_size, ...).


torchmil.nn.ApproxSm

Bases: Module

\(\texttt{Sm}\) operator in the approximate mode, proposed in the paper "\(\texttt{Sm}\): enhanced localization in Multiple Instance Learning for medical imaging classification".

Given an input graph with node features \(\mathbf{U} \in \mathbb{R}^{N \times D}\) and adjacency matrix \(\mathbf{A} \in \mathbb{R}^{N \times N}\), it computes \(\texttt{Sm}(\mathbf{U}) = G(T)\), where

\[\begin{align} G(0) = \mathbf{U}, \quad G(t) = \alpha ( \mathbf{I} - \mathbf{L} ) G(t-1) + (1-\alpha) \mathbf{U}, \end{align}\]

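for \(t \in \{1, \ldots, T\}\), where \(T\) is the number of steps (num_steps), \(\alpha \in (0, 1)\) is a hyperparameter, \(\mathbf{L} = \mathbf{D} - \mathbf{A}\) is the graph Laplacian, and \(\mathbf{D}\) is the degree matrix.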
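The recursion can be written out in a few lines. The sketch below is a single-graph NumPy version for illustration only; the function name approx_sm and the toy inputs are hypothetical, and it does not reproduce torchmil's batched or sparse handling.

```python
import numpy as np

def approx_sm(U, A, alpha=0.5, num_steps=10):
    """Illustrative approximate Sm: returns G(T) from the recursion (not torchmil's code)."""
    U = np.asarray(U, dtype=float)
    A = np.asarray(A, dtype=float)
    L = np.diag(A.sum(axis=1)) - A      # graph Laplacian L = D - A
    M = np.eye(A.shape[0]) - L          # propagation matrix I - L
    G = U.copy()                        # G(0) = U
    for _ in range(num_steps):
        G = alpha * (M @ G) + (1.0 - alpha) * U
    return G

# Hypothetical inputs: 3 nodes, 2 features each, weighted path graph.
A_toy = np.array([[0.0, 0.5, 0.0],
                  [0.5, 0.0, 0.5],
                  [0.0, 0.5, 0.0]])
U_toy = np.array([[1.0, 2.0],
                  [0.0, 2.0],
                  [0.0, 2.0]])
G_toy = approx_sm(U_toy, A_toy, alpha=0.5, num_steps=10)
print(G_toy.shape)  # same shape as the input features
```

Because \(\mathbf{L}\) maps constant vectors to zero, a feature column that is constant across nodes (the second column here) passes through unchanged, while the first column is smoothed along the edges.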

__init__(alpha='trainable', num_steps=10)

Parameters:

  • alpha (Union[float, str], default: 'trainable' ) –

    Alpha value for the Sm operator. If 'trainable', alpha is a trainable parameter.

  • num_steps (int, default: 10 ) –

    Number of iterations \(T\) used to approximate the exact Sm operator.

forward(f, adj_mat)

Forward method.

Parameters:

  • f (Tensor) –

    Input tensor of shape (batch_size, bag_size, ...).

  • adj_mat (Tensor) –

    Adjacency matrix tensor of shape (batch_size, bag_size, bag_size). Sparse tensors are supported.

Returns:

  • g ( Tensor ) –

    Output tensor of shape (batch_size, bag_size, ...).


torchmil.nn.ExactSm

Bases: Module

\(\texttt{Sm}\) operator in the exact mode, proposed in the paper "\(\texttt{Sm}\): enhanced localization in Multiple Instance Learning for medical imaging classification".

Given an input graph with node features \(\mathbf{U} \in \mathbb{R}^{N \times D}\) and adjacency matrix \(\mathbf{A} \in \mathbb{R}^{N \times N}\), it computes

\[\begin{align} \texttt{Sm}(\mathbf{U}) = ( \mathbf{I} + \gamma \mathbf{L} )^{-1} \mathbf{U}, \end{align}\]

where \(\gamma \in (0, \infty)\) is a hyperparameter, \(\mathbf{L} = \mathbf{D} - \mathbf{A}\) is the graph Laplacian, and \(\mathbf{D}\) is the degree matrix.
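As a rough illustration of the exact formula, the NumPy sketch below solves the linear system \(( \mathbf{I} + \gamma \mathbf{L} ) \mathbf{X} = \mathbf{U}\) instead of forming the inverse explicitly. The function name exact_sm and the toy graph are hypothetical, and this is not a description of how torchmil computes the result.

```python
import numpy as np

def exact_sm(U, A, gamma=1.0):
    """Illustrative exact Sm: solves (I + gamma L) X = U (not torchmil's code)."""
    U = np.asarray(U, dtype=float)
    A = np.asarray(A, dtype=float)
    L = np.diag(A.sum(axis=1)) - A            # graph Laplacian L = D - A
    system = np.eye(A.shape[0]) + gamma * L   # positive definite for symmetric, non-negative A
    return np.linalg.solve(system, U)

# Hypothetical inputs: unweighted 3-node path, a single feature concentrated on node 0.
A_path = np.array([[0.0, 1.0, 0.0],
                   [1.0, 0.0, 1.0],
                   [0.0, 1.0, 0.0]])
U_path = np.array([[1.0], [0.0], [0.0]])
X_path = exact_sm(U_path, A_path, gamma=1.0)
print(X_path.ravel())  # values 0.625, 0.25, 0.125
```

The feature mass spreads from node 0 to its neighbours while the column sum stays at 1, which is the smoothing behaviour the operator is meant to provide. Larger values of gamma smooth more strongly.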

__init__(alpha='trainable')

Parameters:

  • alpha (Union[float, str], default: 'trainable' ) –

    Alpha value for the Sm operator. If 'trainable', alpha is a trainable parameter.

forward(f, adj_mat)

Forward method.

Parameters:

  • f (Tensor) –

    Input tensor of shape (batch_size, bag_size, ...).

  • adj_mat (Tensor) –

    Adjacency matrix tensor of shape (batch_size, bag_size, bag_size).

Returns:

  • g ( Tensor ) –

    Output tensor of shape (batch_size, bag_size, ...).