torch_em.data.datasets.medical.siim_acr
The SIIM ACR dataset contains annotations for pneumothorax segmentation in chest X-Rays.
This dataset is located at https://www.kaggle.com/datasets/vbookshelf/pneumothorax-chest-xray-images-and-masks/data. The dataset is from the "SIIM-ACR Pneumothorax Segmentation" competition: https://kaggle.com/competitions/siim-acr-pneumothorax-segmentation. Please cite it if you use this dataset for your research.
"""The SIIM ACR dataset contains annotations for pneumothorax segmentation in
chest X-Rays.

This dataset is located at https://www.kaggle.com/datasets/vbookshelf/pneumothorax-chest-xray-images-and-masks/data.
The dataset is from the "SIIM-ACR Pneumothorax Segmentation" competition:
https://kaggle.com/competitions/siim-acr-pneumothorax-segmentation.
Please cite it if you use this dataset for your research.
"""

import os
from glob import glob
from tqdm import tqdm
from natsort import natsorted
from typing import Union, Tuple, Literal, List

import numpy as np
import imageio.v3 as imageio

from torch.utils.data import Dataset, DataLoader

import torch_em

from .. import util


KAGGLE_DATASET_NAME = "vbookshelf/pneumothorax-chest-xray-images-and-masks"
CHECKSUM = "1ade68d31adb996c531bb686fb9d02fe11876ddf6f25594ab725e18c69d81538"


def get_siim_acr_data(path: Union[os.PathLike, str], download: bool = False) -> str:
    """Download the SIIM ACR dataset.

    If the extracted data folder already exists under `path`, it is returned right away.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        download: Whether to download the data if it is not present.

    Returns:
        Filepath where the data is downloaded.
    """
    data_dir = os.path.join(path, "siim-acr-pneumothorax")
    if os.path.exists(data_dir):
        return data_dir

    os.makedirs(path, exist_ok=True)

    util.download_source_kaggle(path=path, dataset_name=KAGGLE_DATASET_NAME, download=download)

    zip_path = os.path.join(path, "pneumothorax-chest-xray-images-and-masks.zip")
    util._check_checksum(path=zip_path, checksum=CHECKSUM)
    util.unzip(zip_path=zip_path, dst=path)

    return data_dir


def _clean_image_and_label_paths(image_paths, gt_paths):
    """Keep only the image / label pairs whose label has a valid annotation.

    A label is kept when it contains both zero and non-zero pixels. The two inputs are paired
    position by position, so they are expected to have the same length and ordering.

    Args:
        image_paths: List of filepaths for the image data.
        gt_paths: List of filepaths for the label data.

    Returns:
        The filtered list of image filepaths.
        The filtered list of label filepaths.
    """
    def _has_multiple_classes(gt_path):
        gt = imageio.imread(gt_path)
        return np.any(gt) and not np.all(gt)

    paths = [
        (ip, gp) for ip, gp in tqdm(zip(image_paths, gt_paths), total=len(image_paths), desc="Verifying labels")
        if _has_multiple_classes(gp)
    ]
    image_paths = [p[0] for p in paths]
    gt_paths = [p[1] for p in paths]

    return image_paths, gt_paths


def get_siim_acr_paths(
    path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], download: bool = False
) -> Tuple[List[str], List[str]]:
    """Get paths to the SIIM ACR data.

    The 'test' split uses the files tagged '_test_'. The 'train' and 'val' splits are both taken from
    the files tagged '_train_': after filtering, the first 400 pairs form 'val' and the rest 'train'.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        split: The choice of data split.
        download: Whether to download the data if it is not present.

    Returns:
        List of filepaths for the image data.
        List of filepaths for the label data.

    Raises:
        ValueError: If `split` is not one of 'train', 'val' or 'test'.
        RuntimeError: If the number of image files and label files found on disk differs.
    """
    # Reject an unknown split before any download or label verification is started.
    if split not in ("train", "val", "test"):
        raise ValueError(f"'{split}' is not a valid split.")

    data_dir = get_siim_acr_data(path=path, download=download)

    # 'train' and 'val' are both derived from the original 'train' files.
    tag = "test" if split == "test" else "train"
    image_paths = natsorted(glob(os.path.join(data_dir, "png_images", f"*_{tag}_*.png")))
    gt_paths = natsorted(glob(os.path.join(data_dir, "png_masks", f"*_{tag}_*.png")))

    # Check the counts before filtering: the filtering pairs the files up position by position,
    # which would silently hide a mismatch between images and labels.
    if len(image_paths) != len(gt_paths):
        raise RuntimeError(
            f"Found {len(image_paths)} images but {len(gt_paths)} labels for the '{tag}' files in '{data_dir}'."
        )

    image_paths, gt_paths = _clean_image_and_label_paths(image_paths, gt_paths)

    # Next, we create custom train-val split out of the given original 'train' split.
    if split == "train":
        image_paths, gt_paths = image_paths[400:], gt_paths[400:]
    elif split == "val":
        image_paths, gt_paths = image_paths[:400], gt_paths[:400]

    return image_paths, gt_paths


def get_siim_acr_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, int],
    split: Literal["train", "val", "test"],
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> Dataset:
    """Get the SIIM ACR dataset for pneumothorax segmentation.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        resize_inputs: Whether to resize inputs to the desired patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    image_paths, gt_paths = get_siim_acr_paths(path, split, download)

    if resize_inputs:
        resize_kwargs = {"patch_shape": patch_shape, "is_rgb": False}
        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
            kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs
        )

    dataset = torch_em.default_segmentation_dataset(
        raw_paths=image_paths,
        raw_key=None,
        label_paths=gt_paths,
        label_key=None,
        patch_shape=patch_shape,
        is_seg_dataset=False,
        **kwargs
    )
    # NOTE(review): presumably the retry budget of the patch sampler - confirm against torch_em.
    dataset.max_sampling_attempts = 5000

    return dataset


def get_siim_acr_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, int],
    split: Literal["train", "val", "test"],
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Get the SIIM ACR dataloader for pneumothorax segmentation.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        resize_inputs: Whether to resize inputs to the desired patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    dataset = get_siim_acr_dataset(path, patch_shape, split, resize_inputs, download, **ds_kwargs)
    return torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs)
KAGGLE_DATASET_NAME =
'vbookshelf/pneumothorax-chest-xray-images-and-masks'
CHECKSUM =
'1ade68d31adb996c531bb686fb9d02fe11876ddf6f25594ab725e18c69d81538'
def
get_siim_acr_data(path: Union[os.PathLike, str], download: bool = False) -> str:
def get_siim_acr_data(path: Union[os.PathLike, str], download: bool = False) -> str:
    """Make sure the SIIM ACR data is available locally and return its location.

    When the extracted data folder already exists under `path`, nothing else is done.
    Otherwise the Kaggle archive is requested via `util.download_source_kaggle`,
    checked against `CHECKSUM` and unzipped into `path`.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        download: Whether to download the data if it is not present.

    Returns:
        Filepath where the data is downloaded.
    """
    data_dir = os.path.join(path, "siim-acr-pneumothorax")

    if not os.path.exists(data_dir):
        os.makedirs(path, exist_ok=True)
        util.download_source_kaggle(path=path, dataset_name=KAGGLE_DATASET_NAME, download=download)

        # Verify the archive before extracting it.
        archive = os.path.join(path, "pneumothorax-chest-xray-images-and-masks.zip")
        util._check_checksum(path=archive, checksum=CHECKSUM)
        util.unzip(zip_path=archive, dst=path)

    return data_dir
Download the SIIM ACR dataset.
Arguments:
- path: Filepath to a folder where the data is downloaded for further processing.
- download: Whether to download the data if it is not present.
Returns:
Filepath where the data is downloaded.
def
get_siim_acr_paths( path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], download: bool = False) -> Tuple[List[str], List[str]]:
def get_siim_acr_paths(
    path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], download: bool = False
) -> Tuple[List[str], List[str]]:
    """Get paths to the SIIM ACR data.

    The 'test' split uses the files tagged '_test_'. The 'train' and 'val' splits are both taken from
    the files tagged '_train_': after filtering, the first 400 pairs form 'val' and the rest 'train'.
    Only the pairs kept by `_clean_image_and_label_paths` are returned.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        split: The choice of data split.
        download: Whether to download the data if it is not present.

    Returns:
        List of filepaths for the image data.
        List of filepaths for the label data.

    Raises:
        ValueError: If `split` is not one of 'train', 'val' or 'test'.
        RuntimeError: If the number of image files and label files found on disk differs.
    """
    # Reject an unknown split before any download or label verification is started.
    if split not in ("train", "val", "test"):
        raise ValueError(f"'{split}' is not a valid split.")

    data_dir = get_siim_acr_data(path=path, download=download)

    # 'train' and 'val' are both derived from the original 'train' files.
    tag = "test" if split == "test" else "train"
    image_paths = natsorted(glob(os.path.join(data_dir, "png_images", f"*_{tag}_*.png")))
    gt_paths = natsorted(glob(os.path.join(data_dir, "png_masks", f"*_{tag}_*.png")))

    # Check the counts before filtering: the filtering pairs the files up position by position,
    # which would silently hide a mismatch between images and labels.
    if len(image_paths) != len(gt_paths):
        raise RuntimeError(
            f"Found {len(image_paths)} images but {len(gt_paths)} labels for the '{tag}' files in '{data_dir}'."
        )

    image_paths, gt_paths = _clean_image_and_label_paths(image_paths, gt_paths)

    # Next, we create custom train-val split out of the given original 'train' split.
    if split == "train":
        image_paths, gt_paths = image_paths[400:], gt_paths[400:]
    elif split == "val":
        image_paths, gt_paths = image_paths[:400], gt_paths[:400]

    return image_paths, gt_paths
Get paths to the SIIM ACR data.
Arguments:
- path: Filepath to a folder where the data is downloaded for further processing.
- split: The choice of data split.
- download: Whether to download the data if it is not present.
Returns:
List of filepaths for the image data. List of filepaths for the label data.
def
get_siim_acr_dataset( path: Union[os.PathLike, str], patch_shape: Tuple[int, int], split: Literal['train', 'val', 'test'], resize_inputs: bool = False, download: bool = False, **kwargs) -> torch.utils.data.dataset.Dataset:
def get_siim_acr_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, int],
    split: Literal["train", "val", "test"],
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> Dataset:
    """Build the segmentation dataset for the SIIM ACR pneumothorax data.

    The image and label paths of the requested split are looked up with `get_siim_acr_paths`
    and handed to `torch_em.default_segmentation_dataset`.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        resize_inputs: Whether to resize inputs to the desired patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    raw_paths, label_paths = get_siim_acr_paths(path, split, download)

    if resize_inputs:
        # The resize settings use the requested patch shape; the helper may then replace
        # both the keyword arguments and the patch shape.
        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
            kwargs=kwargs,
            patch_shape=patch_shape,
            resize_inputs=resize_inputs,
            resize_kwargs={"patch_shape": patch_shape, "is_rgb": False},
        )

    siim_dataset = torch_em.default_segmentation_dataset(
        raw_paths=raw_paths,
        raw_key=None,
        label_paths=label_paths,
        label_key=None,
        patch_shape=patch_shape,
        is_seg_dataset=False,
        **kwargs
    )
    # NOTE(review): presumably the retry budget of the patch sampler - confirm against torch_em.
    siim_dataset.max_sampling_attempts = 5000

    return siim_dataset
Get the SIIM ACR dataset for pneumothorax segmentation.
Arguments:
- path: Filepath to a folder where the data is downloaded for further processing.
- patch_shape: The patch shape to use for training.
- split: The choice of data split.
- resize_inputs: Whether to resize inputs to the desired patch shape.
- download: Whether to download the data if it is not present.
- kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
Returns:
The segmentation dataset.
def
get_siim_acr_loader( path: Union[os.PathLike, str], batch_size: int, patch_shape: Tuple[int, int], split: Literal['train', 'val', 'test'], resize_inputs: bool = False, download: bool = False, **kwargs) -> torch.utils.data.dataloader.DataLoader:
def get_siim_acr_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, int],
    split: Literal["train", "val", "test"],
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Build the dataloader for the SIIM ACR pneumothorax segmentation data.

    The extra keyword arguments are divided with `util.split_kwargs`: one part goes to
    `get_siim_acr_dataset`, the other to `torch_em.get_data_loader`.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        resize_inputs: Whether to resize inputs to the desired patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    dataset_kwargs, dataloader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)

    siim_dataset = get_siim_acr_dataset(
        path=path,
        patch_shape=patch_shape,
        split=split,
        resize_inputs=resize_inputs,
        download=download,
        **dataset_kwargs
    )

    return torch_em.get_data_loader(dataset=siim_dataset, batch_size=batch_size, **dataloader_kwargs)
Get the SIIM ACR dataloader for pneumothorax segmentation.
Arguments:
- path: Filepath to a folder where the data is downloaded for further processing.
- batch_size: The batch size for training.
- patch_shape: The patch shape to use for training.
- split: The choice of data split.
- resize_inputs: Whether to resize inputs to the desired patch shape.
- download: Whether to download the data if it is not present.
- kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.
Returns:
The DataLoader.