# Datasets
torchmil provides a framework to instantiate datasets for Multiple Instance Learning (MIL) problems. It allows users to create custom datasets that suit their specific needs. In addition, torchmil includes pre-defined datasets that can be used directly; these correspond to popular benchmark datasets in the field of MIL, such as Camelyon16. See `torchmil.datasets` for a complete list of the available datasets.
In the following, we explain the logic behind the design of datasets in torchmil, the required data and folder structure, and how to create your own dataset. We also provide a simple example of how to use the `ProcessedMILDataset` class to create a custom dataset.
## Data representation in torchmil
Take a look at the data representation example to see how data is represented in torchmil.
In torchmil, bags are represented as a `TensorDict` object with at least the following properties:

- `bag['X']`: usually called the bag feature matrix, since it contains the feature vectors extracted from the raw representation of the instances.
- `bag['Y']`: the label of the bag.

Additionally, a bag may contain other properties. The most common ones are:

- `bag['y_inst']`: the labels of the instances in the bag.
- `bag['adj']`: the adjacency matrix of the bag, which represents the relationships between the instances in the bag.
- `bag['coords']`: the coordinates of the instances in the bag, which represent the absolute position of the instances in the bag.
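For intuition, here is a minimal sketch of what such a bag might look like when built by hand. The feature dimension, bag size, and label values below are made up for illustration:

```python
import torch
from tensordict import TensorDict

# A toy bag with 5 instances and 128-dimensional features (illustrative values)
bag = TensorDict(
    {
        "X": torch.randn(5, 128),                 # bag feature matrix
        "Y": torch.tensor(1),                     # bag label
        "y_inst": torch.tensor([0, 0, 1, 0, 1]),  # instance labels
    },
    batch_size=[],
)
print(bag["X"].shape)  # torch.Size([5, 128])
```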
## The `ProcessedMILDataset` class
The `ProcessedMILDataset` class allows for efficient loading and processing of large datasets. To enable this, it expects each bag to have been pre-processed, with its properties saved in separate files (a sketch of how such files might be produced follows the list below):
- A feature file should yield an array of shape `(bag_size, ...)`, where `...` represents the shape of the features.
- A label file should yield an array of arbitrary shape, e.g., `(1,)` for binary classification.
- An instance label file should yield an array of shape `(bag_size, ...)`, where `...` represents the shape of the instance labels.
- A coordinates file should yield an array of shape `(bag_size, coords_dim)`, where `coords_dim` is the dimension of the coordinates.
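As a concrete illustration, the `.npy` files for a single bag could be produced as follows. The folder names and array shapes here are hypothetical toy values:

```python
import os
import numpy as np

# Toy bag: 10 instances, 512-dim features, 2-D coordinates (all values made up)
for folder in ["features", "labels", "patch_labels", "coords"]:
    os.makedirs(folder, exist_ok=True)

np.save("features/bag1.npy", np.random.randn(10, 512).astype(np.float32))  # (bag_size, feat_dim)
np.save("labels/bag1.npy", np.array([1]))                                  # (1,) bag label
np.save("patch_labels/bag1.npy", np.random.randint(0, 2, size=(10,)))      # (bag_size,) instance labels
np.save("coords/bag1.npy", np.random.randint(0, 100, size=(10, 2)))        # (bag_size, coords_dim)
```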
The paths to these properties should be specified in the `__init__` method of the `ProcessedMILDataset` class. To illustrate this behaviour, let's load the CAMELYON16 dataset:
```python
from torchmil.datasets import ProcessedMILDataset

dataset_dir = "/data/datasets/CAMELYON16"
features_path = f"{dataset_dir}/patches_512_preset/features_UNI/"
labels_path = f"{dataset_dir}/patches_512_preset/labels/"
inst_labels_path = f"{dataset_dir}/patches_512_preset/patch_labels/"
coords_path = f"{dataset_dir}/patches_512_preset/coords/"

dataset = ProcessedMILDataset(
    features_path=features_path,
    labels_path=labels_path,
    inst_labels_path=inst_labels_path,
    coords_path=coords_path,
)

print(f"Number of bags: {len(dataset)}")
```
As you can see, we have specified the paths to the properties of each bag. The `ProcessedMILDataset` class will load each bag's properties from the specified files, assuming the following structure:
```
features_path/
├── bag1.npy
├── bag2.npy
└── ...
labels_path/
├── bag1.npy
├── bag2.npy
└── ...
inst_labels_path/
├── bag1.npy
├── bag2.npy
└── ...
coords_path/
├── bag1.npy
├── bag2.npy
└── ...
```
```python
bag = dataset[0]
for key in bag.keys():
    print(f"{key}: {bag[key].shape}")
```
When the `__getitem__` method is called, the `ProcessedMILDataset` class builds the bag. First, it loads the bag's properties from the specified files. Then, if coordinates have been provided, it builds the adjacency matrix of the bag (see the documentation for more details). Finally, it wraps the properties in a `TensorDict` object and returns it.
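To give an idea of what building the adjacency matrix from coordinates involves, here is a minimal sketch, not torchmil's actual implementation (which additionally supports distance-weighted edges via `adj_with_dist` and normalization via `norm_adj`). Instances whose coordinates lie within `dist_thr` of each other are connected:

```python
import torch

def build_adjacency(coords: torch.Tensor, dist_thr: float = 1.5) -> torch.Tensor:
    """Connect instances whose coordinates lie within dist_thr of each other."""
    dist = torch.cdist(coords.float(), coords.float())  # pairwise Euclidean distances
    adj = (dist <= dist_thr).float()
    adj.fill_diagonal_(0.0)  # no self-loops
    return adj

coords = torch.tensor([[0, 0], [0, 1], [5, 5]])
print(build_adjacency(coords))  # instances 0 and 1 are neighbours; instance 2 is isolated
```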
We can choose which bags we want to load using the `bag_names` argument:
```python
bag_names = ["test_001", "test_002"]

dataset = ProcessedMILDataset(
    features_path=features_path,
    labels_path=labels_path,
    inst_labels_path=inst_labels_path,
    coords_path=coords_path,
    bag_names=bag_names,
)

print(f"Number of bags: {len(dataset)}")

bag = dataset[0]
for key in bag.keys():
    print(f"{key}: {bag[key].shape}")
```
We can choose which properties we want to load using the `bag_keys` argument. For example, if we want to load only the features and the labels of the bags, we can do it as follows:
```python
dataset = ProcessedMILDataset(
    features_path=features_path,
    labels_path=labels_path,
    inst_labels_path=inst_labels_path,
    coords_path=coords_path,
    bag_keys=["X", "Y"],
)

print(f"Number of bags: {len(dataset)}")

bag = dataset[0]
for key in bag.keys():
    print(f"{key}: {bag[key].shape}")
```
Feel free to explore all the options in the documentation.
## Extending the `ProcessedMILDataset` class
The `ProcessedMILDataset` class can be extended to add custom functionality. One example is the `BinaryClassificationDataset` class, a subclass of `ProcessedMILDataset` tailored for binary classification tasks. It assumes that the bag label \(Y\) and the instance labels \(\left\{ y_1, \ldots, y_N \right\}\) are binary values, i.e., they can take values in \(\left\{ 0, 1 \right\}\). The class also assumes that the bag label is the maximum of the instance labels, i.e.,

\[ Y = \max \left\{ y_1, \ldots, y_N \right\}. \]
Let's take a look at the implementation to illustrate how to extend the `ProcessedMILDataset` class. The `BinaryClassificationDataset` class is implemented as follows:
```python
import torch
import numpy as np
import warnings


class BinaryClassificationDataset(ProcessedMILDataset):
    def __init__(
        self,
        features_path: str,
        labels_path: str,
        inst_labels_path: str = None,
        coords_path: str = None,
        bag_names: list = None,
        bag_keys: list = ["X", "Y", "y_inst", "adj", "coords"],
        dist_thr: float = 1.5,
        adj_with_dist: bool = False,
        norm_adj: bool = True,
        load_at_init: bool = True,
    ) -> None:
        super().__init__(
            features_path=features_path,
            labels_path=labels_path,
            inst_labels_path=inst_labels_path,
            coords_path=coords_path,
            bag_names=bag_names,
            bag_keys=bag_keys,
            dist_thr=dist_thr,
            adj_with_dist=adj_with_dist,
            norm_adj=norm_adj,
            load_at_init=load_at_init,
        )

    def _fix_inst_labels(self, inst_labels):
        """
        Make sure that instance labels have shape (bag_size,).
        """
        if inst_labels is not None:
            while inst_labels.ndim > 1:
                inst_labels = np.squeeze(inst_labels, axis=-1)
        return inst_labels

    def _fix_labels(self, labels):
        """
        Make sure that labels have shape ().
        """
        labels = np.squeeze(labels)
        return labels

    def _load_inst_labels(self, name):
        inst_labels = super()._load_inst_labels(name)
        inst_labels = self._fix_inst_labels(inst_labels)
        return inst_labels

    def _load_labels(self, name):
        labels = super()._load_labels(name)
        labels = self._fix_labels(labels)
        return labels

    def _consistency_check(self, bag_dict, name):
        """
        Check if the instance labels are consistent with the bag label.
        """
        if "Y" in bag_dict:
            if "y_inst" in bag_dict:
                if bag_dict["Y"] != (bag_dict["y_inst"]).max():
                    msg = f"Instance labels (max(y_inst)={(bag_dict['y_inst']).max()}) are not consistent with bag label (Y={bag_dict['Y']}) for bag {name}. Setting all instance labels to -1 (unknown)."
                    warnings.warn(msg)
                    bag_dict["y_inst"] = np.full((bag_dict["X"].shape[0],), -1)
            else:
                if bag_dict["Y"] == 0:
                    bag_dict["y_inst"] = np.zeros(bag_dict["X"].shape[0])
                else:
                    msg = (
                        f"Instance labels not found for bag {name}. Setting all to -1."
                    )
                    warnings.warn(msg)
                    bag_dict["y_inst"] = np.full((bag_dict["X"].shape[0],), -1)
        return bag_dict

    def _load_bag(self, name: str) -> dict[str, torch.Tensor]:
        bag_dict = super()._load_bag(name)
        bag_dict = self._consistency_check(bag_dict, name)
        return bag_dict
```
As you can see, we have added explicit checks to ensure that the above conditions are fulfilled. If they are not, a warning is emitted. All we need to do is override the corresponding methods to add the desired functionality.
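Following the same pattern, you could override `_load_bag` to add any custom processing of your own. As a hypothetical sketch (the class name and the normalization step are made up for illustration, and we assume `_load_bag` yields NumPy arrays, as the consistency check above suggests), here is a subclass that L2-normalizes the instance features of every bag as it is loaded:

```python
import numpy as np
from torchmil.datasets import ProcessedMILDataset


class NormalizedMILDataset(ProcessedMILDataset):
    """Hypothetical example: L2-normalize instance features when a bag is loaded."""

    def _load_bag(self, name):
        bag_dict = super()._load_bag(name)
        if "X" in bag_dict:
            X = np.asarray(bag_dict["X"])
            norms = np.linalg.norm(X, axis=-1, keepdims=True)
            bag_dict["X"] = X / np.clip(norms, 1e-8, None)  # avoid division by zero
        return bag_dict
```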
## Creating your own dataset
Finally, let's implement a custom dataset. For this, we will use the `WSIDataset` class, which assumes that the bags are Whole Slide Images (WSIs). It also gives the patch coordinates (`coords`) special treatment, normalizing their values.
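For intuition, this normalization might amount to dividing the pixel coordinates by the patch size, so that adjacent patches end up at unit distance. This is a sketch of the idea with hypothetical values, not necessarily `WSIDataset`'s exact implementation; note how unit spacing pairs naturally with the default `dist_thr=1.5` seen above:

```python
import numpy as np

# Pixel coordinates of three 512x512 patches (hypothetical values)
coords = np.array([[0, 0], [512, 0], [1024, 512]])
patch_size = 512

norm_coords = coords / patch_size  # adjacent patches now lie at distance 1
print(norm_coords)  # [[0. 0.] [1. 0.] [2. 1.]]
```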
We are going to use slides from the Genotype-Tissue Expression (GTEx) Project, which can be downloaded for free. In particular, we will use slides of urinary bladder tissue.
To create the dataset, we must first extract the patch `coords` from the original .tiff files and then extract `features` from those patches. A tool like CLAM can be used for this. We will assume that no fine-grained annotations are available, so we will not have access to `labels` or `inst_labels`. We have extracted the features using the foundation model UNI.
Then, creating the dataset is as simple as defining a new class that extends `WSIDataset`:
```python
from torchmil.datasets import WSIDataset
from torchmil.utils.common import read_csv, keep_only_existing_files


class GTExUrinaryBladderDataset(WSIDataset):
    def __init__(
        self,
        root: str,
        features: str = "UNI",
        bag_keys: list = ["X", "adj", "coords"],
        patch_size: int = 512,
        adj_with_dist: bool = False,
        norm_adj: bool = True,
        load_at_init: bool = True,
    ) -> None:
        features_path = f"{root}/patches_{patch_size}/features/features_{features}/"
        labels_path = f"{root}/patches_{patch_size}/labels/"
        patch_labels_path = f"{root}/patches_{patch_size}/inst_labels/"
        coords_path = f"{root}/patches_{patch_size}/coords/"

        # This csv is generated by CLAM, with slide_id containing "bag_name.format"
        bag_names_file = f"{root}/patches_{patch_size}/process_list_autogen.csv"
        dict_list = read_csv(bag_names_file)
        wsi_names = list(set([row["slide_id"].split(".")[0] for row in dict_list]))
        wsi_names = keep_only_existing_files(features_path, wsi_names)

        WSIDataset.__init__(
            self,
            features_path=features_path,
            labels_path=labels_path,
            patch_labels_path=patch_labels_path,
            coords_path=coords_path,
            wsi_names=wsi_names,
            bag_keys=bag_keys,
            patch_size=patch_size,
            adj_with_dist=adj_with_dist,
            norm_adj=norm_adj,
            load_at_init=load_at_init,
        )
```
We have now defined our new `GTExUrinaryBladderDataset` class. We can now instantiate it, using as `bag_keys` only the features `X` and the adjacency matrix `adj`. We only have to specify the root path! We will use `load_at_init=False` so that the features of the slides are only loaded when needed.
```python
# This is my root, change it to your own!
root = "/data/data_fjaviersaezm/GTExTorchmil/UrinaryBladder/"
dataset = GTExUrinaryBladderDataset(
    root=root,
    features="UNI",
    bag_keys=["X", "adj"],
    patch_size=512,
    load_at_init=False,
)
print(dataset.bag_names[:3])
```
Great! The dataset object initialized without problems. Now we can display a bag, which is returned as a `TensorDict`.
```python
el = dataset[0]
for key in el.keys():
    print(f"{key}: {el[key].shape}")
```
Nice! The dataset has correctly loaded the `X` tensor and built the adjacency matrix `adj`. We can already use this bag as input to a MIL model!
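To close the loop, here is a minimal, self-contained sketch of a toy mean-pooling MIL classifier consuming the bag we just loaded. This model is not part of torchmil and is made up for illustration; see the torchmil documentation for the MIL models it actually provides:

```python
import torch


class MeanPoolingMIL(torch.nn.Module):
    """Toy MIL classifier: average the instance features, then classify the bag."""

    def __init__(self, feat_dim: int, n_classes: int = 1) -> None:
        super().__init__()
        self.classifier = torch.nn.Linear(feat_dim, n_classes)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        z = X.mean(dim=0)          # (feat_dim,) bag embedding
        return self.classifier(z)  # (n_classes,) bag logits


model = MeanPoolingMIL(feat_dim=el["X"].shape[-1])
Y_logit = model(el["X"].float())
print(Y_logit)
```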