
Sm Transformer


torchmil.nn.transformers.SmTransformerEncoder

Bases: Encoder

A Transformer encoder with the \(\texttt{Sm}\) operator, skip connections and layer normalization.

Given an input bag \(\mathbf{X} = \left[ \mathbf{x}_1, \ldots, \mathbf{x}_N \right]^\top \in \mathbb{R}^{N \times D}\), it computes:

\[\begin{align*} \mathbf{X}^{0} & = \mathbf{X} \\ \mathbf{Z}^{l} & = \mathbf{X}^{l-1} + \texttt{Sm}( \text{SelfAttention}( \text{LayerNorm}(\mathbf{X}^{l-1}) ) ), \quad l = 1, \ldots, L \\ \mathbf{X}^{l} & = \mathbf{Z}^{l} + \text{MLP}(\text{LayerNorm}(\mathbf{Z}^{l})), \quad l = 1, \ldots, L \\ \end{align*}\]

This module outputs \(\text{SmTransformerEncoder}(\mathbf{X}) = \mathbf{X}^{L}\) if add_self=False, and \(\text{SmTransformerEncoder}(\mathbf{X}) = \mathbf{X}^{L} + \mathbf{X}\) if add_self=True.

See Sm for more details on the Sm operator.
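To make the recursion above concrete, here is an illustrative re-implementation of the encoder as a stack of the SmTransformerLayer modules documented below, with the optional add_self skip connection. This is a sketch built only from the signatures on this page, not the library's actual encoder code.

```python
import torch.nn as nn
from torchmil.nn.transformers import SmTransformerLayer

class StackedSmLayers(nn.Module):
    """Illustrative stack of SmTransformerLayer modules (not torchmil's own encoder code)."""
    def __init__(self, in_dim, att_dim=512, n_heads=4, n_layers=4, add_self=False):
        super().__init__()
        self.add_self = add_self
        self.layers = nn.ModuleList(
            [SmTransformerLayer(in_dim=in_dim, att_dim=att_dim, n_heads=n_heads) for _ in range(n_layers)]
        )

    def forward(self, X, adj, mask=None):
        X0 = X  # X^0 = X
        for layer in self.layers:
            X = layer(X, adj, mask)  # X^l computed from X^{l-1}
        return X + X0 if self.add_self else X  # add_self=True adds the input back
```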

__init__(in_dim, out_dim=None, att_dim=512, n_heads=4, n_layers=4, use_mlp=True, add_self=False, dropout=0.0, sm_alpha='trainable', sm_mode='approx', sm_steps=10)

Class constructor.

Parameters:

  • in_dim (int) –

    Input dimension.

  • out_dim (int, default: None ) –

    Output dimension. If None, out_dim = in_dim.

  • att_dim (int, default: 512 ) –

    Attention dimension.

  • n_heads (int, default: 4 ) –

    Number of heads.

  • n_layers (int, default: 4 ) –

    Number of layers.

  • use_mlp (bool, default: True ) –

    Whether to use a feedforward (MLP) layer.

  • add_self (bool, default: False ) –

    Whether to add the input to the output. If True, att_dim must be equal to in_dim.

  • dropout (float, default: 0.0 ) –

    Dropout rate.

  • sm_alpha (float or str, default: 'trainable' ) –

    Alpha value for the Sm operator.

  • sm_mode (str, default: 'approx' ) –

    Mode of the Sm operator, either 'approx' or 'exact'.

  • sm_steps (int, default: 10 ) –

    Number of steps to approximate the exact Sm operator.

forward(X, adj, mask=None, return_att=False)

Forward method.

Parameters:

  • X (Tensor) –

    Input tensor of shape (batch_size, bag_size, in_dim).

  • adj (Tensor) –

    Adjacency matrix of shape (batch_size, bag_size, bag_size).

  • mask (Tensor, default: None ) –

    Mask tensor of shape (batch_size, bag_size).

  • return_att (bool, default: False ) –

    If True, also returns the attention weights.

Returns:

  • Y ( Tensor ) –

    Output tensor of shape (batch_size, bag_size, in_dim).
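A minimal usage sketch, assuming torchmil is installed and the constructor and forward signatures documented above. The dense adjacency built here (connecting consecutive instances) is purely illustrative; check the library for the expected adjacency format.

```python
import torch
from torchmil.nn.transformers import SmTransformerEncoder

batch_size, bag_size, in_dim = 2, 100, 256
X = torch.randn(batch_size, bag_size, in_dim)

# Toy adjacency: connect each instance to its next neighbor (symmetric).
adj = torch.diag_embed(torch.ones(bag_size - 1), offset=1)
adj = (adj + adj.T).expand(batch_size, bag_size, bag_size)
mask = torch.ones(batch_size, bag_size)

encoder = SmTransformerEncoder(in_dim=in_dim, att_dim=512, n_heads=4, n_layers=4)
Y = encoder(X, adj, mask)  # (batch_size, bag_size, in_dim)
```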


torchmil.nn.transformers.SmTransformerLayer

Bases: Layer

One layer of the Transformer encoder with the \(\texttt{Sm}\) operator.

Given an input bag \(\mathbf{X} = \left[ \mathbf{x}_1, \ldots, \mathbf{x}_N \right]^\top \in \mathbb{R}^{N \times D}\), this module computes:

\[\begin{align*} \mathbf{Z} & = \mathbf{X} + \texttt{Sm}( \text{SelfAttention}( \text{LayerNorm}(\mathbf{X}) ) )\\ \mathbf{Y} & = \mathbf{Z} + \text{MLP}(\text{LayerNorm}(\mathbf{Z})), \\ \end{align*}\]

and outputs \(\mathbf{Y}\).

See Sm for more details on the Sm operator.
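For intuition, the following self-contained PyTorch sketch mirrors this pre-LN layer, replacing the Sm operator with a single neighbor-averaging step driven by the adjacency matrix. The alpha mixing weight and the averaging rule are stand-ins only; see the Sm page for the real operator.

```python
import torch
import torch.nn as nn

class SmLayerSketch(nn.Module):
    """Illustrative layer: Sm is replaced by one neighbor-averaging step."""
    def __init__(self, dim: int, n_heads: int = 4, alpha: float = 0.5):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, n_heads, batch_first=True)
        self.mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))
        self.alpha = alpha  # smoothing weight of the stand-in (hypothetical)

    def sm(self, F: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        # Stand-in for Sm: mix each instance with the mean of its neighbors.
        deg = adj.sum(-1, keepdim=True).clamp(min=1.0)
        return (1.0 - self.alpha) * F + self.alpha * (adj @ F) / deg

    def forward(self, X: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        H = self.norm1(X)
        A, _ = self.attn(H, H, H)            # SelfAttention(LayerNorm(X))
        Z = X + self.sm(A, adj)              # Z = X + Sm(...)
        return Z + self.mlp(self.norm2(Z))   # Y = Z + MLP(LayerNorm(Z))
```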

__init__(in_dim, out_dim=None, att_dim=512, n_heads=4, use_mlp=True, dropout=0.0, sm_alpha='trainable', sm_mode='approx', sm_steps=10)

Class constructor.

Parameters:

  • in_dim (int) –

    Input dimension.

  • out_dim (int, default: None ) –

    Output dimension. If None, out_dim = in_dim.

  • att_dim (int, default: 512 ) –

    Attention dimension.

  • n_heads (int, default: 4 ) –

    Number of heads.

  • use_mlp (bool, default: True ) –

    Whether to use a feedforward (MLP) layer.

  • dropout (float, default: 0.0 ) –

    Dropout rate.

  • sm_alpha (float or str, default: 'trainable' ) –

    Alpha value for the Sm operator.

  • sm_mode (str, default: 'approx' ) –

    Mode of the Sm operator, either 'approx' or 'exact'.

  • sm_steps (int, default: 10 ) –

    Number of steps to approximate the exact Sm operator.

forward(X, adj, mask=None, return_att=False)

Forward method.

Parameters:

  • X (Tensor) –

    Input tensor of shape (batch_size, bag_size, in_dim).

  • adj (Tensor) –

    Adjacency matrix of shape (batch_size, bag_size, bag_size).

  • mask (Tensor, default: None ) –

    Mask tensor of shape (batch_size, bag_size).

  • return_att (bool, default: False ) –

    If True, returns attention weights, of shape (batch_size, n_heads, bag_size, bag_size).

Returns:

  • Y ( Tensor ) –

    Output tensor of shape (batch_size, bag_size, in_dim).
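A short usage sketch for a single layer, again assuming the signatures documented above; the tuple unpacking for return_att=True is an assumption and should be checked against the source.

```python
import torch
from torchmil.nn.transformers import SmTransformerLayer

batch_size, bag_size, in_dim = 2, 50, 256
X = torch.randn(batch_size, bag_size, in_dim)
adj = torch.eye(bag_size).expand(batch_size, bag_size, bag_size)  # toy adjacency

layer = SmTransformerLayer(in_dim=in_dim, att_dim=512, n_heads=4)
Y = layer(X, adj)  # (batch_size, bag_size, in_dim)

# Assumed return structure when return_att=True:
# Y, att = layer(X, adj, return_att=True)
# att: (batch_size, n_heads, bag_size, bag_size)
```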