torch_em.data.datasets.histopathology.pannuke

import os
import shutil
from glob import glob
from typing import List

import h5py
import numpy as np
import vigra

import torch_em
from torch_em.data.datasets import util


# PanNuke Dataset - https://warwick.ac.uk/fac/cross_fac/tia/data/pannuke
URLS = {
    "fold_1": "https://warwick.ac.uk/fac/cross_fac/tia/data/pannuke/fold_1.zip",
    "fold_2": "https://warwick.ac.uk/fac/cross_fac/tia/data/pannuke/fold_2.zip",
    "fold_3": "https://warwick.ac.uk/fac/cross_fac/tia/data/pannuke/fold_3.zip"
}


CHECKSUM = {
    "fold_1": "6e19ad380300e8ce9480f9ab6a14cc91fa4b6a511609b40e3d70bdf9c881ed0b",
    "fold_2": "5bc540cc509f64b5f5a274d6e5a245527dbd3e6d3155d43555115c5d54709b07",
    "fold_3": "c14d372981c42f611ebc80afad01702b89cad8c1b3089daa31931cf5a4b1a39d"
}


def _download_pannuke_dataset(path, download, folds):
    os.makedirs(path, exist_ok=True)

    for fold in folds:
        # skip folds that have already been downloaded and converted
        if os.path.exists(os.path.join(path, f"pannuke_{fold}.h5")):
            continue

        util.download_source(os.path.join(path, f"{fold}.zip"), URLS[fold], download, CHECKSUM[fold])

        print(f"Unzipping the PanNuke dataset fold {fold}...")
        util.unzip(os.path.join(path, f"{fold}.zip"), os.path.join(path, fold), True)

        _convert_to_hdf5(path, fold)


def _convert_to_hdf5(path, fold):
    """Creates one h5 file per fold from the input data, with four datasets (keys):
        - "images" - the raw input images, transposed into the expected format (3 x S x H x W)
        - "labels/masks" - the raw input masks, transposed as above (6 x S x H x W)
        - "labels/instances" - instance labels merged from all channels (S x H x W)
        - "labels/semantic" - semantic labels (S x H x W), with the following class ids:
            (0: Background, 1: Neoplastic cells, 2: Inflammatory,
             3: Connective/Soft tissue cells, 4: Dead cells, 5: Epithelial)
    """
    if os.path.exists(os.path.join(path, f"pannuke_{fold}.h5")):
        return

    print(f"Converting {fold} into h5 file format...")
    # sort both lists so that images and masks from the same directory are paired up
    img_paths = sorted(glob(os.path.join(path, "**", "images.npy"), recursive=True))
    gt_paths = sorted(glob(os.path.join(path, "**", "masks.npy"), recursive=True))

    for img_path, gt_path in zip(img_paths, gt_paths):
        # original (raw) shape: S x H x W x C -> transposed shape (expected): C x S x H x W
        img = np.load(img_path)
        labels = np.load(gt_path)

        instances = _channels_to_instances(labels)
        semantic = _channels_to_semantics(labels)

        img = img.transpose(3, 0, 1, 2)
        labels = labels.transpose(3, 0, 1, 2)

        # e.g. img.shape -> (3, 2656, 256, 256) --- img_chunks -> (3, 1, 256, 256)
        # (same logic as above for the labels)
        img_chunks = (img.shape[0], 1) + img.shape[2:]
        label_chunks = (labels.shape[0], 1) + labels.shape[2:]
        other_label_chunks = (1,) + labels.shape[2:]  # for the instance and semantic labels

        with h5py.File(os.path.join(path, f"pannuke_{fold}.h5"), "w") as f:
            f.create_dataset("images", data=img, compression="gzip", chunks=img_chunks)
            f.create_dataset("labels/masks", data=labels, compression="gzip", chunks=label_chunks)
            f.create_dataset("labels/instances", data=instances, compression="gzip", chunks=other_label_chunks)
            f.create_dataset("labels/semantic", data=semantic, compression="gzip", chunks=other_label_chunks)

    # clean up the extracted directories; only the h5 files are kept
    dirs_to_rm = glob(os.path.join(path, "*[!.h5]"))
    for tmp_dir in dirs_to_rm:
        shutil.rmtree(tmp_dir)


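# A minimal sketch of how to inspect the layout produced by `_convert_to_hdf5`
# (an illustrative helper, not part of the original module; the fold name is
# an assumed example and the file only exists after download and conversion):
def _print_h5_layout(path, fold="fold_1"):
    with h5py.File(os.path.join(path, f"pannuke_{fold}.h5"), "r") as f:
        # `visititems` walks all groups and datasets; print each dataset with its shape
        f.visititems(lambda name, obj: print(name, obj.shape) if isinstance(obj, h5py.Dataset) else None)

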
def _channels_to_instances(labels):
    """Converts the ground-truth of 6 (instance) channels into one instance label per slice,
    merging the instances from all channels. The channel info is:
    (0: Neoplastic cells, 1: Inflammatory, 2: Connective/Soft tissue cells, 3: Dead cells, 4: Epithelial, 5: Background)

    Returns:
        - instance labels of dimensions -> (S x H x W)
    """
    labels = labels.transpose(0, 3, 1, 2)  # to access with the shape S x 6 x H x W
    list_of_instances = []

    for label_slice in labels:  # access the slices (each with 6 channels of H x W labels)
        segmentation = np.zeros(labels.shape[2:])
        max_ids = []
        for label_channel in label_slice[:-1]:  # access all channels except the last (background) one
            # 'start_label' takes care of where to start allocating the instance ids from
            this_labels, max_id, _ = vigra.analysis.relabelConsecutive(
                label_channel.astype("uint64"),
                start_label=max_ids[-1] + 1 if len(max_ids) > 0 else 1)

            # some trailing channels might not have labels, hence we append only for channels with instances
            if max_id > 0:
                max_ids.append(max_id)

            segmentation[this_labels > 0] = this_labels[this_labels > 0]

        list_of_instances.append(segmentation)

    f_segmentation = np.stack(list_of_instances)

    return f_segmentation


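# A toy sanity check for `_channels_to_instances` (an illustrative sketch, not
# part of the original module): instances from different channels must receive
# distinct, consecutive ids, independent of their ids in the input.
def _demo_instances():
    labels = np.zeros((1, 4, 4, 6))  # S x H x W x 6, as stored in masks.npy
    labels[0, :2, :2, 0] = 1  # one nucleus in the "Neoplastic" channel
    labels[0, 2:, 2:, 1] = 7  # one nucleus in the "Inflammatory" channel
    instances = _channels_to_instances(labels)
    print(instances.shape)       # (1, 4, 4)
    print(np.unique(instances))  # [0. 1. 2.] - background plus two instance ids

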
def _channels_to_semantics(labels):
    """Converts the ground-truth of 6 (instance) channels into semantic labels, following this id mapping:
    (1 -> Neoplastic cells, 2 -> Inflammatory, 3 -> Connective/Soft tissue cells,
    4 -> Dead cells, 5 -> Epithelial, 0 -> Background)

    Returns:
        - semantic labels of dimensions -> (S x H x W)
    """
    labels = labels.transpose(0, 3, 1, 2)  # to access with the shape S x 6 x H x W
    list_of_semantic = []

    for label_slice in labels:
        segmentation = np.zeros(labels.shape[2:])
        for i, label_channel in enumerate(label_slice[:-1]):
            segmentation[label_channel > 0] = i + 1
        list_of_semantic.append(segmentation)

    f_segmentation = np.stack(list_of_semantic)

    return f_segmentation


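# The same toy input run through `_channels_to_semantics` (again an
# illustrative sketch): channel indices map to the class ids 1-5, independent
# of the instance ids in the input.
def _demo_semantics():
    labels = np.zeros((1, 4, 4, 6))
    labels[0, :2, :2, 0] = 1  # channel 0 -> class id 1 (Neoplastic cells)
    labels[0, 2:, 2:, 1] = 7  # channel 1 -> class id 2 (Inflammatory)
    semantic = _channels_to_semantics(labels)
    print(np.unique(semantic))  # [0. 1. 2.]

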
def get_pannuke_dataset(
        path,
        patch_shape,
        folds: List[str] = ["fold_1", "fold_2", "fold_3"],
        rois={},
        download=False,
        with_channels=True,
        with_label_channels=False,
        custom_label_choice: str = "instances",
        **kwargs
):
    """Dataset for nucleus segmentation on PanNuke. The raw data is stored under the
    "images" key and the labels under the key selected via `custom_label_choice`
    (see `_convert_to_hdf5` for the available label types).
    """
    assert custom_label_choice in [
        "masks", "instances", "semantic"
    ], "Select the type of labels you want from [masks/instances/semantic] (see `_convert_to_hdf5` for details)"

    if rois is None:
        rois = {}
    assert isinstance(rois, dict)

    _download_pannuke_dataset(path, download, folds)

    data_paths = [os.path.join(path, f"pannuke_{fold}.h5") for fold in folds]
    # folds without an entry in `rois` use the full volume
    data_rois = [rois.get(fold, np.s_[:, :, :]) for fold in folds]

    raw_key = "images"
    label_key = f"labels/{custom_label_choice}"

    return torch_em.default_segmentation_dataset(
        data_paths, raw_key, data_paths, label_key, patch_shape, rois=data_rois,
        with_channels=with_channels, with_label_channels=with_label_channels, **kwargs
    )


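# A minimal usage sketch for `get_pannuke_dataset` (the data directory is an
# assumed example path). The patch shape is 3D - one slice of 256 x 256 pixels -
# since each fold is stored as a stack of shape S x H x W:
def _demo_dataset(data_dir="./data/pannuke"):
    dataset = get_pannuke_dataset(
        data_dir, patch_shape=(1, 256, 256),
        folds=["fold_1"], download=True, custom_label_choice="semantic",
    )
    print(len(dataset))

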
def get_pannuke_loader(
        path,
        patch_shape,
        batch_size,
        folds=["fold_1", "fold_2", "fold_3"],
        download=False,
        rois={},
        custom_label_choice="instances",
        **kwargs
):
    """Dataloader for nucleus segmentation on PanNuke. See `get_pannuke_dataset` for details."""
    dataset_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)

    ds = get_pannuke_dataset(
        path=path,
        patch_shape=patch_shape,
        folds=folds,
        rois=rois,
        download=download,
        custom_label_choice=custom_label_choice,
        **dataset_kwargs)
    return torch_em.get_data_loader(ds, batch_size=batch_size, **loader_kwargs)
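

# A minimal usage sketch for `get_pannuke_loader` (again with an assumed
# example path). Keyword arguments for `torch_em.default_segmentation_dataset`
# and for the torch DataLoader (e.g. `num_workers`) can both be passed; they
# are split internally by `util.split_kwargs`:
def _demo_loader(data_dir="./data/pannuke"):
    loader = get_pannuke_loader(
        data_dir, patch_shape=(1, 256, 256), batch_size=2,
        folds=["fold_1"], download=True, num_workers=2,
    )
    x, y = next(iter(loader))
    print(x.shape, y.shape)  # a batch of raw patches and the selected labels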