torch_em.data.datasets.medical.m2caiseg

The m2caiseg dataset contains annotations for segmentation of organs and instruments in endoscopy.

This dataset is located at https://www.kaggle.com/datasets/salmanmaq/m2caiseg. The data is from the publication https://doi.org/10.48550/arXiv.2008.10134. Please cite it if you use this data in a publication.

  1"""The m2caiseg dataset contains annotations for segmentation of organs and instruments in endoscopy.
  2
  3This dataset is located at https://www.kaggle.com/datasets/salmanmaq/m2caiseg.
  4The data is from the publication https://doi.org/10.48550/arXiv.2008.10134.
  5Please cite it if you use this data in a publication.
  6"""
  7
  8import os
  9from glob import glob
 10from tqdm import tqdm
 11from pathlib import Path
 12from natsort import natsorted
 13from typing import Union, Tuple, List, Literal
 14
 15import numpy as np
 16import imageio.v3 as imageio
 17
 18from torch.utils.data import Dataset, DataLoader
 19
 20import torch_em
 21
 22from .. import util
 23
 24
 25LABEL_MAPS = {
 26    (0, 0, 0): 0,  # out of frame
 27    (0, 85, 170): 1,  # grasper
 28    (0, 85, 255): 2,  # bipolar
 29    (0, 170, 255): 3,  # hook
 30    (0, 255, 85): 4,  # scissors
 31    (0, 255, 170): 5,  # clipper
 32    (85, 0, 170): 6,  # irrigator
 33    (85, 0, 255): 7,  # specimen bag
 34    (170, 85, 85): 8,  # trocars
 35    (170, 170, 170): 9,  # clip
 36    (85, 170, 0): 10,  # liver
 37    (85, 170, 255): 11,  # gall bladder
 38    (85, 255, 0): 12,  # fat
 39    (85, 255, 170): 13,  # upper wall
 40    (170, 0, 255): 14,  # artery
 41    (255, 0, 255): 15,  # intestine
 42    (255, 255, 0): 16,  # bile
 43    (255, 0, 0): 17,  # blood
 44    (170, 0, 85): 18,  # unknown
 45}
 46
 47
 48def get_m2caiseg_data(path: Union[os.PathLike, str], download: bool = False) -> str:
 49    """Get the m2caiseg dataset.
 50
 51    Args:
 52        path: Filepath to a folder where the data is downloaded for further processing.
 53        download: Whether to download the data if it is not present.
 54
 55    Returns:
 56        Filepath where the data is downloaded.
 57    """
 58    data_dir = os.path.join(path, r"m2caiSeg dataset")
 59    if os.path.exists(data_dir):
 60        return data_dir
 61
 62    os.makedirs(path, exist_ok=True)
 63
 64    util.download_source_kaggle(path=path, dataset_name="salmanmaq/m2caiseg", download=download)
 65    zip_path = os.path.join(path, "m2caiseg.zip")
 66    util.unzip(zip_path=zip_path, dst=path)
 67
 68    return data_dir
 69
 70
 71def get_m2caiseg_paths(
 72    path: Union[os.PathLike, str], split: Literal["train", "val", "test"], download: bool = False
 73) -> Tuple[List[str], List[str]]:
 74    """Get paths to the m2caiseg data.
 75
 76    Args:
 77        path: Filepath to a folder where the data is downloaded for further processing.
 78        split: The choice of data split.
 79        download: Whether to download the data if it is not present.
 80
 81    Returns:
 82        List of filepaths for the image data.
 83        List of filepaths for the label data.
 84    """
 85    data_dir = get_m2caiseg_data(path=path, download=download)
 86
 87    if split == "val":
 88        impaths = natsorted(glob(os.path.join(data_dir, "train", "images", "*.jpg")))
 89        gpaths = natsorted(glob(os.path.join(data_dir, "train", "groundtruth", "*.png")))
 90
 91        imids = [os.path.split(_p)[-1] for _p in impaths]
 92        gids = [os.path.split(_p)[-1] for _p in gpaths]
 93
 94        image_paths = [
 95            _p for _p in natsorted(
 96                glob(os.path.join(data_dir, "trainval", "images", "*.jpg"))
 97            ) if os.path.split(_p)[-1] not in imids
 98        ]
 99        gt_paths = [
100            _p for _p in natsorted(
101                glob(os.path.join(data_dir, "trainval", "groundtruth", "*.png"))
102            ) if os.path.split(_p)[-1] not in gids
103        ]
104
105    else:
106        image_paths = natsorted(glob(os.path.join(data_dir, split, "images", "*.jpg")))
107        gt_paths = natsorted(glob(os.path.join(data_dir, split, "groundtruth", "*.png")))
108
109    images_dir = os.path.join(data_dir, "preprocessed", split, "images")
110    mask_dir = os.path.join(data_dir, "preprocessed", split, "masks")
111    if os.path.exists(images_dir) and os.path.exists(mask_dir):
112        return natsorted(glob(os.path.join(images_dir, "*"))), natsorted(glob(os.path.join(mask_dir, "*")))
113
114    os.makedirs(images_dir, exist_ok=True)
115    os.makedirs(mask_dir, exist_ok=True)
116
117    fimage_paths, fgt_paths = [], []
118    for image_path, gt_path in tqdm(zip(image_paths, gt_paths), total=len(image_paths)):
119        image = imageio.imread(image_path)
120        gt = imageio.imread(gt_path)
121
122        image_id = Path(image_path).stem
123        gt_id = Path(gt_path).stem
124
125        if image.shape != gt.shape:
126            print("This pair of image and labels mismatch.")
127            continue
128
129        dst_image_path = os.path.join(images_dir, f"{image_id}.tif")
130        dst_gt_path = os.path.join(mask_dir, f"{gt_id}.tif")
131
132        fimage_paths.append(image_path)
133        fgt_paths.append(dst_gt_path)
134        if os.path.exists(dst_gt_path) and os.path.exists(dst_image_path):
135            continue
136
137        instances = np.zeros(gt.shape[:2])
138        for lmap in LABEL_MAPS:
139            binary_map = (gt == lmap).all(axis=2)
140            instances[binary_map > 0] = LABEL_MAPS[lmap]
141
142        imageio.imwrite(dst_image_path, image, compression="zlib")
143        imageio.imwrite(dst_gt_path, instances, compression="zlib")
144
145    return fimage_paths, fgt_paths
146
147
def get_m2caiseg_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, int],
    split: Literal["train", "val", "test"],
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> Dataset:
    """Build the m2caiseg segmentation dataset of organs and instruments.

    Args:
        path: Folder in which the data is stored or will be downloaded for further processing.
        patch_shape: Shape of the patches used for training.
        split: Which data split to load.
        resize_inputs: Whether to resize the inputs to the desired patch shape.
        download: Whether the data may be downloaded when it is not already present.
        kwargs: Further keyword arguments forwarded to `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    image_paths, gt_paths = get_m2caiseg_paths(path=path, split=split, download=download)

    if resize_inputs:
        # The helper hands back updated dataset kwargs together with the patch shape to use.
        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
            kwargs=kwargs,
            patch_shape=patch_shape,
            resize_inputs=resize_inputs,
            resize_kwargs={"patch_shape": patch_shape, "is_rgb": True},
        )

    return torch_em.default_segmentation_dataset(
        raw_paths=image_paths,
        label_paths=gt_paths,
        raw_key=None,
        label_key=None,
        patch_shape=patch_shape,
        is_seg_dataset=False,
        **kwargs,
    )
186
187
def get_m2caiseg_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, int],
    split: Literal["train", "val", "test"],
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Build a DataLoader over the m2caiseg organ and instrument segmentation dataset.

    Args:
        path: Folder in which the data is stored or will be downloaded for further processing.
        batch_size: Number of samples per batch.
        patch_shape: Shape of the patches used for training.
        split: Which data split to load.
        resize_inputs: Whether to resize the inputs to the desired patch shape.
        download: Whether the data may be downloaded when it is not already present.
        kwargs: Further keyword arguments; those accepted by `torch_em.default_segmentation_dataset`
            go to the dataset, the remaining ones to the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # Route each keyword argument either to the dataset or to the loader.
    dataset_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)

    segmentation_dataset = get_m2caiseg_dataset(
        path=path,
        patch_shape=patch_shape,
        split=split,
        resize_inputs=resize_inputs,
        download=download,
        **dataset_kwargs,
    )
    return torch_em.get_data_loader(segmentation_dataset, batch_size, **loader_kwargs)
# Colour code of each class in the RGB ground-truth images, mapped to its integer class id.
# The class id of a colour is its position in the list below.
LABEL_MAPS = {
    rgb: class_id
    for class_id, rgb in enumerate(
        [
            (0, 0, 0),  # 0: out of frame
            (0, 85, 170),  # 1: grasper
            (0, 85, 255),  # 2: bipolar
            (0, 170, 255),  # 3: hook
            (0, 255, 85),  # 4: scissors
            (0, 255, 170),  # 5: clipper
            (85, 0, 170),  # 6: irrigator
            (85, 0, 255),  # 7: specimen bag
            (170, 85, 85),  # 8: trocars
            (170, 170, 170),  # 9: clip
            (85, 170, 0),  # 10: liver
            (85, 170, 255),  # 11: gall bladder
            (85, 255, 0),  # 12: fat
            (85, 255, 170),  # 13: upper wall
            (170, 0, 255),  # 14: artery
            (255, 0, 255),  # 15: intestine
            (255, 255, 0),  # 16: bile
            (255, 0, 0),  # 17: blood
            (170, 0, 85),  # 18: unknown
        ]
    )
}
def get_m2caiseg_data(path: Union[os.PathLike, str], download: bool = False) -> str:
def get_m2caiseg_data(path: Union[os.PathLike, str], download: bool = False) -> str:
    """Make sure the m2caiseg dataset is available locally.

    If the extracted dataset folder is missing, the archive is fetched from Kaggle
    and unpacked into `path`.

    Args:
        path: Folder in which the data is stored or will be downloaded to.
        download: Whether the data may be downloaded when it is not already present.

    Returns:
        Path of the extracted dataset folder.
    """
    dataset_folder = os.path.join(path, "m2caiSeg dataset")

    if not os.path.exists(dataset_folder):
        os.makedirs(path, exist_ok=True)
        util.download_source_kaggle(path=path, dataset_name="salmanmaq/m2caiseg", download=download)
        # The Kaggle download is a single zip archive that is extracted next to itself.
        util.unzip(zip_path=os.path.join(path, "m2caiseg.zip"), dst=path)

    return dataset_folder

Get the m2caiseg dataset.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • download: Whether to download the data if it is not present.
Returns:

Filepath where the data is downloaded.

def get_m2caiseg_paths( path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], download: bool = False) -> Tuple[List[str], List[str]]:
 72def get_m2caiseg_paths(
 73    path: Union[os.PathLike, str], split: Literal["train", "val", "test"], download: bool = False
 74) -> Tuple[List[str], List[str]]:
 75    """Get paths to the m2caiseg data.
 76
 77    Args:
 78        path: Filepath to a folder where the data is downloaded for further processing.
 79        split: The choice of data split.
 80        download: Whether to download the data if it is not present.
 81
 82    Returns:
 83        List of filepaths for the image data.
 84        List of filepaths for the label data.
 85    """
 86    data_dir = get_m2caiseg_data(path=path, download=download)
 87
 88    if split == "val":
 89        impaths = natsorted(glob(os.path.join(data_dir, "train", "images", "*.jpg")))
 90        gpaths = natsorted(glob(os.path.join(data_dir, "train", "groundtruth", "*.png")))
 91
 92        imids = [os.path.split(_p)[-1] for _p in impaths]
 93        gids = [os.path.split(_p)[-1] for _p in gpaths]
 94
 95        image_paths = [
 96            _p for _p in natsorted(
 97                glob(os.path.join(data_dir, "trainval", "images", "*.jpg"))
 98            ) if os.path.split(_p)[-1] not in imids
 99        ]
100        gt_paths = [
101            _p for _p in natsorted(
102                glob(os.path.join(data_dir, "trainval", "groundtruth", "*.png"))
103            ) if os.path.split(_p)[-1] not in gids
104        ]
105
106    else:
107        image_paths = natsorted(glob(os.path.join(data_dir, split, "images", "*.jpg")))
108        gt_paths = natsorted(glob(os.path.join(data_dir, split, "groundtruth", "*.png")))
109
110    images_dir = os.path.join(data_dir, "preprocessed", split, "images")
111    mask_dir = os.path.join(data_dir, "preprocessed", split, "masks")
112    if os.path.exists(images_dir) and os.path.exists(mask_dir):
113        return natsorted(glob(os.path.join(images_dir, "*"))), natsorted(glob(os.path.join(mask_dir, "*")))
114
115    os.makedirs(images_dir, exist_ok=True)
116    os.makedirs(mask_dir, exist_ok=True)
117
118    fimage_paths, fgt_paths = [], []
119    for image_path, gt_path in tqdm(zip(image_paths, gt_paths), total=len(image_paths)):
120        image = imageio.imread(image_path)
121        gt = imageio.imread(gt_path)
122
123        image_id = Path(image_path).stem
124        gt_id = Path(gt_path).stem
125
126        if image.shape != gt.shape:
127            print("This pair of image and labels mismatch.")
128            continue
129
130        dst_image_path = os.path.join(images_dir, f"{image_id}.tif")
131        dst_gt_path = os.path.join(mask_dir, f"{gt_id}.tif")
132
133        fimage_paths.append(image_path)
134        fgt_paths.append(dst_gt_path)
135        if os.path.exists(dst_gt_path) and os.path.exists(dst_image_path):
136            continue
137
138        instances = np.zeros(gt.shape[:2])
139        for lmap in LABEL_MAPS:
140            binary_map = (gt == lmap).all(axis=2)
141            instances[binary_map > 0] = LABEL_MAPS[lmap]
142
143        imageio.imwrite(dst_image_path, image, compression="zlib")
144        imageio.imwrite(dst_gt_path, instances, compression="zlib")
145
146    return fimage_paths, fgt_paths

Get paths to the m2caiseg data.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • split: The choice of data split.
  • download: Whether to download the data if it is not present.
Returns:

List of filepaths for the image data.
List of filepaths for the label data.

def get_m2caiseg_dataset( path: Union[os.PathLike, str], patch_shape: Tuple[int, int], split: Literal['train', 'val', 'test'], resize_inputs: bool = False, download: bool = False, **kwargs) -> torch.utils.data.dataset.Dataset:
def get_m2caiseg_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, int],
    split: Literal["train", "val", "test"],
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> Dataset:
    """Build the m2caiseg segmentation dataset of organs and instruments.

    Args:
        path: Folder in which the data is stored or will be downloaded for further processing.
        patch_shape: Shape of the patches used for training.
        split: Which data split to load.
        resize_inputs: Whether to resize the inputs to the desired patch shape.
        download: Whether the data may be downloaded when it is not already present.
        kwargs: Further keyword arguments forwarded to `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    image_paths, gt_paths = get_m2caiseg_paths(path=path, split=split, download=download)

    if resize_inputs:
        # The helper hands back updated dataset kwargs together with the patch shape to use.
        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
            kwargs=kwargs,
            patch_shape=patch_shape,
            resize_inputs=resize_inputs,
            resize_kwargs={"patch_shape": patch_shape, "is_rgb": True},
        )

    return torch_em.default_segmentation_dataset(
        raw_paths=image_paths,
        label_paths=gt_paths,
        raw_key=None,
        label_key=None,
        patch_shape=patch_shape,
        is_seg_dataset=False,
        **kwargs,
    )

Get the m2caiseg dataset for organ and instrument segmentation.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • patch_shape: The patch shape to use for training.
  • split: The choice of data split.
  • resize_inputs: Whether to resize inputs to the desired patch shape.
  • download: Whether to download the data if it is not present.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset.
Returns:

The segmentation dataset.

def get_m2caiseg_loader( path: Union[os.PathLike, str], batch_size: int, patch_shape: Tuple[int, int], split: Literal['train', 'val', 'test'], resize_inputs: bool = False, download: bool = False, **kwargs) -> torch.utils.data.dataloader.DataLoader:
def get_m2caiseg_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, int],
    split: Literal["train", "val", "test"],
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Build a DataLoader over the m2caiseg organ and instrument segmentation dataset.

    Args:
        path: Folder in which the data is stored or will be downloaded for further processing.
        batch_size: Number of samples per batch.
        patch_shape: Shape of the patches used for training.
        split: Which data split to load.
        resize_inputs: Whether to resize the inputs to the desired patch shape.
        download: Whether the data may be downloaded when it is not already present.
        kwargs: Further keyword arguments; those accepted by `torch_em.default_segmentation_dataset`
            go to the dataset, the remaining ones to the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # Route each keyword argument either to the dataset or to the loader.
    dataset_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)

    segmentation_dataset = get_m2caiseg_dataset(
        path=path,
        patch_shape=patch_shape,
        split=split,
        resize_inputs=resize_inputs,
        download=download,
        **dataset_kwargs,
    )
    return torch_em.get_data_loader(segmentation_dataset, batch_size, **loader_kwargs)

Get the m2caiseg dataloader for organ and instrument segmentation.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • batch_size: The batch size for training.
  • patch_shape: The patch shape to use for training.
  • split: The choice of data split.
  • resize_inputs: Whether to resize inputs to the desired patch shape.
  • download: Whether to download the data if it is not present.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset or for the PyTorch DataLoader.
Returns:

The DataLoader.