torch_em.data.datasets.medical.siim_acr

The SIIM ACR dataset contains annotations for pneumothorax segmentation in chest X-Rays.

This dataset is located at https://www.kaggle.com/datasets/vbookshelf/pneumothorax-chest-xray-images-and-masks/data. The dataset is from the "SIIM-ACR Pneumothorax Segmentation" competition: https://kaggle.com/competitions/siim-acr-pneumothorax-segmentation. Please cite it if you use this dataset for your research.

  1"""The SIIM ACR dataset contains annotations for pneumothorax segmentation in
  2chest X-Rays.
  3
  4This dataset is located at https://www.kaggle.com/datasets/vbookshelf/pneumothorax-chest-xray-images-and-masks/data.
  5The dataset is from the "SIIM-ACR Pneumothorax Segmentation" competition:
  6https://kaggle.com/competitions/siim-acr-pneumothorax-segmentation.
  7Please cite it if you use this dataset for your research.
  8"""
  9
 10import os
 11from glob import glob
 12from tqdm import tqdm
 13from natsort import natsorted
 14from typing import Union, Tuple, Literal, List
 15
 16import numpy as np
 17import imageio.v3 as imageio
 18
 19from torch.utils.data import Dataset, DataLoader
 20
 21import torch_em
 22
 23from .. import util
 24
 25
# Kaggle identifier of the dataset; passed as `dataset_name` to `util.download_source_kaggle`.
KAGGLE_DATASET_NAME = "vbookshelf/pneumothorax-chest-xray-images-and-masks"
# Expected checksum of the downloaded zip archive, verified with `util._check_checksum` before unzipping.
# NOTE(review): 64 hex characters, presumably SHA256 - confirm against `util._check_checksum`.
CHECKSUM = "1ade68d31adb996c531bb686fb9d02fe11876ddf6f25594ab725e18c69d81538"
 28
 29
 30def get_siim_acr_data(path: Union[os.PathLike, str], download: bool = False) -> str:
 31    """Download the SIIM ACR dataset.
 32
 33    Args:
 34        path: Filepath to a folder where the data is downloaded for further processing.
 35        download: Whether to download the data if it is not present.
 36
 37    Returns:
 38        Filepath where the data is downloaded.
 39    """
 40    data_dir = os.path.join(path, "siim-acr-pneumothorax")
 41    if os.path.exists(data_dir):
 42        return data_dir
 43
 44    os.makedirs(path, exist_ok=True)
 45
 46    util.download_source_kaggle(path=path, dataset_name=KAGGLE_DATASET_NAME, download=download)
 47
 48    zip_path = os.path.join(path, "pneumothorax-chest-xray-images-and-masks.zip")
 49    util._check_checksum(path=zip_path, checksum=CHECKSUM)
 50    util.unzip(zip_path=zip_path, dst=path)
 51
 52    return data_dir
 53
 54
 55def _clean_image_and_label_paths(image_paths, gt_paths):
 56    # NOTE: Extract paths with image and corresponding label paths with valid annotations.
 57    def _has_multiple_classes(gt_path):
 58        gt = imageio.imread(gt_path)
 59        return np.any(gt) and not np.all(gt)
 60
 61    paths = [
 62        (ip, gp) for ip, gp in tqdm(zip(image_paths, gt_paths), total=len(image_paths), desc="Verifying labels")
 63        if _has_multiple_classes(gp)
 64    ]
 65    image_paths = [p[0] for p in paths]
 66    gt_paths = [p[1] for p in paths]
 67
 68    return image_paths, gt_paths
 69
 70
 71def get_siim_acr_paths(
 72    path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], download: bool = False
 73) -> Tuple[List[str], List[str]]:
 74    """Get paths to the SIIM ACR data.
 75
 76    Args:
 77        path: Filepath to a folder where the data is downloaded for further processing.
 78        split: The choice of data split.
 79        download: Whether to download the data if it is not present.
 80
 81    Returns:
 82        List of filepaths for the image data.
 83        List of filepaths for the label data.
 84    """
 85    data_dir = get_siim_acr_data(path=path, download=download)
 86
 87    if split == "test":
 88        image_paths = natsorted(glob(os.path.join(data_dir, "png_images", f"*_{split}_*.png")))
 89        gt_paths = natsorted(glob(os.path.join(data_dir, "png_masks", f"*_{split}_*.png")))
 90
 91        image_paths, gt_paths = _clean_image_and_label_paths(image_paths, gt_paths)
 92    else:
 93        image_paths = natsorted(glob(os.path.join(data_dir, "png_images", "*_train_*.png")))
 94        gt_paths = natsorted(glob(os.path.join(data_dir, "png_masks", "*_train_*.png")))
 95
 96        image_paths, gt_paths = _clean_image_and_label_paths(image_paths, gt_paths)
 97
 98        # Next, we create custom train-val split out of the given original 'train' split.
 99        if split == "train":
100            image_paths, gt_paths = image_paths[400:], gt_paths[400:]
101        elif split == "val":
102            image_paths, gt_paths = image_paths[:400], gt_paths[:400]
103        else:
104            raise ValueError(f"'{split}' is not a valid split.")
105
106    assert len(image_paths) == len(gt_paths)
107
108    return image_paths, gt_paths
109
110
def get_siim_acr_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, int],
    split: Literal["train", "val", "test"],
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> Dataset:
    """Build the SIIM ACR dataset for pneumothorax segmentation.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        resize_inputs: Whether to resize inputs to the desired patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    raw_paths, label_paths = get_siim_acr_paths(path, split, download)

    if resize_inputs:
        # The helper returns the updated keyword arguments and the patch shape to use.
        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
            kwargs=kwargs,
            patch_shape=patch_shape,
            resize_inputs=resize_inputs,
            resize_kwargs={"patch_shape": patch_shape, "is_rgb": False},
        )

    siim_dataset = torch_em.default_segmentation_dataset(
        raw_paths=raw_paths,
        raw_key=None,
        label_paths=label_paths,
        label_key=None,
        patch_shape=patch_shape,
        is_seg_dataset=False,
        **kwargs
    )
    siim_dataset.max_sampling_attempts = 5000
    return siim_dataset
152
153
def get_siim_acr_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, int],
    split: Literal["train", "val", "test"],
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Build the SIIM ACR dataloader for pneumothorax segmentation.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        resize_inputs: Whether to resize inputs to the desired patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # Separate the keyword arguments meant for the dataset from those meant for the loader.
    dataset_kwargs, dataloader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)

    siim_dataset = get_siim_acr_dataset(
        path=path,
        patch_shape=patch_shape,
        split=split,
        resize_inputs=resize_inputs,
        download=download,
        **dataset_kwargs
    )
    return torch_em.get_data_loader(dataset=siim_dataset, batch_size=batch_size, **dataloader_kwargs)
KAGGLE_DATASET_NAME = 'vbookshelf/pneumothorax-chest-xray-images-and-masks'
CHECKSUM = '1ade68d31adb996c531bb686fb9d02fe11876ddf6f25594ab725e18c69d81538'
def get_siim_acr_data(path: Union[os.PathLike, str], download: bool = False) -> str:
def get_siim_acr_data(path: Union[os.PathLike, str], download: bool = False) -> str:
    """Fetch and unpack the SIIM ACR data from Kaggle, unless it is already available.

    If the folder "siim-acr-pneumothorax" already exists inside `path`,
    nothing is downloaded, verified or extracted.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        download: Whether to download the data if it is not present.

    Returns:
        Filepath of the data folder inside `path`.
    """
    extracted_dir = os.path.join(path, "siim-acr-pneumothorax")

    if not os.path.exists(extracted_dir):
        os.makedirs(path, exist_ok=True)
        util.download_source_kaggle(path=path, dataset_name=KAGGLE_DATASET_NAME, download=download)

        # Verify the archive before extracting it into `path`.
        archive = os.path.join(path, "pneumothorax-chest-xray-images-and-masks.zip")
        util._check_checksum(path=archive, checksum=CHECKSUM)
        util.unzip(zip_path=archive, dst=path)

    return extracted_dir

Download the SIIM ACR dataset.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • download: Whether to download the data if it is not present.
Returns:

Filepath where the data is downloaded.

def get_siim_acr_paths( path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], download: bool = False) -> Tuple[List[str], List[str]]:
 72def get_siim_acr_paths(
 73    path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], download: bool = False
 74) -> Tuple[List[str], List[str]]:
 75    """Get paths to the SIIM ACR data.
 76
 77    Args:
 78        path: Filepath to a folder where the data is downloaded for further processing.
 79        split: The choice of data split.
 80        download: Whether to download the data if it is not present.
 81
 82    Returns:
 83        List of filepaths for the image data.
 84        List of filepaths for the label data.
 85    """
 86    data_dir = get_siim_acr_data(path=path, download=download)
 87
 88    if split == "test":
 89        image_paths = natsorted(glob(os.path.join(data_dir, "png_images", f"*_{split}_*.png")))
 90        gt_paths = natsorted(glob(os.path.join(data_dir, "png_masks", f"*_{split}_*.png")))
 91
 92        image_paths, gt_paths = _clean_image_and_label_paths(image_paths, gt_paths)
 93    else:
 94        image_paths = natsorted(glob(os.path.join(data_dir, "png_images", "*_train_*.png")))
 95        gt_paths = natsorted(glob(os.path.join(data_dir, "png_masks", "*_train_*.png")))
 96
 97        image_paths, gt_paths = _clean_image_and_label_paths(image_paths, gt_paths)
 98
 99        # Next, we create custom train-val split out of the given original 'train' split.
100        if split == "train":
101            image_paths, gt_paths = image_paths[400:], gt_paths[400:]
102        elif split == "val":
103            image_paths, gt_paths = image_paths[:400], gt_paths[:400]
104        else:
105            raise ValueError(f"'{split}' is not a valid split.")
106
107    assert len(image_paths) == len(gt_paths)
108
109    return image_paths, gt_paths

Get paths to the SIIM ACR data.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • split: The choice of data split.
  • download: Whether to download the data if it is not present.
Returns:

List of filepaths for the image data. List of filepaths for the label data.

def get_siim_acr_dataset( path: Union[os.PathLike, str], patch_shape: Tuple[int, int], split: Literal['train', 'val', 'test'], resize_inputs: bool = False, download: bool = False, **kwargs) -> torch.utils.data.dataset.Dataset:
def get_siim_acr_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, int],
    split: Literal["train", "val", "test"],
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> Dataset:
    """Build the SIIM ACR dataset for pneumothorax segmentation.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        resize_inputs: Whether to resize inputs to the desired patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    raw_paths, label_paths = get_siim_acr_paths(path, split, download)

    if resize_inputs:
        # The helper returns the updated keyword arguments and the patch shape to use.
        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
            kwargs=kwargs,
            patch_shape=patch_shape,
            resize_inputs=resize_inputs,
            resize_kwargs={"patch_shape": patch_shape, "is_rgb": False},
        )

    siim_dataset = torch_em.default_segmentation_dataset(
        raw_paths=raw_paths,
        raw_key=None,
        label_paths=label_paths,
        label_key=None,
        patch_shape=patch_shape,
        is_seg_dataset=False,
        **kwargs
    )
    siim_dataset.max_sampling_attempts = 5000
    return siim_dataset

Get the SIIM ACR dataset for pneumothorax segmentation.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • patch_shape: The patch shape to use for training.
  • split: The choice of data split.
  • resize_inputs: Whether to resize inputs to the desired patch shape.
  • download: Whether to download the data if it is not present.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset.
Returns:

The segmentation dataset.

def get_siim_acr_loader( path: Union[os.PathLike, str], batch_size: int, patch_shape: Tuple[int, int], split: Literal['train', 'val', 'test'], resize_inputs: bool = False, download: bool = False, **kwargs) -> torch.utils.data.dataloader.DataLoader:
def get_siim_acr_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, int],
    split: Literal["train", "val", "test"],
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Build the SIIM ACR dataloader for pneumothorax segmentation.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        resize_inputs: Whether to resize inputs to the desired patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # Separate the keyword arguments meant for the dataset from those meant for the loader.
    dataset_kwargs, dataloader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)

    siim_dataset = get_siim_acr_dataset(
        path=path,
        patch_shape=patch_shape,
        split=split,
        resize_inputs=resize_inputs,
        download=download,
        **dataset_kwargs
    )
    return torch_em.get_data_loader(dataset=siim_dataset, batch_size=batch_size, **dataloader_kwargs)

Get the SIIM ACR dataloader for pneumothorax segmentation.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • batch_size: The batch size for training.
  • patch_shape: The patch shape to use for training.
  • split: The choice of data split.
  • resize_inputs: Whether to resize inputs to the desired patch shape.
  • download: Whether to download the data if it is not present.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset or for the PyTorch DataLoader.
Returns:

The DataLoader.