torch_em.data.datasets.medical.acdc
The ACDC dataset contains annotations for multi-structure segmentation in cardiac MRI.
The labels have the following mapping:
- 0 (background), 1 (right ventricle cavity), 2 (myocardium), 3 (left ventricle cavity)
The database is located at https://humanheart-project.creatis.insa-lyon.fr/database/#collection/637218c173e9f0047faa00fb
The dataset is from the publication https://doi.org/10.1109/TMI.2018.2837502. Please cite it if you use this dataset for a publication.
"""The ACDC dataset contains annotations for multi-structure segmentation in cardiac MRI.

The labels have the following mapping:
- 0 (background), 1 (right ventricle cavity), 2 (myocardium), 3 (left ventricle cavity)

The database is located at
https://humanheart-project.creatis.insa-lyon.fr/database/#collection/637218c173e9f0047faa00fb

The dataset is from the publication https://doi.org/10.1109/TMI.2018.2837502.
Please cite it if you use this dataset for a publication.
"""

import os
from glob import glob
from natsort import natsorted
from typing import Union, Tuple, Literal, List

from torch.utils.data import Dataset, DataLoader

import torch_em

from .. import util
from ... import ConcatDataset


URL = "https://humanheart-project.creatis.insa-lyon.fr/database/api/v1/collection/637218c173e9f0047faa00fb/download"
CHECKSUM = "2787e08b0d3525cbac710fc3bdf69ee7c5fd7446472e49db8bc78548802f6b5e"

# Maps the public split names to the folder names inside "<root>/database".
_SPLIT_DIRS = {"train": "training", "test": "testing"}


def get_acdc_data(path: Union[os.PathLike, str], download: bool = False) -> str:
    """Download the ACDC dataset.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        download: Whether to download the data if it is not present.

    Returns:
        Filepath where the data is downloaded.
    """
    zip_path = os.path.join(path, "ACDC.zip")
    trg_dir = os.path.join(path, "ACDC")
    # An existing extraction folder is taken as "already downloaded".
    if os.path.exists(trg_dir):
        return trg_dir

    os.makedirs(path, exist_ok=True)

    util.download_source(path=zip_path, url=URL, download=download, checksum=CHECKSUM)
    util.unzip(zip_path=zip_path, dst=path, remove=False)

    return trg_dir


def get_acdc_paths(
    path: Union[os.PathLike, str], split: Literal["train", "test"], download: bool = False
) -> Tuple[List[str], List[str]]:
    """Get paths to the ACDC data.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        split: The choice of data split. Either 'train' or 'test'.
        download: Whether to download the data if it is not present.

    Returns:
        List of filepaths for the image data.
        List of filepaths for the label data.

    Raises:
        ValueError: If `split` is not a valid data split.
    """
    # Validate before touching the filesystem or network, so an invalid split never triggers a download.
    if split not in _SPLIT_DIRS:
        raise ValueError(f"'{split}' is not a valid data split.")

    root_dir = get_acdc_data(path=path, download=download)
    input_dir = os.path.join(root_dir, "database", _SPLIT_DIRS[split])

    image_paths, gt_paths = [], []
    for per_patient_dir in natsorted(glob(os.path.join(input_dir, "patient*"))):
        # The volumes with "frame" in their name are particular time frames: end diastole (ED) and
        # end systole (ES) phase instances, which have manual segmentations.
        for vol_path in glob(os.path.join(per_patient_dir, "*frame*.nii.gz")):
            # Decide on the file name only. Searching the whole path for "gt" would classify every
            # volume as a label as soon as any parent directory name contains "gt".
            if os.path.basename(vol_path).endswith("_gt.nii.gz"):
                gt_paths.append(vol_path)
            else:
                image_paths.append(vol_path)

    return natsorted(image_paths), natsorted(gt_paths)


def get_acdc_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, ...],
    split: Literal["train", "test"],
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> Dataset:
    """Get the ACDC dataset for cardiac structure segmentation.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        resize_inputs: Whether to resize inputs to the desired patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.

    Raises:
        RuntimeError: If no volumes were found for the split, or if the number of image and label
            volumes differ.
    """
    image_paths, gt_paths = get_acdc_paths(path, split, download)

    # Images and labels are paired by position in the two sorted lists. A missing file on either side
    # would silently shift every later pair, so fail loudly instead.
    if len(image_paths) != len(gt_paths):
        raise RuntimeError(
            f"Found {len(image_paths)} image volumes but {len(gt_paths)} label volumes for the '{split}' split."
        )
    if not image_paths:
        raise RuntimeError(f"Could not find any ACDC volumes for the '{split}' split.")

    if resize_inputs:
        resize_kwargs = {"patch_shape": patch_shape, "is_rgb": False}
        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
            kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs
        )

    # Build one dataset per (image, label) volume pair and concatenate them.
    all_datasets = [
        torch_em.default_segmentation_dataset(
            raw_paths=image_path,
            raw_key="data",
            label_paths=gt_path,
            label_key="data",
            patch_shape=patch_shape,
            is_seg_dataset=True,
            **kwargs
        )
        for image_path, gt_path in zip(image_paths, gt_paths)
    ]

    return ConcatDataset(*all_datasets)


def get_acdc_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, ...],
    split: Literal["train", "test"],
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Get the ACDC dataloader for cardiac structure segmentation.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        resize_inputs: Whether to resize inputs to the desired patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # Separate the keyword arguments meant for the dataset from those meant for the DataLoader.
    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    dataset = get_acdc_dataset(path, patch_shape, split, resize_inputs, download, **ds_kwargs)
    return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
URL =
'https://humanheart-project.creatis.insa-lyon.fr/database/api/v1/collection/637218c173e9f0047faa00fb/download'
CHECKSUM =
'2787e08b0d3525cbac710fc3bdf69ee7c5fd7446472e49db8bc78548802f6b5e'
def
get_acdc_data(path: Union[os.PathLike, str], download: bool = False) -> str:
def get_acdc_data(path: Union[os.PathLike, str], download: bool = False) -> str:
    """Download the ACDC dataset and extract it into `path`.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        download: Whether to download the data if it is not present.

    Returns:
        Filepath where the data is downloaded.
    """
    data_dir = os.path.join(path, "ACDC")
    # Only fetch and extract when the extraction folder is not there yet.
    if not os.path.exists(data_dir):
        archive = os.path.join(path, "ACDC.zip")
        os.makedirs(path, exist_ok=True)
        util.download_source(path=archive, url=URL, download=download, checksum=CHECKSUM)
        # The archive is kept on disk after extraction (remove=False).
        util.unzip(zip_path=archive, dst=path, remove=False)
    return data_dir
Download the ACDC dataset.
Arguments:
- path: Filepath to a folder where the data is downloaded for further processing.
- download: Whether to download the data if it is not present.
Returns:
Filepath where the data is downloaded.
def
get_acdc_paths( path: Union[os.PathLike, str], split: Literal['train', 'test'], download: bool = False) -> Tuple[List[str], List[str]]:
def get_acdc_paths(
    path: Union[os.PathLike, str], split: Literal["train", "test"], download: bool = False
) -> Tuple[List[str], List[str]]:
    """Get paths to the ACDC data.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        split: The choice of data split. Either 'train' or 'test'.
        download: Whether to download the data if it is not present.

    Returns:
        List of filepaths for the image data.
        List of filepaths for the label data.

    Raises:
        ValueError: If `split` is not a valid data split.
    """
    split_dirs = {"train": "training", "test": "testing"}
    # Validate before touching the filesystem or network, so an invalid split never triggers a download.
    if split not in split_dirs:
        raise ValueError(f"'{split}' is not a valid data split.")

    root_dir = get_acdc_data(path=path, download=download)
    input_dir = os.path.join(root_dir, "database", split_dirs[split])

    image_paths, gt_paths = [], []
    for per_patient_dir in natsorted(glob(os.path.join(input_dir, "patient*"))):
        # The volumes with "frame" in their name are particular time frames: end diastole (ED) and
        # end systole (ES) phase instances, which have manual segmentations.
        for vol_path in glob(os.path.join(per_patient_dir, "*frame*.nii.gz")):
            # Decide on the file name only. Searching the whole path for "gt" would classify every
            # volume as a label as soon as any parent directory name contains "gt".
            if os.path.basename(vol_path).endswith("_gt.nii.gz"):
                gt_paths.append(vol_path)
            else:
                image_paths.append(vol_path)

    return natsorted(image_paths), natsorted(gt_paths)
Get paths to the ACDC data.
Arguments:
- path: Filepath to a folder where the data is downloaded for further processing.
- split: The choice of data split.
- download: Whether to download the data if it is not present.
Returns:
List of filepaths for the image data. List of filepaths for the label data.
def
get_acdc_dataset( path: Union[os.PathLike, str], patch_shape: Tuple[int, ...], split: Literal['train', 'test'], resize_inputs: bool = False, download: bool = False, **kwargs) -> torch.utils.data.dataset.Dataset:
def get_acdc_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, ...],
    split: Literal["train", "test"],
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> Dataset:
    """Get the ACDC dataset for cardiac structure segmentation.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        resize_inputs: Whether to resize inputs to the desired patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.

    Raises:
        RuntimeError: If no volumes were found for the split, or if the number of image and label
            volumes differ.
    """
    image_paths, gt_paths = get_acdc_paths(path, split, download)

    # Images and labels are paired by position in the two sorted lists. A missing file on either side
    # would silently shift every later pair, so fail loudly instead.
    if len(image_paths) != len(gt_paths):
        raise RuntimeError(
            f"Found {len(image_paths)} image volumes but {len(gt_paths)} label volumes for the '{split}' split."
        )
    if not image_paths:
        raise RuntimeError(f"Could not find any ACDC volumes for the '{split}' split.")

    if resize_inputs:
        resize_kwargs = {"patch_shape": patch_shape, "is_rgb": False}
        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
            kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs
        )

    # Build one dataset per (image, label) volume pair and concatenate them.
    all_datasets = [
        torch_em.default_segmentation_dataset(
            raw_paths=image_path,
            raw_key="data",
            label_paths=gt_path,
            label_key="data",
            patch_shape=patch_shape,
            is_seg_dataset=True,
            **kwargs
        )
        for image_path, gt_path in zip(image_paths, gt_paths)
    ]

    return ConcatDataset(*all_datasets)
Get the ACDC dataset for cardiac structure segmentation.
Arguments:
- path: Filepath to a folder where the data is downloaded for further processing.
- patch_shape: The patch shape to use for training.
- split: The choice of data split.
- resize_inputs: Whether to resize inputs to the desired patch shape.
- download: Whether to download the data if it is not present.
- kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
Returns:
The segmentation dataset.
def
get_acdc_loader( path: Union[os.PathLike, str], batch_size: int, patch_shape: Tuple[int, ...], split: Literal['train', 'test'], resize_inputs: bool = False, download: bool = False, **kwargs) -> torch.utils.data.dataloader.DataLoader:
def get_acdc_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, ...],
    split: Literal["train", "test"],
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Build a DataLoader over the ACDC dataset for cardiac structure segmentation.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        resize_inputs: Whether to resize inputs to the desired patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # Arguments accepted by the dataset constructor go there; everything else goes to the loader.
    dataset_kwargs, dataloader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    acdc_dataset = get_acdc_dataset(
        path, patch_shape, split, resize_inputs, download, **dataset_kwargs
    )
    return torch_em.get_data_loader(acdc_dataset, batch_size, **dataloader_kwargs)
Get the ACDC dataloader for cardiac structure segmentation.
Arguments:
- path: Filepath to a folder where the data is downloaded for further processing.
- batch_size: The batch size for training.
- patch_shape: The patch shape to use for training.
- split: The choice of data split.
- resize_inputs: Whether to resize inputs to the desired patch shape.
- download: Whether to download the data if it is not present.
- kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.
Returns:
The DataLoader.