General MIL model
torchmil.models.MILModel
Bases: Module
Base class for Multiple Instance Learning (MIL) models in torchmil.
Subclasses should implement the following methods:
forward
: Forward pass of the model. Accepts bag features (and optionally other arguments) and returns the bag label prediction (and optionally other outputs).
compute_loss
: Compute the inner losses of the model. Accepts bag labels and bag features (and optionally other arguments) and returns the output of the forward method together with a dictionary of (loss_name, loss_value) pairs. By default, the model has no inner losses, so this dictionary is empty.
predict
: Predict bag labels and, optionally, instance labels. Accepts bag features (and optionally other arguments) and returns bag label predictions (and optionally instance label predictions).
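The three-method contract above can be sketched as follows. This is a minimal illustration, not a model shipped with torchmil: `MeanPoolMIL`, `inst_scorer`, and `in_dim` are hypothetical names, and the class derives from `torch.nn.Module` so that the sketch runs with PyTorch alone. In real code it would presumably subclass `MILModel` instead.

```python
import torch
from torch import nn


class MeanPoolMIL(nn.Module):
    """Hypothetical mean-pooling MIL model following the three-method contract."""

    def __init__(self, in_dim: int) -> None:
        super().__init__()
        self.inst_scorer = nn.Linear(in_dim, 1)  # one logit per instance

    def _instance_logits(self, X: torch.Tensor) -> torch.Tensor:
        # (batch_size, bag_size, in_dim) -> (batch_size, bag_size)
        return self.inst_scorer(X).squeeze(-1)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        # Bag logit is the mean of the instance logits: shape (batch_size,)
        return self._instance_logits(X).mean(dim=1)

    def compute_loss(self, Y: torch.Tensor, X: torch.Tensor):
        # This model has no inner losses, so the dictionary is empty.
        Y_pred = self.forward(X)
        return Y_pred, {}

    def predict(self, X: torch.Tensor, return_inst_pred: bool = False):
        inst_logits = self._instance_logits(X)
        Y_pred = inst_logits.mean(dim=1)
        if return_inst_pred:
            return Y_pred, inst_logits
        return Y_pred


model = MeanPoolMIL(in_dim=8)
X = torch.randn(4, 10, 8)  # 4 bags, 10 instances each, 8 features per instance
Y = torch.randint(0, 2, (4,)).float()

Y_pred, loss_dict = model.compute_loss(Y, X)
Y_pred_again, y_inst_pred = model.predict(X, return_inst_pred=True)
```

Here `Y_pred` has shape `(batch_size,)` and `y_inst_pred` has shape `(batch_size, bag_size)`, matching the shapes documented for `forward` and `predict`.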
__init__(*args, **kwargs)
Initializes the module.
forward(X, *args, **kwargs)
Parameters:
- X (Tensor) – Bag features of shape (batch_size, bag_size, ...).
Returns:
- Y_pred (Tensor) – Bag label prediction of shape (batch_size,).
compute_loss(Y, X, *args, **kwargs)
Parameters:
- Y (Tensor) – Bag labels of shape (batch_size,).
- X (Tensor) – Bag features of shape (batch_size, bag_size, ...).
Returns:
- Y_pred (Tensor) – Bag label prediction of shape (batch_size,).
- loss_dict (dict) – Dictionary containing the loss values.
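One common way to consume the `(Y_pred, loss_dict)` pair in a training step is to add the inner losses to the main criterion loss. The helper below is a sketch under that assumption. `total_loss`, its `weights` argument, and the loss name `hypothetical_reg` are illustrative and not part of torchmil, and this page does not state how torchmil itself combines these values.

```python
def total_loss(criterion_loss, loss_dict, weights=None):
    """Add each inner loss (optionally weighted) to the main criterion loss."""
    weights = weights or {}
    total = criterion_loss
    for name, value in loss_dict.items():
        # Inner losses without an explicit weight count with weight 1.0.
        total = total + weights.get(name, 1.0) * value
    return total


# Plain floats stand in for scalar loss tensors in this sketch.
no_inner = total_loss(0.7, {})
combined = total_loss(
    0.5, {"hypothetical_reg": 0.25}, weights={"hypothetical_reg": 2.0}
)
```

Because the helper relies only on `+` and `*`, it should work unchanged when the values are scalar tensors, in which case the result can be passed to `backward()`.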
predict(X, return_inst_pred=False, *args, **kwargs)
Parameters:
- X (Tensor) – Bag features of shape (batch_size, bag_size, ...).
Returns:
- Y_pred (Tensor) – Bag label prediction of shape (batch_size,).
- y_inst_pred (Tensor) – If return_inst_pred=True, also returns instance label predictions of shape (batch_size, bag_size).
torchmil.models.MILModelWrapper
Bases: MILModel
A wrapper class for MIL models in torchmil.
It lets all models that inherit from MILModel be used through a common interface:
model_A = ... # forward accepts arguments 'X', 'adj'
model_B = ... # forward accepts argument 'X'
model_A_w = MILModelWrapper(model_A)
model_B_w = MILModelWrapper(model_B)
bag = TensorDict({'X': ..., 'adj': ..., ...})
Y_pred_A = model_A_w(bag) # calls model_A.forward(X=bag['X'], adj=bag['adj'])
Y_pred_B = model_B_w(bag) # calls model_B.forward(X=bag['X'])
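The dispatch described in the comments above (each model receives only the bag entries its `forward` names) can be imitated with `inspect.signature`. This is only a sketch of the idea, not torchmil's implementation: `call_with_bag`, `ToyModelA`, and `ToyModelB` are hypothetical, and a plain `dict` stands in for `TensorDict`.

```python
import inspect


class ToyModelA:
    def forward(self, X, adj):
        return ("A", X, adj)


class ToyModelB:
    def forward(self, X):
        return ("B", X)


def call_with_bag(method, bag, **kwargs):
    """Call `method` with only the bag entries that its signature names."""
    params = inspect.signature(method).parameters
    selected = {name: bag[name] for name in params if name in bag}
    return method(**selected, **kwargs)


# The bag holds more keys than either model needs; extras are ignored.
bag = {"X": [[1.0, 2.0]], "adj": [[1]], "Y": [0]}

out_A = call_with_bag(ToyModelA().forward, bag)  # receives X and adj
out_B = call_with_bag(ToyModelB().forward, bag)  # receives X only
```

Filtering by signature means one bag layout can serve models with different inputs. In this sketch, passing a keyword argument that duplicates a bag key would raise a `TypeError`.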
__init__(model)
forward(bag, **kwargs)
Parameters:
- bag (TensorDict) – Dictionary containing one key for each argument accepted by the model's forward method.
Returns:
- out (Any) – Output of the model's forward method.
compute_loss(bag, **kwargs)
Parameters:
- bag (TensorDict) – Dictionary containing one key for each argument accepted by the model's forward method.
Returns:
- out (tuple[Any, dict]) – Output of the model's compute_loss method.
predict(bag, **kwargs)
Parameters:
- bag (TensorDict) – Dictionary containing one key for each argument accepted by the model's forward method.
Returns:
- out (Any) – Output of the model's predict method.