torch_em.data.datasets.medical.m2caiseg
The m2caiseg dataset contains annotations for segmentation of organs and instruments in endoscopy.
This dataset is located at https://www.kaggle.com/datasets/salmanmaq/m2caiseg. The data is from the publication https://doi.org/10.48550/arXiv.2008.10134. Please cite it if you use this data in a publication.
"""The m2caiseg dataset contains annotations for segmentation of organs and instruments in endoscopy.

This dataset is located at https://www.kaggle.com/datasets/salmanmaq/m2caiseg.
The data is from the publication https://doi.org/10.48550/arXiv.2008.10134.
Please cite it if you use this data in a publication.
"""

import os
from glob import glob
from tqdm import tqdm
from pathlib import Path
from natsort import natsorted
from typing import Union, Tuple, List, Literal

import numpy as np
import imageio.v3 as imageio

from torch.utils.data import Dataset, DataLoader

import torch_em

from .. import util


# Maps each RGB colour code found in the ground-truth images to its integer class id.
LABEL_MAPS = {
    (0, 0, 0): 0,  # out of frame
    (0, 85, 170): 1,  # grasper
    (0, 85, 255): 2,  # bipolar
    (0, 170, 255): 3,  # hook
    (0, 255, 85): 4,  # scissors
    (0, 255, 170): 5,  # clipper
    (85, 0, 170): 6,  # irrigator
    (85, 0, 255): 7,  # specimen bag
    (170, 85, 85): 8,  # trocars
    (170, 170, 170): 9,  # clip
    (85, 170, 0): 10,  # liver
    (85, 170, 255): 11,  # gall bladder
    (85, 255, 0): 12,  # fat
    (85, 255, 170): 13,  # upper wall
    (170, 0, 255): 14,  # artery
    (255, 0, 255): 15,  # intestine
    (255, 255, 0): 16,  # bile
    (255, 0, 0): 17,  # blood
    (170, 0, 85): 18,  # unknown
}


def get_m2caiseg_data(path: Union[os.PathLike, str], download: bool = False) -> str:
    """Get the m2caiseg dataset.

    If the extracted "m2caiSeg dataset" folder is not yet present in `path`, the archive is
    fetched from Kaggle ("salmanmaq/m2caiseg") and unzipped into `path`.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        download: Whether to download the data if it is not present.

    Returns:
        Filepath where the data is downloaded.
    """
    data_dir = os.path.join(path, r"m2caiSeg dataset")
    if os.path.exists(data_dir):
        return data_dir

    os.makedirs(path, exist_ok=True)

    util.download_source_kaggle(path=path, dataset_name="salmanmaq/m2caiseg", download=download)
    zip_path = os.path.join(path, "m2caiseg.zip")
    util.unzip(zip_path=zip_path, dst=path)

    return data_dir


def get_m2caiseg_paths(
    path: Union[os.PathLike, str], split: Literal["train", "val", "test"], download: bool = False
) -> Tuple[List[str], List[str]]:
    """Get paths to the m2caiseg data.

    On the first call for a split, each image is converted to tif and each RGB-coded ground-truth
    image is converted to a label image (via `LABEL_MAPS`). Both are cached under
    "<data_dir>/preprocessed/<split>/". Image / ground-truth pairs with mismatching shapes are skipped.
    The "val" split consists of the images in "trainval" that are not part of "train".

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        split: The choice of data split. One of "train", "val" or "test".
        download: Whether to download the data if it is not present.

    Returns:
        List of filepaths for the image data.
        List of filepaths for the label data.

    Raises:
        ValueError: If `split` is not one of "train", "val" or "test".
    """
    valid_splits = ("train", "val", "test")
    if split not in valid_splits:
        raise ValueError(f"Invalid split '{split}'. Choose one of {valid_splits}.")

    data_dir = get_m2caiseg_data(path=path, download=download)

    if split == "val":
        # Validation data = everything in "trainval" that does not also appear in "train".
        train_image_names = {
            os.path.basename(p) for p in glob(os.path.join(data_dir, "train", "images", "*.jpg"))
        }
        train_gt_names = {
            os.path.basename(p) for p in glob(os.path.join(data_dir, "train", "groundtruth", "*.png"))
        }
        image_paths = [
            p for p in natsorted(glob(os.path.join(data_dir, "trainval", "images", "*.jpg")))
            if os.path.basename(p) not in train_image_names
        ]
        gt_paths = [
            p for p in natsorted(glob(os.path.join(data_dir, "trainval", "groundtruth", "*.png")))
            if os.path.basename(p) not in train_gt_names
        ]
    else:
        image_paths = natsorted(glob(os.path.join(data_dir, split, "images", "*.jpg")))
        gt_paths = natsorted(glob(os.path.join(data_dir, split, "groundtruth", "*.png")))

    images_dir = os.path.join(data_dir, "preprocessed", split, "images")
    mask_dir = os.path.join(data_dir, "preprocessed", split, "masks")

    # Re-use the cached conversion unless it is missing or visibly incomplete.
    # Example: an earlier run was interrupted after writing an image but before writing its mask.
    cached_images = natsorted(glob(os.path.join(images_dir, "*.tif")))
    cached_masks = natsorted(glob(os.path.join(mask_dir, "*.tif")))
    if cached_images and len(cached_images) == len(cached_masks):
        return cached_images, cached_masks

    os.makedirs(images_dir, exist_ok=True)
    os.makedirs(mask_dir, exist_ok=True)

    out_image_paths, out_gt_paths = [], []
    for image_path, gt_path in tqdm(zip(image_paths, gt_paths), total=len(image_paths)):
        image = imageio.imread(image_path)
        gt = imageio.imread(gt_path)

        if image.shape != gt.shape:
            print("This pair of image and labels mismatch.")
            continue

        dst_image_path = os.path.join(images_dir, f"{Path(image_path).stem}.tif")
        dst_gt_path = os.path.join(mask_dir, f"{Path(gt_path).stem}.tif")

        # Return the converted tif (not the source jpg) so this matches the cached branch above.
        out_image_paths.append(dst_image_path)
        out_gt_paths.append(dst_gt_path)
        if os.path.exists(dst_gt_path) and os.path.exists(dst_image_path):
            continue

        # Translate each RGB colour code into its class id; pixels matching no code stay 0.
        labels = np.zeros(gt.shape[:2], dtype="uint8")
        for color, class_id in LABEL_MAPS.items():
            labels[(gt == color).all(axis=2)] = class_id

        imageio.imwrite(dst_image_path, image, compression="zlib")
        imageio.imwrite(dst_gt_path, labels, compression="zlib")

    return out_image_paths, out_gt_paths


def get_m2caiseg_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, int],
    split: Literal["train", "val", "test"],
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> Dataset:
    """Get the m2caiseg dataset for organ and instrument segmentation.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        resize_inputs: Whether to resize inputs to the desired patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    image_paths, gt_paths = get_m2caiseg_paths(path, split, download)

    if resize_inputs:
        resize_kwargs = {"patch_shape": patch_shape, "is_rgb": True}
        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
            kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs
        )

    return torch_em.default_segmentation_dataset(
        raw_paths=image_paths,
        raw_key=None,
        label_paths=gt_paths,
        label_key=None,
        patch_shape=patch_shape,
        is_seg_dataset=False,
        **kwargs
    )


def get_m2caiseg_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, int],
    split: Literal["train", "val", "test"],
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Get the m2caiseg dataloader for organ and instrument segmentation.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        resize_inputs: Whether to resize inputs to the desired patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # Route kwargs accepted by `default_segmentation_dataset` to the dataset, the rest to the loader.
    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    dataset = get_m2caiseg_dataset(path, patch_shape, split, resize_inputs, download, **ds_kwargs)
    return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
# Maps each RGB colour code found in the m2caiSeg ground-truth images to its integer class id.
LABEL_MAPS = {
    (0, 0, 0): 0,  # out of frame
    (0, 85, 170): 1,  # grasper
    (0, 85, 255): 2,  # bipolar
    (0, 170, 255): 3,  # hook
    (0, 255, 85): 4,  # scissors
    (0, 255, 170): 5,  # clipper
    (85, 0, 170): 6,  # irrigator
    (85, 0, 255): 7,  # specimen bag
    (170, 85, 85): 8,  # trocars
    (170, 170, 170): 9,  # clip
    (85, 170, 0): 10,  # liver
    (85, 170, 255): 11,  # gall bladder
    (85, 255, 0): 12,  # fat
    (85, 255, 170): 13,  # upper wall
    (170, 0, 255): 14,  # artery
    (255, 0, 255): 15,  # intestine
    (255, 255, 0): 16,  # bile
    (255, 0, 0): 17,  # blood
    (170, 0, 85): 18,  # unknown
}
def
get_m2caiseg_data(path: Union[os.PathLike, str], download: bool = False) -> str:
def get_m2caiseg_data(path: Union[os.PathLike, str], download: bool = False) -> str:
    """Get the m2caiseg dataset.

    Nothing is fetched when the extracted "m2caiSeg dataset" folder already exists inside `path`.
    Otherwise the archive is requested from Kaggle ("salmanmaq/m2caiseg") and unzipped into `path`.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        download: Whether to download the data if it is not present.

    Returns:
        Filepath where the data is downloaded.
    """
    target_dir = os.path.join(path, r"m2caiSeg dataset")

    if not os.path.exists(target_dir):
        os.makedirs(path, exist_ok=True)
        util.download_source_kaggle(path=path, dataset_name="salmanmaq/m2caiseg", download=download)
        archive = os.path.join(path, "m2caiseg.zip")
        util.unzip(zip_path=archive, dst=path)

    return target_dir
Get the m2caiseg dataset.
Arguments:
- path: Filepath to a folder where the data is downloaded for further processing.
- download: Whether to download the data if it is not present.
Returns:
Filepath where the data is downloaded.
def
get_m2caiseg_paths( path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], download: bool = False) -> Tuple[List[str], List[str]]:
def get_m2caiseg_paths(
    path: Union[os.PathLike, str], split: Literal["train", "val", "test"], download: bool = False
) -> Tuple[List[str], List[str]]:
    """Get paths to the m2caiseg data.

    On the first call for a split, each image is converted to tif and each RGB-coded ground-truth
    image is converted to a label image (via `LABEL_MAPS`). Both are cached under
    "<data_dir>/preprocessed/<split>/". Image / ground-truth pairs with mismatching shapes are skipped.
    The "val" split consists of the images in "trainval" that are not part of "train".

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        split: The choice of data split. One of "train", "val" or "test".
        download: Whether to download the data if it is not present.

    Returns:
        List of filepaths for the image data.
        List of filepaths for the label data.

    Raises:
        ValueError: If `split` is not one of "train", "val" or "test".
    """
    valid_splits = ("train", "val", "test")
    if split not in valid_splits:
        raise ValueError(f"Invalid split '{split}'. Choose one of {valid_splits}.")

    data_dir = get_m2caiseg_data(path=path, download=download)

    if split == "val":
        # Validation data = everything in "trainval" that does not also appear in "train".
        train_image_names = {
            os.path.basename(p) for p in glob(os.path.join(data_dir, "train", "images", "*.jpg"))
        }
        train_gt_names = {
            os.path.basename(p) for p in glob(os.path.join(data_dir, "train", "groundtruth", "*.png"))
        }
        image_paths = [
            p for p in natsorted(glob(os.path.join(data_dir, "trainval", "images", "*.jpg")))
            if os.path.basename(p) not in train_image_names
        ]
        gt_paths = [
            p for p in natsorted(glob(os.path.join(data_dir, "trainval", "groundtruth", "*.png")))
            if os.path.basename(p) not in train_gt_names
        ]
    else:
        image_paths = natsorted(glob(os.path.join(data_dir, split, "images", "*.jpg")))
        gt_paths = natsorted(glob(os.path.join(data_dir, split, "groundtruth", "*.png")))

    images_dir = os.path.join(data_dir, "preprocessed", split, "images")
    mask_dir = os.path.join(data_dir, "preprocessed", split, "masks")

    # Re-use the cached conversion unless it is missing or visibly incomplete.
    # Example: an earlier run was interrupted after writing an image but before writing its mask.
    cached_images = natsorted(glob(os.path.join(images_dir, "*.tif")))
    cached_masks = natsorted(glob(os.path.join(mask_dir, "*.tif")))
    if cached_images and len(cached_images) == len(cached_masks):
        return cached_images, cached_masks

    os.makedirs(images_dir, exist_ok=True)
    os.makedirs(mask_dir, exist_ok=True)

    out_image_paths, out_gt_paths = [], []
    for image_path, gt_path in tqdm(zip(image_paths, gt_paths), total=len(image_paths)):
        image = imageio.imread(image_path)
        gt = imageio.imread(gt_path)

        if image.shape != gt.shape:
            print("This pair of image and labels mismatch.")
            continue

        dst_image_path = os.path.join(images_dir, f"{Path(image_path).stem}.tif")
        dst_gt_path = os.path.join(mask_dir, f"{Path(gt_path).stem}.tif")

        # Return the converted tif (not the source jpg) so this matches the cached branch above.
        out_image_paths.append(dst_image_path)
        out_gt_paths.append(dst_gt_path)
        if os.path.exists(dst_gt_path) and os.path.exists(dst_image_path):
            continue

        # Translate each RGB colour code into its class id; pixels matching no code stay 0.
        labels = np.zeros(gt.shape[:2], dtype="uint8")
        for color, class_id in LABEL_MAPS.items():
            labels[(gt == color).all(axis=2)] = class_id

        imageio.imwrite(dst_image_path, image, compression="zlib")
        imageio.imwrite(dst_gt_path, labels, compression="zlib")

    return out_image_paths, out_gt_paths
Get paths to the m2caiseg data.
Arguments:
- path: Filepath to a folder where the data is downloaded for further processing.
- split: The choice of data split.
- download: Whether to download the data if it is not present.
Returns:
List of filepaths for the image data, and list of filepaths for the label data.
def
get_m2caiseg_dataset( path: Union[os.PathLike, str], patch_shape: Tuple[int, int], split: Literal['train', 'val', 'test'], resize_inputs: bool = False, download: bool = False, **kwargs) -> torch.utils.data.dataset.Dataset:
def get_m2caiseg_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, int],
    split: Literal["train", "val", "test"],
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> Dataset:
    """Get the m2caiseg dataset for organ and instrument segmentation.

    The images are treated as RGB when resizing, and the labels are loaded as-is
    (`is_seg_dataset=False`, no data keys).

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        resize_inputs: Whether to resize inputs to the desired patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    raw_paths, label_paths = get_m2caiseg_paths(path, split, download)

    if resize_inputs:
        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
            kwargs=kwargs,
            patch_shape=patch_shape,
            resize_inputs=resize_inputs,
            resize_kwargs={"patch_shape": patch_shape, "is_rgb": True},
        )

    return torch_em.default_segmentation_dataset(
        raw_paths=raw_paths, raw_key=None,
        label_paths=label_paths, label_key=None,
        patch_shape=patch_shape, is_seg_dataset=False,
        **kwargs
    )
Get the m2caiseg dataset for organ and instrument segmentation.
Arguments:
- path: Filepath to a folder where the data is downloaded for further processing.
- patch_shape: The patch shape to use for training.
- split: The choice of data split.
- resize_inputs: Whether to resize inputs to the desired patch shape.
- download: Whether to download the data if it is not present.
- kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
Returns:
The segmentation dataset.
def
get_m2caiseg_loader( path: Union[os.PathLike, str], batch_size: int, patch_shape: Tuple[int, int], split: Literal['train', 'val', 'test'], resize_inputs: bool = False, download: bool = False, **kwargs) -> torch.utils.data.dataloader.DataLoader:
def get_m2caiseg_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, int],
    split: Literal["train", "val", "test"],
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Get the m2caiseg dataloader for organ and instrument segmentation.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        resize_inputs: Whether to resize inputs to the desired patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # kwargs understood by `default_segmentation_dataset` go to the dataset; the rest go to the loader.
    dataset_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    dataset = get_m2caiseg_dataset(
        path, patch_shape, split, resize_inputs, download, **dataset_kwargs
    )
    return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
Get the m2caiseg dataloader for organ and instrument segmentation.
Arguments:
- path: Filepath to a folder where the data is downloaded for further processing.
- batch_size: The batch size for training.
- patch_shape: The patch shape to use for training.
- split: The choice of data split.
- resize_inputs: Whether to resize inputs to the desired patch shape.
- download: Whether to download the data if it is not present.
- kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.
Returns:
The DataLoader.