torch_em.data.datasets.medical.spider

The SPIDER dataset contains annotations for segmentation of vertebrae, intervertebral discs and spinal canal in T1 and T2 MRI series.

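This dataset is described in the following publication and hosted on Zenodo:
- https://zenodo.org/records/10159290
- https://www.nature.com/articles/s41597-024-03090-w

Please cite it if you use this data in a publication.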

"""The SPIDER dataset contains annotations for the segmentation of vertebrae,
intervertebral discs and the spinal canal in T1 and T2 MRI series.

This dataset is described in the following publication and hosted on Zenodo:
- https://zenodo.org/records/10159290
- https://www.nature.com/articles/s41597-024-03090-w

Please cite it if you use this data in a publication.
"""
 10
 11import os
 12from glob import glob
 13from natsort import natsorted
 14from typing import Tuple, List, Union
 15
 16from torch.utils.data import Dataset, DataLoader
 17
 18import torch_em
 19
 20from .. import util
 21
 22
 23URL = {
 24    "images": "https://zenodo.org/records/10159290/files/images.zip?download=1",
 25    "masks": "https://zenodo.org/records/10159290/files/masks.zip?download=1"
 26}
 27
 28CHECKSUMS = {
 29    "images": "a54cba2905284ff6cc9999f1dd0e4d871c8487187db7cd4b068484eac2f50f17",
 30    "masks": "13a6e25a8c0d74f507e16ebb2edafc277ceeaf2598474f1fed24fdf59cb7f18f"
 31}
 32
 33
 34def get_spider_data(path: Union[os.PathLike, str], download: bool = False) -> str:
 35    """Download the SPIDER dataset.
 36
 37    Args:
 38        path: Filepath to a folder where the data is downloaded for further processing.
 39        download: Whether to download the data if it is not present.
 40
 41    Returns:
 42        Filepath where the data is downloaded.
 43    """
 44    data_dir = os.path.join(path, "data")
 45    if os.path.exists(data_dir):
 46        return data_dir
 47
 48    os.makedirs(path, exist_ok=True)
 49
 50    zip_path = os.path.join(path, "images.zip")
 51    util.download_source(path=zip_path, url=URL["images"], download=download, checksum=CHECKSUMS["images"])
 52    util.unzip(zip_path=zip_path, dst=data_dir)
 53
 54    zip_path = os.path.join(path, "masks.zip")
 55    util.download_source(path=zip_path, url=URL["masks"], download=download, checksum=CHECKSUMS["masks"])
 56    util.unzip(zip_path=zip_path, dst=data_dir)
 57
 58    return data_dir
 59
 60
 61def get_spider_paths(path: Union[os.PathLike, str], download: bool = False) -> Tuple[List[str], List[str]]:
 62    """Get paths to the SPIDER data.
 63
 64    Args:
 65        path: Filepath to a folder where the data is downloaded for further processing.
 66        download: Whether to download the data if it is not present.
 67
 68    Returns:
 69        List of filepaths for the image data.
 70        List of filepaths for the label data.
 71    """
 72    data_dir = get_spider_data(path, download)
 73
 74    image_paths = natsorted(glob(os.path.join(data_dir, "images", "*.mha")))
 75    gt_paths = natsorted(glob(os.path.join(data_dir, "masks", "*.mha")))
 76
 77    return image_paths, gt_paths
 78
 79
 80def get_spider_dataset(
 81    path: Union[os.PathLike, str],
 82    patch_shape: Tuple[int, ...],
 83    resize_inputs: bool = False,
 84    download: bool = False,
 85    **kwargs
 86) -> Dataset:
 87    """Get the SPIDER dataset.
 88
 89    Args:
 90        path: Filepath to a folder where the data is downloaded for further processing.
 91        patch_shape: The patch shape to use for training.
 92        resize_inputs: Whether to resize inputs to the desired patch shape.
 93        download: Whether to download the data if it is not present.
 94        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
 95
 96    Returns:
 97        The segmentation dataset.
 98    """
 99    # TODO: expose the choice to choose specific MRI modality, for now this works for our interests.
100    image_paths, gt_paths = get_spider_paths(path, download)
101
102    if resize_inputs:
103        resize_kwargs = {"patch_shape": patch_shape, "is_rgb": False}
104        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
105            kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs
106        )
107
108    return torch_em.default_segmentation_dataset(
109        raw_paths=image_paths,
110        raw_key=None,
111        label_paths=gt_paths,
112        label_key=None,
113        is_seg_dataset=True,
114        patch_shape=patch_shape,
115        **kwargs
116    )
117
118
def get_spider_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, ...],
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Build a DataLoader over the SPIDER segmentation dataset.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        resize_inputs: Whether to resize inputs to the desired patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # Route each extra keyword either to the dataset constructor or to the loader.
    dataset_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    return torch_em.get_data_loader(
        get_spider_dataset(
            path=path,
            patch_shape=patch_shape,
            resize_inputs=resize_inputs,
            download=download,
            **dataset_kwargs,
        ),
        batch_size,
        **loader_kwargs,
    )
# Zenodo download links for the two SPIDER archives (record 10159290).
URL = dict(
    images="https://zenodo.org/records/10159290/files/images.zip?download=1",
    masks="https://zenodo.org/records/10159290/files/masks.zip?download=1",
)
# Expected checksums of the archives above; `get_spider_data` hands them to
# `util.download_source` (64 hex digits, presumably SHA-256 -- confirm in `util`).
CHECKSUMS = dict(
    images="a54cba2905284ff6cc9999f1dd0e4d871c8487187db7cd4b068484eac2f50f17",
    masks="13a6e25a8c0d74f507e16ebb2edafc277ceeaf2598474f1fed24fdf59cb7f18f",
)
def get_spider_data(path: Union[os.PathLike, str], download: bool = False) -> str:
    """Download the SPIDER dataset.

    Images and masks ship as two separate zip archives. Each archive is downloaded and
    extracted into `<path>/data` only if its extracted folder (`data/images` or
    `data/masks`) is missing. An interrupted earlier run, e.g. images extracted but masks
    not, is therefore completed instead of being mistaken for a finished download.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        download: Whether to download the data if it is not present.

    Returns:
        Filepath where the data is downloaded.
    """
    data_dir = os.path.join(path, "data")

    for name in ("images", "masks"):
        # Check each extracted subfolder (the ones `get_spider_paths` globs) rather than
        # just `data_dir`. Otherwise a partially extracted dataset is returned as complete.
        if os.path.exists(os.path.join(data_dir, name)):
            continue

        os.makedirs(path, exist_ok=True)
        zip_path = os.path.join(path, f"{name}.zip")
        util.download_source(path=zip_path, url=URL[name], download=download, checksum=CHECKSUMS[name])
        util.unzip(zip_path=zip_path, dst=data_dir)

    return data_dir

Download the SPIDER dataset.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • download: Whether to download the data if it is not present.
Returns:

Filepath where the data is downloaded.

def get_spider_paths(path: Union[os.PathLike, str], download: bool = False) -> Tuple[List[str], List[str]]:
    """Collect the image and mask volumes (`.mha` files) of the SPIDER dataset.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        download: Whether to download the data if it is not present.

    Returns:
        List of filepaths for the image data.
        List of filepaths for the label data.
    """
    root = get_spider_data(path, download)

    def _sorted_volumes(subfolder):
        # Both lists are natsorted the same way, so images and masks are paired by position
        # (this assumes matching file names in both folders).
        return natsorted(glob(os.path.join(root, subfolder, "*.mha")))

    return _sorted_volumes("images"), _sorted_volumes("masks")

Get paths to the SPIDER data.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • download: Whether to download the data if it is not present.
Returns:

A tuple of two lists: the filepaths for the image data and the filepaths for the label data.

def get_spider_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, ...],
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> Dataset:
    """Build the SPIDER segmentation dataset.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        patch_shape: The patch shape to use for training.
        resize_inputs: Whether to resize inputs to the desired patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    # TODO: expose the choice to choose specific MRI modality, for now this works for our interests.
    raw_paths, label_paths = get_spider_paths(path, download)

    if resize_inputs:
        # Grayscale MRI volumes, hence `is_rgb=False`; the helper may also adjust patch_shape.
        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
            kwargs=kwargs,
            patch_shape=patch_shape,
            resize_inputs=resize_inputs,
            resize_kwargs=dict(patch_shape=patch_shape, is_rgb=False),
        )

    # The .mha volumes are read without an internal key, so both keys are None.
    dataset_kwargs = dict(
        raw_paths=raw_paths,
        raw_key=None,
        label_paths=label_paths,
        label_key=None,
        is_seg_dataset=True,
        patch_shape=patch_shape,
    )
    return torch_em.default_segmentation_dataset(**dataset_kwargs, **kwargs)

Get the SPIDER dataset.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • patch_shape: The patch shape to use for training.
  • resize_inputs: Whether to resize inputs to the desired patch shape.
  • download: Whether to download the data if it is not present.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset.
Returns:

The segmentation dataset.

def get_spider_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, ...],
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Build a DataLoader over the SPIDER segmentation dataset.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        resize_inputs: Whether to resize inputs to the desired patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # Route each extra keyword either to the dataset constructor or to the loader.
    dataset_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    return torch_em.get_data_loader(
        get_spider_dataset(
            path=path,
            patch_shape=patch_shape,
            resize_inputs=resize_inputs,
            download=download,
            **dataset_kwargs,
        ),
        batch_size,
        **loader_kwargs,
    )

Get the SPIDER dataloader.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • batch_size: The batch size for training.
  • patch_shape: The patch shape to use for training.
  • resize_inputs: Whether to resize inputs to the desired patch shape.
  • download: Whether to download the data if it is not present.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset or for the PyTorch DataLoader.
Returns:

The DataLoader.