torch_em.data.datasets.histopathology.bcss

  1import os
  2import shutil
  3from glob import glob
  4from pathlib import Path
  5
  6from sklearn.model_selection import train_test_split
  7
  8import torch
  9import torch_em
 10from torch_em.data.datasets import util
 11from torch_em.data import ImageCollectionDataset
 12
 13
# Link to the Google Drive folder that hosts the BCSS data.
URL = "https://drive.google.com/drive/folders/1zqbdkQF8i5cEmZOGmbdQm-EP8dRYtvss?usp=sharing"


# TODO: determine the checksum of the download; no checksum is set so far.
CHECKSUM = None


# File stems (file names without extension) of the images that are assigned to the test split.
# Every image whose stem is not listed here is assigned to the train split.
TEST_LIST = [
    "TCGA-A2-A0SX-DX1_xmin53791_ymin56683_MPP-0.2500", "TCGA-BH-A0BG-DX1_xmin64019_ymin24975_MPP-0.2500",
    "TCGA-AR-A1AI-DX1_xmin38671_ymin10616_MPP-0.2500", "TCGA-E2-A574-DX1_xmin54962_ymin47475_MPP-0.2500",
    "TCGA-GM-A3XL-DX1_xmin29910_ymin15820_MPP-0.2500", "TCGA-E2-A14X-DX1_xmin88836_ymin66393_MPP-0.2500",
    "TCGA-A2-A04P-DX1_xmin104246_ymin48517_MPP-0.2500", "TCGA-E2-A14N-DX1_xmin21383_ymin66838_MPP-0.2500",
    "TCGA-EW-A1OV-DX1_xmin126026_ymin65132_MPP-0.2500", "TCGA-S3-AA15-DX1_xmin55486_ymin28926_MPP-0.2500",
    "TCGA-LL-A5YO-DX1_xmin36631_ymin44396_MPP-0.2500", "TCGA-GI-A2C9-DX1_xmin20882_ymin11843_MPP-0.2500",
    "TCGA-BH-A0BW-DX1_xmin42346_ymin30843_MPP-0.2500", "TCGA-E2-A1B6-DX1_xmin16266_ymin50634_MPP-0.2500",
    "TCGA-AO-A0J2-DX1_xmin33561_ymin14515_MPP-0.2500"
]
 31
 32
 33def _download_bcss_dataset(path, download):
 34    """Current recommendation:
 35        - download the folder from URL manually
 36        - use the consortium's git repo to download the dataset (https://github.com/PathologyDataScience/BCSS)
 37    """
 38    raise NotImplementedError("Please download the dataset using the drive link / git repo directly")
 39
 40    # FIXME: limitation for the installation below:
 41    #   - only downloads first 50 files - due to `gdown`'s download folder function
 42    #   - (optional) clone their git repo to download their data
 43    util.download_source_gdrive(path=path, url=URL, download=download, checksum=CHECKSUM, download_type="folder")
 44
 45
 46def _get_image_and_label_paths(path):
 47    # when downloading the files from `URL`, the input images are stored under `rgbs_colorNormalized`
 48    # when getting the files from the git repo's command line feature, the input images are stored under `images`
 49    if os.path.exists(os.path.join(path, "images")):
 50        image_paths = sorted(glob(os.path.join(path, "images", "*")))
 51        label_paths = sorted(glob(os.path.join(path, "masks", "*")))
 52    elif os.path.exists(os.path.join(path, "0_Public-data-Amgad2019_0.25MPP", "rgbs_colorNormalized")):
 53        image_paths = sorted(glob(os.path.join(path, "0_Public-data-Amgad2019_0.25MPP", "rgbs_colorNormalized", "*")))
 54        label_paths = sorted(glob(os.path.join(path, "0_Public-data-Amgad2019_0.25MPP", "masks", "*")))
 55    else:
 56        raise ValueError(
 57            "Please check the image directory. "
 58            "If downloaded from gdrive, it's named \"rgbs_colorNormalized\", if from github it's named \"images\""
 59        )
 60
 61    return image_paths, label_paths
 62
 63
 64def _assort_bcss_data(path, download):
 65    if download:
 66        _download_bcss_dataset(path, download)
 67
 68    if os.path.exists(os.path.join(path, "train")) and os.path.exists(os.path.join(path, "test")):
 69        return
 70
 71    all_image_paths, all_label_paths = _get_image_and_label_paths(path)
 72
 73    train_img_dir, train_lab_dir = os.path.join(path, "train", "images"), os.path.join(path, "train", "masks")
 74    test_img_dir, test_lab_dir = os.path.join(path, "test", "images"), os.path.join(path, "test", "masks")
 75    os.makedirs(train_img_dir, exist_ok=True)
 76    os.makedirs(train_lab_dir, exist_ok=True)
 77    os.makedirs(test_img_dir, exist_ok=True)
 78    os.makedirs(test_lab_dir, exist_ok=True)
 79
 80    for image_path, label_path in zip(all_image_paths, all_label_paths):
 81        img_idx, label_idx = os.path.split(image_path)[-1], os.path.split(label_path)[-1]
 82        if Path(image_path).stem in TEST_LIST:
 83            # move image and label to test
 84            dst_img_path, dst_lab_path = os.path.join(test_img_dir, img_idx), os.path.join(test_lab_dir, label_idx)
 85            shutil.copy(src=image_path, dst=dst_img_path)
 86            shutil.copy(src=label_path, dst=dst_lab_path)
 87        else:
 88            # move image and label to train
 89            dst_img_path, dst_lab_path = os.path.join(train_img_dir, img_idx), os.path.join(train_lab_dir, label_idx)
 90            shutil.copy(src=image_path, dst=dst_img_path)
 91            shutil.copy(src=label_path, dst=dst_lab_path)
 92
 93
 94def get_bcss_dataset(
 95    path, patch_shape, split=None, val_fraction=0.2, download=False, label_dtype=torch.int64, **kwargs
 96):
 97    """Dataset for breast cancer tissue segmentation in histopathology.
 98
 99    This dataset is from https://bcsegmentation.grand-challenge.org/BCSS/.
100    Please cite this paper (https://doi.org/10.1093/bioinformatics/btz083) if you use this dataset for a publication.
101
102    NOTE: There are multiple semantic instances in tissue labels. Below mentioned are their respective index details:
103        - 0: outside_roi (~background)
104        - 1: tumor
105        - 2: stroma
106        - 3: lymphocytic_infiltrate
107        - 4: necrosis_or_debris
108        - 5: glandular_secretions
109        - 6: blood
110        - 7: exclude
111        - 8: metaplasia_NOS
112        - 9: fat
113        - 10: plasma_cells
114        - 11: other_immune_infiltrate
115        - 12: mucoid_material
116        - 13: normal_acinus_or_duct
117        - 14: lymphatics
118        - 15: undetermined
119        - 16: nerve
120        - 17: skin_adnexa
121        - 18: blood_vessel
122        - 19: angioinvasion
123        - 20: dcis
124        - 21: other
125    """
126    _assort_bcss_data(path, download)
127
128    if split is None:
129        image_paths = sorted(glob(os.path.join(path, "*", "images", "*")))
130        label_paths = sorted(glob(os.path.join(path, "*", "masks", "*")))
131    else:
132        assert split in ["train", "val", "test"], "Please choose from the available `train` / `val` / `test` splits"
133        if split == "test":
134            image_paths = sorted(glob(os.path.join(path, "test", "images", "*")))
135            label_paths = sorted(glob(os.path.join(path, "test", "masks", "*")))
136        else:
137            image_paths = sorted(glob(os.path.join(path, "train", "images", "*")))
138            label_paths = sorted(glob(os.path.join(path, "train", "masks", "*")))
139
140            (train_image_paths, val_image_paths,
141             train_label_paths, val_label_paths) = train_test_split(
142                image_paths, label_paths, test_size=val_fraction, random_state=42
143            )
144
145            image_paths = train_image_paths if split == "train" else val_image_paths
146            label_paths = train_label_paths if split == "train" else val_label_paths
147
148    assert len(image_paths) == len(label_paths)
149
150    dataset = ImageCollectionDataset(
151        image_paths, label_paths, patch_shape=patch_shape, label_dtype=label_dtype, **kwargs
152    )
153    return dataset
154
155
def get_bcss_loader(
        path, patch_shape, batch_size, split=None, val_fraction=0.2, download=False, label_dtype=torch.int64, **kwargs
):
    """Dataloader for breast cancer tissue segmentation in histopathology. See `get_bcss_dataset` for details.

    The keyword arguments in `kwargs` are divided into arguments for the dataset and
    arguments for the data loader; `batch_size` is passed on to the data loader.
    """
    dataset_kwargs, dataloader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    bcss_dataset = get_bcss_dataset(
        path=path,
        patch_shape=patch_shape,
        split=split,
        val_fraction=val_fraction,
        download=download,
        label_dtype=label_dtype,
        **dataset_kwargs,
    )
    return torch_em.get_data_loader(bcss_dataset, batch_size, **dataloader_kwargs)
# Link to the Google Drive folder that hosts the BCSS data.
URL = 'https://drive.google.com/drive/folders/1zqbdkQF8i5cEmZOGmbdQm-EP8dRYtvss?usp=sharing'
# No checksum is set for the download so far.
CHECKSUM = None
# File stems of the images that are assigned to the test split; all other images are assigned to the train split.
TEST_LIST = ['TCGA-A2-A0SX-DX1_xmin53791_ymin56683_MPP-0.2500', 'TCGA-BH-A0BG-DX1_xmin64019_ymin24975_MPP-0.2500', 'TCGA-AR-A1AI-DX1_xmin38671_ymin10616_MPP-0.2500', 'TCGA-E2-A574-DX1_xmin54962_ymin47475_MPP-0.2500', 'TCGA-GM-A3XL-DX1_xmin29910_ymin15820_MPP-0.2500', 'TCGA-E2-A14X-DX1_xmin88836_ymin66393_MPP-0.2500', 'TCGA-A2-A04P-DX1_xmin104246_ymin48517_MPP-0.2500', 'TCGA-E2-A14N-DX1_xmin21383_ymin66838_MPP-0.2500', 'TCGA-EW-A1OV-DX1_xmin126026_ymin65132_MPP-0.2500', 'TCGA-S3-AA15-DX1_xmin55486_ymin28926_MPP-0.2500', 'TCGA-LL-A5YO-DX1_xmin36631_ymin44396_MPP-0.2500', 'TCGA-GI-A2C9-DX1_xmin20882_ymin11843_MPP-0.2500', 'TCGA-BH-A0BW-DX1_xmin42346_ymin30843_MPP-0.2500', 'TCGA-E2-A1B6-DX1_xmin16266_ymin50634_MPP-0.2500', 'TCGA-AO-A0J2-DX1_xmin33561_ymin14515_MPP-0.2500']
def get_bcss_dataset( path, patch_shape, split=None, val_fraction=0.2, download=False, label_dtype=torch.int64, **kwargs):
 95def get_bcss_dataset(
 96    path, patch_shape, split=None, val_fraction=0.2, download=False, label_dtype=torch.int64, **kwargs
 97):
 98    """Dataset for breast cancer tissue segmentation in histopathology.
 99
100    This dataset is from https://bcsegmentation.grand-challenge.org/BCSS/.
101    Please cite this paper (https://doi.org/10.1093/bioinformatics/btz083) if you use this dataset for a publication.
102
103    NOTE: There are multiple semantic instances in tissue labels. Below mentioned are their respective index details:
104        - 0: outside_roi (~background)
105        - 1: tumor
106        - 2: stroma
107        - 3: lymphocytic_infiltrate
108        - 4: necrosis_or_debris
109        - 5: glandular_secretions
110        - 6: blood
111        - 7: exclude
112        - 8: metaplasia_NOS
113        - 9: fat
114        - 10: plasma_cells
115        - 11: other_immune_infiltrate
116        - 12: mucoid_material
117        - 13: normal_acinus_or_duct
118        - 14: lymphatics
119        - 15: undetermined
120        - 16: nerve
121        - 17: skin_adnexa
122        - 18: blood_vessel
123        - 19: angioinvasion
124        - 20: dcis
125        - 21: other
126    """
127    _assort_bcss_data(path, download)
128
129    if split is None:
130        image_paths = sorted(glob(os.path.join(path, "*", "images", "*")))
131        label_paths = sorted(glob(os.path.join(path, "*", "masks", "*")))
132    else:
133        assert split in ["train", "val", "test"], "Please choose from the available `train` / `val` / `test` splits"
134        if split == "test":
135            image_paths = sorted(glob(os.path.join(path, "test", "images", "*")))
136            label_paths = sorted(glob(os.path.join(path, "test", "masks", "*")))
137        else:
138            image_paths = sorted(glob(os.path.join(path, "train", "images", "*")))
139            label_paths = sorted(glob(os.path.join(path, "train", "masks", "*")))
140
141            (train_image_paths, val_image_paths,
142             train_label_paths, val_label_paths) = train_test_split(
143                image_paths, label_paths, test_size=val_fraction, random_state=42
144            )
145
146            image_paths = train_image_paths if split == "train" else val_image_paths
147            label_paths = train_label_paths if split == "train" else val_label_paths
148
149    assert len(image_paths) == len(label_paths)
150
151    dataset = ImageCollectionDataset(
152        image_paths, label_paths, patch_shape=patch_shape, label_dtype=label_dtype, **kwargs
153    )
154    return dataset

Dataset for breast cancer tissue segmentation in histopathology.

This dataset is from https://bcsegmentation.grand-challenge.org/BCSS/. Please cite this paper (https://doi.org/10.1093/bioinformatics/btz083) if you use this dataset for a publication.

NOTE: The tissue labels contain multiple semantic classes. Their respective label indices are: 0: outside_roi (~background); 1: tumor; 2: stroma; 3: lymphocytic_infiltrate; 4: necrosis_or_debris; 5: glandular_secretions; 6: blood; 7: exclude; 8: metaplasia_NOS; 9: fat; 10: plasma_cells; 11: other_immune_infiltrate; 12: mucoid_material; 13: normal_acinus_or_duct; 14: lymphatics; 15: undetermined; 16: nerve; 17: skin_adnexa; 18: blood_vessel; 19: angioinvasion; 20: dcis; 21: other.

def get_bcss_loader( path, patch_shape, batch_size, split=None, val_fraction=0.2, download=False, label_dtype=torch.int64, **kwargs):
def get_bcss_loader(
        path, patch_shape, batch_size, split=None, val_fraction=0.2, download=False, label_dtype=torch.int64, **kwargs
):
    """Dataloader for breast cancer tissue segmentation in histopathology. See `get_bcss_dataset` for details.

    The keyword arguments in `kwargs` are divided into arguments for the dataset and
    arguments for the data loader; `batch_size` is passed on to the data loader.
    """
    dataset_kwargs, dataloader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    bcss_dataset = get_bcss_dataset(
        path=path,
        patch_shape=patch_shape,
        split=split,
        val_fraction=val_fraction,
        download=download,
        label_dtype=label_dtype,
        **dataset_kwargs,
    )
    return torch_em.get_data_loader(bcss_dataset, batch_size, **dataloader_kwargs)

Dataloader for breast cancer tissue segmentation in histopathology. See get_bcss_dataset for details.