torch_em.data.datasets.medical.hil_toothseg

The HIL ToothSeg dataset contains annotations for teeth segmentation in panoramic dental radiographs.

This dataset is from the publication https://www.mdpi.com/1424-8220/21/9/3110. Please cite it if you use this dataset for your research.

  1"""The HIL ToothSeg dataset contains annotations for teeth segmentation
  2in panoramic dental radiographs.
  3
  4This dataset is from the publication https://www.mdpi.com/1424-8220/21/9/3110.
  5Please cite it if you use this dataset for your research.
  6"""
  7
  8import os
  9from glob import glob
 10from tqdm import tqdm
 11from pathlib import Path
 12from natsort import natsorted
 13from typing import Union, Literal, Tuple, List
 14
 15import numpy as np
 16import imageio.v3 as imageio
 17
 18from torch.utils.data import Dataset, DataLoader
 19
 20import torch_em
 21
 22from .. import util
 23
 24
 25URL = "https://hitl-public-datasets.s3.eu-central-1.amazonaws.com/Teeth+Segmentation.zip"
 26CHECKSUM = "3b628165a218a5e8d446d1313e6ecbe7cfc599a3d6418cd60b4fb78745becc2e"
 27
 28
 29def get_hil_toothseg_data(path: Union[os.PathLike, str], download: bool = False) -> str:
 30    """Download the HIL ToothSeg dataset.
 31
 32    Args:
 33        path: Filepath to a folder where the data is downloaded for further processing.
 34        download: Whether to download the data if it is not present.
 35    """
 36    data_dir = os.path.join(path, r"Teeth Segmentation PNG")
 37    if os.path.exists(data_dir):
 38        return data_dir
 39
 40    os.makedirs(path, exist_ok=True)
 41
 42    zip_path = os.path.join(path, "Teeth_Segmentation.zip")
 43    util.download_source(path=zip_path, url=URL, download=download, checksum=CHECKSUM)
 44    util.unzip(zip_path=zip_path, dst=path)
 45
 46    return data_dir
 47
 48
 49def get_hil_toothseg_paths(
 50    path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], download: bool = False
 51) -> Tuple[List[str], List[str]]:
 52    """Get paths to the HIL ToothSeg data.
 53
 54    Args:
 55        path: Filepath to a folder where the data is downloaded for further processing.
 56        split: The data split to use. Either 'train', 'val' or 'test'.
 57        download: Whether to download the data if it is not present.
 58
 59    Returns:
 60        List of filepaths for the image data.
 61        List of filepaths for the label data.
 62    """
 63    import cv2 as cv
 64
 65    data_dir = get_hil_toothseg_data(path=path, download=download)
 66
 67    image_paths = natsorted(glob(os.path.join(data_dir, "d2", "img", "*")))
 68    gt_paths = natsorted(glob(os.path.join(data_dir, "d2", "masks_machine", "*")))
 69
 70    neu_gt_dir = os.path.join(data_dir, "preprocessed", "gt")
 71    os.makedirs(neu_gt_dir, exist_ok=True)
 72
 73    neu_gt_paths = []
 74    for gt_path in tqdm(gt_paths, desc="Preprocessing inputs"):
 75        neu_gt_path = os.path.join(neu_gt_dir, f"{Path(gt_path).stem}.tif")
 76        neu_gt_paths.append(neu_gt_path)
 77        if os.path.exists(neu_gt_path):
 78            continue
 79
 80        rgb_gt = cv.imread(gt_path)
 81        rgb_gt = cv.cvtColor(rgb_gt, cv.COLOR_BGR2RGB)
 82        incolors = np.unique(rgb_gt.reshape(-1, rgb_gt.shape[2]), axis=0)
 83
 84        # the first id is always background, let's remove it
 85        if np.array_equal(incolors[0], np.array([0, 0, 0])):
 86            incolors = incolors[1:]
 87
 88        instances = np.zeros(rgb_gt.shape[:2])
 89
 90        color_to_id = {tuple(cvalue): i for i, cvalue in enumerate(incolors, start=1)}
 91        for cvalue, idx in color_to_id.items():
 92            binary_map = (rgb_gt == cvalue).all(axis=2)
 93            instances[binary_map] = idx
 94
 95        imageio.imwrite(neu_gt_path, instances)
 96
 97    if split == "train":
 98        image_paths, neu_gt_paths = image_paths[:450], neu_gt_paths[:450]
 99    elif split == "val":
100        image_paths, neu_gt_paths = image_paths[425:475], neu_gt_paths[425:475]
101    elif split == "test":
102        image_paths, neu_gt_paths = image_paths[475:], neu_gt_paths[475:]
103    else:
104        raise ValueError(f"{split} is not a valid split.")
105
106    return image_paths, neu_gt_paths
107
108
def get_hil_toothseg_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, int],
    split: Literal["train", "val", "test"],
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> Dataset:
    """Build the HIL ToothSeg segmentation dataset for the requested split.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        patch_shape: The patch shape to use for training.
        split: The data split to use. Either 'train', 'val' or 'test'.
        resize_inputs: Whether to resize the inputs to the patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    raw_files, label_files = get_hil_toothseg_paths(path, split, download)

    if resize_inputs:
        # The helper returns updated dataset kwargs together with a (possibly changed) patch shape.
        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
            kwargs=kwargs,
            patch_shape=patch_shape,
            resize_inputs=resize_inputs,
            resize_kwargs={"patch_shape": patch_shape, "is_rgb": False},
        )

    # Images and labels are individual 2d files, read without an internal key.
    return torch_em.default_segmentation_dataset(
        raw_paths=raw_files, raw_key=None,
        label_paths=label_files, label_key=None,
        is_seg_dataset=False, patch_shape=patch_shape,
        **kwargs,
    )
147
148
def get_hil_toothseg_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, int],
    split: Literal["train", "val", "test"],
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Build a DataLoader over the HIL ToothSeg teeth segmentation dataset.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        split: The data split to use. Either 'train', 'val' or 'test'.
        resize_inputs: Whether to resize the inputs to the patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # Separate the keyword arguments meant for the dataset from those meant for the loader.
    dataset_kwargs, dataloader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    return torch_em.get_data_loader(
        get_hil_toothseg_dataset(path, patch_shape, split, resize_inputs, download, **dataset_kwargs),
        batch_size,
        **dataloader_kwargs,
    )
# Download location of the zipped HIL ToothSeg dataset.
URL: str = "https://hitl-public-datasets.s3.eu-central-1.amazonaws.com/Teeth+Segmentation.zip"
# Expected checksum of the downloaded archive; handed to `util.download_source` for verification.
CHECKSUM: str = "3b628165a218a5e8d446d1313e6ecbe7cfc599a3d6418cd60b4fb78745becc2e"
def get_hil_toothseg_data(path: Union[os.PathLike, str], download: bool = False) -> str:
    """Download and extract the HIL ToothSeg dataset, unless it is already present.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        download: Whether to download the data if it is not present.

    Returns:
        The filepath to the 'Teeth Segmentation PNG' folder inside `path`.
    """
    target_dir = os.path.join(path, r"Teeth Segmentation PNG")

    # Only fetch and unpack the archive when the extracted folder does not exist yet.
    if not os.path.exists(target_dir):
        os.makedirs(path, exist_ok=True)
        archive = os.path.join(path, "Teeth_Segmentation.zip")
        util.download_source(path=archive, url=URL, download=download, checksum=CHECKSUM)
        util.unzip(zip_path=archive, dst=path)

    return target_dir

Download the HIL ToothSeg dataset.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • download: Whether to download the data if it is not present.
def get_hil_toothseg_paths(
    path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], download: bool = False
) -> Tuple[List[str], List[str]]:
    """Get paths to the HIL ToothSeg data.

    The RGB label masks are converted once into instance label images, where background (black)
    is 0 and every other distinct color gets a consecutive id starting at 1. They are stored as
    tif files in `<data_dir>/preprocessed/gt` and reused on later calls.

    The naturally sorted images are divided into disjoint splits: 'train' uses the first 425
    images, 'val' the images 425 to 474, and 'test' all remaining images.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        split: The data split to use. Either 'train', 'val' or 'test'.
        download: Whether to download the data if it is not present.

    Returns:
        List of filepaths for the image data.
        List of filepaths for the label data.

    Raises:
        ValueError: If `split` is invalid, if the numbers of images and label masks differ,
            or if a label mask cannot be read.
    """
    # Resolve the split first, so an invalid value fails before any download / preprocessing.
    # 'val' starts exactly where 'train' ends, so the two splits never share images.
    if split == "train":
        selection = slice(None, 425)
    elif split == "val":
        selection = slice(425, 475)
    elif split == "test":
        selection = slice(475, None)
    else:
        raise ValueError(f"{split} is not a valid split.")

    import cv2 as cv

    data_dir = get_hil_toothseg_data(path=path, download=download)

    image_paths = natsorted(glob(os.path.join(data_dir, "d2", "img", "*")))
    gt_paths = natsorted(glob(os.path.join(data_dir, "d2", "masks_machine", "*")))

    # Images and masks are paired purely by their sorted order, so the counts must agree.
    if len(image_paths) != len(gt_paths):
        raise ValueError(
            f"Found {len(image_paths)} images but {len(gt_paths)} label masks in '{data_dir}'."
        )

    neu_gt_dir = os.path.join(data_dir, "preprocessed", "gt")
    os.makedirs(neu_gt_dir, exist_ok=True)

    neu_gt_paths = []
    for gt_path in tqdm(gt_paths, desc="Preprocessing inputs"):
        neu_gt_path = os.path.join(neu_gt_dir, f"{Path(gt_path).stem}.tif")
        neu_gt_paths.append(neu_gt_path)
        if os.path.exists(neu_gt_path):  # converted in an earlier call
            continue

        rgb_gt = cv.imread(gt_path)
        if rgb_gt is None:  # cv.imread returns None instead of raising for unreadable files
            raise ValueError(f"Could not read the label mask at '{gt_path}'.")
        rgb_gt = cv.cvtColor(rgb_gt, cv.COLOR_BGR2RGB)

        # Every distinct RGB color in the mask marks one label.
        incolors = np.unique(rgb_gt.reshape(-1, rgb_gt.shape[2]), axis=0)

        # np.unique sorts the colors, so black (background) comes first if present: drop it.
        if np.array_equal(incolors[0], np.array([0, 0, 0])):
            incolors = incolors[1:]

        instances = np.zeros(rgb_gt.shape[:2])
        for idx, cvalue in enumerate(incolors, start=1):
            instances[(rgb_gt == cvalue).all(axis=2)] = idx

        imageio.imwrite(neu_gt_path, instances)

    return image_paths[selection], neu_gt_paths[selection]

Get paths to the HIL ToothSeg data.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • split: The data split to use. Either 'train', 'val' or 'test'.
  • download: Whether to download the data if it is not present.
Returns:

A list of filepaths for the image data and a list of filepaths for the label data.

def get_hil_toothseg_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, int],
    split: Literal["train", "val", "test"],
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> Dataset:
    """Build the HIL ToothSeg segmentation dataset for the requested split.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        patch_shape: The patch shape to use for training.
        split: The data split to use. Either 'train', 'val' or 'test'.
        resize_inputs: Whether to resize the inputs to the patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    raw_files, label_files = get_hil_toothseg_paths(path, split, download)

    if resize_inputs:
        # The helper returns updated dataset kwargs together with a (possibly changed) patch shape.
        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
            kwargs=kwargs,
            patch_shape=patch_shape,
            resize_inputs=resize_inputs,
            resize_kwargs={"patch_shape": patch_shape, "is_rgb": False},
        )

    # Images and labels are individual 2d files, read without an internal key.
    return torch_em.default_segmentation_dataset(
        raw_paths=raw_files, raw_key=None,
        label_paths=label_files, label_key=None,
        is_seg_dataset=False, patch_shape=patch_shape,
        **kwargs,
    )

Get the HIL ToothSeg dataset for teeth segmentation.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • patch_shape: The patch shape to use for training.
  • split: The data split to use. Either 'train', 'val' or 'test'.
  • resize_inputs: Whether to resize the inputs to the patch shape.
  • download: Whether to download the data if it is not present.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset.
Returns:

The segmentation dataset.

def get_hil_toothseg_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, int],
    split: Literal["train", "val", "test"],
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Build a DataLoader over the HIL ToothSeg teeth segmentation dataset.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        split: The data split to use. Either 'train', 'val' or 'test'.
        resize_inputs: Whether to resize the inputs to the patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # Separate the keyword arguments meant for the dataset from those meant for the loader.
    dataset_kwargs, dataloader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    return torch_em.get_data_loader(
        get_hil_toothseg_dataset(path, patch_shape, split, resize_inputs, download, **dataset_kwargs),
        batch_size,
        **dataloader_kwargs,
    )

Get the HIL ToothSeg dataloader for teeth segmentation.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • batch_size: The batch size for training.
  • patch_shape: The patch shape to use for training.
  • split: The data split to use. Either 'train', 'val' or 'test'.
  • resize_inputs: Whether to resize the inputs to the patch shape.
  • download: Whether to download the data if it is not present.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset or for the PyTorch DataLoader.
Returns:

The DataLoader.