
Note

See this notebook for an explanation with examples of how batching is performed in torchmil.

Batches in torchmil

Batching is crucial for training deep learning models. However, in MIL, bags can have different sizes. To handle this, torchmil pads the tensors of each bag to the size of the largest bag in the batch. A mask tensor indicates which elements of the padded tensors are real instances and which are padding. Models use this mask to ignore the padded elements (e.g., in the attention mechanism).
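To make the role of the mask concrete, here is a minimal sketch in plain PyTorch of attention pooling that ignores padded instances. It is an illustration only, not torchmil's implementation: the tensors are built by hand, and the attention scores are hypothetical constants standing in for the output of a learned layer.

```python
import torch

# Two bags with 3 and 1 instances (2 features each), padded to max_bag_size = 3.
X = torch.tensor([
    [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],  # bag 0: three real instances
    [[7.0, 8.0], [0.0, 0.0], [0.0, 0.0]],  # bag 1: one real instance, two padding rows
])  # shape (batch_size, max_bag_size, feat_dim)
mask = torch.tensor([[1, 1, 1], [1, 0, 0]], dtype=torch.bool)

# Hypothetical attention scores; a real model would compute these from X.
scores = torch.zeros(2, 3)

# Set padded positions to -inf so that softmax gives them zero weight.
scores = scores.masked_fill(~mask, float("-inf"))
weights = torch.softmax(scores, dim=1)  # (batch_size, max_bag_size)

# Attention-weighted bag embedding; padding rows contribute nothing.
z = (weights.unsqueeze(-1) * X).sum(dim=1)  # (batch_size, feat_dim)
```

Without the `masked_fill` step, the two padding rows of the second bag would each receive a third of the attention weight and dilute its embedding.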

Why not use torch.nested?

torch.nested offers a more flexible method for handling bags of varying sizes. However, since the PyTorch API for nested tensors is still in the prototype stage, torchmil currently relies on the padding approach.

The function torchmil.data.collate_fn is used to collate a list of bags into a batch. This function can be used as the collate_fn argument of the PyTorch DataLoader. The function torchmil.data.pad_tensors is used to pad the tensors in the bags.
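The sketch below shows how a collate function plugs into a PyTorch DataLoader. To keep it self-contained, it uses a simplified stand-in named `simple_collate` and a toy dataset named `ToyBagDataset`. Both names, and the `"X"` and `"mask"` keys, are assumptions made for illustration and are not part of torchmil.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyBagDataset(Dataset):
    """Hypothetical dataset: each item is a bag dictionary with a variable number of instances."""

    def __init__(self, bag_sizes, feat_dim=4):
        self.bags = [{"X": torch.randn(n, feat_dim)} for n in bag_sizes]

    def __len__(self):
        return len(self.bags)

    def __getitem__(self, idx):
        return self.bags[idx]


def simple_collate(batch_list):
    """Simplified stand-in for a MIL collate function (dense tensors only)."""
    max_size = max(bag["X"].shape[0] for bag in batch_list)
    feat_dim = batch_list[0]["X"].shape[1]
    X = torch.zeros(len(batch_list), max_size, feat_dim)
    mask = torch.zeros(len(batch_list), max_size)
    for i, bag in enumerate(batch_list):
        n = bag["X"].shape[0]
        X[i, :n] = bag["X"]  # copy the real instances
        mask[i, :n] = 1      # mark them as not padded
    return {"X": X, "mask": mask}


dataset = ToyBagDataset(bag_sizes=[3, 5, 2, 4])
loader = DataLoader(dataset, batch_size=2, shuffle=False, collate_fn=simple_collate)

# Each batch is padded only to the largest bag it contains.
shapes = [(tuple(b["X"].shape), tuple(b["mask"].shape)) for b in loader]
```

With torchmil installed, the same DataLoader call would take `collate_fn=torchmil.data.collate_fn` in place of the stand-in.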


torchmil.data.collate_fn(batch_list, sparse=True)

Collate function for MIL datasets. Given a list of bags (represented as dictionaries), it pads the tensors in the bags to the same shape and returns a dictionary representing the batch. The keys in the returned dictionary are the keys in the bag dictionaries. Additionally, the returned dictionary contains a mask for the padded tensors. This mask is 1 where the tensor is not padded and 0 where the tensor is padded.

Parameters:

  • batch_list (list[dict[str, Tensor]]) –

    List of dictionaries. Each dictionary represents a bag and should contain the same keys. The values can be:

    • Tensors of shape (bag_size, ...). In this case, the tensors are padded to the same shape.
    • Sparse tensors in COO format. In this case, the resulting sparse tensor has shape (batch_size, max_bag_size, max_bag_size), where max_bag_size is the maximum bag size in the batch. If sparse=False, the sparse tensor is converted to a dense tensor.
  • sparse (bool, default: True ) –

    If True, the sparse tensors are returned as sparse tensors. If False, the sparse tensors are converted to dense tensors.

Returns:

  • batch_dict ( TensorDict ) –

    Dictionary with the same keys as the bag dictionaries. The values are tensors of shape (batch_size, max_bag_size, ...) or sparse tensors of shape (batch_size, max_bag_size, max_bag_size). Additionally, the dictionary contains a mask of shape (batch_size, max_bag_size).
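As an illustration of the sparse case, the following sketch shows one way per-bag sparse COO matrices of different sizes could be combined into a single sparse tensor of shape (batch_size, max_bag_size, max_bag_size). The helper `batch_sparse_adjacency` is hypothetical and written in plain PyTorch; it is not taken from torchmil and may differ from what collate_fn does internally.

```python
import torch


def batch_sparse_adjacency(adj_list):
    """Combine sparse (n_i, n_i) matrices into one sparse (B, M, M) tensor, M = max n_i."""
    max_size = max(adj.shape[0] for adj in adj_list)
    indices, values = [], []
    for b, adj in enumerate(adj_list):
        adj = adj.coalesce()
        idx = adj.indices()  # (2, nnz) row/column indices
        # Prepend the batch index so each entry becomes (batch, row, col).
        batch_row = torch.full((1, idx.shape[1]), b, dtype=idx.dtype)
        indices.append(torch.cat([batch_row, idx], dim=0))
        values.append(adj.values())
    return torch.sparse_coo_tensor(
        torch.cat(indices, dim=1),
        torch.cat(values),
        size=(len(adj_list), max_size, max_size),
    )


# Two bags with 2 and 3 instances.
adj_a = torch.sparse_coo_tensor(
    torch.tensor([[0, 1], [1, 0]]), torch.tensor([1.0, 1.0]), size=(2, 2)
)
adj_b = torch.sparse_coo_tensor(
    torch.tensor([[0, 1, 2], [1, 2, 0]]), torch.tensor([1.0, 1.0, 1.0]), size=(3, 3)
)

batched = batch_sparse_adjacency([adj_a, adj_b])  # sparse, shape (2, 3, 3)
dense = batched.to_dense()                        # analogous to requesting sparse=False
```

Padding costs nothing in the sparse representation: the extra row and column of the smaller bag simply have no stored entries.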


torchmil.data.pad_tensors(tensor_list, padding_value=0)

Pads a list of tensors to the same shape and returns a mask.

Parameters:

  • tensor_list (list[Tensor]) –

    List of tensors, each of shape (bag_size, ...).

  • padding_value (int, default: 0 ) –

    Value to pad with.

Returns:

  • padded_tensor ( Tensor ) –

    Padded tensor of shape (batch_size, max_bag_size, ...).

  • mask ( Tensor ) –

    Mask of shape (batch_size, max_bag_size).
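The padding step and its mask can be sketched in plain PyTorch with `torch.nn.utils.rnn.pad_sequence`. This is a rough equivalent written for illustration, assuming the output shapes documented above; it is not the implementation of torchmil.data.pad_tensors.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Three bags with 3, 1 and 2 instances of 2 features each.
tensor_list = [torch.ones(3, 2), torch.ones(1, 2), torch.ones(2, 2)]

# Pad along the first (instance) dimension up to the largest bag.
padded = pad_sequence(tensor_list, batch_first=True, padding_value=0.0)  # (3, 3, 2)

# Build the mask from the bag sizes: 1 for real instances, 0 for padding.
lengths = torch.tensor([t.shape[0] for t in tensor_list])
max_bag_size = padded.shape[1]
mask = (torch.arange(max_bag_size).unsqueeze(0) < lengths.unsqueeze(1)).float()
```

Deriving the mask from the bag sizes, rather than from the padded values, keeps it correct even when a real instance happens to equal the padding value.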