
General MIL model

torchmil.models.MILModel

Bases: Module

Base class for Multiple Instance Learning (MIL) models in torchmil.

Subclasses should implement the following methods:

  • forward: Forward pass of the model. Accepts bag features (and optionally other arguments) and returns the bag label prediction (and optionally other outputs).
  • compute_loss: Compute the inner losses of the model. Accepts bag labels and bag features (and optionally other arguments) and returns the output of the forward method together with a dictionary of (loss_name, loss_value) pairs. By default, the model has no inner losses, so this dictionary is empty.
  • predict: Predict bag and (optionally) instance labels. Accepts bag features (and optionally other arguments) and returns label predictions (and optionally instance label predictions).
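
The contract above can be illustrated with a minimal sketch. So that it runs without torchmil installed, the hypothetical `MeanPoolMIL` below subclasses `torch.nn.Module` (the base of `MILModel`) directly. In real code you would subclass `torchmil.models.MILModel` and could rely on its default `compute_loss`. The mean-pooling architecture and the `in_dim` argument are illustrative assumptions, not torchmil code.

```python
import torch
from torch import nn


class MeanPoolMIL(nn.Module):
    """Hypothetical MIL model: score every instance, then average over the bag."""

    def __init__(self, in_dim: int):
        super().__init__()
        self.inst_scorer = nn.Linear(in_dim, 1)  # one logit per instance

    def _inst_logits(self, X: torch.Tensor) -> torch.Tensor:
        # (batch_size, bag_size, in_dim) -> (batch_size, bag_size)
        return self.inst_scorer(X).squeeze(-1)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        # Bag logit = mean of instance logits, shape (batch_size,)
        return self._inst_logits(X).mean(dim=1)

    def compute_loss(self, Y: torch.Tensor, X: torch.Tensor):
        # No inner losses: forward output plus an empty loss dictionary.
        # Y is unused here; the supervised loss is computed outside the model.
        return self.forward(X), {}

    def predict(self, X: torch.Tensor, return_inst_pred: bool = False):
        inst_logits = self._inst_logits(X)
        Y_pred = inst_logits.mean(dim=1)
        if return_inst_pred:
            return Y_pred, inst_logits  # instance scores, (batch_size, bag_size)
        return Y_pred


model = MeanPoolMIL(in_dim=8)
X = torch.randn(4, 10, 8)  # 4 bags, 10 instances each, 8 features per instance
print(model(X).shape)  # torch.Size([4])
```
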
__init__(*args, **kwargs)

Initializes the module.

forward(X, *args, **kwargs)

Parameters:

  • X (Tensor) –

    Bag features of shape (batch_size, bag_size, ...).

Returns:

  • Y_pred (Tensor) –

    Bag label prediction of shape (batch_size,).
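
The `...` in the input shape means each instance may itself be multi-dimensional (for example an image patch). A hedged sketch of one way a `forward` can honour this: flatten the trailing dimensions of every instance, encode each instance, and pool over the bag. The class name `FlattenMaxMIL`, the `inst_numel` and `hidden` arguments, and the max-pooling choice are illustrative assumptions, not torchmil code.

```python
import torch
from torch import nn


class FlattenMaxMIL(nn.Module):
    """Hypothetical forward() accepting X of shape (batch_size, bag_size, ...)."""

    def __init__(self, inst_numel: int, hidden: int = 16):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(inst_numel, hidden), nn.ReLU())
        self.head = nn.Linear(hidden, 1)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        batch_size, bag_size = X.shape[:2]
        # Flatten the trailing "..." dims of every instance into one feature vector
        flat = X.reshape(batch_size, bag_size, -1)
        H = self.encoder(flat)  # (batch_size, bag_size, hidden)
        bag_emb = H.max(dim=1).values  # max-pool over instances: (batch_size, hidden)
        return self.head(bag_emb).squeeze(-1)  # (batch_size,)


# Bags of 3x4x4 patches: X has shape (batch_size=2, bag_size=5, 3, 4, 4)
X = torch.randn(2, 5, 3, 4, 4)
model = FlattenMaxMIL(inst_numel=3 * 4 * 4)
print(model(X).shape)  # torch.Size([2])
```

Because `reshape(batch_size, bag_size, -1)` leaves an already-flat input unchanged, the same model also accepts features of shape (2, 5, 48).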

compute_loss(Y, X, *args, **kwargs)

Parameters:

  • Y (Tensor) –

    Bag labels of shape (batch_size,).

  • X (Tensor) –

    Bag features of shape (batch_size, bag_size, ...).

Returns:

  • Y_pred (Tensor) –

    Bag label prediction of shape (batch_size,).

  • loss_dict (dict) –

    Dictionary containing the loss values.
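
A hedged sketch of how the pair returned by `compute_loss` might be used in a training step, assuming the inner losses are simply added to a supervised bag-level loss computed outside the model. The model `SparseScoreMIL`, its L1 penalty on instance scores and the `l1_coef` argument are hypothetical and only serve to produce a non-empty `loss_dict`.

```python
import torch
from torch import nn
import torch.nn.functional as F


class SparseScoreMIL(nn.Module):
    """Hypothetical model with one inner loss: an L1 penalty on instance scores."""

    def __init__(self, in_dim: int, l1_coef: float = 0.01):
        super().__init__()
        self.inst_scorer = nn.Linear(in_dim, 1)
        self.l1_coef = l1_coef

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        # Bag logit = max instance logit, shape (batch_size,)
        return self.inst_scorer(X).squeeze(-1).max(dim=1).values

    def compute_loss(self, Y: torch.Tensor, X: torch.Tensor):
        inst_logits = self.inst_scorer(X).squeeze(-1)  # (batch_size, bag_size)
        Y_pred = inst_logits.max(dim=1).values  # same output as forward()
        loss_dict = {"l1_scores": self.l1_coef * inst_logits.abs().mean()}
        return Y_pred, loss_dict


torch.manual_seed(0)
model = SparseScoreMIL(in_dim=8)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
X = torch.randn(4, 10, 8)
Y = torch.tensor([0.0, 1.0, 1.0, 0.0])  # binary bag labels, shape (batch_size,)

# One training step: supervised bag loss plus every inner loss from compute_loss
Y_pred, loss_dict = model.compute_loss(Y, X)
loss = F.binary_cross_entropy_with_logits(Y_pred, Y) + sum(loss_dict.values())
optimizer.zero_grad()
loss.backward()
optimizer.step()
print({name: round(value.item(), 4) for name, value in loss_dict.items()})
```

Summing `loss_dict.values()` also works for a model without inner losses: the sum of an empty dictionary is 0, so only the supervised term remains.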

predict(X, return_inst_pred=False, *args, **kwargs)

Parameters:

  • X (Tensor) –

    Bag features of shape (batch_size, bag_size, ...).

Returns:

  • Y_pred (Tensor) –

    Bag label prediction of shape (batch_size,).

  • y_inst_pred (Tensor) –

    If return_inst_pred=True, returns instance label predictions of shape (batch_size, bag_size).
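
One common design choice, assumed here rather than taken from torchmil, is attention-based pooling where the attention weights double as instance-level predictions. The hypothetical `AttnPoolMIL` below returns them from `predict` when `return_inst_pred=True` and returns only the bag prediction otherwise. Running `predict` under `torch.no_grad()` is a typical inference pattern, not a requirement of the interface.

```python
import torch
from torch import nn


class AttnPoolMIL(nn.Module):
    """Hypothetical attention-pooling model; attention weights serve as instance predictions."""

    def __init__(self, in_dim: int, hidden: int = 16):
        super().__init__()
        self.attn = nn.Sequential(nn.Linear(in_dim, hidden), nn.Tanh(), nn.Linear(hidden, 1))
        self.classifier = nn.Linear(in_dim, 1)

    def _pool(self, X: torch.Tensor):
        # Attention over instances: (batch_size, bag_size), each row sums to 1
        att = torch.softmax(self.attn(X).squeeze(-1), dim=1)
        # Attention-weighted mean of instance features: (batch_size, in_dim)
        bag_emb = torch.bmm(att.unsqueeze(1), X).squeeze(1)
        return self.classifier(bag_emb).squeeze(-1), att

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        return self._pool(X)[0]

    def predict(self, X: torch.Tensor, return_inst_pred: bool = False):
        Y_pred, att = self._pool(X)
        return (Y_pred, att) if return_inst_pred else Y_pred


model = AttnPoolMIL(in_dim=8)
X = torch.randn(3, 6, 8)
with torch.no_grad():
    Y_pred, y_inst_pred = model.predict(X, return_inst_pred=True)
print(Y_pred.shape, y_inst_pred.shape)  # torch.Size([3]) torch.Size([3, 6])
```
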


torchmil.models.MILModelWrapper

Bases: MILModel

A wrapper class for MIL models in torchmil. It allows all models that inherit from MILModel to be used through a common interface:

model_A = ... # forward accepts arguments 'X', 'adj'
model_B = ... # forward accepts argument 'X'
model_A_w = MILModelWrapper(model_A)
model_B_w = MILModelWrapper(model_B)

bag = TensorDict({'X': ..., 'adj': ..., ...})
Y_pred_A = model_A_w(bag) # calls model_A.forward(X=bag['X'], adj=bag['adj'])
Y_pred_B = model_B_w(bag) # calls model_B.forward(X=bag['X'])
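
The dispatch in the example above can be sketched in plain Python. This is an assumption about how such a wrapper could work, not torchmil's source: it reads the wrapped forward's parameter names with `inspect.signature` and passes only the matching keys of the bag. A plain `dict` stands in for `TensorDict`, and `forward_A`, `forward_B` and `SketchWrapper` are hypothetical names.

```python
import inspect


def forward_A(X, adj):
    """Stand-in for model_A.forward, which accepts 'X' and 'adj'."""
    return f"A(X={X}, adj={adj})"


def forward_B(X):
    """Stand-in for model_B.forward, which accepts only 'X'."""
    return f"B(X={X})"


class SketchWrapper:
    """Hypothetical wrapper: route bag entries to forward by parameter name."""

    def __init__(self, forward_fn):
        self.forward_fn = forward_fn
        self.arg_names = set(inspect.signature(forward_fn).parameters)

    def __call__(self, bag, **kwargs):
        # Keep only the bag keys that the wrapped forward actually accepts
        args = {k: v for k, v in bag.items() if k in self.arg_names}
        return self.forward_fn(**args, **kwargs)


bag = {"X": "x0", "adj": "a0", "y_inst": "ignored"}  # dict in place of TensorDict
print(SketchWrapper(forward_A)(bag))  # A(X=x0, adj=a0)
print(SketchWrapper(forward_B)(bag))  # B(X=x0)
```

Filtering by parameter name lets a single bag carry every field any model might need, while each model only receives the arguments it declares.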

__init__(model)

forward(bag, **kwargs)

Parameters:

  • bag (TensorDict) –

    Dictionary containing one key for each argument accepted by the model's forward method.

Returns:

  • out (Any) –

    Output of the model's forward method.

compute_loss(bag, **kwargs)

Parameters:

  • bag (TensorDict) –

    Dictionary containing one key for each argument accepted by the model's forward method.

Returns:

  • out (tuple[Any, dict]) –

    Output of the model's compute_loss method.

predict(bag, **kwargs)

Parameters:

  • bag (TensorDict) –

    Dictionary containing one key for each argument accepted by the model's forward method.

Returns:

  • out (Any) –

    Output of the model's predict method.
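
A hedged sketch of how one helper could serve forward, compute_loss and predict alike, with extra keyword arguments such as `return_inst_pred` forwarded unchanged. Here the bag is filtered against the signature of the method being called, so `compute_loss` can also receive `'Y'`; this is an assumption about the design, not a statement of torchmil's implementation. `ToyModel` and `call_with_bag` are hypothetical, and plain lists and dicts stand in for tensors and `TensorDict`.

```python
import inspect


class ToyModel:
    """Hypothetical model exposing the three MILModel methods (plain Python)."""

    def forward(self, X):
        return sum(X) / len(X)  # mean of instance scores as the bag prediction

    def compute_loss(self, Y, X):
        return self.forward(X), {}  # no inner losses

    def predict(self, X, return_inst_pred=False):
        Y_pred = self.forward(X)
        return (Y_pred, list(X)) if return_inst_pred else Y_pred


def call_with_bag(method, bag, **kwargs):
    """Call `method` with the bag entries its signature accepts, plus kwargs."""
    names = inspect.signature(method).parameters
    args = {k: v for k, v in bag.items() if k in names}
    return method(**args, **kwargs)


model = ToyModel()
bag = {"X": [1.0, 2.0, 3.0], "Y": 1, "adj": None}  # 'adj' is not used by ToyModel
print(call_with_bag(model.compute_loss, bag))  # (2.0, {})
print(call_with_bag(model.predict, bag, return_inst_pred=True))  # (2.0, [1.0, 2.0, 3.0])
```
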