
Toy dataset

torchmil.datasets.ToyDataset

Bases: Dataset

This class represents a synthetic dataset for Multiple Instance Learning (MIL) tasks. It generates synthetic bags of instances from a given dataset, where each bag is labeled based on the presence or absence of specific "positive" instances. This class is particularly useful for simulating MIL scenarios, where the goal is to learn from bags of instances rather than individual data points.

Bag generation. The dataset generates bags by sampling instances from the input (data, labels) pair. A bag is labeled as positive if it contains at least one instance from a predefined set of positive labels (obj_labels). The probability of generating a positive bag can be controlled via the argument pos_class_prob. The size of each bag can be specified using the argument bag_size. Additionally, each bag includes instance-level labels, indicating whether individual instances belong to the positive class.
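The bag-generation procedure above can be sketched in plain Python. This is a minimal illustration, not torchmil's actual implementation. The helper name `sample_bag` is hypothetical. We assume three details the text does not specify: a positive bag gets one guaranteed positive instance and is filled with random instances, a negative bag draws only instances whose label is not in `obj_labels`, and sampling is done with replacement.

```python
import random


def sample_bag(data, labels, obj_labels, bag_size, positive, rng):
    """Hypothetical sketch: build one MIL bag from a (data, labels) pair."""
    obj = set(obj_labels)
    pos_idx = [i for i, lab in enumerate(labels) if lab in obj]
    neg_idx = [i for i, lab in enumerate(labels) if lab not in obj]

    if positive:
        # Guarantee one positive instance, then fill the rest from any instance
        # (sampling with replacement is an assumption of this sketch).
        idx = [rng.choice(pos_idx)]
        idx += rng.choices(range(len(labels)), k=bag_size - 1)
    else:
        # Negative bags draw only instances whose label is not in obj_labels.
        idx = rng.choices(neg_idx, k=bag_size)
    rng.shuffle(idx)  # so the guaranteed positive is not always first

    X = [data[i] for i in idx]
    y_inst = [int(labels[i] in obj) for i in idx]
    Y = max(y_inst)  # the bag is positive iff any instance is positive
    return {"X": X, "Y": Y, "y_inst": y_inst}


# Toy usage: 10 one-dimensional instances with labels 0..4, positives {1, 2}
rng = random.Random(0)
data = [[float(i)] for i in range(10)]
labels = [i % 5 for i in range(10)]
pos_bag = sample_bag(data, labels, [1, 2], bag_size=4, positive=True, rng=rng)
neg_bag = sample_bag(data, labels, [1, 2], bag_size=4, positive=False, rng=rng)
print(pos_bag["Y"], neg_bag["Y"])  # 1 0
```

Deriving `Y` from `y_inst` rather than from the `positive` flag keeps the bag label consistent with the instance labels by construction.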

Each bag is returned as a dictionary (TensorDict) with the following keys:

  • X: The bag's feature matrix of shape (bag_size, num_features).
  • Y: The bag's label (1 for positive, 0 for negative).
  • y_inst: The instance-level labels within the bag, one per instance (1 if the instance's label is in obj_labels, 0 otherwise).
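The relationships between these three keys can be checked on the consumer side. The helper `check_bag` below is hypothetical. It assumes that `Y` equals the maximum of `y_inst` (the "at least one positive instance" rule described above). It only uses `len`, `max` and `int`, so it should also accept 1-D tensors, although it is demonstrated here on plain lists only.

```python
def check_bag(bag):
    """Hypothetical sanity check for a bag dict with keys X, Y and y_inst."""
    X, Y, y_inst = bag["X"], bag["Y"], bag["y_inst"]
    # One feature vector per instance label, and at least one instance.
    if len(X) != len(y_inst) or len(y_inst) == 0:
        return False
    # Instance labels are binary.
    if any(int(y) not in (0, 1) for y in y_inst):
        return False
    # Bag label follows the MIL rule: positive iff some instance is positive.
    return int(Y) == int(max(y_inst))


example = {"X": [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]], "Y": 1, "y_inst": [0, 1, 0]}
print(check_bag(example))  # True
```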

MNIST example. We can create a MIL dataset from the original MNIST as follows:

```python
import torch
from torchvision import datasets, transforms

from torchmil.datasets import ToyDataset

# Load MNIST dataset
mnist_train = datasets.MNIST('data', train=True, download=True, transform=transforms.ToTensor())

# Extract features and labels.
# Note: .data returns the raw uint8 images and bypasses the transform,
# so pixels are flattened to 784 features and scaled to [0, 1] manually.
data = mnist_train.data.numpy().reshape(-1, 28*28) / 255
labels = mnist_train.targets.numpy()

# Define positive labels
obj_labels = [1, 2]  # Digits 1 and 2 are considered positive

# Create MIL dataset
toy_dataset = ToyDataset(data, labels, num_bags=1000, obj_labels=obj_labels, bag_size=10)

# Retrieve a bag
bag = toy_dataset[0]
X, Y, y_inst = bag['X'], bag['Y'], bag['y_inst']
```

__init__(data, labels, num_bags, obj_labels, bag_size, pos_class_prob=0.5, seed=0)

Constructor of the ToyDataset class.

Parameters:

  • data (ndarray) –

    Data matrix of shape (num_instances, num_features).

  • labels (ndarray) –

    Labels vector of shape (num_instances,).

  • num_bags (int) –

    Number of bags to generate.

  • obj_labels (list[int]) –

    List of labels to consider as positive.

  • bag_size (Union[int, tuple[int, int]]) –

    Number of instances per bag. If a tuple (min_size, max_size) is provided, the bag size is sampled uniformly from this range.

  • pos_class_prob (float, default: 0.5) –

    Probability of generating a positive bag.

  • seed (int, default: 0) –

    Random seed.
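To illustrate how bag_size, pos_class_prob and seed interact, here is a hypothetical sketch (`plan_bags` is not part of torchmil) that only decides the size and label of each bag. It assumes that a tuple bag_size is sampled inclusively on both ends and that each bag label is drawn independently with probability pos_class_prob. The library might instead fix the number of positive bags, so treat the counts as illustrative.

```python
import random


def plan_bags(num_bags, bag_size, pos_class_prob=0.5, seed=0):
    """Hypothetical sketch: draw a (size, label) pair for each bag."""
    rng = random.Random(seed)  # fixed seed -> reproducible plan
    plan = []
    for _ in range(num_bags):
        if isinstance(bag_size, tuple):
            min_size, max_size = bag_size
            size = rng.randint(min_size, max_size)  # inclusive range (assumption)
        else:
            size = bag_size
        label = int(rng.random() < pos_class_prob)  # Bernoulli(pos_class_prob)
        plan.append((size, label))
    return plan


plan = plan_bags(num_bags=1000, bag_size=(5, 15), pos_class_prob=0.3, seed=0)
sizes = [s for s, _ in plan]
print(min(sizes) >= 5 and max(sizes) <= 15)  # True
print(sum(lab for _, lab in plan))  # number of positive bags, close to 1000 * 0.3
```

Calling `plan_bags` again with the same seed yields the same plan, which is the reproducibility the seed parameter is meant to provide.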

__getitem__(index)

Parameters:

  • index (int) –

    Index of the bag to retrieve.

Returns:

  • bag_dict (TensorDict) –

    Dictionary containing the following keys:

    • X: Bag features of shape (bag_size, num_features).
    • Y: Label of the bag (1 for positive, 0 for negative).
    • y_inst: Instance labels of the bag, one per instance.
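Because X holds one feature vector per instance and the bag size can vary when a tuple is passed for bag_size, bags of different sizes cannot be stacked directly into one batch. A common remedy is padding plus a mask. The function `pad_collate` below is a hypothetical pure-Python sketch of that idea, using plain lists in place of tensors. It is not torchmil's own batching utility.

```python
def pad_collate(bags, pad_value=0.0):
    """Hypothetical sketch: batch variable-size bags by padding and masking."""
    max_len = max(len(b["X"]) for b in bags)
    feat_dim = len(bags[0]["X"][0])
    batch = {"X": [], "Y": [], "y_inst": [], "mask": []}
    for b in bags:
        n = len(b["X"])
        pad = max_len - n
        # Pad features with constant rows and instance labels with 0.
        batch["X"].append([list(x) for x in b["X"]] + [[pad_value] * feat_dim for _ in range(pad)])
        batch["y_inst"].append(list(b["y_inst"]) + [0] * pad)
        # The mask marks real instances (1) versus padding (0).
        batch["mask"].append([1] * n + [0] * pad)
        batch["Y"].append(b["Y"])
    return batch


bags = [
    {"X": [[0.0], [1.0]], "Y": 0, "y_inst": [0, 0]},
    {"X": [[2.0], [3.0], [4.0]], "Y": 1, "y_inst": [0, 1, 0]},
]
batch = pad_collate(bags)
print(batch["mask"])  # [[1, 1, 0], [1, 1, 1]]
```

A model consuming such a batch should use the mask to ignore padded positions, for example when pooling instance scores into a bag score.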