Trainer
torchmil.utils.Trainer
Generic trainer class for training multiple instance learning (MIL) models.
__init__(model, optimizer, metrics_dict={'accuracy': torchmetrics.Accuracy(task='binary')}, obj_metric='accuracy', obj_metric_mode='max', lr_scheduler=None, annealing_scheduler_dict=None, device='cuda', logger=None, early_stop_patience=None, disable_pbar=False, verbose=True)
Parameters:
- model (MILModel) – MIL model to be trained. Must be an instance of MILModel.
- optimizer (Optimizer) – Optimizer used to train the model.
- metrics_dict (dict[str, Metric], default: {'accuracy': Accuracy(task='binary')}) – Dictionary of metrics to compute during training, keyed by metric name. Values should be instances of torchmetrics.Metric.
- obj_metric (str, default: 'accuracy') – Objective metric used for early stopping and for tracking the best model. Must be one of the keys in metrics_dict.
- obj_metric_mode (str, default: 'max') – Direction of the objective metric; must be 'max' or 'min'. With 'max', the best model is the one with the highest objective value; with 'min', the one with the lowest.
- lr_scheduler (_LRScheduler, default: None) – Learning rate scheduler.
- annealing_scheduler_dict (dict[str, AnnealingScheduler], default: None) – Dictionary of annealing schedulers for loss coefficients. Keys should be loss names and values instances of AnnealingScheduler.
- device (str, default: 'cuda') – Device used for training (e.g. 'cuda' or 'cpu').
- logger (Logger, default: None) – Logger used to record metrics. Any object with a log method works, for example a Wandb Run.
- early_stop_patience (int, default: None) – Patience for early stopping, i.e. how many epochs without improvement of obj_metric are tolerated before training stops. If None, early stopping is disabled.
- disable_pbar (bool, default: False) – If True, disable the progress bar.
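The interplay of obj_metric, obj_metric_mode, and early_stop_patience can be sketched in plain Python. This is a minimal illustration under assumptions, not torchmil's implementation: the helper names is_improvement and BestModelTracker are hypothetical, and the sketch assumes patience counts consecutive epochs without strict improvement and that the best state is kept as a deep copy.

```python
import copy
import math
from typing import Optional


def is_improvement(value: float, best: float, mode: str) -> bool:
    """Hypothetical helper: True if `value` is strictly better than `best`."""
    if mode == "max":
        return value > best
    if mode == "min":
        return value < best
    raise ValueError(f"mode must be 'max' or 'min', got {mode!r}")


class BestModelTracker:
    """Hypothetical helper: tracks the best objective value, a deep-copied
    snapshot of the matching state, and an early-stopping counter."""

    def __init__(self, mode: str = "max", patience: Optional[int] = None):
        if mode not in ("max", "min"):
            raise ValueError(f"mode must be 'max' or 'min', got {mode!r}")
        self.mode = mode
        self.patience = patience  # None disables early stopping
        self.best_value = -math.inf if mode == "max" else math.inf
        self.best_state: Optional[dict] = None
        self.epochs_without_improvement = 0

    def update(self, value: float, state: dict) -> bool:
        """Record one epoch's objective value; return True if training should stop."""
        if is_improvement(value, self.best_value, self.mode):
            self.best_value = value
            self.best_state = copy.deepcopy(state)  # snapshot, not a live reference
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        return (
            self.patience is not None
            and self.epochs_without_improvement >= self.patience
        )


# Example: maximize validation accuracy with a patience of 2 epochs.
tracker = BestModelTracker(mode="max", patience=2)
for epoch, val_acc in enumerate([0.70, 0.75, 0.74, 0.73, 0.80]):
    if tracker.update(val_acc, {"epoch": epoch}):
        print(f"early stop at epoch {epoch}")  # early stop at epoch 3
        break
print(tracker.best_value, tracker.best_state)  # 0.75 {'epoch': 1}
```

Keeping the comparison in one function makes 'max' and 'min' symmetric, so switching from accuracy to a loss-like objective only changes the mode string.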
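Loss-coefficient annealing can be pictured as follows. LinearAnnealing and weighted_total are hypothetical and do not reproduce the interface of torchmil's AnnealingScheduler; the sketch only illustrates matching scheduler keys to loss names and scaling each loss by a coefficient that changes as training progresses. Whether the real trainer steps its schedulers per batch or per epoch is not stated here.

```python
from typing import Dict


class LinearAnnealing:
    """Hypothetical scheduler: ramps a coefficient linearly from `start` to `end`
    over `num_steps` calls to step(), then holds it at `end`."""

    def __init__(self, start: float, end: float, num_steps: int):
        if num_steps <= 0:
            raise ValueError("num_steps must be positive")
        self.start, self.end, self.num_steps = start, end, num_steps
        self.current_step = 0

    def coefficient(self) -> float:
        frac = min(self.current_step / self.num_steps, 1.0)
        return self.start + frac * (self.end - self.start)

    def step(self) -> None:
        self.current_step += 1


def weighted_total(losses: Dict[str, float], schedulers: Dict[str, LinearAnnealing]) -> float:
    """Sum the losses, scaling each one by the coefficient of the scheduler
    registered under the same key (losses without a scheduler keep weight 1.0)."""
    return sum(
        value * (schedulers[name].coefficient() if name in schedulers else 1.0)
        for name, value in losses.items()
    )


# Example: warm up a hypothetical "kl_loss" term from 0 to 1 over 4 steps.
schedulers = {"kl_loss": LinearAnnealing(start=0.0, end=1.0, num_steps=4)}
losses = {"ce_loss": 0.5, "kl_loss": 2.0}
for _ in range(5):
    print(weighted_total(losses, schedulers))  # 0.5, 1.0, 1.5, 2.0, 2.5
    schedulers["kl_loss"].step()
```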
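Because the logger only needs to expose a log method, a small in-memory object can stand in for a Wandb Run, for example in tests. InMemoryLogger is hypothetical, and the metric keys below are illustrative; the keys the trainer actually emits are not documented here.

```python
from typing import Dict, List


class InMemoryLogger:
    """Hypothetical stand-in for a Wandb Run: the only requirement is a log method."""

    def __init__(self) -> None:
        self.history: List[Dict[str, float]] = []

    def log(self, metrics: Dict[str, float]) -> None:
        # Store a copy so that later mutation of `metrics` does not alter the history.
        self.history.append(dict(metrics))


logger = InMemoryLogger()
logger.log({"train/loss": 0.42, "val/accuracy": 0.81})  # illustrative keys only
print(logger.history)
```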
train(max_epochs, train_dataloader, val_dataloader=None, test_dataloader=None)
Train the model.
Parameters:
- max_epochs (int) – Maximum number of epochs to train.
- train_dataloader (DataLoader) – Training dataloader.
- val_dataloader (DataLoader, default: None) – Validation dataloader. If None, the training dataloader is used for validation.
- test_dataloader (DataLoader, default: None) – Test dataloader. If None, test metrics are not computed.
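The documented behaviour of train (validation falling back to the training dataloader, and test metrics computed only when a test dataloader is given) can be outlined as below. train_sketch and fake_epoch are hypothetical; run_epoch stands in for _shared_loop(dataloader, epoch, mode). The sketch assumes test metrics are computed every epoch and omits best-model tracking and early stopping.

```python
from typing import Callable, Dict, Iterable, List, Optional

EpochFn = Callable[[Iterable, int, str], Dict[str, float]]


def train_sketch(
    max_epochs: int,
    run_epoch: EpochFn,
    train_dataloader: Iterable,
    val_dataloader: Optional[Iterable] = None,
    test_dataloader: Optional[Iterable] = None,
) -> List[Dict[str, float]]:
    """Hypothetical outline of the epoch loop; not torchmil's implementation."""
    if val_dataloader is None:
        val_dataloader = train_dataloader  # documented fallback for validation
    history = []
    for epoch in range(max_epochs):
        metrics = dict(run_epoch(train_dataloader, epoch, "train"))
        metrics.update(run_epoch(val_dataloader, epoch, "val"))
        if test_dataloader is not None:  # test metrics only if a test loader is given
            metrics.update(run_epoch(test_dataloader, epoch, "test"))
        history.append(metrics)
    return history


def fake_epoch(loader, epoch: int, mode: str) -> Dict[str, float]:
    """Hypothetical stand-in for _shared_loop: reports the number of batches seen."""
    return {f"{mode}/n_batches": float(len(loader))}


history = train_sketch(2, fake_epoch, train_dataloader=[1, 2, 3])
print(history[0])  # {'train/n_batches': 3.0, 'val/n_batches': 3.0}
```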
get_model_state_dict()
Return a deep copy of the model's state dictionary.
Returns:
- dict – State dictionary of the model.
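The deep copy matters because, in PyTorch, state_dict() by default returns tensors that share storage with the live parameters, so later optimizer steps would silently change a stored reference. The plain-Python analogy below uses lists in place of tensors (an assumption made to keep it dependency-free) to contrast a reference, a shallow copy, and copy.deepcopy.

```python
import copy

# Lists stand in for parameter tensors (assumption for a dependency-free demo).
params = {"linear.weight": [1.0, 2.0], "linear.bias": [0.0]}

alias = params                    # plain reference: sees every later update
shallow = dict(params)            # new outer dict, but the inner lists are shared
snapshot = copy.deepcopy(params)  # fully independent copy, frozen at this point

# Simulate an in-place optimizer update on the live parameters.
params["linear.weight"][0] -= 0.5

print(alias["linear.weight"][0])     # 0.5
print(shallow["linear.weight"][0])   # 0.5
print(snapshot["linear.weight"][0])  # 1.0
```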
get_best_model_state_dict()
Get the state dictionary of the best model, i.e. the model with the best value of the objective metric.
Returns:
- dict – State dictionary of the best model.
get_best_model()
_log(metrics)
Log metrics using the logger.
Parameters:
- metrics (dict[str, float]) – Dictionary of metrics to be logged.
_shared_loop(dataloader, epoch=0, mode='train')
Shared training/validation/test loop.
Parameters:
- dataloader (DataLoader) – Dataloader to iterate over.
- epoch (int, default: 0) – Epoch number.
- mode (str, default: 'train') – Mode of the loop. Must be one of 'train', 'val', or 'test'.
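A rough outline of a shared loop that validates mode and averages a per-batch loss follows. shared_loop_sketch, the step callable, and the '<mode>/loss' key format are hypothetical. In a real PyTorch loop, 'train' mode would additionally call backward() and optimizer.step(), while 'val' and 'test' would typically run under torch.no_grad().

```python
from typing import Callable, Dict, Iterable, Tuple

VALID_MODES = ("train", "val", "test")


def shared_loop_sketch(
    dataloader: Iterable[Tuple[float, float]],
    step: Callable[[float, float], float],
    epoch: int = 0,
    mode: str = "train",
) -> Dict[str, float]:
    """Hypothetical outline of one pass over `dataloader` in a given mode."""
    if mode not in VALID_MODES:
        raise ValueError(f"mode must be one of {VALID_MODES}, got {mode!r}")
    total, n_batches = 0.0, 0
    for x, y in dataloader:
        # A real 'train' pass would also backpropagate and step the optimizer here.
        total += step(x, y)
        n_batches += 1
    mean_loss = total / n_batches if n_batches else 0.0
    return {f"{mode}/loss": mean_loss, "epoch": float(epoch)}


batches = [(1.0, 2.0), (3.0, 3.0)]
print(shared_loop_sketch(batches, lambda x, y: abs(x - y), epoch=1, mode="val"))
```

Sharing one loop across the three modes keeps metric computation identical for training, validation, and test, so only the gradient-related steps need to branch on mode.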