torch_em.data.datasets.medical.acdc

The ACDC dataset contains annotations for multi-structure segmentation in cardiac MRI.

The labels have the following mapping:

  • 0 (background), 1 (right ventricle cavity), 2 (myocardium), 3 (left ventricle cavity)

The database is located at https://humanheart-project.creatis.insa-lyon.fr/database/#collection/637218c173e9f0047faa00fb

The dataset is from the publication https://doi.org/10.1109/TMI.2018.2837502. Please cite it if you use this dataset for a publication.
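
For reference, the label mapping above written out as a Python dictionary (a convenience sketch for downstream use, not a constant defined by the module):

# Label ids used in the ACDC ground-truth volumes.
ACDC_LABELS = {
    0: "background",
    1: "right ventricle cavity",
    2: "myocardium",
    3: "left ventricle cavity",
}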

  1"""The ACDC dataset contains annotations for multi-structure segmentation in cardiac MRI.
  2
  3The labels have the following mapping:
  4- 0 (background), 1 (right ventricle cavity),2 (myocardium), 3 (left ventricle cavity)
  5
  6The database is located at
  7https://humanheart-project.creatis.insa-lyon.fr/database/#collection/637218c173e9f0047faa00fb
  8
  9The dataset is from the publication https://doi.org/10.1109/TMI.2018.2837502.
 10Please cite it if you use this dataset for a publication.
 11"""
 12
 13import os
 14from glob import glob
 15from natsort import natsorted
 16from typing import Union, Tuple, Literal, List
 17
 18from torch.utils.data import Dataset, DataLoader
 19
 20import torch_em
 21
 22from .. import util
 23from ... import ConcatDataset
 24
 25
 26URL = "https://humanheart-project.creatis.insa-lyon.fr/database/api/v1/collection/637218c173e9f0047faa00fb/download"
 27CHECKSUM = "2787e08b0d3525cbac710fc3bdf69ee7c5fd7446472e49db8bc78548802f6b5e"
 28
 29
 30def get_acdc_data(path: Union[os.PathLike, str], download: bool = False) -> str:
 31    """Download the ACDC dataset.
 32
 33    Args:
 34        path: Filepath to a folder where the data is downloaded for further processing.
 35        download: Whether to download the data if it is not present.
 36
 37    Returns:
 38        Filepath where the data is downlaoded.
 39    """
 40    zip_path = os.path.join(path, "ACDC.zip")
 41    trg_dir = os.path.join(path, "ACDC")
 42    if os.path.exists(trg_dir):
 43        return trg_dir
 44
 45    os.makedirs(path, exist_ok=True)
 46
 47    util.download_source(path=zip_path, url=URL, download=download, checksum=CHECKSUM)
 48    util.unzip(zip_path=zip_path, dst=path, remove=False)
 49
 50    return trg_dir
 51
 52
 53def get_acdc_paths(
 54    path: Union[os.PathLike, str], split: Literal["train", "test"], download: bool = False
 55) -> Tuple[List[str], List[str]]:
 56    """Get paths to the ACDC data.
 57
 58    Args:
 59        path: Filepath to a folder where the data is downloaded for further processing.
 60        download: Whether to download the data if it is not present.
 61
 62    Returns:
 63        List of filepaths for the image data.
 64        List of filepaths for the label data.
 65    """
 66    root_dir = get_acdc_data(path=path, download=download)
 67
 68    if split == "train":
 69        input_dir = os.path.join(root_dir, "database", "training")
 70    elif split == "test":
 71        input_dir = os.path.join(root_dir, "database", "testing")
 72    else:
 73        raise ValueError(f"'{split}' is not a valid data split.")
 74
 75    all_patient_dirs = natsorted(glob(os.path.join(input_dir, "patient*")))
 76
 77    image_paths, gt_paths = [], []
 78    for per_patient_dir in all_patient_dirs:
 79        # the volumes with frames are for particular time frames (end diastole (ED) and end systole (ES))
 80        # the "frames" denote - ED and ES phase instances, which have manual segmentations.
 81        all_volumes = glob(os.path.join(per_patient_dir, "*frame*.nii.gz"))
 82        for vol_path in all_volumes:
 83            sres = vol_path.find("gt")
 84            if sres == -1:  # this means the search was invalid, hence it's the  mri volume
 85                image_paths.append(vol_path)
 86            else:  # this means that the search went through, hence it's the ground truth volume
 87                gt_paths.append(vol_path)
 88
 89    return natsorted(image_paths), natsorted(gt_paths)
 90
 91
 92def get_acdc_dataset(
 93    path: Union[os.PathLike, str],
 94    patch_shape: Tuple[int, ...],
 95    split: Literal["train", "test"],
 96    resize_inputs: bool = False,
 97    download: bool = False,
 98    **kwargs
 99) -> Dataset:
100    """Get the ACDC dataset for cardiac structure segmentation.
101
102    Args:
103        path: Filepath to a folder where the data is downloaded for further processing.
104        patch_shape: The patch shape to use for training.
105        split: The choice of data split.
106        resize_inputs: Whether to resize inputs to the desired patch shape.
107        download: Whether to download the data if it is not present.
108        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
109
110    Returns:
111        The segmentation dataset.
112    """
113    image_paths, gt_paths = get_acdc_paths(path, split, download)
114
115    if resize_inputs:
116        resize_kwargs = {"patch_shape": patch_shape, "is_rgb": False}
117        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
118            kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs
119        )
120
121    all_datasets = []
122    for image_path, gt_path in zip(image_paths, gt_paths):
123        per_vol_ds = torch_em.default_segmentation_dataset(
124            raw_paths=image_path,
125            raw_key="data",
126            label_paths=gt_path,
127            label_key="data",
128            patch_shape=patch_shape,
129            is_seg_dataset=True,
130            **kwargs
131        )
132        all_datasets.append(per_vol_ds)
133
134    return ConcatDataset(*all_datasets)
135
136
137def get_acdc_loader(
138    path: Union[os.PathLike, str],
139    batch_size: int,
140    patch_shape: Tuple[int, ...],
141    split: Literal["train", "test"],
142    resize_inputs: bool = False,
143    download: bool = False,
144    **kwargs
145) -> DataLoader:
146    """Get the ACDC dataloader for cardiac structure segmentation.
147
148    Args:
149        path: Filepath to a folder where the data is downloaded for further processing.
150        batch_size: The batch size for training.
151        patch_shape: The patch shape to use for training.
152        split: The choice of data split.
153        resize_inputs: Whether to resize inputs to the desired patch shape.
154        download: Whether to download the data if it is not present.
155        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.
156
157    Returns:
158        The DataLoader.
159    """
160    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
161    dataset = get_acdc_dataset(path, patch_shape, split, resize_inputs, download, **ds_kwargs)
162    return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
URL = 'https://humanheart-project.creatis.insa-lyon.fr/database/api/v1/collection/637218c173e9f0047faa00fb/download'
CHECKSUM = '2787e08b0d3525cbac710fc3bdf69ee7c5fd7446472e49db8bc78548802f6b5e'
def get_acdc_data(path: Union[os.PathLike, str], download: bool = False) -> str:

Download the ACDC dataset.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • download: Whether to download the data if it is not present.
Returns:

Filepath where the data is downloaded.
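
A minimal usage sketch (the target folder below is a placeholder, not a value prescribed by the function):

from torch_em.data.datasets.medical import acdc

# Downloads and unzips the archive on first use; afterwards the existing folder is returned directly.
data_dir = acdc.get_acdc_data(path="./data/acdc", download=True)
print(data_dir)  # -> "./data/acdc/ACDC"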

def get_acdc_paths( path: Union[os.PathLike, str], split: Literal['train', 'test'], download: bool = False) -> Tuple[List[str], List[str]]:

Get paths to the ACDC data.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • split: The choice of data split.
  • download: Whether to download the data if it is not present.
Returns:

List of filepaths for the image data.
List of filepaths for the label data.
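
A small sketch of how the returned lists are used (the folder is a placeholder; the pairing follows from the natsorted "*frame*.nii.gz" / "*frame*_gt.nii.gz" naming):

from torch_em.data.datasets.medical import acdc

image_paths, gt_paths = acdc.get_acdc_paths(path="./data/acdc", split="train", download=True)
assert len(image_paths) == len(gt_paths)
# Each MRI volume is paired index-wise with its ground-truth volume,
# e.g. ".../patient001_frame01.nii.gz" with ".../patient001_frame01_gt.nii.gz".
print(image_paths[0], gt_paths[0])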

def get_acdc_dataset( path: Union[os.PathLike, str], patch_shape: Tuple[int, ...], split: Literal['train', 'test'], resize_inputs: bool = False, download: bool = False, **kwargs) -> torch.utils.data.dataset.Dataset:

Get the ACDC dataset for cardiac structure segmentation.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • patch_shape: The patch shape to use for training.
  • split: The choice of data split.
  • resize_inputs: Whether to resize inputs to the desired patch shape.
  • download: Whether to download the data if it is not present.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset.
Returns:

The segmentation dataset.
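
For example (path and patch shape are placeholder choices; a patch shape of (1, 256, 256) samples 2d slices from the 3d volumes):

from torch_em.data.datasets.medical import acdc

ds = acdc.get_acdc_dataset(
    path="./data/acdc",
    patch_shape=(1, 256, 256),  # placeholder: 2d patches from the 3d cardiac MRI volumes
    split="train",
    resize_inputs=True,
    download=True,
)
print(len(ds))  # number of patches over all concatenated per-volume datasets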

def get_acdc_loader( path: Union[os.PathLike, str], batch_size: int, patch_shape: Tuple[int, ...], split: Literal['train', 'test'], resize_inputs: bool = False, download: bool = False, **kwargs) -> torch.utils.data.dataloader.DataLoader:

Get the ACDC dataloader for cardiac structure segmentation.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • batch_size: The batch size for training.
  • patch_shape: The patch shape to use for training.
  • split: The choice of data split.
  • resize_inputs: Whether to resize inputs to the desired patch shape.
  • download: Whether to download the data if it is not present.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset or for the PyTorch DataLoader.
Returns:

The DataLoader.
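
A minimal sketch for building and iterating over the loader (path, patch shape and loader settings are placeholder choices):

from torch_em.data.datasets.medical import acdc

loader = acdc.get_acdc_loader(
    path="./data/acdc",
    batch_size=4,
    patch_shape=(1, 256, 256),
    split="train",
    resize_inputs=True,
    download=True,
    num_workers=2,  # forwarded to the PyTorch DataLoader via util.split_kwargs
)
x, y = next(iter(loader))
print(x.shape, y.shape)  # batches of image patches and corresponding label patches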