torch_em.data.datasets.medical.toothfairy

The ToothFairy data contains annotations for mandibular canal (v1) and multiple structures (v2) segmentation in CBCT scans.

NOTE: The dataset is located at https://ditto.ing.unimore.it/. To download the dataset, please follow the steps below:

  • Choose either v1 (https://ditto.ing.unimore.it/toothfairy) or v2 (https://ditto.ing.unimore.it/toothfairy2).
  • Visit the website and scroll down to the 'Download' section, which requires you to sign up.
  • After signing up, use your credentials to log in to the dataset home page.
  • Click on the blue icon stating: 'Download Dataset' to download the zipped files to the desired path.

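The relevant links for the dataset are:

  • ToothFairy Challenge: https://toothfairy.grand-challenge.org/
  • ToothFairy2 Challenge: https://toothfairy2.grand-challenge.org/
  • Publication 1: https://doi.org/10.1109/ACCESS.2022.3144840
  • Publication 2: https://doi.org/10.1109/CVPR52688.2022.02046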

Please cite them if you use this dataset for your research.

  1"""The ToothFairy data contains annotations for mandibular canal (v1) and multiple structures (v2)
  2segmentation in CBCT scans.
  3
  4NOTE: The dataset is located at https://ditto.ing.unimore.it/
  5To download the dataset, please follow the mentioned steps:
  6- Choose either v1 (https://ditto.ing.unimore.it/toothfairy) or v2 (https://ditto.ing.unimore.it/toothfairy2).
  7- Visit the website, scroll down to the 'Download' section, which expects you to sign up.
  8- After signing up, use your credentials to login to the dataset home page.
  9- Click on the blue icon stating: 'Download Dataset' to download the zipped files to the desired path.
 10
 11The relevant links for the dataset are:
 12- ToothFairy Challenge: https://toothfairy.grand-challenge.org/
 13- ToothFairy2 Challenge: https://toothfairy2.grand-challenge.org/
 14- Publication 1: https://doi.org/10.1109/ACCESS.2022.3144840
 15- Publication 2: https://doi.org/10.1109/CVPR52688.2022.02046
 16
 17Please cite them if you use this dataset for your research.
 18"""
 19
 20import os
 21from glob import glob
 22from tqdm import tqdm
 23from natsort import natsorted
 24from typing import Union, Tuple, Literal, List
 25
 26import numpy as np
 27
 28from torch.utils.data import Dataset, DataLoader
 29
 30import torch_em
 31
 32from .. import util
 33
 34
 35def get_toothfairy_data(
 36    path: Union[os.PathLike, str], version: Literal["v1", "v2"] = "v2", download: bool = False
 37) -> str:
 38    """Obtain the ToothFairy datasets.
 39
 40    Args:
 41        path: Filepath to a folder where the data is downloaded for further processing.
 42        version: The version of dataset. Either v1 (ToothFairy) or v2 (ToothFairy2).
 43        download: Whether to download the data if it is not present.
 44
 45    Returns:
 46        Filepath to the already downloaded dataset.
 47    """
 48    data_dir = os.path.join(path, "ToothFairy_Dataset/Dataset" if version == "v1" else "Dataset112_ToothFairy2")
 49    if os.path.exists(data_dir):
 50        return data_dir
 51
 52    if download:
 53        msg = "Download is set to True, but 'torch_em' cannot download this dataset. "
 54        msg += "See `get_toothfairy2_data` for details."
 55        raise NotImplementedError(msg)
 56
 57    if version == "v1":
 58        zip_path = os.path.join(path, "ToothFairy_Dataset.zip")
 59    elif version == "v2":
 60        zip_path = os.path.join(path, "ToothFairy2_Dataset.zip")
 61    else:
 62        raise ValueError(f"'{version}' is not a valid version.")
 63
 64    if not os.path.exists(zip_path):
 65        raise FileNotFoundError(f"It's expected to place the downloaded toothfairy zipfile at '{path}'.")
 66
 67    util.unzip(zip_path=zip_path, dst=path, remove=False)
 68
 69    return data_dir
 70
 71
 72def _preprocess_toothfairy_inputs(path, data_dir):
 73    import nibabel as nib
 74
 75    images_dir = os.path.join(path, "data", "images")
 76    gt_dir = os.path.join(path, "data", "dense_labels")
 77    if os.path.exists(images_dir) and os.path.exists(gt_dir):
 78        return natsorted(glob(os.path.join(images_dir, "*.nii.gz"))), natsorted(glob(os.path.join(gt_dir, "*.nii.gz")))
 79
 80    os.makedirs(images_dir, exist_ok=True)
 81    os.makedirs(gt_dir, exist_ok=True)
 82
 83    image_paths, gt_paths = [], []
 84    for patient_dir in tqdm(glob(os.path.join(data_dir, "P*")), desc="Preprocessing inputs"):
 85        dense_anns_path = os.path.join(patient_dir, "gt_alpha.npy")
 86        if not os.path.exists(dense_anns_path):
 87            continue
 88
 89        image_path = os.path.join(patient_dir, "data.npy")
 90        image, gt = np.load(image_path), np.load(dense_anns_path)
 91        image_nifti, gt_nifti = nib.Nifti2Image(image, np.eye(4)), nib.Nifti2Image(gt, np.eye(4))
 92
 93        patient_id = os.path.split(patient_dir)[-1]
 94        trg_image_path = os.path.join(images_dir, f"{patient_id}.nii.gz")
 95        trg_gt_path = os.path.join(gt_dir, f"{patient_id}.nii.gz")
 96
 97        nib.save(image_nifti, trg_image_path)
 98        nib.save(gt_nifti, trg_gt_path)
 99
100        image_paths.append(trg_image_path)
101        gt_paths.append(trg_gt_path)
102
103    return image_paths, gt_paths
104
105
def get_toothfairy_paths(
    path: Union[os.PathLike, str],
    split: Literal['train', 'val', 'test'],
    version: Literal["v1", "v2"] = "v2",
    download: bool = False,
) -> Tuple[List[str], List[str]]:
    """Get paths to the ToothFairy data.

    The volumes are split deterministically by their naturally sorted position:
    - v1: the first 100 volumes are 'train', the next 25 'val' and the remaining ones 'test'.
    - v2: the first 400 volumes are 'train', the next 25 'val' and the remaining ones 'test'.
    For v2, the '.mha' files in `imagesTr` / `labelsTr` are used.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        split: The choice of data split.
        version: The version of dataset. Either 'v1' (ToothFairy) or 'v2' (ToothFairy2).
        download: Whether to download the data if it is not present.

    Returns:
        List of filepaths for the image data.
        List of filepaths for the label data.

    Raises:
        ValueError: If `split` is not one of 'train', 'val' or 'test'.
    """
    # Validate the split before the (potentially slow) extraction and v1 preprocessing.
    if split not in ("train", "val", "test"):
        raise ValueError(f"'{split}' is not a valid split.")

    data_dir = get_toothfairy_data(path, version, download)

    if version == "v1":
        # v1 ships numpy arrays, which are converted to NIfTI first.
        image_paths, gt_paths = _preprocess_toothfairy_inputs(path, data_dir)
        n_train, n_val = 100, 25
    else:
        image_paths = natsorted(glob(os.path.join(data_dir, "imagesTr", "*.mha")))
        gt_paths = natsorted(glob(os.path.join(data_dir, "labelsTr", "*.mha")))
        n_train, n_val = 400, 25

    split_slices = {
        "train": slice(None, n_train),
        "val": slice(n_train, n_train + n_val),
        "test": slice(n_train + n_val, None),
    }
    selection = split_slices[split]
    return image_paths[selection], gt_paths[selection]
152
153
def get_toothfairy_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, ...],
    split: Literal['train', 'val', 'test'],
    version: Literal["v1", "v2"] = "v2",
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> Dataset:
    """Get the ToothFairy dataset for canal and teeth segmentation.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        patch_shape: The patch shape to use for training.
        split: The choice of data split ('train', 'val' or 'test').
        version: The version of dataset. Either 'v1' (ToothFairy) or 'v2' (ToothFairy2).
        resize_inputs: Whether to resize inputs to the desired patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    raw_files, label_files = get_toothfairy_paths(path, split, version, download)

    if resize_inputs:
        # The resize transform is configured with the patch shape requested by the caller;
        # the helper returns the updated kwargs and the patch shape to use from here on.
        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
            kwargs=kwargs,
            patch_shape=patch_shape,
            resize_inputs=resize_inputs,
            resize_kwargs={"patch_shape": patch_shape, "is_rgb": False},
        )

    # v1 inputs (NIfTI files written during preprocessing) are accessed with the key "data";
    # v2 inputs ('.mha' files) are passed without a key.
    file_key = "data" if version == "v1" else None

    return torch_em.default_segmentation_dataset(
        raw_paths=raw_files,
        raw_key=file_key,
        label_paths=label_files,
        label_key=file_key,
        is_seg_dataset=True,
        patch_shape=patch_shape,
        **kwargs,
    )
194
195
def get_toothfairy_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, ...],
    split: Literal['train', 'val', 'test'],
    version: Literal["v1", "v2"] = "v2",
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Get the ToothFairy dataloader for canal and teeth segmentation.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        split: The choice of data split ('train', 'val' or 'test').
        version: The version of dataset. Either 'v1' (ToothFairy) or 'v2' (ToothFairy2).
        resize_inputs: Whether to resize inputs to the desired patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # Route each keyword argument either to the dataset constructor or to the DataLoader.
    dataset_kwargs, dataloader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    return torch_em.get_data_loader(
        get_toothfairy_dataset(path, patch_shape, split, version, resize_inputs, download, **dataset_kwargs),
        batch_size,
        **dataloader_kwargs,
    )
def get_toothfairy_data(
    path: Union[os.PathLike, str], version: Literal["v1", "v2"] = "v2", download: bool = False
) -> str:
    """Obtain the ToothFairy datasets.

    'torch_em' cannot download this dataset automatically. The zip file has to be downloaded
    manually from https://ditto.ing.unimore.it/ (see the module docstring) and placed in `path`.
    It is extracted on first use.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        version: The version of dataset. Either v1 (ToothFairy) or v2 (ToothFairy2).
        download: Whether to download the data if it is not present. Automatic download is not
            supported: if neither the extracted data nor the zip file is found, passing `True`
            raises a `NotImplementedError`. An already present zip file is extracted either way.

    Returns:
        Filepath to the already downloaded dataset.

    Raises:
        ValueError: If `version` is neither 'v1' nor 'v2'.
        NotImplementedError: If the data and zip file are missing and `download` is True.
        FileNotFoundError: If the data and zip file are missing and `download` is False.
    """
    # Resolve (and validate) the version first, so that an invalid value can never be
    # silently mapped onto an existing v2 folder.
    if version == "v1":
        data_dir = os.path.join(path, "ToothFairy_Dataset", "Dataset")
        zip_path = os.path.join(path, "ToothFairy_Dataset.zip")
    elif version == "v2":
        data_dir = os.path.join(path, "Dataset112_ToothFairy2")
        zip_path = os.path.join(path, "ToothFairy2_Dataset.zip")
    else:
        raise ValueError(f"'{version}' is not a valid version.")

    if os.path.exists(data_dir):
        return data_dir

    # Only complain about the missing download support when there is nothing to extract.
    if not os.path.exists(zip_path):
        if download:
            msg = "Download is set to True, but 'torch_em' cannot download this dataset. "
            msg += "See the docstring of `torch_em.data.datasets.medical.toothfairy` for details."
            raise NotImplementedError(msg)
        raise FileNotFoundError(f"It's expected to place the downloaded toothfairy zipfile at '{zip_path}'.")

    util.unzip(zip_path=zip_path, dst=path, remove=False)

    return data_dir

Obtain the ToothFairy datasets.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • version: The version of dataset. Either v1 (ToothFairy) or v2 (ToothFairy2).
  • download: Whether to download the data if it is not present.
Returns:

Filepath to the already downloaded dataset.

def get_toothfairy_paths(
    path: Union[os.PathLike, str],
    split: Literal['train', 'val', 'test'],
    version: Literal["v1", "v2"] = "v2",
    download: bool = False,
) -> Tuple[List[str], List[str]]:
    """Get paths to the ToothFairy data.

    The volumes are split deterministically by their naturally sorted position:
    - v1: the first 100 volumes are 'train', the next 25 'val' and the remaining ones 'test'.
    - v2: the first 400 volumes are 'train', the next 25 'val' and the remaining ones 'test'.
    For v2, the '.mha' files in `imagesTr` / `labelsTr` are used.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        split: The choice of data split.
        version: The version of dataset. Either 'v1' (ToothFairy) or 'v2' (ToothFairy2).
        download: Whether to download the data if it is not present.

    Returns:
        List of filepaths for the image data.
        List of filepaths for the label data.

    Raises:
        ValueError: If `split` is not one of 'train', 'val' or 'test'.
    """
    # Validate the split before the (potentially slow) extraction and v1 preprocessing.
    if split not in ("train", "val", "test"):
        raise ValueError(f"'{split}' is not a valid split.")

    data_dir = get_toothfairy_data(path, version, download)

    if version == "v1":
        # v1 ships numpy arrays, which are converted to NIfTI first.
        image_paths, gt_paths = _preprocess_toothfairy_inputs(path, data_dir)
        n_train, n_val = 100, 25
    else:
        image_paths = natsorted(glob(os.path.join(data_dir, "imagesTr", "*.mha")))
        gt_paths = natsorted(glob(os.path.join(data_dir, "labelsTr", "*.mha")))
        n_train, n_val = 400, 25

    split_slices = {
        "train": slice(None, n_train),
        "val": slice(n_train, n_train + n_val),
        "test": slice(n_train + n_val, None),
    }
    selection = split_slices[split]
    return image_paths[selection], gt_paths[selection]

Get paths to the ToothFairy data.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • split: The choice of data split.
  • version: The version of dataset. Either 'v1' (ToothFairy) or 'v2' (ToothFairy2).
  • download: Whether to download the data if it is not present.
Returns:

A list of filepaths for the image data, and a list of filepaths for the label data.

def get_toothfairy_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, ...],
    split: Literal['train', 'val', 'test'],
    version: Literal["v1", "v2"] = "v2",
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> Dataset:
    """Get the ToothFairy dataset for canal and teeth segmentation.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        patch_shape: The patch shape to use for training.
        split: The choice of data split ('train', 'val' or 'test').
        version: The version of dataset. Either 'v1' (ToothFairy) or 'v2' (ToothFairy2).
        resize_inputs: Whether to resize inputs to the desired patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    raw_files, label_files = get_toothfairy_paths(path, split, version, download)

    if resize_inputs:
        # The resize transform is configured with the patch shape requested by the caller;
        # the helper returns the updated kwargs and the patch shape to use from here on.
        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
            kwargs=kwargs,
            patch_shape=patch_shape,
            resize_inputs=resize_inputs,
            resize_kwargs={"patch_shape": patch_shape, "is_rgb": False},
        )

    # v1 inputs (NIfTI files written during preprocessing) are accessed with the key "data";
    # v2 inputs ('.mha' files) are passed without a key.
    file_key = "data" if version == "v1" else None

    return torch_em.default_segmentation_dataset(
        raw_paths=raw_files,
        raw_key=file_key,
        label_paths=label_files,
        label_key=file_key,
        is_seg_dataset=True,
        patch_shape=patch_shape,
        **kwargs,
    )

Get the ToothFairy dataset for canal and teeth segmentation.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • patch_shape: The patch shape to use for training.
  • split: The choice of data split.
  • version: The version of dataset. Either 'v1' (ToothFairy) or 'v2' (ToothFairy2).
  • resize_inputs: Whether to resize inputs to the desired patch shape.
  • download: Whether to download the data if it is not present.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset.
Returns:

The segmentation dataset.

def get_toothfairy_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, ...],
    split: Literal['train', 'val', 'test'],
    version: Literal["v1", "v2"] = "v2",
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Get the ToothFairy dataloader for canal and teeth segmentation.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        split: The choice of data split ('train', 'val' or 'test').
        version: The version of dataset. Either 'v1' (ToothFairy) or 'v2' (ToothFairy2).
        resize_inputs: Whether to resize inputs to the desired patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # Route each keyword argument either to the dataset constructor or to the DataLoader.
    dataset_kwargs, dataloader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    return torch_em.get_data_loader(
        get_toothfairy_dataset(path, patch_shape, split, version, resize_inputs, download, **dataset_kwargs),
        batch_size,
        **dataloader_kwargs,
    )

Get the ToothFairy dataloader for canal and teeth segmentation.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • batch_size: The batch size for training.
  • patch_shape: The patch shape to use for training.
  • split: The choice of data split.
  • version: The version of dataset. Either 'v1' (ToothFairy) or 'v2' (ToothFairy2).
  • resize_inputs: Whether to resize inputs to the desired patch shape.
  • download: Whether to download the data if it is not present.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset or for the PyTorch DataLoader.
Returns:

The DataLoader.