torch_em.data.datasets.medical.btcv

  1import os
  2import numpy as np
  3from glob import glob
  4from typing import Optional, List, Tuple
  5
  6import torch
  7
  8import torch_em
  9
 10from .. import util
 11from ....transform.label import OneHotTransform
 12from ... import ConcatDataset, MinSemanticLabelForegroundSampler
 13
 14
 15_PATHS = {
 16    "Abdomen": "RawData.zip",
 17    "Cervix": "CervixRawData.zip"
 18}
 19
 20
 21# https://www.synapse.org/#!Synapse:syn3193805/wiki/217789
 22ABDOMEN_ORGANS = {
 23    "spleen": 1, "right kidney": 2, "left kidney": 3, "gallbladder": 4, "esophagus": 5, "liver": 6, "stomach": 7,
 24    "aorta": 8, "inferior vena cava": 9, "portal vein and splenic vein": 10, "pancreas": 11, "right adrenal gland": 12,
 25    "left adrenal gland": 13,
 26}
 27
 28
 29# https://www.synapse.org/#!Synapse:syn3193805/wiki/217790
 30CERVICAL_ORGANS = {
 31    "bladder": 1, "uterus": 2, "rectum": 3, "small bowel": 4
 32}
 33
 34
 35def _unzip_btcv_data(path, region):
 36    _target_dir = os.path.join(path, region)
 37
 38    # if the directory exists, we assume the assorting has been done
 39    if os.path.exists(_target_dir):
 40        return
 41
 42    # now, let's prepare the directories where we unzip and store the inputs
 43    os.makedirs(_target_dir)
 44
 45    # let's unzip the objects to the desired directory
 46    zip_path = os.path.join(path, _PATHS[region])
 47    assert os.path.exists(zip_path), f"Looks like the zip file for {region} CT scans is missing."
 48    util.unzip(zip_path, _target_dir, remove=False)
 49
 50
 51def _assort_btcv_dataset(path, anatomy):
 52    if anatomy is None:  # if not specified, we take both the anatomies into account
 53        anatomy = list(_PATHS.keys())
 54
 55    if isinstance(anatomy, str):
 56        anatomy = [anatomy]
 57
 58    for _region in anatomy:
 59        assert _region in _PATHS.keys(), anatomy
 60        _unzip_btcv_data(path, _region)
 61
 62    return anatomy
 63
 64
 65def _check_organ_match_anatomy(organs, anatomy):
 66    # the sequence of anatomies assorted are:
 67    # we have a list of two list. list at first index is for abdomen, and second is for cervix
 68    from collections import defaultdict
 69    all_organs = defaultdict(list)
 70    if organs is None:  # if passed None, we return all organ labels
 71        if "Abdomen" in anatomy:
 72            all_organs["Abdomen"] = list(ABDOMEN_ORGANS.keys())
 73
 74        if "Cervix" in anatomy:
 75            all_organs["Cervix"] = list(CERVICAL_ORGANS.keys())
 76
 77        return all_organs
 78
 79    if isinstance(organs, str):
 80        organs = [organs]
 81
 82    for organ_name in organs:
 83        _match_found = False
 84        if organ_name in ABDOMEN_ORGANS and "Abdomen" in anatomy:
 85            all_organs["Abdomen"].append(organ_name)
 86            _match_found = True
 87
 88        if organ_name in CERVICAL_ORGANS and "Cervix" in anatomy:
 89            all_organs["Cervix"].append(organ_name)
 90            _match_found = True
 91
 92        if not _match_found:
 93            raise ValueError(f"{organ_name} not in {anatomy}")
 94
 95    return all_organs
 96
 97
 98def _get_organ_ids(anatomy, organs):
 99    # now, let's get the organ ids
100    for _region in anatomy:
101        _region_dict = ABDOMEN_ORGANS if _region == "Abdomen" else CERVICAL_ORGANS
102        per_region_organ_ids = [
103            _region_dict[organ_name] for organ_name in organs[_region]
104        ]
105        organs[_region] = per_region_organ_ids
106
107    return organs
108
109
110def _get_raw_and_label_paths(path, anatomy):
111    raw_paths, label_paths = {}, {}
112    for _region in anatomy:
113        raw_paths[_region] = sorted(glob(os.path.join(path, _region, "RawData", "Training", "img", "*.nii.gz")))
114        label_paths[_region] = sorted(glob(os.path.join(path, _region, "RawData", "Training", "label", "*.nii.gz")))
115    return raw_paths, label_paths
116
117
class InstancesFromOneHot:
    """Label transform that turns a label image into an id image via a one-hot encoding.

    The label image is first encoded by `self.transform` (by default `OneHotTransform(class_ids=class_ids)`).
    Every pixel whose value is 1 in channel `i` then receives the id `i + 1` in the output.

    Args:
        class_ids: The label ids to encode; stored as `self.class_ids` and used to build the default transform.
        transform: Optional callable replacing the default one-hot transform. Its output is iterated over
            its first axis, one channel at a time.
    """

    def __init__(self, class_ids, transform=None):
        self.class_ids = class_ids
        self.transform = OneHotTransform(class_ids=self.class_ids) if transform is None else transform

    def __call__(self, labels):
        """Return a float64 id image of shape `encoded.shape[1:]` (0 = background, `i + 1` = channel `i`).

        Where channels overlap, the later channel wins.
        """
        encoded = self.transform(labels)
        id_image = np.zeros(encoded.shape[1:])
        for instance_id, channel in enumerate(encoded, start=1):
            id_image[channel == 1] = instance_id
        return id_image
134
135
def get_btcv_dataset(
    path: str,
    patch_shape: Tuple[int, ...],
    ndim: int,
    organs: Optional[List] = None,
    anatomy: Optional[List] = None,
    min_foreground_fraction: float = 0.001,
    download: bool = False,
    **kwargs
) -> torch.utils.data.Dataset:
    """Dataset for multi-organ segmentation in CT scans.

    This dataset is from the Multi-Atlas Labeling Beyond the Cranial Vault - Workshop and Challenge.
    Link: https://www.synapse.org/#!Synapse:syn3193805/wiki/89480
    Please cite it if you use this dataset for a publication.

    Steps on how to get the dataset?
        1. Join the challenge using their official website: https://www.synapse.org/#!Synapse:syn3193805
        2. Next, go to "Files" -> (download the respective zip files)
            - "Abdomen" -> "RawData.zip" downloads all the abdominal CT scans
            - "Cervix" -> "CervixRawData.zip" downloads all the cervical CT scans
        3. Provide the path to the parent directory, where the zipped file(s) have been downloaded.
           The dataset takes care of the rest.

    Args:
        path: The path where the zip files / the prepared datasets exist.
            - Expected initial structure: `path` should have two zip files, namely `RawData.zip` and `CervixRawData.zip`
        patch_shape: The patch shape (for 2d or 3d patches)
        ndim: The dimensions of the inputs (use `2` for getting 2d patches, and `3` for getting 3d patches)
        organs: The organ name(s) to segment, as a single name or a list of names
            (see `ABDOMEN_ORGANS` and `CERVICAL_ORGANS` for the valid names).
            - default: None (i.e., returns labels with all organ types)
        anatomy: The anatomical region(s) of choice ("Abdomen" and / or "Cervix"), as a single name or a list.
            - default: None (i.e., returns both the available anatomies - abdomen and cervix)
            Regions without any selected organ are skipped.
        min_foreground_fraction: Passed as `min_fraction` to `MinSemanticLabelForegroundSampler`.
            If `organs` or `anatomy` is given, it is applied per organ id (`min_fraction_per_id=True`).
        download: (NOT SUPPORTED) Downloads the dataset
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        dataset: The segmentation dataset, concatenated over the selected anatomies.

    Raises:
        NotImplementedError: If `download` is True.
        ValueError: If a requested organ does not belong to the selected anatomies, or no organ is selected.
        RuntimeError: If no scans, or different numbers of scans and labels, are found for a selected anatomy.
    """
    if download:
        raise NotImplementedError(
            "The BTCV dataset cannot be automatically downloaded from `torch_em`. "
            "Please download the dataset (see `get_btcv_dataset` for the download steps) "
            "and provide the parent directory where the zip files are stored."
        )

    # NOTE(review): per-organ foreground fractions are enforced whenever `organs` OR `anatomy` is given.
    # With many organs (e.g. all 13 abdominal ones) this may make sampling slow - confirm this is intended.
    min_fraction_per_id = False if organs is None and anatomy is None else True

    anatomy = _assort_btcv_dataset(path, anatomy)
    organs = _check_organ_match_anatomy(organs, anatomy)
    organs = _get_organ_ids(anatomy, organs)
    raw_paths, label_paths = _get_raw_and_label_paths(path, anatomy)

    all_datasets = []
    for per_anatomy in anatomy:
        semantic_ids = organs.get(per_anatomy, [])
        # Skip regions without any selected organ (e.g. `organs="liver"` with both anatomies).
        # Otherwise the region would presumably only contribute samples without any target label.
        if not semantic_ids:
            continue

        raw_files, label_files = raw_paths[per_anatomy], label_paths[per_anatomy]
        if not raw_files or len(raw_files) != len(label_files):
            raise RuntimeError(
                f"Expected the same, non-zero number of scans and labels for '{per_anatomy}' in {path}, "
                f"but found {len(raw_files)} scans and {len(label_files)} labels."
            )

        sampler = MinSemanticLabelForegroundSampler(
            semantic_ids=semantic_ids,
            min_fraction=min_foreground_fraction,
            min_fraction_per_id=min_fraction_per_id
        )
        # convert the labels of the selected organs into an id image via a one-hot encoding
        label_transform = InstancesFromOneHot(class_ids=semantic_ids)
        dataset = torch_em.default_segmentation_dataset(
            raw_files, "data", label_files, "data",
            patch_shape, ndim=ndim, sampler=sampler, label_transform=label_transform,
            **kwargs
        )
        # allow many attempts to find a patch that satisfies the foreground sampler
        for _ds in dataset.datasets:
            _ds.max_sampling_attempts = 5000

        all_datasets.append(dataset)

    if not all_datasets:
        raise ValueError(f"No organs were selected for the anatomies {anatomy}.")

    return ConcatDataset(*all_datasets)
210
211
def get_btcv_loader(
    path,
    patch_shape,
    batch_size,
    ndim,
    organs=None,
    anatomy=None,
    min_foreground_fraction=0.001,
    download=False,
    **kwargs
):
    """Dataloader for multi-organ segmentation in CT scans. See `get_btcv_dataset` for details.

    Args:
        batch_size: The batch size, passed to `torch_em.get_data_loader`.
        kwargs: Split by `util.split_kwargs` against `torch_em.default_segmentation_dataset`.
            The dataset-related part goes to `get_btcv_dataset`; the rest goes to `torch_em.get_data_loader`.
        All other arguments are forwarded to `get_btcv_dataset` unchanged.

    Returns:
        The data loader returned by `torch_em.get_data_loader`.
    """
    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    dataset = get_btcv_dataset(
        path=path,
        patch_shape=patch_shape,
        ndim=ndim,
        organs=organs,
        anatomy=anatomy,
        min_foreground_fraction=min_foreground_fraction,
        download=download,
        **ds_kwargs,
    )
    return torch_em.get_data_loader(dataset, batch_size=batch_size, **loader_kwargs)
# Abdominal organ name -> integer label id.
ABDOMEN_ORGANS = {
    "spleen": 1,
    "right kidney": 2,
    "left kidney": 3,
    "gallbladder": 4,
    "esophagus": 5,
    "liver": 6,
    "stomach": 7,
    "aorta": 8,
    "inferior vena cava": 9,
    "portal vein and splenic vein": 10,
    "pancreas": 11,
    "right adrenal gland": 12,
    "left adrenal gland": 13,
}
# Cervical organ name -> integer label id.
CERVICAL_ORGANS = {
    "bladder": 1,
    "uterus": 2,
    "rectum": 3,
    "small bowel": 4,
}
class InstancesFromOneHot:
class InstancesFromOneHot:
    """Label transform that turns a label image into an id image via a one-hot encoding.

    The label image is first encoded by `self.transform` (by default `OneHotTransform(class_ids=class_ids)`).
    Every pixel whose value is 1 in channel `i` then receives the id `i + 1` in the output.

    Args:
        class_ids: The label ids to encode; stored as `self.class_ids` and used to build the default transform.
        transform: Optional callable replacing the default one-hot transform. Its output is iterated over
            its first axis, one channel at a time.
    """

    def __init__(self, class_ids, transform=None):
        self.class_ids = class_ids
        self.transform = OneHotTransform(class_ids=self.class_ids) if transform is None else transform

    def __call__(self, labels):
        """Return a float64 id image of shape `encoded.shape[1:]` (0 = background, `i + 1` = channel `i`).

        Where channels overlap, the later channel wins.
        """
        encoded = self.transform(labels)
        id_image = np.zeros(encoded.shape[1:])
        for instance_id, channel in enumerate(encoded, start=1):
            id_image[channel == 1] = instance_id
        return id_image
InstancesFromOneHot(class_ids, transform=None)
120    def __init__(self, class_ids, transform=None):
121        self.class_ids = class_ids
122
123        if transform is None:
124            self.transform = OneHotTransform(class_ids=self.class_ids)
125        else:
126            self.transform = transform
class_ids
def get_btcv_dataset( path: str, patch_shape: Tuple[int, ...], ndim: int, organs: Optional[List] = None, anatomy: Optional[List] = None, min_foreground_fraction: float = 0.001, download: bool = False, **kwargs) -> torch.utils.data.dataset.Dataset:
def get_btcv_dataset(
    path: str,
    patch_shape: Tuple[int, ...],
    ndim: int,
    organs: Optional[List] = None,
    anatomy: Optional[List] = None,
    min_foreground_fraction: float = 0.001,
    download: bool = False,
    **kwargs
) -> torch.utils.data.Dataset:
    """Dataset for multi-organ segmentation in CT scans.

    This dataset is from the Multi-Atlas Labeling Beyond the Cranial Vault - Workshop and Challenge.
    Link: https://www.synapse.org/#!Synapse:syn3193805/wiki/89480
    Please cite it if you use this dataset for a publication.

    Steps on how to get the dataset?
        1. Join the challenge using their official website: https://www.synapse.org/#!Synapse:syn3193805
        2. Next, go to "Files" -> (download the respective zip files)
            - "Abdomen" -> "RawData.zip" downloads all the abdominal CT scans
            - "Cervix" -> "CervixRawData.zip" downloads all the cervical CT scans
        3. Provide the path to the parent directory, where the zipped file(s) have been downloaded.
           The dataset takes care of the rest.

    Args:
        path: The path where the zip files / the prepared datasets exist.
            - Expected initial structure: `path` should have two zip files, namely `RawData.zip` and `CervixRawData.zip`
        patch_shape: The patch shape (for 2d or 3d patches)
        ndim: The dimensions of the inputs (use `2` for getting 2d patches, and `3` for getting 3d patches)
        organs: The organ name(s) to segment, as a single name or a list of names
            (see `ABDOMEN_ORGANS` and `CERVICAL_ORGANS` for the valid names).
            - default: None (i.e., returns labels with all organ types)
        anatomy: The anatomical region(s) of choice ("Abdomen" and / or "Cervix"), as a single name or a list.
            - default: None (i.e., returns both the available anatomies - abdomen and cervix)
            Regions without any selected organ are skipped.
        min_foreground_fraction: Passed as `min_fraction` to `MinSemanticLabelForegroundSampler`.
            If `organs` or `anatomy` is given, it is applied per organ id (`min_fraction_per_id=True`).
        download: (NOT SUPPORTED) Downloads the dataset
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        dataset: The segmentation dataset, concatenated over the selected anatomies.

    Raises:
        NotImplementedError: If `download` is True.
        ValueError: If a requested organ does not belong to the selected anatomies, or no organ is selected.
        RuntimeError: If no scans, or different numbers of scans and labels, are found for a selected anatomy.
    """
    if download:
        raise NotImplementedError(
            "The BTCV dataset cannot be automatically downloaded from `torch_em`. "
            "Please download the dataset (see `get_btcv_dataset` for the download steps) "
            "and provide the parent directory where the zip files are stored."
        )

    # NOTE(review): per-organ foreground fractions are enforced whenever `organs` OR `anatomy` is given.
    # With many organs (e.g. all 13 abdominal ones) this may make sampling slow - confirm this is intended.
    min_fraction_per_id = False if organs is None and anatomy is None else True

    anatomy = _assort_btcv_dataset(path, anatomy)
    organs = _check_organ_match_anatomy(organs, anatomy)
    organs = _get_organ_ids(anatomy, organs)
    raw_paths, label_paths = _get_raw_and_label_paths(path, anatomy)

    all_datasets = []
    for per_anatomy in anatomy:
        semantic_ids = organs.get(per_anatomy, [])
        # Skip regions without any selected organ (e.g. `organs="liver"` with both anatomies).
        # Otherwise the region would presumably only contribute samples without any target label.
        if not semantic_ids:
            continue

        raw_files, label_files = raw_paths[per_anatomy], label_paths[per_anatomy]
        if not raw_files or len(raw_files) != len(label_files):
            raise RuntimeError(
                f"Expected the same, non-zero number of scans and labels for '{per_anatomy}' in {path}, "
                f"but found {len(raw_files)} scans and {len(label_files)} labels."
            )

        sampler = MinSemanticLabelForegroundSampler(
            semantic_ids=semantic_ids,
            min_fraction=min_foreground_fraction,
            min_fraction_per_id=min_fraction_per_id
        )
        # convert the labels of the selected organs into an id image via a one-hot encoding
        label_transform = InstancesFromOneHot(class_ids=semantic_ids)
        dataset = torch_em.default_segmentation_dataset(
            raw_files, "data", label_files, "data",
            patch_shape, ndim=ndim, sampler=sampler, label_transform=label_transform,
            **kwargs
        )
        # allow many attempts to find a patch that satisfies the foreground sampler
        for _ds in dataset.datasets:
            _ds.max_sampling_attempts = 5000

        all_datasets.append(dataset)

    if not all_datasets:
        raise ValueError(f"No organs were selected for the anatomies {anatomy}.")

    return ConcatDataset(*all_datasets)

Dataset for multi-organ segmentation in CT scans.

This dataset is from the Multi-Atlas Labeling Beyond the Cranial Vault - Workshop and Challenge (https://www.synapse.org/#!Synapse:syn3193805/wiki/89480). Please cite it if you use this dataset for a publication.

Steps on how to get the dataset:
1. Join the challenge using their official website: https://www.synapse.org/#!Synapse:syn3193805
2. Next, go to "Files" and download the respective zip files:
   - "Abdomen" -> "RawData.zip" downloads all the abdominal CT scans
   - "Cervix" -> "CervixRawData.zip" downloads all the cervical CT scans
3. Provide the path to the parent directory where the zipped file(s) have been downloaded. The dataset takes care of the rest.

Arguments:
  • path: The path where the zip files / the prepared datasets exist.
    • Expected initial structure: path should have two zip files, namely RawData.zip and CervixRawData.zip
  • patch_shape: The patch shape (for 2d or 3d patches)
  • ndim: The dimensions of the inputs (use 2 for getting 2d patches, and 3 for getting 3d patches)
  • organs: The organs in the respective anatomical regions of choice
    • default: None (i.e., returns labels with all organ types)
  • anatomy: The anatomical regions of choice from the provided scans
    • default: None (i.e., returns both the available anatomies - abdomen and cervix)
  • min_foreground_fraction: The minimal foreground fraction required by the patch sampler
    • default: 0.001
  • download: (NOT SUPPORTED) Downloads the dataset
Returns:

dataset: The segmentation dataset for the selected anatomies and organs.

def get_btcv_loader( path, patch_shape, batch_size, ndim, organs=None, anatomy=None, min_foreground_fraction=0.001, download=False, **kwargs):
def get_btcv_loader(
    path,
    patch_shape,
    batch_size,
    ndim,
    organs=None,
    anatomy=None,
    min_foreground_fraction=0.001,
    download=False,
    **kwargs
):
    """Dataloader for multi-organ segmentation in CT scans. See `get_btcv_dataset` for details.

    Args:
        batch_size: The batch size, passed to `torch_em.get_data_loader`.
        kwargs: Split by `util.split_kwargs` against `torch_em.default_segmentation_dataset`.
            The dataset-related part goes to `get_btcv_dataset`; the rest goes to `torch_em.get_data_loader`.
        All other arguments are forwarded to `get_btcv_dataset` unchanged.

    Returns:
        The data loader returned by `torch_em.get_data_loader`.
    """
    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    dataset = get_btcv_dataset(
        path=path,
        patch_shape=patch_shape,
        ndim=ndim,
        organs=organs,
        anatomy=anatomy,
        min_foreground_fraction=min_foreground_fraction,
        download=download,
        **ds_kwargs,
    )
    return torch_em.get_data_loader(dataset, batch_size=batch_size, **loader_kwargs)

Dataloader for multi-organ segmentation in CT scans. See get_btcv_dataset for details.