torch_em.data.datasets.medical.hil_toothseg
The HIL ToothSeg dataset contains annotations for teeth segmentation in panoramic dental radiographs.
This dataset is from the publication https://www.mdpi.com/1424-8220/21/9/3110. Please cite it if you use this dataset for your research.
1"""The HIL ToothSeg dataset contains annotations for teeth segmentation 2in panoramic dental radiographs. 3 4This dataset is from the publication https://www.mdpi.com/1424-8220/21/9/3110. 5Please cite it if you use this dataset for your research. 6""" 7 8import os 9from glob import glob 10from tqdm import tqdm 11from pathlib import Path 12from natsort import natsorted 13from typing import Union, Literal, Tuple, List 14 15import numpy as np 16import imageio.v3 as imageio 17 18from torch.utils.data import Dataset, DataLoader 19 20import torch_em 21 22from .. import util 23 24 25URL = "https://hitl-public-datasets.s3.eu-central-1.amazonaws.com/Teeth+Segmentation.zip" 26CHECKSUM = "3b628165a218a5e8d446d1313e6ecbe7cfc599a3d6418cd60b4fb78745becc2e" 27 28 29def get_hil_toothseg_data(path: Union[os.PathLike, str], download: bool = False) -> str: 30 """Download the HIL ToothSeg dataset. 31 32 Args: 33 path: Filepath to a folder where the data is downloaded for further processing. 34 download: Whether to download the data if it is not present. 35 """ 36 data_dir = os.path.join(path, r"Teeth Segmentation PNG") 37 if os.path.exists(data_dir): 38 return data_dir 39 40 os.makedirs(path, exist_ok=True) 41 42 zip_path = os.path.join(path, "Teeth_Segmentation.zip") 43 util.download_source(path=zip_path, url=URL, download=download, checksum=CHECKSUM) 44 util.unzip(zip_path=zip_path, dst=path) 45 46 return data_dir 47 48 49def get_hil_toothseg_paths( 50 path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], download: bool = False 51) -> Tuple[List[str], List[str]]: 52 """Get paths to the HIL ToothSeg data. 53 54 Args: 55 path: Filepath to a folder where the data is downloaded for further processing. 56 split: The data split to use. Either 'train', 'val' or 'test'. 57 download: Whether to download the data if it is not present. 58 59 Returns: 60 List of filepaths for the image data. 61 List of filepaths for the label data. 
62 """ 63 import cv2 as cv 64 65 data_dir = get_hil_toothseg_data(path=path, download=download) 66 67 image_paths = natsorted(glob(os.path.join(data_dir, "d2", "img", "*"))) 68 gt_paths = natsorted(glob(os.path.join(data_dir, "d2", "masks_machine", "*"))) 69 70 neu_gt_dir = os.path.join(data_dir, "preprocessed", "gt") 71 os.makedirs(neu_gt_dir, exist_ok=True) 72 73 neu_gt_paths = [] 74 for gt_path in tqdm(gt_paths, desc="Preprocessing inputs"): 75 neu_gt_path = os.path.join(neu_gt_dir, f"{Path(gt_path).stem}.tif") 76 neu_gt_paths.append(neu_gt_path) 77 if os.path.exists(neu_gt_path): 78 continue 79 80 rgb_gt = cv.imread(gt_path) 81 rgb_gt = cv.cvtColor(rgb_gt, cv.COLOR_BGR2RGB) 82 incolors = np.unique(rgb_gt.reshape(-1, rgb_gt.shape[2]), axis=0) 83 84 # the first id is always background, let's remove it 85 if np.array_equal(incolors[0], np.array([0, 0, 0])): 86 incolors = incolors[1:] 87 88 instances = np.zeros(rgb_gt.shape[:2]) 89 90 color_to_id = {tuple(cvalue): i for i, cvalue in enumerate(incolors, start=1)} 91 for cvalue, idx in color_to_id.items(): 92 binary_map = (rgb_gt == cvalue).all(axis=2) 93 instances[binary_map] = idx 94 95 imageio.imwrite(neu_gt_path, instances) 96 97 if split == "train": 98 image_paths, neu_gt_paths = image_paths[:450], neu_gt_paths[:450] 99 elif split == "val": 100 image_paths, neu_gt_paths = image_paths[425:475], neu_gt_paths[425:475] 101 elif split == "test": 102 image_paths, neu_gt_paths = image_paths[475:], neu_gt_paths[475:] 103 else: 104 raise ValueError(f"{split} is not a valid split.") 105 106 return image_paths, neu_gt_paths 107 108 109def get_hil_toothseg_dataset( 110 path: Union[os.PathLike, str], 111 patch_shape: Tuple[int, int], 112 split: Literal["train", "val", "test"], 113 resize_inputs: bool = False, 114 download: bool = False, 115 **kwargs 116) -> Dataset: 117 """Get the HIL ToothSeg dataset for teeth segmentation. 118 119 Args: 120 path: Filepath to a folder where the data is downloaded for further processing. 121 patch_shape: The patch shape to use for training. 122 split: The data split to use. Either 'train', 'val' or 'test'. 123 resize_inputs: Whether to resize the inputs to the patch shape. 124 download: Whether to download the data if it is not present. 125 kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`. 126 127 Returns: 128 The segmentation dataset. 129 """ 130 image_paths, gt_paths = get_hil_toothseg_paths(path, split, download) 131 132 if resize_inputs: 133 resize_kwargs = {"patch_shape": patch_shape, "is_rgb": False} 134 kwargs, patch_shape = util.update_kwargs_for_resize_trafo( 135 kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs 136 ) 137 138 return torch_em.default_segmentation_dataset( 139 raw_paths=image_paths, 140 raw_key=None, 141 label_paths=gt_paths, 142 label_key=None, 143 is_seg_dataset=False, 144 patch_shape=patch_shape, 145 **kwargs 146 ) 147 148 149def get_hil_toothseg_loader( 150 path: Union[os.PathLike, str], 151 batch_size: int, 152 patch_shape: Tuple[int, int], 153 split: Literal["train", "val", "test"], 154 resize_inputs: bool = False, 155 download: bool = False, 156 **kwargs 157) -> DataLoader: 158 """Get the HIL ToothSeg dataloader for teeth segmentation. 159 160 Args: 161 path: Filepath to a folder where the data is downloaded for further processing. 162 batch_size: The batch size for training. 163 patch_shape: The patch shape to use for training. 164 split: The data split to use. Either 'train', 'val' or 'test'. 
165 resize_inputs: Whether to resize the inputs to the patch shape. 166 download: Whether to download the data if it is not present. 167 kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader. 168 169 Returns: 170 The DataLoader. 171 """ 172 ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) 173 dataset = get_hil_toothseg_dataset(path, patch_shape, split, resize_inputs, download, **ds_kwargs) 174 return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
URL = 'https://hitl-public-datasets.s3.eu-central-1.amazonaws.com/Teeth+Segmentation.zip'
CHECKSUM = '3b628165a218a5e8d446d1313e6ecbe7cfc599a3d6418cd60b4fb78745becc2e'
def get_hil_toothseg_data(path: Union[os.PathLike, str], download: bool = False) -> str:
def get_hil_toothseg_data(path: Union[os.PathLike, str], download: bool = False) -> str:
    """Download the HIL ToothSeg dataset.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        download: Whether to download the data if it is not present.
    """
    data_dir = os.path.join(path, r"Teeth Segmentation PNG")
    if os.path.exists(data_dir):
        return data_dir

    os.makedirs(path, exist_ok=True)

    zip_path = os.path.join(path, "Teeth_Segmentation.zip")
    util.download_source(path=zip_path, url=URL, download=download, checksum=CHECKSUM)
    util.unzip(zip_path=zip_path, dst=path)

    return data_dir
Download the HIL ToothSeg dataset.
Arguments:
- path: Filepath to a folder where the data is downloaded for further processing.
- download: Whether to download the data if it is not present.
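Example (a minimal sketch; the target folder "./data/hil_toothseg" is an arbitrary choice, any writable directory works):

from torch_em.data.datasets.medical import hil_toothseg

# Downloads and unpacks the zip archive on first use, then returns the
# folder containing the extracted "Teeth Segmentation PNG" data.
data_dir = hil_toothseg.get_hil_toothseg_data("./data/hil_toothseg", download=True)
print(data_dir)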
def get_hil_toothseg_paths(path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], download: bool = False) -> Tuple[List[str], List[str]]:
def get_hil_toothseg_paths(
    path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], download: bool = False
) -> Tuple[List[str], List[str]]:
    """Get paths to the HIL ToothSeg data.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        split: The data split to use. Either 'train', 'val' or 'test'.
        download: Whether to download the data if it is not present.

    Returns:
        List of filepaths for the image data.
        List of filepaths for the label data.
    """
    import cv2 as cv

    data_dir = get_hil_toothseg_data(path=path, download=download)

    image_paths = natsorted(glob(os.path.join(data_dir, "d2", "img", "*")))
    gt_paths = natsorted(glob(os.path.join(data_dir, "d2", "masks_machine", "*")))

    neu_gt_dir = os.path.join(data_dir, "preprocessed", "gt")
    os.makedirs(neu_gt_dir, exist_ok=True)

    neu_gt_paths = []
    for gt_path in tqdm(gt_paths, desc="Preprocessing inputs"):
        neu_gt_path = os.path.join(neu_gt_dir, f"{Path(gt_path).stem}.tif")
        neu_gt_paths.append(neu_gt_path)
        if os.path.exists(neu_gt_path):
            continue

        rgb_gt = cv.imread(gt_path)
        rgb_gt = cv.cvtColor(rgb_gt, cv.COLOR_BGR2RGB)
        incolors = np.unique(rgb_gt.reshape(-1, rgb_gt.shape[2]), axis=0)

        # the first id is always background, let's remove it
        if np.array_equal(incolors[0], np.array([0, 0, 0])):
            incolors = incolors[1:]

        instances = np.zeros(rgb_gt.shape[:2])

        color_to_id = {tuple(cvalue): i for i, cvalue in enumerate(incolors, start=1)}
        for cvalue, idx in color_to_id.items():
            binary_map = (rgb_gt == cvalue).all(axis=2)
            instances[binary_map] = idx

        imageio.imwrite(neu_gt_path, instances)

    if split == "train":
        image_paths, neu_gt_paths = image_paths[:450], neu_gt_paths[:450]
    elif split == "val":
        image_paths, neu_gt_paths = image_paths[425:475], neu_gt_paths[425:475]
    elif split == "test":
        image_paths, neu_gt_paths = image_paths[475:], neu_gt_paths[475:]
    else:
        raise ValueError(f"{split} is not a valid split.")

    return image_paths, neu_gt_paths
Get paths to the HIL ToothSeg data.
Arguments:
- path: Filepath to a folder where the data is downloaded for further processing.
- split: The data split to use. Either 'train', 'val' or 'test'.
- download: Whether to download the data if it is not present.
Returns:
List of filepaths for the image data.
List of filepaths for the label data.
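Example (a minimal sketch; note that the mask preprocessing relies on OpenCV, which this function imports lazily, and the folder path is an arbitrary choice):

from torch_em.data.datasets.medical.hil_toothseg import get_hil_toothseg_paths

# Fetches the data if necessary, converts the RGB masks to instance label
# tiffs once, and returns the filepaths for the 'train' split.
image_paths, label_paths = get_hil_toothseg_paths(
    "./data/hil_toothseg", split="train", download=True
)
print(len(image_paths), len(label_paths))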
def get_hil_toothseg_dataset(path: Union[os.PathLike, str], patch_shape: Tuple[int, int], split: Literal['train', 'val', 'test'], resize_inputs: bool = False, download: bool = False, **kwargs) -> torch.utils.data.dataset.Dataset:
def get_hil_toothseg_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, int],
    split: Literal["train", "val", "test"],
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> Dataset:
    """Get the HIL ToothSeg dataset for teeth segmentation.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        patch_shape: The patch shape to use for training.
        split: The data split to use. Either 'train', 'val' or 'test'.
        resize_inputs: Whether to resize the inputs to the patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    image_paths, gt_paths = get_hil_toothseg_paths(path, split, download)

    if resize_inputs:
        resize_kwargs = {"patch_shape": patch_shape, "is_rgb": False}
        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
            kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs
        )

    return torch_em.default_segmentation_dataset(
        raw_paths=image_paths,
        raw_key=None,
        label_paths=gt_paths,
        label_key=None,
        is_seg_dataset=False,
        patch_shape=patch_shape,
        **kwargs
    )
Get the HIL ToothSeg dataset for teeth segmentation.
Arguments:
- path: Filepath to a folder where the data is downloaded for further processing.
- patch_shape: The patch shape to use for training.
- split: The data split to use. Either 'train', 'val' or 'test'.
- resize_inputs: Whether to resize the inputs to the patch shape.
- download: Whether to download the data if it is not present.
- kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset.
Returns:
The segmentation dataset.
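Example (a minimal sketch; the patch shape (512, 512) and the folder "./data/hil_toothseg" are illustrative values, not requirements):

from torch_em.data.datasets.medical.hil_toothseg import get_hil_toothseg_dataset

dataset = get_hil_toothseg_dataset(
    path="./data/hil_toothseg",
    patch_shape=(512, 512),
    split="val",
    resize_inputs=True,   # resize the panoramic radiographs to the patch shape
    download=True,
)
print(len(dataset))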
def get_hil_toothseg_loader(path: Union[os.PathLike, str], batch_size: int, patch_shape: Tuple[int, int], split: Literal['train', 'val', 'test'], resize_inputs: bool = False, download: bool = False, **kwargs) -> torch.utils.data.dataloader.DataLoader:
def get_hil_toothseg_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, int],
    split: Literal["train", "val", "test"],
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Get the HIL ToothSeg dataloader for teeth segmentation.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        split: The data split to use. Either 'train', 'val' or 'test'.
        resize_inputs: Whether to resize the inputs to the patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    dataset = get_hil_toothseg_dataset(path, patch_shape, split, resize_inputs, download, **ds_kwargs)
    return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
Get the HIL ToothSeg dataloader for teeth segmentation.
Arguments:
- path: Filepath to a folder where the data is downloaded for further processing.
- batch_size: The batch size for training.
- patch_shape: The patch shape to use for training.
- split: The data split to use. Either 'train', 'val' or 'test'.
- resize_inputs: Whether to resize the inputs to the patch shape.
- download: Whether to download the data if it is not present.
- kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset or for the PyTorch DataLoader.
Returns:
The DataLoader.
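Example (a minimal sketch; batch size, patch shape and num_workers are illustrative values, and num_workers is assumed to be split off and forwarded to the PyTorch DataLoader as described above):

from torch_em.data.datasets.medical.hil_toothseg import get_hil_toothseg_loader

loader = get_hil_toothseg_loader(
    path="./data/hil_toothseg",
    batch_size=4,
    patch_shape=(512, 512),
    split="train",
    resize_inputs=True,
    download=True,
    num_workers=2,  # passed on to the PyTorch DataLoader
)

# Each batch is expected to yield an image tensor and a label tensor.
x, y = next(iter(loader))
print(x.shape, y.shape)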