Training a MIL model for CT Intracranial Hemorrhage Detection
Intracranial hemorrhage (ICH) is a serious, life-threatening emergency caused by blood leakage inside the brain. Radiologists confirm the presence of ICH using a Computed Tomography (CT) scan, which consists of a significant number of slices, each representing a section of the head at a given height.
Training a model to detect ICH in CT scans is a challenging task, as it requires that a team of radiologists manually label each CT scan, indicating the presence of ICH in each slice. This is a time-consuming and expensive process, and it is not always feasible to obtain such detailed annotations for large datasets.
An alternative approach is to use Multiple Instance Learning (MIL), which allows us to train a model using weak labels. In this case, we can use the presence of ICH in the whole CT scan as a weak label, without requiring detailed annotations for each slice. In the following, we explain how to train a simple MIL model to detect ICH using the torchmil library.
ICH detection as a MIL problem
We treat a CT scan as a bag of instances, where each instance is a slice of the CT scan.
The labels of the slices are \(\mathbf{y} = \left[ y_1, \ldots, y_N \right]^\top \in \{0, 1\}^N\). A slice is given a positive label (\(y_i = 1\)) if it contains ICH, and a negative label (\(y_i = 0\)) if it does not. The slice labels are not available at training time, because obtaining them requires a team of radiologists to manually annotate every slice of every CT scan.
Instead, we have access to the bag label \(Y \in \{0, 1\}\), which indicates whether the CT scan contains ICH or not. The relation between the instance labels and the bag label is as follows:
\[ Y = \max \left\{ y_1, \ldots, y_N \right\}. \]
This means that if at least one slice in the CT scan has hemorrhage (i.e., \(y_i = 1\)), then the bag is labeled as positive (\(Y = 1\)). Otherwise, the bag is labeled as negative (\(Y = 0\)). This is a typical setting for MIL, where we have access to weak labels (the bag labels) but not to the instance labels.
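The rule described above can be checked on a toy example. The following is a minimal sketch, not part of torchmil: the helper `bag_label` and the toy label lists are hypothetical and only illustrate how a bag label follows from (normally unobserved) slice labels.

```python
from typing import Sequence


def bag_label(instance_labels: Sequence[int]) -> int:
    """Return the bag label Y = max(y_1, ..., y_N) for binary instance labels."""
    if len(instance_labels) == 0:
        raise ValueError("A bag must contain at least one instance.")
    return max(instance_labels)


# Toy scans: each list holds hypothetical slice labels (1 = slice shows ICH).
healthy_scan = [0, 0, 0, 0]
ich_scan = [0, 0, 1, 1, 0]

print(bag_label(healthy_scan))  # 0: no slice shows hemorrhage
print(bag_label(ich_scan))      # 1: at least one slice shows hemorrhage
```

Note that the bag label says nothing about which slices are positive, or how many of them there are.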
The data
For this tutorial, we will use the RSNA dataset, which can be found on Kaggle. The dataset is composed of a set of CT scans, each containing a number of slices. Each slice is a 2D image, and the CT scan is a 3D volume. As part of torchmil, we have published a version of this dataset adapted for MIL; see the Hugging Face repository.
Let us first visualize the data using the torchmil.visualize.vis_ctscan module. In the following, we first load the slices of a bag, then use the function slices_to_canvas to create a canvas with all the images, and finally draw a small contour around each slice in the canvas using draw_slices_contour.
import torch
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
# torchmil's facilities
from torchmil.visualize.vis_ctscan import slices_to_canvas, draw_slices_contour
CSV_PATH = "/data/datasets/RSNA_ICH/bags_train.csv"
IMG_PATH = "/data/datasets/RSNA_ICH/original/"
# Randomly select a bag
df = pd.read_csv(CSV_PATH)
bag_names = df["bag_name"].unique()
bag_name = bag_names[np.random.randint(0, len(bag_names))] # Randomly select a bag
bag_df = df[df["bag_name"] == bag_name].sort_values("order")
inst_names = bag_df["instance_name"].values
inst_labels = bag_df["instance_label"].values
inst_names_list = [inst_name.split(".")[0] for inst_name in inst_names]
inst_imgs = [np.load(IMG_PATH + inst_name + ".npy") for inst_name in inst_names_list]
print("This scan has {} slices".format(len(inst_imgs)))
# Using torchmil's functions
canvas = slices_to_canvas(inst_imgs, 512)
canvas_contours = draw_slices_contour(canvas, slice_size=512, contour_prop=0.05)
fig, ax = plt.subplots(figsize=(30, 20))
ax.imshow((canvas_contours * 255).astype(np.uint8))
ax.set_xticks([])
ax.set_yticks([])
plt.show()
In practice, training a MIL model directly on the slices is often computationally prohibitive. For this reason, MIL models usually operate on pre-computed features extracted from each instance. Although torchmil allows defining models that receive the original slices as input, in this tutorial we will use the pre-computed features. We have processed the RSNA dataset to be used for MIL binary classification problems. It can be downloaded from here.
We now make use of torchmil.datasets.RSNAMILDataset to create an object that serves as a torch.utils.data.Dataset and contains the RSNA data. You only need to provide the root path to the processed dataset, and the desired features and partition to load. See how simple it is to instantiate the train dataset:
from torchmil.datasets import RSNAMILDataset
from sklearn.model_selection import train_test_split
dataset = RSNAMILDataset(
    root="/data/datasets/RSNA_ICH/MIL_processed/",
    features="resnet50",
    partition="train",
    load_at_init=True,
)
# Split the dataset into train and validation sets
bag_labels = dataset.get_bag_labels()
idx = list(range(len(bag_labels)))
val_prop = 0.2
idx_train, idx_val = train_test_split(
    idx, test_size=val_prop, random_state=1234, stratify=bag_labels
)
train_dataset = dataset.subset(idx_train)
val_dataset = dataset.subset(idx_val)
test_dataset = RSNAMILDataset(
    root="/data/datasets/RSNA_ICH/MIL_processed/",
    features="resnet50",
    partition="test",
    load_at_init=True,
)
In torchmil, each bag is a TensorDict. The different keys correspond to different elements of the bag. In this case, each bag has a feature matrix X, the bag label Y, the instance coordinates coords, and the instance labels y_inst. Recall that the instance labels cannot be used during training; they are available only for evaluation purposes.
bag = train_dataset[0]
print(bag)
Mini-batching of bags
Typically, the bags in a MIL dataset have different sizes. This can be a problem when creating mini-batches. To solve this, we use the function collate_fn from the torchmil.data module. This function creates a mini-batch of bags by padding the bags with zeros to the size of the largest bag in the batch. The function also returns a mask tensor that indicates which instances are real and which are padding.
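To make the padding idea concrete, here is a minimal NumPy sketch. It is an illustration under stated assumptions, not torchmil's collate_fn: the helper `pad_bags` is hypothetical, and the convention that the mask is True for real instances is assumed for this example.

```python
import numpy as np


def pad_bags(bags):
    """Zero-pad 2D feature matrices (n_instances, n_features) to a common size.

    Hypothetical helper for illustration only; it is not torchmil's collate_fn.
    """
    max_size = max(bag.shape[0] for bag in bags)
    n_features = bags[0].shape[1]
    X = np.zeros((len(bags), max_size, n_features), dtype=bags[0].dtype)
    mask = np.zeros((len(bags), max_size), dtype=bool)
    for i, bag in enumerate(bags):
        n = bag.shape[0]
        X[i, :n] = bag      # copy the real instances
        mask[i, :n] = True  # mark them as real; the rest stays False (padding)
    return X, mask


# Two toy bags with 3 and 5 instances, 4 features each.
bags = [np.ones((3, 4), dtype=np.float32), np.ones((5, 4), dtype=np.float32)]
X, mask = pad_bags(bags)

print(X.shape)           # (2, 5, 4)
print(mask.sum(axis=1))  # [3 5]
```

The smaller bag is padded with two rows of zeros, and the mask records that only its first three rows are real instances.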
Why not use torch.nested?
torch.nested offers a more flexible method for handling bags of varying sizes. However, since the PyTorch API for nested tensors is still in the prototype stage, torchmil currently relies on the padding approach.
Let's create the dataloaders and visualize the shape of a mini-batch. Since the RSNA dataset does not have many instances per bag, we can use a batch_size of 64 for the train and validation sets.
from torchmil.data import collate_fn
batch_size = 64
# Create dataloaders
train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn
)
val_dataloader = torch.utils.data.DataLoader(
    val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn
)
test_dataloader = torch.utils.data.DataLoader(
    test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn
)
it = iter(train_dataloader)
batch = next(it)
data_shape = (batch["X"].shape[-1],)
print("Batch: ", batch)
Each batch is again a TensorDict with an additional key mask that indicates which instances are real and which are padding. As we can see, the bags are padded with zeros to the size of the largest bag in the batch. The function collate_fn also pads other tensors, such as the adjacency matrix or the instance coordinates.
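The mask matters because any operation that aggregates over instances should ignore the padded rows. The sketch below illustrates this with a simple masked mean in NumPy. It is only an illustration: the helper `masked_mean` is hypothetical, the mask is assumed to be True for real instances, and this is not how torchmil models pool instances internally.

```python
import numpy as np


def masked_mean(X, mask):
    """Average instance features over real instances only.

    X: (batch, bag_size, n_features); mask: (batch, bag_size), True for real instances.
    """
    weights = mask[..., None].astype(X.dtype)  # (batch, bag_size, 1)
    total = (X * weights).sum(axis=1)          # padded rows contribute zero
    count = weights.sum(axis=1)                # number of real instances per bag
    return total / np.maximum(count, 1.0)      # guard against empty bags


# One bag with two real instances and one padded row.
X = np.array([[[1.0, 3.0], [3.0, 5.0], [0.0, 0.0]]])
mask = np.array([[True, True, False]])

print(masked_mean(X, mask))  # [[2. 4.]]
print(X.mean(axis=1))        # naive mean: biased towards zero by the padded row
```

Without the mask, the padded row would be counted as a third instance and pull the average towards zero.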
Training a model on RSNA
We have shown how to load the RSNA dataset for the binary classification task. Now, let us train a MIL model on this dataset! For this example, we will use torchmil's implementation of TransformerABMIL, a version of ABMIL where a Transformer encoder is applied to refine the instances before the attention pooling. To highlight how simple it is to instantiate a model in torchmil, we will leave all the parameters at their default values except for in_shape, which reflects the data shape. Feel free to check the documentation of TransformerABMIL to see the different parameters that this model accepts.
from torchmil.models import TransformerABMIL
model = TransformerABMIL(in_shape=data_shape)
See? It could not be easier! Now, let's train the model. torchmil offers an easy-to-use trainer class, torchmil.utils.trainer.Trainer, that provides a generic training loop for any MIL model. It also shows the evolution of the losses and the desired metrics over the epochs.
Note
This Trainer gives you the flexibility to log the results using any wrapped logger, to use annealing for the loss functions via the annealing_scheduler_dict dictionary, or to set a learning rate scheduler using the parameter lr_scheduler. You can also follow multiple metrics during training thanks to the parameter metrics_dict and the integration with the torchmetrics package.
For now, let us keep it simple and train the model for 10 epochs using the torch.optim.Adam optimizer. First, we instantiate the trainer.
from torchmil.utils.trainer import Trainer
import torchmetrics
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)
device = "cuda" if torch.cuda.is_available() else "cpu"
trainer = Trainer(
    model=model,
    optimizer=optimizer,
    metrics_dict={
        "auroc": torchmetrics.AUROC(task="binary").to(device),
        "acc": torchmetrics.Accuracy(task="binary").to(device),
    },
    obj_metric="BCEWithLogitsLoss",
    obj_metric_mode="min",
    device=device,  # use the device selected above instead of hard-coding "cuda"
    verbose=False,
)
trainer.train(
    max_epochs=10, train_dataloader=train_dataloader, val_dataloader=val_dataloader
)
As training progresses, the loss decreases and the accuracy increases. This is a good sign that the model is learning to predict the bag labels!
Evaluating the model
Let's evaluate the model. We are going to compute the accuracy and the F1 score on the test set. The accuracy is the proportion of correctly classified bags, while the F1 score is the harmonic mean of precision and recall. The F1 score is a good metric for imbalanced datasets, which is relevant here because MIL datasets typically contain more negative bags than positive bags.
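To see why accuracy alone can be misleading on imbalanced data, consider the following toy sketch in plain Python. The helper `accuracy_and_f1` and the toy labels are hypothetical and serve only to illustrate the two definitions; the actual evaluation in this tutorial relies on scikit-learn.

```python
def accuracy_and_f1(y_true, y_pred):
    """Compute accuracy and F1 score for binary labels from the confusion counts."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)
    tn = sum(1 for t, p in pairs if t == 0 and p == 0)
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)
    accuracy = (tp + tn) / len(pairs)
    precision = tp / (tp + fp) if tp + fp > 0 else 0.0
    recall = tp / (tp + fn) if tp + fn > 0 else 0.0
    if precision + recall > 0:
        f1 = 2 * precision * recall / (precision + recall)
    else:
        f1 = 0.0
    return accuracy, f1


# Toy imbalanced test set: 8 negative bags and 2 positive bags.
y_true = [0] * 8 + [1] * 2
# A trivial classifier that always predicts "negative".
y_pred = [0] * 10

acc, f1 = accuracy_and_f1(y_true, y_pred)
print(acc, f1)  # 0.8 0.0
```

The trivial classifier reaches 80% accuracy without detecting a single positive bag, while its F1 score of zero exposes the failure.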
from sklearn.metrics import accuracy_score, f1_score
Y_pred_list = []
Y_list = []
model.eval()
with torch.no_grad():  # gradients are not needed for evaluation
    for batch in test_dataloader:
        batch = batch.to(device)
        # Predict the bag logits; a positive logit is mapped to the positive class
        out = model(batch["X"], batch["mask"])
        Y_pred = (out > 0).float()
        Y_pred_list.append(Y_pred)
        Y_list.append(batch["Y"])
Y_pred = torch.cat(Y_pred_list).cpu().numpy()
Y = torch.cat(Y_list).cpu().numpy()
# scikit-learn metrics expect the arguments in the order (y_true, y_pred)
print(f"test/bag/acc: {accuracy_score(Y, Y_pred)}")
print(f"test/bag/f1: {f1_score(Y, Y_pred)}")
Good! Our model is working well: both the accuracy and the F1 score are high. And we got this result in less than two minutes on a GPU and with very few lines of code, thanks to torchmil!