
Datasets

torchmil provides a framework to instantiate datasets for Multiple Instance Learning (MIL) problems. It allows users to create custom datasets that suit their specific needs. In addition, torchmil includes some pre-defined datasets that can be used directly. These correspond to popular benchmark datasets in the field of MIL, such as Camelyon16. See torchmil.datasets for a complete list of the datasets available in torchmil.

In the following, we explain the logic behind the design of datasets in torchmil, the required data and folder structure, and how to create your own dataset. We will also provide a simple example of how to use the ProcessedMILDataset class to create a custom dataset.

Data representation in torchmil

Take a look at the data representation example to see how the data is represented in torchmil.

In torchmil, bags are represented as a TensorDict object with at least the following properties:

  • bag['X']: it is usually called the bag feature matrix, since it contains the feature vectors extracted from the raw representation of the instances.
  • bag['Y']: it represents the label of the bag.

Additionally, a bag may contain other properties. The most common ones are:

  • bag['y_inst']: it contains the labels of the instances in the bag.
  • bag['adj']: it contains the adjacency matrix of the bag, which represents the relationships between the instances in the bag.
  • bag['coords']: it contains the coordinates of the instances in the bag, which represent the absolute position of the instances in the bag.
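To make these conventions concrete, here is a minimal sketch of a toy bag with four instances. As a simplifying assumption it uses a plain Python dict of NumPy arrays instead of a TensorDict. The feature dimension, the values and the neighbour rule used for adj are made up for illustration and are not necessarily what torchmil does.

```python
import numpy as np

# Toy bag: 4 instances with 8-dimensional features (illustrative values only).
rng = np.random.default_rng(0)
bag_size, feat_dim = 4, 8

bag = {
    "X": rng.normal(size=(bag_size, feat_dim)),             # bag feature matrix
    "Y": np.array([1]),                                     # bag label
    "y_inst": np.array([0, 1, 0, 0]),                       # one label per instance
    "coords": np.array([[0, 0], [0, 1], [1, 0], [3, 3]]),   # one 2D position per instance
}

# Toy adjacency: connect instances whose coordinates differ by at most 1 along every axis
# (Chebyshev distance 1); self-loops are excluded because their distance is 0.
cheb = np.abs(bag["coords"][:, None, :] - bag["coords"][None, :, :]).max(axis=-1)
bag["adj"] = (cheb == 1).astype(np.float32)

for key, value in bag.items():
    print(f"{key}: {value.shape}")
```

Note that every per-instance property (X, y_inst, coords, and both axes of adj) shares the same leading dimension, the bag size.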

The ProcessedMILDataset class

The ProcessedMILDataset class allows for efficient loading and processing of large datasets. To enable this, it expects each bag to have been pre-processed, with each of its properties saved in a separate file:

  • A feature file should yield an array of shape (bag_size, ...), where ... represents the shape of the features.
  • A label file should yield an array of arbitrary shape, e.g., (1,) for binary classification.
  • An instance label file should yield an array of shape (bag_size, ...), where ... represents the shape of the instance labels.
  • A coordinates file should yield an array of shape (bag_size, coords_dim), where coords_dim is the dimension of the coordinates.
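These shape conventions can be checked before saving a bag. The following is a hypothetical helper (check_bag_arrays is not part of torchmil) that encodes them as stated above: every per-instance array must share the leading dimension bag_size, and coordinates must be a 2D array.

```python
import numpy as np


def check_bag_arrays(features, label, inst_labels=None, coords=None):
    """Hypothetical sanity check for the per-bag file conventions (not torchmil API).

    Returns bag_size if all arrays follow the expected shapes, otherwise raises ValueError.
    """
    features = np.asarray(features)
    if features.ndim < 1:
        raise ValueError("features must have shape (bag_size, ...)")
    bag_size = features.shape[0]
    # The label may have an arbitrary shape, e.g. (1,); we only require it to be non-empty.
    if np.asarray(label).size == 0:
        raise ValueError("label must not be empty")
    # Instance labels: leading dimension must be bag_size.
    if inst_labels is not None and np.asarray(inst_labels).shape[:1] != (bag_size,):
        raise ValueError("instance labels must have shape (bag_size, ...)")
    # Coordinates: exactly (bag_size, coords_dim).
    if coords is not None:
        coords = np.asarray(coords)
        if coords.ndim != 2 or coords.shape[0] != bag_size:
            raise ValueError("coords must have shape (bag_size, coords_dim)")
    return bag_size
```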

The paths to these files are passed to the constructor (__init__) of the ProcessedMILDataset class. To illustrate this behaviour, let's load the CAMELYON16 dataset:

from torchmil.datasets import ProcessedMILDataset

dataset_dir = "/data/datasets/CAMELYON16"
features_path = "/data/datasets/CAMELYON16/patches_512_preset/features_UNI/"
labels_path = "/data/datasets/CAMELYON16/patches_512_preset/labels/"
inst_labels_path = "/data/datasets/CAMELYON16/patches_512_preset/patch_labels/"
coords_path = "/data/datasets/CAMELYON16/patches_512_preset/coords/"

dataset = ProcessedMILDataset(
    features_path=features_path,
    labels_path=labels_path,
    inst_labels_path=inst_labels_path,
    coords_path=coords_path,
)

print(f"Number of bags: {len(dataset)}")
Number of bags: 399

As you can see, we have specified the path to the properties of each bag. The ProcessedMILDataset class will load the properties of each bag from the specified files assuming the following structure:

features_path/
├── bag1.npy
├── bag2.npy
└── ...
labels_path/
├── bag1.npy
├── bag2.npy
└── ...
inst_labels_path/
├── bag1.npy
├── bag2.npy
└── ...
coords_path/
├── bag1.npy
├── bag2.npy
└── ...
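A minimal sketch, using only NumPy and the standard library, of how pre-processed bags could be written to this layout. The folder names, bag names and shapes are made up for illustration; the only convention taken from the text is one .npy file per bag and per property, with matching file names across folders.

```python
import tempfile
from pathlib import Path

import numpy as np


def save_bag(root, name, features, label, inst_labels, coords):
    """Write one bag as <root>/<property folder>/<name>.npy (hypothetical helper)."""
    arrays = {
        "features": features,
        "labels": label,
        "inst_labels": inst_labels,
        "coords": coords,
    }
    for folder, array in arrays.items():
        out_dir = Path(root) / folder
        out_dir.mkdir(parents=True, exist_ok=True)
        np.save(out_dir / f"{name}.npy", array)


root = Path(tempfile.mkdtemp())
rng = np.random.default_rng(0)
for name, n in [("bag1", 3), ("bag2", 5)]:
    save_bag(
        root,
        name,
        features=rng.normal(size=(n, 16)).astype(np.float32),
        label=np.array([1]),
        inst_labels=np.zeros(n),
        coords=rng.integers(0, 10, size=(n, 2)),
    )

# Bag names are the file stems shared by all property folders.
bag_names = sorted(p.stem for p in (root / "features").glob("*.npy"))
print(bag_names)  # ['bag1', 'bag2']
```

The four folders under root would then be passed as features_path, labels_path, inst_labels_path and coords_path.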
Let's take a look at one of the bags:

bag = dataset[0]
for key in bag.keys():
    print(f"{key}: {bag[key].shape}")
X: torch.Size([1983, 1024])
Y: torch.Size([1])
y_inst: torch.Size([1983])
adj: torch.Size([1983, 1983])
coords: torch.Size([1983, 2])

When the __getitem__ method is called, the ProcessedMILDataset class builds the bag. First, it loads the properties of the bag from the specified files. Then, if the coordinates have been provided, it builds the adjacency matrix of the bag (see the documentation for more details). Finally, it returns a TensorDict object containing the properties of the bag.
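The exact rule torchmil uses to build adj is described in its documentation. Purely as an illustration, here is a sketch of one common choice, assuming coordinates in patch-grid units: two instances are connected when their Euclidean distance is at most a threshold (with dist_thr = 1.5, each patch is linked to its 8 neighbours, since diagonal neighbours lie at sqrt(2) ≈ 1.41), optionally followed by the symmetric normalization D^{-1/2} A D^{-1/2}. The function name and the normalization are assumptions, not torchmil's implementation.

```python
import numpy as np


def build_adj(coords, dist_thr=1.5, normalize=True):
    """Illustrative distance-threshold adjacency (assumed scheme, not torchmil's code).

    coords: array of shape (bag_size, coords_dim), assumed to be in patch-grid units.
    """
    coords = np.asarray(coords, dtype=np.float64)
    # Pairwise Euclidean distances, shape (bag_size, bag_size).
    dist = np.linalg.norm(coords[:, None, :] - coords[None, :, :], axis=-1)
    # Edge if within the threshold; distance 0 (self-loops) is excluded.
    adj = ((dist <= dist_thr) & (dist > 0)).astype(np.float64)
    if normalize:
        deg = adj.sum(axis=1)
        inv_sqrt = np.zeros_like(deg)
        inv_sqrt[deg > 0] = 1.0 / np.sqrt(deg[deg > 0])
        # Symmetric normalization D^{-1/2} A D^{-1/2}; isolated instances keep zero rows.
        adj = inv_sqrt[:, None] * adj * inv_sqrt[None, :]
    return adj


coords = np.array([[0, 0], [0, 1], [1, 1], [5, 5]])
print(build_adj(coords, normalize=False))
```

This dense version needs bag_size^2 entries, so for bags with thousands of patches a sparse representation is usually preferable.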

We can choose which bags we want to load using the bag_names argument:

bag_names = ["test_001", "test_002"]
dataset = ProcessedMILDataset(
    features_path=features_path,
    labels_path=labels_path,
    inst_labels_path=inst_labels_path,
    coords_path=coords_path,
    bag_names=bag_names,
)
print(f"Number of bags: {len(dataset)}")
bag = dataset[0]
for key in bag.keys():
    print(f"{key}: {bag[key].shape}")
Number of bags: 2
X: torch.Size([12255, 1024])
Y: torch.Size([1])
y_inst: torch.Size([12255])
adj: torch.Size([12255, 12255])
coords: torch.Size([12255, 2])
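Because bag_names is simply a list of bag (file) names, dataset splits can be built with plain Python. A sketch, assuming (as the example names above suggest) that test slides share the prefix test_; the helper and the other names in the list are hypothetical.

```python
def split_by_prefix(bag_names, prefix="test_"):
    """Split bag names into (non-matching, matching) by a name prefix (hypothetical helper)."""
    matching = sorted(n for n in bag_names if n.startswith(prefix))
    others = sorted(n for n in bag_names if not n.startswith(prefix))
    return others, matching


all_names = ["normal_001", "tumor_001", "test_001", "test_002"]
train_names, test_names = split_by_prefix(all_names)
print(train_names)  # ['normal_001', 'tumor_001']
print(test_names)   # ['test_001', 'test_002']
```

Each list could then be passed as bag_names to its own ProcessedMILDataset instance.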

We can choose which properties we want to load using the bag_keys argument. For example, if we want to load only the features and the labels of the bags, we can do it as follows:

dataset = ProcessedMILDataset(
    features_path=features_path,
    labels_path=labels_path,
    inst_labels_path=inst_labels_path,
    coords_path=coords_path,
    bag_keys=["X", "Y"],
)
print(f"Number of bags: {len(dataset)}")
bag = dataset[0]
for key in bag.keys():
    print(f"{key}: {bag[key].shape}")
Number of bags: 399
X: torch.Size([1983, 1024])
Y: torch.Size([1])

See the documentation for the full list of options.
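To clarify what bag_names and bag_keys control, here is a deliberately simplified, hypothetical stand-in (ToyMILDataset is not torchmil code). It assumes the folder names features, labels, inst_labels and coords under one root, reads only the files of the requested keys, returns a plain dict instead of a TensorDict, and does not build an adjacency matrix.

```python
from pathlib import Path

import numpy as np


class ToyMILDataset:
    """Hypothetical, simplified loader mimicking the bag_names / bag_keys behaviour."""

    # Maps each bag key to the folder holding its .npy files (assumed layout).
    KEY_TO_FOLDER = {"X": "features", "Y": "labels", "y_inst": "inst_labels", "coords": "coords"}

    def __init__(self, root, bag_names=None, bag_keys=("X", "Y")):
        self.root = Path(root)
        if bag_names is None:
            # Default: every bag that has a features file.
            bag_names = sorted(p.stem for p in (self.root / "features").glob("*.npy"))
        self.bag_names = list(bag_names)
        self.bag_keys = list(bag_keys)

    def __len__(self):
        return len(self.bag_names)

    def __getitem__(self, idx):
        name = self.bag_names[idx]
        # Only the requested properties are read from disk.
        return {
            key: np.load(self.root / self.KEY_TO_FOLDER[key] / f"{name}.npy")
            for key in self.bag_keys
        }
```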

Extending the ProcessedMILDataset class

The ProcessedMILDataset can be extended to add custom functionalities. One example is the BinaryClassificationDataset class, which is a subclass of ProcessedMILDataset that is tailored for binary classification tasks. It assumes that the bag label \(Y\) and the instance labels \(\left\{ y_1, \ldots, y_N \right\}\) are binary values, i.e., they can take values in \(\left\{ 0, 1 \right\}\). The class also assumes that the bag label is the maximum of the instance labels, i.e.,

\[ \begin{gather} Y = \max \left\{ y_1, \ldots, y_N \right\}. \end{gather} \]
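For binary labels, this rule says that a bag is positive if and only if at least one of its instances is positive. A short NumPy sketch of the comparison performed by _consistency_check in the implementation shown below (is_consistent is a hypothetical name):

```python
import numpy as np


def is_consistent(Y, y_inst):
    """Return True if the bag label equals the max of the binary instance labels."""
    y_inst = np.asarray(y_inst)
    # np.squeeze handles labels stored with shape (1,) as well as scalars.
    return int(np.squeeze(Y)) == int(y_inst.max())


print(is_consistent(1, [0, 0, 1, 0]))  # True: one positive instance -> positive bag
print(is_consistent(0, [0, 0, 0]))     # True: no positive instance -> negative bag
print(is_consistent(0, [0, 1, 0]))     # False: inconsistent labels
```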

Let's take a look at the implementation to illustrate how to extend the ProcessedMILDataset class. The BinaryClassificationDataset class is implemented as follows:

import torch
import numpy as np
import warnings

from torchmil.datasets import ProcessedMILDataset


class BinaryClassificationDataset(ProcessedMILDataset):
    def __init__(
        self,
        features_path: str,
        labels_path: str,
        inst_labels_path: str = None,
        coords_path: str = None,
        bag_names: list = None,
        bag_keys: list = ["X", "Y", "y_inst", "adj", "coords"],
        dist_thr: float = 1.5,
        adj_with_dist: bool = False,
        norm_adj: bool = True,
        load_at_init: bool = True,
    ) -> None:
        super().__init__(
            features_path=features_path,
            labels_path=labels_path,
            inst_labels_path=inst_labels_path,
            coords_path=coords_path,
            bag_names=bag_names,
            bag_keys=bag_keys,
            dist_thr=dist_thr,
            adj_with_dist=adj_with_dist,
            norm_adj=norm_adj,
            load_at_init=load_at_init,
        )

    def _fix_inst_labels(self, inst_labels):
        """
        Make sure that instance labels have shape (bag_size,).
        """
        if inst_labels is not None:
            while inst_labels.ndim > 1:
                inst_labels = np.squeeze(inst_labels, axis=-1)
        return inst_labels

    def _fix_labels(self, labels):
        """
        Make sure that labels have shape ().
        """
        labels = np.squeeze(labels)
        return labels

    def _load_inst_labels(self, name):
        inst_labels = super()._load_inst_labels(name)
        inst_labels = self._fix_inst_labels(inst_labels)
        return inst_labels

    def _load_labels(self, name):
        labels = super()._load_labels(name)
        labels = self._fix_labels(labels)
        return labels

    def _consistency_check(self, bag_dict, name):
        """
        Check if the instance labels are consistent with the bag label.
        """
        if "Y" in bag_dict:
            if "y_inst" in bag_dict:
                if bag_dict["Y"] != (bag_dict["y_inst"]).max():
                    msg = f"Instance labels (max(y_inst)={(bag_dict['y_inst']).max()}) are not consistent with bag label (Y={bag_dict['Y']}) for bag {name}. Setting all instance labels to -1 (unknown)."
                    warnings.warn(msg)
                    bag_dict["y_inst"] = np.full((bag_dict["X"].shape[0],), -1)
            else:
                if bag_dict["Y"] == 0:
                    bag_dict["y_inst"] = np.zeros(bag_dict["X"].shape[0])
                else:
                    msg = (
                        f"Instance labels not found for bag {name}. Setting all to -1."
                    )
                    warnings.warn(msg)
                    bag_dict["y_inst"] = np.full((bag_dict["X"].shape[0],), -1)
        return bag_dict

    def _load_bag(self, name: str) -> dict[str, torch.Tensor]:
        bag_dict = super()._load_bag(name)
        bag_dict = self._consistency_check(bag_dict, name)
        return bag_dict

As you can see, we have added explicit checks to ensure that the above conditions are fulfilled. If they are not, a warning is issued. To add the desired functionality, all we need to do is override the corresponding methods.

Creating your own dataset

Finally, let's implement a custom dataset. For this, we will use the WSIDataset class, which assumes that the bags are Whole Slide Images (WSIs). It also gives the coordinates of the patches (coords) a special treatment, normalizing their values.
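The normalization is not spelled out here. One plausible scheme, shown purely as an assumption, converts the pixel coordinates produced by a patch extractor into patch-grid units by subtracting the minimum and dividing by patch_size. Neighbouring patches then end up at distance 1, which is what a threshold such as dist_thr = 1.5 expects.

```python
import numpy as np


def normalize_coords(coords_px, patch_size=512):
    """Map pixel coordinates of patch corners to integer grid indices (assumed scheme)."""
    coords_px = np.asarray(coords_px)
    # Shift so the top-left patch sits at the origin, then express positions in patch units.
    return (coords_px - coords_px.min(axis=0)) // patch_size


coords_px = np.array([[1024, 2048], [1536, 2048], [1024, 2560]])
print(normalize_coords(coords_px))
# [[0 0]
#  [1 0]
#  [0 1]]
```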

We are going to use the slides from the Genotype-Tissue Expression (GTEx) Project, which can be downloaded for free. In particular, we will use slides of UrinaryBladder tissue.

To create the dataset, we must first extract the coords of the patches from the original .tiff files and then extract features from those patches. A tool like CLAM can be used for this. We will assume that no annotations are available, so we will not have access to labels or inst_labels. We have extracted the features using the foundation model UNI.

Then, creating the dataset is as simple as defining a new class that extends WSIDataset:

from torchmil.datasets import WSIDataset
from torchmil.utils.common import read_csv, keep_only_existing_files


class GTExUrinaryBladderDataset(WSIDataset):
    def __init__(
        self,
        root: str,
        features: str = "UNI",
        bag_keys: list = ["X", "adj", "coords"],
        patch_size: int = 512,
        adj_with_dist: bool = False,
        norm_adj: bool = True,
        load_at_init: bool = True,
    ) -> None:
        features_path = f"{root}/patches_{patch_size}/features/features_{features}/"
        labels_path = f"{root}/patches_{patch_size}/labels/"
        patch_labels_path = f"{root}/patches_{patch_size}/inst_labels/"
        coords_path = f"{root}/patches_{patch_size}/coords/"

        # This csv is generated by CLAM, with slide_id containing "bag_name.format"
        bag_names_file = f"{root}/patches_{patch_size}/process_list_autogen.csv"
        dict_list = read_csv(bag_names_file)
        wsi_names = list(set([row["slide_id"].split(".")[0] for row in dict_list]))
        wsi_names = keep_only_existing_files(features_path, wsi_names)

        WSIDataset.__init__(
            self,
            features_path=features_path,
            labels_path=labels_path,
            patch_labels_path=patch_labels_path,
            coords_path=coords_path,
            wsi_names=wsi_names,
            bag_keys=bag_keys,
            patch_size=patch_size,
            adj_with_dist=adj_with_dist,
            norm_adj=norm_adj,
            load_at_init=load_at_init,
        )
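The bag-name logic in this constructor can be reproduced with the standard library alone. Below is a sketch with a hypothetical helper name that reads the slide_id column with csv.DictReader, assuming (as the comment in the code says) that each slide_id has the form "bag_name.format" and assuming features are stored as <name>.npy. Unlike list(set(...)), sorting makes the bag order deterministic across runs, because Python randomizes string hashing per process.

```python
import csv
from pathlib import Path


def wsi_names_from_csv(csv_file, features_path):
    """Hypothetical helper: unique slide names from a CLAM-style CSV that have a features file."""
    reader = csv.DictReader(csv_file)
    # "GTEX-XXXX.svs" -> "GTEX-XXXX"; Path.stem drops only the last suffix.
    names = sorted({Path(row["slide_id"]).stem for row in reader})
    features_dir = Path(features_path)
    # Keep only slides whose features were actually extracted.
    return [n for n in names if (features_dir / f"{n}.npy").exists()]
```

For example, with open(bag_names_file, newline="") as f: wsi_names = wsi_names_from_csv(f, features_path). Note that Path.stem keeps dots inside the name, whereas split(".")[0] cuts at the first dot.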

With the GTExUrinaryBladderDataset class defined, we can instantiate it using only the features X and the adjacency matrix adj as bag_keys. We only have to specify the root path! We set load_at_init=False so that the features of the slides are only loaded when needed.

# This is my root, change it to your own!
root = "/data/data_fjaviersaezm/GTExTorchmil/UrinaryBladder/"
dataset = GTExUrinaryBladderDataset(
    root=root, features="UNI", bag_keys=["X", "adj"], patch_size=512, load_at_init=False
)
print(dataset.bag_names[:3])
['GTEX-N7MS-2125', 'GTEX-N7MT-1825', 'GTEX-NFK9-2125']
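The trade-off behind load_at_init can be illustrated with a small, hypothetical sketch (BagStore is not torchmil code, and whether torchmil caches lazily loaded bags is not assumed here). In eager mode every bag is read in the constructor and indexing is a dictionary lookup. In lazy mode nothing is read up front and each access reads the bag again, which keeps memory use bounded at the cost of repeated I/O.

```python
class BagStore:
    """Hypothetical sketch of eager vs. lazy bag loading (not torchmil's implementation).

    `loader` is any callable that reads one bag given its name.
    """

    def __init__(self, names, loader, load_at_init=True):
        self.names = list(names)
        self.loader = loader
        self.load_at_init = load_at_init
        self.cache = {}
        if load_at_init:
            # Eager: read every bag up front and keep it in memory.
            self.cache = {name: loader(name) for name in self.names}

    def __len__(self):
        return len(self.names)

    def __getitem__(self, idx):
        name = self.names[idx]
        if self.load_at_init:
            return self.cache[name]
        # Lazy: read the bag on every access; nothing is kept in memory.
        return self.loader(name)


reads = []


def fake_loader(name):
    reads.append(name)  # record each simulated disk read
    return {"X": name}


lazy = BagStore(["a", "b", "c"], fake_loader, load_at_init=False)
print(len(reads))  # 0: nothing read yet
lazy[1]
print(reads)       # ['b']
```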

Great! The dataset object was initialized without problems. Now we can display a bag, which is returned as a dictionary-like TensorDict.

el = dataset[0]
for key in el.keys():
    print(f"{key}: {el[key].shape}")
X: torch.Size([825, 1024])
adj: torch.Size([825, 825])

Nice! The dataset has correctly loaded the X tensor and has built the adjacency matrix adj. We can already use this bag as input for a MIL model!