torch_em.data.datasets.histopathology.pannuke
import os
import h5py
import vigra
import shutil
import numpy as np
from glob import glob
from typing import List

import torch_em
from torch_em.data.datasets import util


# PanNuke Dataset - https://warwick.ac.uk/fac/cross_fac/tia/data/pannuke
URLS = {
    "fold_1": "https://warwick.ac.uk/fac/cross_fac/tia/data/pannuke/fold_1.zip",
    "fold_2": "https://warwick.ac.uk/fac/cross_fac/tia/data/pannuke/fold_2.zip",
    "fold_3": "https://warwick.ac.uk/fac/cross_fac/tia/data/pannuke/fold_3.zip"
}


CHECKSUM = {
    "fold_1": "6e19ad380300e8ce9480f9ab6a14cc91fa4b6a511609b40e3d70bdf9c881ed0b",
    "fold_2": "5bc540cc509f64b5f5a274d6e5a245527dbd3e6d3155d43555115c5d54709b07",
    "fold_3": "c14d372981c42f611ebc80afad01702b89cad8c1b3089daa31931cf5a4b1a39d"
}


def _download_pannuke_dataset(path, download, folds):
    """Download, unzip and convert each requested fold into `<path>/pannuke_<fold>.h5`.

    Folds whose h5 file already exists are skipped; all other folds are still processed.
    """
    os.makedirs(path, exist_ok=True)

    for tmp_fold in folds:
        if os.path.exists(os.path.join(path, f"pannuke_{tmp_fold}.h5")):
            # Only this fold is done - the remaining folds may still need to be prepared.
            continue

        zip_path = os.path.join(path, f"{tmp_fold}.zip")
        util.download_source(zip_path, URLS[tmp_fold], download, CHECKSUM[tmp_fold])

        print(f"Unzipping the PanNuke dataset in {tmp_fold} directories...")
        util.unzip(zip_path, os.path.join(path, f"{tmp_fold}"), True)

        _convert_to_hdf5(path, tmp_fold)


def _convert_to_hdf5(path, fold):
    """Create `<path>/pannuke_<fold>.h5` from the unzipped numpy data of `fold`, with 4 essentials (keys):
        - "images" - the raw input images, transposed to channel-first (3 x S x H x W)
        - "labels/masks" - the raw input masks, transposed as above (6 x S x H x W)
        - "labels/instances" - the converted all-instance labels (S x H x W)
        - "labels/semantic" - the converted semantic labels (S x H x W)
        - where, the semantic label ids are as follows:
            (0: Background, 1: Neoplastic cells, 2: Inflammatory,
             3: Connective/Soft tissue cells, 4: Dead Cells, 5: Epithelial)

    Afterwards the unzipped `<path>/<fold>` directory is deleted.

    Raises:
        RuntimeError: if the fold directory does not contain exactly one `images.npy` and one `masks.npy`.
    """
    h5_path = os.path.join(path, f"pannuke_{fold}.h5")
    if os.path.exists(h5_path):
        return

    print(f"Converting {fold} into h5 file format...")
    # Search only the directory this fold was unzipped into, so data left over from other folds
    # can never end up in this fold's h5 file.
    fold_dir = os.path.join(path, fold)
    img_paths = sorted(glob(os.path.join(fold_dir, "**", "images.npy"), recursive=True))
    gt_paths = sorted(glob(os.path.join(fold_dir, "**", "masks.npy"), recursive=True))
    if len(img_paths) != 1 or len(gt_paths) != 1:
        raise RuntimeError(
            f"Expected exactly one 'images.npy' and one 'masks.npy' in {fold_dir}, "
            f"found {len(img_paths)} and {len(gt_paths)}."
        )

    # original (raw) shape : S x H x W x C -> transposed shape (expected) : C x S x H x W
    img = np.load(img_paths[0])
    labels = np.load(gt_paths[0])

    instances = _channels_to_instances(labels)
    semantic = _channels_to_semantics(labels)

    img = img.transpose(3, 0, 1, 2)
    labels = labels.transpose(3, 0, 1, 2)

    # img.shape -> (3, 2656, 256, 256) --- img_chunks -> (3, 1, 256, 256)
    # (same logic as above for labels)
    img_chunks = (img.shape[0], 1) + img.shape[2:]
    label_chunks = (labels.shape[0], 1) + labels.shape[2:]
    other_label_chunks = (1,) + labels.shape[2:]  # for instance and semantic labels

    with h5py.File(h5_path, "w") as f:
        f.create_dataset("images", data=img, compression="gzip", chunks=img_chunks)
        f.create_dataset("labels/masks", data=labels, compression="gzip", chunks=label_chunks)
        f.create_dataset("labels/instances", data=instances, compression="gzip", chunks=other_label_chunks)
        f.create_dataset("labels/semantic", data=semantic, compression="gzip", chunks=other_label_chunks)

    # Remove only the unzipped data of this fold (never other content of `path`).
    shutil.rmtree(fold_dir)


def _channels_to_instances(labels):
    """Merge the ground-truth of 6 (instance) channels into 1 label image with instances from all channels.

    Expects `labels` as S x H x W x 6 with channel info -
    (0: Neoplastic cells, 1: Inflammatory, 2: Connective/Soft tissue cells, 3: Dead Cells, 4: Epithelial,
     5: Background). The last (background) channel is ignored. The ids of each remaining channel are relabeled
    consecutively, starting after the largest id assigned so far, so different channels get different ids.
    Where channels overlap, later channels overwrite earlier ones.

    Returns:
        - instance labels of dimensions -> (S x H x W)
    """
    labels = labels.transpose(0, 3, 1, 2)  # to access with the shape S x 6 x H x W
    list_of_instances = []

    for label_slice in labels:  # access the slices (each with 6 channels of H x W labels)
        segmentation = np.zeros(labels.shape[2:])
        next_id = 1  # first id to hand out for the next channel
        for label_channel in label_slice[:-1]:  # access the channels, skipping the background channel
            # the 'start_label' takes care of where to start allocating the instance ids from
            this_labels, max_id, _ = vigra.analysis.relabelConsecutive(
                label_channel.astype("uint64"), start_label=next_id
            )

            # some channels might not have labels, hence advancing the counter only for channels with RoIs
            if max_id > 0:
                next_id = max_id + 1

            foreground = this_labels > 0
            segmentation[foreground] = this_labels[foreground]

        list_of_instances.append(segmentation)

    return np.stack(list_of_instances)


def _channels_to_semantics(labels):
    """Convert the ground-truth of 6 (instance) channels into semantic labels, using the ids:
        (1 -> Neoplastic cells, 2 -> Inflammatory, 3 -> Connective/Soft tissue cells,
         4 -> Dead Cells, 5 -> Epithelial, 0 -> Background)

    Expects `labels` as S x H x W x 6; the last (background) channel is ignored. Where channels overlap,
    the later channel's id wins.

    Returns:
        - semantic labels of dimensions -> (S x H x W)
    """
    semantic = np.zeros(labels.shape[:3])
    for channel_id in range(labels.shape[-1] - 1):
        semantic[labels[..., channel_id] > 0] = channel_id + 1
    return semantic


def get_pannuke_dataset(
    path,
    patch_shape,
    folds: List[str] = ["fold_1", "fold_2", "fold_3"],
    rois={},
    download=False,
    with_channels=True,
    with_label_channels=False,
    custom_label_choice: str = "instances",
    **kwargs
):
    """Get the PanNuke dataset for nucleus segmentation.

    Args:
        path: Folder where the zip files are downloaded to and the `pannuke_<fold>.h5` files are stored.
        patch_shape: The patch shape used for sampling from the data.
        folds: The folds to use, a subset of "fold_1", "fold_2" and "fold_3".
        rois: Mapping from fold name to the region of interest used for that fold.
            Folds without an entry use the full data (`np.s_[:, :, :]`). `None` is treated like `{}`.
        download: Whether to download the data if it is not present.
        with_channels: Passed on to `torch_em.default_segmentation_dataset`.
        with_label_channels: Passed on to `torch_em.default_segmentation_dataset`.
        custom_label_choice: The labels to load - "masks", "instances" or "semantic"
            (see `_convert_to_hdf5` for details).
        kwargs: Further keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The dataset created by `torch_em.default_segmentation_dataset`.
    """
    # NOTE: the list / dict defaults above are only read, never mutated, so sharing them across calls is safe.
    assert custom_label_choice in [
        "masks", "instances", "semantic"
    ], "Select the type of labels you want from [masks/instances/semantic] (See `_convert_to_hdf5` for details)"

    if rois is None:
        rois = {}
    assert isinstance(rois, dict)

    _download_pannuke_dataset(path, download, folds)

    data_paths = [os.path.join(path, f"pannuke_{fold}.h5") for fold in folds]
    data_rois = [rois.get(fold, np.s_[:, :, :]) for fold in folds]

    raw_key = "images"
    label_key = f"labels/{custom_label_choice}"

    return torch_em.default_segmentation_dataset(
        data_paths, raw_key, data_paths, label_key, patch_shape, rois=data_rois,
        with_channels=with_channels, with_label_channels=with_label_channels, **kwargs
    )


def get_pannuke_loader(
    path,
    patch_shape,
    batch_size,
    folds=["fold_1", "fold_2", "fold_3"],
    download=False,
    rois={},
    custom_label_choice="instances",
    **kwargs
):
    """Get the dataloader for the PanNuke dataset.

    Args:
        path: Folder where the zip files are downloaded to and the `pannuke_<fold>.h5` files are stored.
        patch_shape: The patch shape used for sampling from the data.
        batch_size: The batch size of the loader.
        folds: The folds to use, a subset of "fold_1", "fold_2" and "fold_3".
        download: Whether to download the data if it is not present.
        rois: Mapping from fold name to the region of interest used for that fold.
        custom_label_choice: The labels to load - "masks", "instances" or "semantic".
        kwargs: Split via `util.split_kwargs` into arguments for `torch_em.default_segmentation_dataset`
            (passed to `get_pannuke_dataset`) and arguments for `torch_em.get_data_loader`.

    Returns:
        The loader returned by `torch_em.get_data_loader`.
    """
    dataset_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)

    ds = get_pannuke_dataset(
        path=path,
        patch_shape=patch_shape,
        folds=folds,
        rois=rois,
        download=download,
        custom_label_choice=custom_label_choice,
        **dataset_kwargs)
    return torch_em.get_data_loader(ds, batch_size=batch_size, **loader_kwargs)
URLS =
{'fold_1': 'https://warwick.ac.uk/fac/cross_fac/tia/data/pannuke/fold_1.zip', 'fold_2': 'https://warwick.ac.uk/fac/cross_fac/tia/data/pannuke/fold_2.zip', 'fold_3': 'https://warwick.ac.uk/fac/cross_fac/tia/data/pannuke/fold_3.zip'}
CHECKSUM =
{'fold_1': '6e19ad380300e8ce9480f9ab6a14cc91fa4b6a511609b40e3d70bdf9c881ed0b', 'fold_2': '5bc540cc509f64b5f5a274d6e5a245527dbd3e6d3155d43555115c5d54709b07', 'fold_3': 'c14d372981c42f611ebc80afad01702b89cad8c1b3089daa31931cf5a4b1a39d'}
def
get_pannuke_dataset( path, patch_shape, folds: List[str] = ['fold_1', 'fold_2', 'fold_3'], rois={}, download=False, with_channels=True, with_label_channels=False, custom_label_choice: str = 'instances', **kwargs):
def get_pannuke_dataset(
    path,
    patch_shape,
    folds: List[str] = ["fold_1", "fold_2", "fold_3"],
    rois={},
    download=False,
    with_channels=True,
    with_label_channels=False,
    custom_label_choice: str = "instances",
    **kwargs
):
    """Get the PanNuke dataset for nucleus segmentation.

    Args:
        path: Folder where the zip files are downloaded to and the `pannuke_<fold>.h5` files are stored.
        patch_shape: The patch shape used for sampling from the data.
        folds: The folds to use, a subset of "fold_1", "fold_2" and "fold_3".
        rois: Mapping from fold name to the region of interest used for that fold.
            Folds without an entry use the full data (`np.s_[:, :, :]`). `None` is treated like `{}`.
        download: Whether to download the data if it is not present.
        with_channels: Passed on to `torch_em.default_segmentation_dataset`.
        with_label_channels: Passed on to `torch_em.default_segmentation_dataset`.
        custom_label_choice: The labels to load - "masks", "instances" or "semantic"
            (see `_convert_to_hdf5` for details).
        kwargs: Further keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The dataset created by `torch_em.default_segmentation_dataset`.
    """
    # NOTE: the list / dict defaults above are only read, never mutated, so sharing them across calls is safe.
    assert custom_label_choice in [
        "masks", "instances", "semantic"
    ], "Select the type of labels you want from [masks/instances/semantic] (See `_convert_to_hdf5` for details)"

    if rois is None:
        rois = {}
    assert isinstance(rois, dict)

    _download_pannuke_dataset(path, download, folds)

    data_paths = [os.path.join(path, f"pannuke_{fold}.h5") for fold in folds]
    data_rois = [rois.get(fold, np.s_[:, :, :]) for fold in folds]

    raw_key = "images"
    label_key = f"labels/{custom_label_choice}"

    return torch_em.default_segmentation_dataset(
        data_paths, raw_key, data_paths, label_key, patch_shape, rois=data_rois,
        with_channels=with_channels, with_label_channels=with_label_channels, **kwargs
    )
def
get_pannuke_loader( path, patch_shape, batch_size, folds=['fold_1', 'fold_2', 'fold_3'], download=False, rois={}, custom_label_choice='instances', **kwargs):
def get_pannuke_loader(
    path,
    patch_shape,
    batch_size,
    folds=["fold_1", "fold_2", "fold_3"],
    download=False,
    rois={},
    custom_label_choice="instances",
    **kwargs
):
    """Get the dataloader for the PanNuke dataset.

    Args:
        path: Folder where the zip files are downloaded to and the `pannuke_<fold>.h5` files are stored.
        patch_shape: The patch shape used for sampling from the data.
        batch_size: The batch size of the loader.
        folds: The folds to use, a subset of "fold_1", "fold_2" and "fold_3".
        download: Whether to download the data if it is not present.
        rois: Mapping from fold name to the region of interest used for that fold.
        custom_label_choice: The labels to load - "masks", "instances" or "semantic".
        kwargs: Split via `util.split_kwargs` into arguments for `torch_em.default_segmentation_dataset`
            (passed to `get_pannuke_dataset`) and arguments for `torch_em.get_data_loader`.

    Returns:
        The loader returned by `torch_em.get_data_loader`.
    """
    dataset_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)

    ds = get_pannuke_dataset(
        path=path,
        patch_shape=patch_shape,
        folds=folds,
        rois=rois,
        download=download,
        custom_label_choice=custom_label_choice,
        **dataset_kwargs)
    return torch_em.get_data_loader(ds, batch_size=batch_size, **loader_kwargs)
Get the dataloader for the PanNuke dataset.