torch_em.data.datasets.medical.sega

The SegA dataset contains annotations for aorta segmentation in CT scans.

The dataset is from the publication https://doi.org/10.1007/978-3-031-53241-2. Please cite it if you use this dataset for your research.

  1"""The SegA dataset contains annotations for aorta segmentation in CT scans.
  2
  3The dataset is from the publication https://doi.org/10.1007/978-3-031-53241-2.
  4Please cite it if you use this dataset for your research.
  5"""
  6
  7import os
  8from glob import glob
  9from pathlib import Path
 10from natsort import natsorted
 11from typing import Union, Tuple, Optional, Literal, List
 12
 13from torch.utils.data import Dataset, DataLoader
 14
 15import torch_em
 16
 17from .. import util
 18
 19
# Download link for each SegA subset, keyed by the lowercase subset name.
URL = {
    "kits": "https://figshare.com/ndownloader/files/30950821",
    "rider": "https://figshare.com/ndownloader/files/30950914",
    "dongyang": "https://figshare.com/ndownloader/files/30950971"
}

# Checksum passed to `util.download_source` for the archive of each subset (same keys as `URL`).
CHECKSUMS = {
    "kits": "6c9c2ea31e5998348acf1c4f6683ae07041bd6c8caf309dd049adc7f222de26e",
    "rider": "7244038a6a4f70ae70b9288a2ce874d32128181de2177c63a7612d9ab3c4f5fa",
    "dongyang": "0187e90038cba0564e6304ef0182969ff57a31b42c5969d2b9188a27219da541"
}

# Local archive file name for each subset; its stem is also used as the name of the extracted folder.
ZIPFILES = {
    "kits": "KiTS.zip",
    "rider": "Rider.zip",
    "dongyang": "Dongyang.zip"
}
 37
 38
 39def get_sega_data(
 40    path: Union[os.PathLike, str],
 41    data_choice: Optional[Literal["KiTS", "Rider", "Dongyang"]] = None,
 42    download: bool = False
 43) -> str:
 44    """Dwonload the SegA dataset.
 45
 46    Args:
 47        path: Filepath to a folder where the data is downloaded for further processing.
 48        data_choice: The choice of dataset.
 49        download: Whether to download the data if it is not present.
 50
 51    Returns:
 52        Filepath where the data is downloaded.
 53    """
 54    data_choice = data_choice.lower()
 55    zip_fid = ZIPFILES[data_choice]
 56    data_dir = os.path.join(path, Path(zip_fid).stem)
 57    if os.path.exists(data_dir):
 58        return data_dir
 59
 60    os.makedirs(path, exist_ok=True)
 61
 62    zip_path = os.path.join(path, zip_fid)
 63    util.download_source(path=zip_path, url=URL[data_choice], download=download, checksum=CHECKSUMS[data_choice])
 64    util.unzip(zip_path=zip_path, dst=path)
 65
 66    return data_dir
 67
 68
 69def get_sega_paths(
 70    path: Union[os.PathLike, str],
 71    data_choice: Optional[Literal["KiTS", "Rider", "Dongyang"]] = None,
 72    download: bool = False
 73) -> Tuple[List[str], List[str]]:
 74    """Get paths to the SegA data.
 75
 76    Args:
 77        path: Filepath to a folder where the data is downloaded for further processing.
 78        data_choice: The choice of dataset.
 79        download: Whether to download the data if it is not present.
 80
 81    Returns:
 82        List of filepaths for the image data.
 83        List of filepaths for the label data.
 84    """
 85    if data_choice is None:
 86        data_choices = URL.keys()
 87    else:
 88        if isinstance(data_choice, str):
 89            data_choices = [data_choice]
 90
 91    data_dirs = [get_sega_data(path=path, data_choice=data_choice, download=download) for data_choice in data_choices]
 92
 93    image_paths, gt_paths = [], []
 94    for data_dir in data_dirs:
 95        all_volumes_paths = glob(os.path.join(data_dir, "*", "*.nrrd"))
 96        for volume_path in all_volumes_paths:
 97            if volume_path.endswith(".seg.nrrd"):
 98                gt_paths.append(volume_path)
 99            else:
100                image_paths.append(volume_path)
101
102    # now let's wrap the volumes to nifti format
103    fimage_dir = os.path.join(path, "data", "images")
104    fgt_dir = os.path.join(path, "data", "labels")
105
106    os.makedirs(fimage_dir, exist_ok=True)
107    os.makedirs(fgt_dir, exist_ok=True)
108
109    fimage_paths, fgt_paths = [], []
110    for image_path, gt_path in zip(natsorted(image_paths), natsorted(gt_paths)):
111        fimage_path = os.path.join(fimage_dir, f"{Path(image_path).stem}.nii.gz")
112        fgt_path = os.path.join(fgt_dir, f"{Path(image_path).stem}.nii.gz")
113
114        fimage_paths.append(fimage_path)
115        fgt_paths.append(fgt_path)
116
117        if os.path.exists(fimage_path) and os.path.exists(fgt_path):
118            continue
119
120        import nrrd
121        import numpy as np
122        import nibabel as nib
123
124        image = nrrd.read(image_path)[0]
125        gt = nrrd.read(gt_path)[0]
126
127        image_nifti = nib.Nifti2Image(image, np.eye(4))
128        gt_nifti = nib.Nifti2Image(gt, np.eye(4))
129
130        nib.save(image_nifti, fimage_path)
131        nib.save(gt_nifti, fgt_path)
132
133    return natsorted(fimage_paths), natsorted(fgt_paths)
134
135
def get_sega_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, ...],
    data_choice: Optional[Literal["KiTS", "Rider", "Dongyang"]] = None,
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> Dataset:
    """Create the SegA dataset for aorta segmentation in computed tomography angiography (CTA) scans.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        patch_shape: The patch shape to use for training.
        data_choice: The choice of dataset.
        resize_inputs: Whether to resize the inputs to the patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    raw_paths, label_paths = get_sega_paths(path, data_choice, download)

    if resize_inputs:
        # The helper returns the updated dataset kwargs together with the patch shape to use from here on.
        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
            kwargs=kwargs,
            patch_shape=patch_shape,
            resize_inputs=resize_inputs,
            resize_kwargs={"patch_shape": patch_shape, "is_rgb": False},
        )

    # Images and labels are both read with the key "data".
    dataset = torch_em.default_segmentation_dataset(
        raw_paths=raw_paths,
        raw_key="data",
        label_paths=label_paths,
        label_key="data",
        patch_shape=patch_shape,
        is_seg_dataset=True,
        **kwargs
    )
    return dataset
174
175
def get_sega_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, ...],
    data_choice: Optional[Literal["KiTS", "Rider", "Dongyang"]] = None,
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Create the SegA dataloader for aorta segmentation in computed tomography angiography (CTA) scans.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        data_choice: The choice of dataset.
        resize_inputs: Whether to resize the inputs to the patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # Separate the arguments understood by the dataset factory from those meant for the DataLoader.
    dataset_kwargs, dataloader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    dataset = get_sega_dataset(
        path=path,
        patch_shape=patch_shape,
        data_choice=data_choice,
        resize_inputs=resize_inputs,
        download=download,
        **dataset_kwargs
    )
    loader = torch_em.get_data_loader(dataset, batch_size, **dataloader_kwargs)
    return loader
URL = {'kits': 'https://figshare.com/ndownloader/files/30950821', 'rider': 'https://figshare.com/ndownloader/files/30950914', 'dongyang': 'https://figshare.com/ndownloader/files/30950971'}
CHECKSUMS = {'kits': '6c9c2ea31e5998348acf1c4f6683ae07041bd6c8caf309dd049adc7f222de26e', 'rider': '7244038a6a4f70ae70b9288a2ce874d32128181de2177c63a7612d9ab3c4f5fa', 'dongyang': '0187e90038cba0564e6304ef0182969ff57a31b42c5969d2b9188a27219da541'}
ZIPFILES = {'kits': 'KiTS.zip', 'rider': 'Rider.zip', 'dongyang': 'Dongyang.zip'}
def get_sega_data( path: Union[os.PathLike, str], data_choice: Optional[Literal['KiTS', 'Rider', 'Dongyang']] = None, download: bool = False) -> str:
def get_sega_data(
    path: Union[os.PathLike, str],
    data_choice: Optional[Literal["KiTS", "Rider", "Dongyang"]] = None,
    download: bool = False
) -> str:
    """Download the SegA dataset.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        data_choice: The choice of dataset. It is matched case-insensitively against the keys of `ZIPFILES`.
            A value is required; the `None` default is only kept for signature compatibility.
        download: Whether to download the data if it is not present.

    Returns:
        Filepath where the data is downloaded.

    Raises:
        ValueError: If `data_choice` is not a string or does not name a known subset.
    """
    # Fail early with a readable message instead of an AttributeError on `None.lower()`.
    if not isinstance(data_choice, str):
        raise ValueError(f"'data_choice' must be a string naming one SegA subset, got {data_choice!r}.")

    subset = data_choice.lower()
    if subset not in ZIPFILES:
        raise ValueError(f"'{data_choice}' is not a valid data choice. Choose from {list(ZIPFILES)}.")

    zip_fid = ZIPFILES[subset]
    # The extracted folder is looked up under the archive's stem (e.g. "KiTS" for "KiTS.zip").
    data_dir = os.path.join(path, Path(zip_fid).stem)
    if os.path.exists(data_dir):
        # Already extracted: nothing to download.
        return data_dir

    os.makedirs(path, exist_ok=True)

    zip_path = os.path.join(path, zip_fid)
    util.download_source(path=zip_path, url=URL[subset], download=download, checksum=CHECKSUMS[subset])
    util.unzip(zip_path=zip_path, dst=path)

    return data_dir

Download the SegA dataset.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • data_choice: The choice of dataset.
  • download: Whether to download the data if it is not present.
Returns:

Filepath where the data is downloaded.

def get_sega_paths( path: Union[os.PathLike, str], data_choice: Optional[Literal['KiTS', 'Rider', 'Dongyang']] = None, download: bool = False) -> Tuple[List[str], List[str]]:
 70def get_sega_paths(
 71    path: Union[os.PathLike, str],
 72    data_choice: Optional[Literal["KiTS", "Rider", "Dongyang"]] = None,
 73    download: bool = False
 74) -> Tuple[List[str], List[str]]:
 75    """Get paths to the SegA data.
 76
 77    Args:
 78        path: Filepath to a folder where the data is downloaded for further processing.
 79        data_choice: The choice of dataset.
 80        download: Whether to download the data if it is not present.
 81
 82    Returns:
 83        List of filepaths for the image data.
 84        List of filepaths for the label data.
 85    """
 86    if data_choice is None:
 87        data_choices = URL.keys()
 88    else:
 89        if isinstance(data_choice, str):
 90            data_choices = [data_choice]
 91
 92    data_dirs = [get_sega_data(path=path, data_choice=data_choice, download=download) for data_choice in data_choices]
 93
 94    image_paths, gt_paths = [], []
 95    for data_dir in data_dirs:
 96        all_volumes_paths = glob(os.path.join(data_dir, "*", "*.nrrd"))
 97        for volume_path in all_volumes_paths:
 98            if volume_path.endswith(".seg.nrrd"):
 99                gt_paths.append(volume_path)
100            else:
101                image_paths.append(volume_path)
102
103    # now let's wrap the volumes to nifti format
104    fimage_dir = os.path.join(path, "data", "images")
105    fgt_dir = os.path.join(path, "data", "labels")
106
107    os.makedirs(fimage_dir, exist_ok=True)
108    os.makedirs(fgt_dir, exist_ok=True)
109
110    fimage_paths, fgt_paths = [], []
111    for image_path, gt_path in zip(natsorted(image_paths), natsorted(gt_paths)):
112        fimage_path = os.path.join(fimage_dir, f"{Path(image_path).stem}.nii.gz")
113        fgt_path = os.path.join(fgt_dir, f"{Path(image_path).stem}.nii.gz")
114
115        fimage_paths.append(fimage_path)
116        fgt_paths.append(fgt_path)
117
118        if os.path.exists(fimage_path) and os.path.exists(fgt_path):
119            continue
120
121        import nrrd
122        import numpy as np
123        import nibabel as nib
124
125        image = nrrd.read(image_path)[0]
126        gt = nrrd.read(gt_path)[0]
127
128        image_nifti = nib.Nifti2Image(image, np.eye(4))
129        gt_nifti = nib.Nifti2Image(gt, np.eye(4))
130
131        nib.save(image_nifti, fimage_path)
132        nib.save(gt_nifti, fgt_path)
133
134    return natsorted(fimage_paths), natsorted(fgt_paths)

Get paths to the SegA data.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • data_choice: The choice of dataset.
  • download: Whether to download the data if it is not present.
Returns:

List of filepaths for the image data. List of filepaths for the label data.

def get_sega_dataset( path: Union[os.PathLike, str], patch_shape: Tuple[int, ...], data_choice: Optional[Literal['KiTS', 'Rider', 'Dongyang']] = None, resize_inputs: bool = False, download: bool = False, **kwargs) -> torch.utils.data.dataset.Dataset:
def get_sega_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, ...],
    data_choice: Optional[Literal["KiTS", "Rider", "Dongyang"]] = None,
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> Dataset:
    """Create the SegA dataset for aorta segmentation in computed tomography angiography (CTA) scans.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        patch_shape: The patch shape to use for training.
        data_choice: The choice of dataset.
        resize_inputs: Whether to resize the inputs to the patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    raw_paths, label_paths = get_sega_paths(path, data_choice, download)

    if resize_inputs:
        # The helper returns the updated dataset kwargs together with the patch shape to use from here on.
        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
            kwargs=kwargs,
            patch_shape=patch_shape,
            resize_inputs=resize_inputs,
            resize_kwargs={"patch_shape": patch_shape, "is_rgb": False},
        )

    # Images and labels are both read with the key "data".
    dataset = torch_em.default_segmentation_dataset(
        raw_paths=raw_paths,
        raw_key="data",
        label_paths=label_paths,
        label_key="data",
        patch_shape=patch_shape,
        is_seg_dataset=True,
        **kwargs
    )
    return dataset

Get the SegA dataset for segmentation of aorta in computed tomography angiography (CTA) scans.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • patch_shape: The patch shape to use for training.
  • data_choice: The choice of dataset.
  • resize_inputs: Whether to resize the inputs to the patch shape.
  • download: Whether to download the data if it is not present.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset.
Returns:

The segmentation dataset.

def get_sega_loader( path: Union[os.PathLike, str], batch_size: int, patch_shape: Tuple[int, ...], data_choice: Optional[Literal['KiTS', 'Rider', 'Dongyang']] = None, resize_inputs: bool = False, download: bool = False, **kwargs) -> torch.utils.data.dataloader.DataLoader:
def get_sega_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, ...],
    data_choice: Optional[Literal["KiTS", "Rider", "Dongyang"]] = None,
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Create the SegA dataloader for aorta segmentation in computed tomography angiography (CTA) scans.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        data_choice: The choice of dataset.
        resize_inputs: Whether to resize the inputs to the patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # Separate the arguments understood by the dataset factory from those meant for the DataLoader.
    dataset_kwargs, dataloader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    dataset = get_sega_dataset(
        path=path,
        patch_shape=patch_shape,
        data_choice=data_choice,
        resize_inputs=resize_inputs,
        download=download,
        **dataset_kwargs
    )
    loader = torch_em.get_data_loader(dataset, batch_size, **dataloader_kwargs)
    return loader

Get the SegA dataloader for segmentation of aorta in computed tomography angiography (CTA) scans.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • batch_size: The batch size for training.
  • patch_shape: The patch shape to use for training.
  • data_choice: The choice of dataset.
  • resize_inputs: Whether to resize the inputs to the patch shape.
  • download: Whether to download the data if it is not present.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset or for the PyTorch DataLoader.
Returns:

The DataLoader.