
Binary classification dataset

torchmil.datasets.BinaryClassificationDataset

Bases: ProcessedMILDataset

Dataset for binary classification MIL problems. See torchmil.datasets.ProcessedMILDataset for more information.

For a given bag with bag label \(Y\) and instance labels \(\left\{ y_1, \ldots, y_N \right \}\), this dataset assumes that

\[\begin{gather} Y \in \left\{ 0, 1 \right\}, \quad y_n \in \left\{ 0, 1 \right\}, \quad \forall n \in \left\{ 1, \ldots, N \right\},\\ Y = \max \left\{ y_1, \ldots, y_N \right\}. \end{gather}\]

When instance labels are not provided, every instance label is set to 0 if the bag label is 0, and to -1 if the bag label is 1. If instance labels are provided but are inconsistent with the bag label (that is, their maximum does not equal the bag label), a warning is issued and all instance labels are set to -1.
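
The rule above can be illustrated with a minimal, dependency-free sketch. This is not the library's implementation: `resolve_instance_labels` is a hypothetical helper written only for this illustration, and it uses plain Python lists instead of tensors.

```python
import warnings


def resolve_instance_labels(bag_label, bag_size, inst_labels=None):
    """Hypothetical helper mirroring the documented rule (not torchmil code).

    bag_label:   0 or 1, the bag label Y.
    bag_size:    number of instances N in the bag.
    inst_labels: optional list of 0/1 instance labels y_1..y_N.
    """
    if inst_labels is None:
        # No instance labels given: a negative bag implies all instances are 0;
        # for a positive bag the instance labels are filled with -1.
        fill = 0 if bag_label == 0 else -1
        return [fill] * bag_size

    if max(inst_labels) != bag_label:
        # Provided labels violate Y = max{y_1, ..., y_N}: warn and reset to -1.
        warnings.warn("Instance labels are not consistent with the bag label.")
        return [-1] * len(inst_labels)

    # Provided labels agree with the bag label: keep them unchanged.
    return list(inst_labels)


print(resolve_instance_labels(0, 3))             # [0, 0, 0]
print(resolve_instance_labels(1, 3))             # [-1, -1, -1]
print(resolve_instance_labels(1, 3, [0, 1, 0]))  # [0, 1, 0]
```

Calling the helper with `bag_label=0` and `inst_labels=[0, 1]` would trigger the warning branch and return `[-1, -1]`, since the maximum instance label (1) disagrees with the bag label (0).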

__init__(features_path, labels_path, inst_labels_path=None, coords_path=None, bag_names=None, bag_keys=['X', 'Y', 'y_inst', 'adj', 'coords'], dist_thr=1.5, adj_with_dist=False, norm_adj=True, load_at_init=True)
__getitem__(index)

Parameters:

  • index (int) –

    Index of the bag to retrieve.

Returns:

  • bag_dict (TensorDict) –

    Dictionary containing the keys defined in bag_keys and their corresponding values.

    • X: Features of the bag, of shape (bag_size, ...).
    • Y: Label of the bag.
    • y_inst: Instance labels of the bag, of shape (bag_size, ...).
    • adj: Adjacency matrix of the bag. It is a sparse COO tensor of shape (bag_size, bag_size). If norm_adj=True, the adjacency matrix is normalized.
    • coords: Coordinates of the bag, of shape (bag_size, coords_dim).