torch_em.data.datasets.histopathology.monusac

  1import os
  2import shutil
  3from glob import glob
  4from tqdm import tqdm
  5from pathlib import Path
  6from typing import Optional, List
  7
  8import imageio.v3 as imageio
  9
 10import torch_em
 11from .. import util
 12
 13
 14URL = {
 15    "train": "https://drive.google.com/uc?export=download&id=1lxMZaAPSpEHLSxGA9KKMt_r-4S8dwLhq",
 16    "test": "https://drive.google.com/uc?export=download&id=1G54vsOdxWY1hG7dzmkeK3r0xz9s-heyQ"
 17}
 18
 19
 20CHECKSUM = {
 21    "train": "5b7cbeb34817a8f880d3fddc28391e48d3329a91bf3adcbd131ea149a725cd92",
 22    "test": "bcbc38f6bf8b149230c90c29f3428cc7b2b76f8acd7766ce9fc908fc896c2674"
 23}
 24
 25# here's the description: https://drive.google.com/file/d/1kdOl3s6uQBRv0nToSIf1dPuceZunzL4N/view
 26ORGAN_SPLITS = {
 27    "train": {
 28        "lung": ["TCGA-55-1594", "TCGA-69-7760", "TCGA-69-A59K", "TCGA-73-4668", "TCGA-78-7220",
 29                 "TCGA-86-7713", "TCGA-86-8672", "TCGA-L4-A4E5", "TCGA-MP-A4SY", "TCGA-MP-A4T7"],
 30        "kidney": ["TCGA-5P-A9K0", "TCGA-B9-A44B", "TCGA-B9-A8YI", "TCGA-DW-7841", "TCGA-EV-5903", "TCGA-F9-A97G",
 31                   "TCGA-G7-A8LD", "TCGA-MH-A560", "TCGA-P4-AAVK", "TCGA-SX-A7SR", "TCGA-UZ-A9PO", "TCGA-UZ-A9PU"],
 32        "breast": ["TCGA-A2-A0CV", "TCGA-A2-A0ES", "TCGA-B6-A0WZ", "TCGA-BH-A18T", "TCGA-D8-A1X5",
 33                   "TCGA-E2-A154", "TCGA-E9-A22B", "TCGA-E9-A22G", "TCGA-EW-A6SD", "TCGA-S3-AA11"],
 34        "prostate": ["TCGA-EJ-5495", "TCGA-EJ-5505", "TCGA-EJ-5517", "TCGA-G9-6342", "TCGA-G9-6499",
 35                     "TCGA-J4-A67Q", "TCGA-J4-A67T", "TCGA-KK-A59X", "TCGA-KK-A6E0", "TCGA-KK-A7AW",
 36                     "TCGA-V1-A8WL", "TCGA-V1-A9O9", "TCGA-X4-A8KQ", "TCGA-YL-A9WY"]
 37    },
 38    "test": {
 39        "lung": ["TCGA-49-6743", "TCGA-50-6591", "TCGA-55-7570", "TCGA-55-7573",
 40                 "TCGA-73-4662", "TCGA-78-7152", "TCGA-MP-A4T7"],
 41        "kidney": ["TCGA-2Z-A9JG", "TCGA-2Z-A9JN", "TCGA-DW-7838", "TCGA-DW-7963",
 42                   "TCGA-F9-A8NY", "TCGA-IZ-A6M9", "TCGA-MH-A55W"],
 43        "breast": ["TCGA-A2-A04X", "TCGA-A2-A0ES", "TCGA-D8-A3Z6", "TCGA-E2-A108", "TCGA-EW-A6SB"],
 44        "prostate": ["TCGA-G9-6356", "TCGA-G9-6367", "TCGA-VP-A87E", "TCGA-VP-A87H", "TCGA-X4-A8KS", "TCGA-YL-A9WL"]
 45    },
 46}
 47
 48
 49def _download_monusac(path, download, split):
 50    assert split in ["train", "test"], "Please choose from train/test"
 51
 52    # check if we have extracted the images and labels already
 53    im_path = os.path.join(path, "images", split)
 54    label_path = os.path.join(path, "labels", split)
 55    if os.path.exists(im_path) and os.path.exists(label_path):
 56        return
 57
 58    os.makedirs(path, exist_ok=True)
 59    zip_path = os.path.join(path, f"monusac_{split}.zip")
 60    util.download_source_gdrive(zip_path, URL[split], download=download, checksum=CHECKSUM[split])
 61
 62    _process_monusac(path, split)
 63
 64    _check_channel_consistency(path, split)
 65
 66
 67def _check_channel_consistency(path, split):
 68    "The provided tif images have RGBA channels, check and remove the alpha channel"
 69    all_image_path = glob(os.path.join(path, "images", split, "*.tif"))
 70    for image_path in all_image_path:
 71        image = imageio.imread(image_path)
 72        assert image.shape[-1] == 4, f"Image has an unexpected shape: {image.shape}"
 73        rgb_image = image[..., :-1]  # get rid of the alpha channel
 74        imageio.imwrite(image_path, rgb_image)
 75
 76
 77def _process_monusac(path, split):
 78    util.unzip(os.path.join(path, f"monusac_{split}.zip"), path)
 79
 80    # assorting the images into expected dir;
 81    # converting the label xml files to numpy arrays (of same dimension as input images) in the expected dir
 82    root_img_save_dir = os.path.join(path, "images", split)
 83    root_label_save_dir = os.path.join(path, "labels", split)
 84
 85    os.makedirs(root_img_save_dir, exist_ok=True)
 86    os.makedirs(root_label_save_dir, exist_ok=True)
 87
 88    all_patient_dir = sorted(glob(os.path.join(path, "MoNuSAC*", "*")))
 89
 90    for patient_dir in tqdm(all_patient_dir, desc=f"Converting {split} inputs for all patients"):
 91        all_img_dir = sorted(glob(os.path.join(patient_dir, "*.tif")))
 92        all_xml_label_dir = sorted(glob(os.path.join(patient_dir, "*.xml")))
 93
 94        if len(all_img_dir) != len(all_xml_label_dir):
 95            _convert_missing_tif_from_svs(patient_dir)
 96            all_img_dir = sorted(glob(os.path.join(patient_dir, "*.tif")))
 97
 98        assert len(all_img_dir) == len(all_xml_label_dir)
 99
100        for img_path, xml_label_path in zip(all_img_dir, all_xml_label_dir):
101            desired_label_shape = imageio.imread(img_path).shape[:-1]
102
103            img_id = os.path.split(img_path)[-1]
104            dst = os.path.join(root_img_save_dir, img_id)
105            shutil.move(src=img_path, dst=dst)
106
107            _label = util.generate_labeled_array_from_xml(shape=desired_label_shape, xml_file=xml_label_path)
108            _fileid = img_id.split(".")[0]
109            imageio.imwrite(os.path.join(root_label_save_dir, f"{_fileid}.tif"), _label)
110
111    shutil.rmtree(glob(os.path.join(path, "MoNuSAC*"))[0])
112
113
def _convert_missing_tif_from_svs(patient_dir):
    """Create the tif images that are missing in a patient folder from its svs scans.

    Cause: Happens only in the test split, maybe some images were missed while the
    data was converted.
    Fix: The original svs scans are available, so every scan without a tif of the same
    name next to it is converted and saved as tif.

    Args:
        patient_dir: Folder of one patient, containing the svs scans.
    """
    for svs_path in sorted(glob(os.path.join(patient_dir, "*.svs"))):
        stem, _ = os.path.splitext(svs_path)
        tif_path = f"{stem}.tif"
        if os.path.exists(tif_path):
            continue

        scan = util.convert_svs_to_array(svs_path)
        # the arrays from the svs scans are supposed to be RGB images
        assert scan.shape[-1] == 3
        imageio.imwrite(tif_path, scan)
128
129
def get_patient_id(path, split_wrt="-01Z-00-"):
    """Extract the patient id from the name of an image or label file.

    Input names: "TCGA-<XX>-<XXXX>-01Z-00-DX<X>-(<X>, <00X>).tif" (example: TCGA-2Z-A9JG-01Z-00-DX1_1.tif)
    Expected:    "TCGA-<XX>-<XXXX>"                                (example: TCGA-2Z-A9JG)

    Args:
        path: Path to the file; only its stem (name without the last suffix) is used.
        split_wrt: Separator that follows the patient id in the file name.

    Returns:
        The part of the stem in front of the first `split_wrt`, or the whole stem if
        the separator does not occur.
    """
    file_stem = Path(path).stem
    patient_id, _, _ = file_stem.partition(split_wrt)
    return patient_id
138
139
def get_monusac_dataset(
    path, patch_shape, split, organ_type: Optional[List[str]] = None, download=False,
    offsets=None, boundaries=False, binary=False, **kwargs
):
    """Dataset from https://monusac-2020.grand-challenge.org/Data/

    Args:
        path: Root folder of the dataset; the data is prepared there if necessary.
        patch_shape: Forwarded to `torch_em.default_segmentation_dataset`.
        split: Either "train" or "test".
        organ_type: Organs (keys of `ORGAN_SPLITS[split]`) to restrict the data to.
            All images of the split are used if None.
        download: Forwarded to the download helper.
        offsets, boundaries, binary: Forwarded to `util.add_instance_label_transform`.
        kwargs: Further keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The result of `torch_em.default_segmentation_dataset` for the selected files.
    """
    _download_monusac(path, download, split)

    def _sorted_files(kind):
        return sorted(glob(os.path.join(path, kind, split, "*")))

    image_paths = _sorted_files("images")
    label_paths = _sorted_files("labels")

    if organ_type is not None:
        # collect the patients of all requested organs
        selected_patients = {
            patient for organ in organ_type for patient in ORGAN_SPLITS[split][organ]
        }
        image_paths = [p for p in image_paths if get_patient_id(p) in selected_patients]
        label_paths = [p for p in label_paths if get_patient_id(p) in selected_patients]

    assert len(image_paths) == len(label_paths)

    kwargs, _ = util.add_instance_label_transform(
        kwargs, add_binary_target=True, binary=binary, boundaries=boundaries, offsets=offsets
    )
    return torch_em.default_segmentation_dataset(
        image_paths, None, label_paths, None, patch_shape, is_seg_dataset=False, **kwargs
    )
166
167
def get_monusac_loader(
    path, patch_shape, split, batch_size, organ_type=None, download=False,
    offsets=None, boundaries=False, binary=False, **kwargs
):
    """Data loader for the dataset returned by `get_monusac_dataset`.

    `kwargs` is divided by `util.split_kwargs` into the arguments for
    `torch_em.default_segmentation_dataset` and the remaining ones, which are handed
    to `torch_em.get_data_loader` together with `batch_size`. All other parameters are
    passed on to `get_monusac_dataset`.
    """
    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    return torch_em.get_data_loader(
        get_monusac_dataset(
            path, patch_shape, split, organ_type=organ_type, download=download,
            offsets=offsets, boundaries=boundaries, binary=binary, **ds_kwargs
        ),
        batch_size,
        **loader_kwargs
    )
URL = {'train': 'https://drive.google.com/uc?export=download&id=1lxMZaAPSpEHLSxGA9KKMt_r-4S8dwLhq', 'test': 'https://drive.google.com/uc?export=download&id=1G54vsOdxWY1hG7dzmkeK3r0xz9s-heyQ'}
CHECKSUM = {'train': '5b7cbeb34817a8f880d3fddc28391e48d3329a91bf3adcbd131ea149a725cd92', 'test': 'bcbc38f6bf8b149230c90c29f3428cc7b2b76f8acd7766ce9fc908fc896c2674'}
ORGAN_SPLITS = {'train': {'lung': ['TCGA-55-1594', 'TCGA-69-7760', 'TCGA-69-A59K', 'TCGA-73-4668', 'TCGA-78-7220', 'TCGA-86-7713', 'TCGA-86-8672', 'TCGA-L4-A4E5', 'TCGA-MP-A4SY', 'TCGA-MP-A4T7'], 'kidney': ['TCGA-5P-A9K0', 'TCGA-B9-A44B', 'TCGA-B9-A8YI', 'TCGA-DW-7841', 'TCGA-EV-5903', 'TCGA-F9-A97G', 'TCGA-G7-A8LD', 'TCGA-MH-A560', 'TCGA-P4-AAVK', 'TCGA-SX-A7SR', 'TCGA-UZ-A9PO', 'TCGA-UZ-A9PU'], 'breast': ['TCGA-A2-A0CV', 'TCGA-A2-A0ES', 'TCGA-B6-A0WZ', 'TCGA-BH-A18T', 'TCGA-D8-A1X5', 'TCGA-E2-A154', 'TCGA-E9-A22B', 'TCGA-E9-A22G', 'TCGA-EW-A6SD', 'TCGA-S3-AA11'], 'prostate': ['TCGA-EJ-5495', 'TCGA-EJ-5505', 'TCGA-EJ-5517', 'TCGA-G9-6342', 'TCGA-G9-6499', 'TCGA-J4-A67Q', 'TCGA-J4-A67T', 'TCGA-KK-A59X', 'TCGA-KK-A6E0', 'TCGA-KK-A7AW', 'TCGA-V1-A8WL', 'TCGA-V1-A9O9', 'TCGA-X4-A8KQ', 'TCGA-YL-A9WY']}, 'test': {'lung': ['TCGA-49-6743', 'TCGA-50-6591', 'TCGA-55-7570', 'TCGA-55-7573', 'TCGA-73-4662', 'TCGA-78-7152', 'TCGA-MP-A4T7'], 'kidney': ['TCGA-2Z-A9JG', 'TCGA-2Z-A9JN', 'TCGA-DW-7838', 'TCGA-DW-7963', 'TCGA-F9-A8NY', 'TCGA-IZ-A6M9', 'TCGA-MH-A55W'], 'breast': ['TCGA-A2-A04X', 'TCGA-A2-A0ES', 'TCGA-D8-A3Z6', 'TCGA-E2-A108', 'TCGA-EW-A6SB'], 'prostate': ['TCGA-G9-6356', 'TCGA-G9-6367', 'TCGA-VP-A87E', 'TCGA-VP-A87H', 'TCGA-X4-A8KS', 'TCGA-YL-A9WL']}}
def get_patient_id(path, split_wrt='-01Z-00-'):
def get_patient_id(path, split_wrt="-01Z-00-"):
    """Extract the patient id from the name of an image or label file.

    Input names: "TCGA-<XX>-<XXXX>-01Z-00-DX<X>-(<X>, <00X>).tif" (example: TCGA-2Z-A9JG-01Z-00-DX1_1.tif)
    Expected:    "TCGA-<XX>-<XXXX>"                                (example: TCGA-2Z-A9JG)

    Args:
        path: Path to the file; only its stem (name without the last suffix) is used.
        split_wrt: Separator that follows the patient id in the file name.

    Returns:
        The part of the stem in front of the first `split_wrt`, or the whole stem if
        the separator does not occur.
    """
    file_stem = Path(path).stem
    patient_id, _, _ = file_stem.partition(split_wrt)
    return patient_id

Gets us the patient id in the expected format. Input Names: "TCGA-<XX>-<XXXX>-01Z-00-DX<X>-(<X>, <00X>).tif" (example: TCGA-2Z-A9JG-01Z-00-DX1_1.tif). Expected: "TCGA-<XX>-<XXXX>" (example: TCGA-2Z-A9JG).

def get_monusac_dataset( path, patch_shape, split, organ_type: Optional[List[str]] = None, download=False, offsets=None, boundaries=False, binary=False, **kwargs):
def get_monusac_dataset(
    path, patch_shape, split, organ_type: Optional[List[str]] = None, download=False,
    offsets=None, boundaries=False, binary=False, **kwargs
):
    """Dataset from https://monusac-2020.grand-challenge.org/Data/

    Args:
        path: Root folder of the dataset; the data is prepared there if necessary.
        patch_shape: Forwarded to `torch_em.default_segmentation_dataset`.
        split: Either "train" or "test".
        organ_type: Organs (keys of `ORGAN_SPLITS[split]`) to restrict the data to.
            All images of the split are used if None.
        download: Forwarded to the download helper.
        offsets, boundaries, binary: Forwarded to `util.add_instance_label_transform`.
        kwargs: Further keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The result of `torch_em.default_segmentation_dataset` for the selected files.
    """
    _download_monusac(path, download, split)

    def _sorted_files(kind):
        return sorted(glob(os.path.join(path, kind, split, "*")))

    image_paths = _sorted_files("images")
    label_paths = _sorted_files("labels")

    if organ_type is not None:
        # collect the patients of all requested organs
        selected_patients = {
            patient for organ in organ_type for patient in ORGAN_SPLITS[split][organ]
        }
        image_paths = [p for p in image_paths if get_patient_id(p) in selected_patients]
        label_paths = [p for p in label_paths if get_patient_id(p) in selected_patients]

    assert len(image_paths) == len(label_paths)

    kwargs, _ = util.add_instance_label_transform(
        kwargs, add_binary_target=True, binary=binary, boundaries=boundaries, offsets=offsets
    )
    return torch_em.default_segmentation_dataset(
        image_paths, None, label_paths, None, patch_shape, is_seg_dataset=False, **kwargs
    )
def get_monusac_loader( path, patch_shape, split, batch_size, organ_type=None, download=False, offsets=None, boundaries=False, binary=False, **kwargs):
def get_monusac_loader(
    path, patch_shape, split, batch_size, organ_type=None, download=False,
    offsets=None, boundaries=False, binary=False, **kwargs
):
    """Data loader for the dataset returned by `get_monusac_dataset`.

    `kwargs` is divided by `util.split_kwargs` into the arguments for
    `torch_em.default_segmentation_dataset` and the remaining ones, which are handed
    to `torch_em.get_data_loader` together with `batch_size`. All other parameters are
    passed on to `get_monusac_dataset`.
    """
    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    return torch_em.get_data_loader(
        get_monusac_dataset(
            path, patch_shape, split, organ_type=organ_type, download=download,
            offsets=offsets, boundaries=boundaries, binary=binary, **ds_kwargs
        ),
        batch_size,
        **loader_kwargs
    )