torch_em.data.datasets.light_microscopy.deepbacs

DeepBacs is a dataset for segmenting bacteria in label-free light microscopy.

This dataset is from the publication https://doi.org/10.1038/s42003-022-03634-z. Please cite it if you use this dataset in your research.

  1"""DeepBacs is a dataset for segmenting bacteria in label-free light microscopy.
  2
  3This dataset is from the publication https://doi.org/10.1038/s42003-022-03634-z.
  4Please cite it if you use this dataset in your research.
  5"""
  6
  7import os
  8import shutil
  9import numpy as np
 10from glob import glob
 11from typing import Tuple, Union
 12
 13import torch_em
 14from torch.utils.data import Dataset, DataLoader
 15from .. import util
 16
# Download URL (a zip archive hosted on Zenodo) for each bacteria type.
# The keys of this dictionary are the values accepted for `bac_type`.
URLS = {
    "s_aureus": "https://zenodo.org/record/5550933/files/DeepBacs_Data_Segmentation_Staph_Aureus_dataset.zip?download=1",
    "e_coli": "https://zenodo.org/record/5550935/files/DeepBacs_Data_Segmentation_E.coli_Brightfield_dataset.zip?download=1",
    "b_subtilis": "https://zenodo.org/record/5639253/files/Multilabel_U-Net_dataset_B.subtilis.zip?download=1",
    "mixed": "https://zenodo.org/record/5551009/files/DeepBacs_Data_Segmentation_StarDist_MIXED_dataset.zip?download=1",
}
# Checksum handed to `util.download_source` for the archive with the same key in `URLS`.
# NOTE(review): the "b_subtilis" value ("1") does not have the form of the other
# entries and looks like a placeholder - confirm the real checksum before relying on it.
CHECKSUMS = {
    "s_aureus": "4047792f1248ee82fce34121d0ade84828e55db5a34656cc25beec46eacaf307",
    "e_coli": "f812a2f814c3875c78fcc1609a2e9b34c916c7a9911abbf8117f423536ef1c17",
    "b_subtilis": "1",
    "mixed": "2730e6b391637d6dc05bbc7b8c915fd8184d835ac3611e13f23ac6f10f86c2a0",
}
 29
 30
 31def _assort_val_set(path, bac_type):
 32    image_paths = glob(os.path.join(path, bac_type, "training", "source", "*"))
 33    image_paths = [os.path.split(_path)[-1] for _path in image_paths]
 34
 35    val_partition = 0.2
 36    # let's get a balanced set of bacterias, if bac_type is mixed
 37    if bac_type == "mixed":
 38        _sort_1, _sort_2, _sort_3 = [], [], []
 39        for _path in image_paths:
 40            if _path.startswith("JE2"):
 41                _sort_1.append(_path)
 42            elif _path.startswith("pos"):
 43                _sort_2.append(_path)
 44            elif _path.startswith("train_"):
 45                _sort_3.append(_path)
 46
 47        val_image_paths = [
 48            *np.random.choice(_sort_1, size=int(val_partition * len(_sort_1)), replace=False),
 49            *np.random.choice(_sort_2, size=int(val_partition * len(_sort_2)), replace=False),
 50            *np.random.choice(_sort_3, size=int(val_partition * len(_sort_3)), replace=False)
 51        ]
 52    else:
 53        val_image_paths = np.random.choice(image_paths, size=int(val_partition * len(image_paths)), replace=False)
 54
 55    val_image_dir = os.path.join(path, bac_type, "val", "source")
 56    val_label_dir = os.path.join(path, bac_type, "val", "target")
 57    os.makedirs(val_image_dir, exist_ok=True)
 58    os.makedirs(val_label_dir, exist_ok=True)
 59
 60    for sample_id in val_image_paths:
 61        src_val_image_path = os.path.join(path, bac_type, "training", "source", sample_id)
 62        dst_val_image_path = os.path.join(val_image_dir, sample_id)
 63        shutil.move(src_val_image_path, dst_val_image_path)
 64
 65        src_val_label_path = os.path.join(path, bac_type, "training", "target", sample_id)
 66        dst_val_label_path = os.path.join(val_label_dir, sample_id)
 67        shutil.move(src_val_label_path, dst_val_label_path)
 68
 69
 70def get_deebacs_data(path: Union[os.PathLike, str], bac_type: str, download: bool) -> str:
 71    f"""Download the DeepBacs training data.
 72
 73    Args:
 74        path: Filepath to a folder where the downloaded data will be saved.
 75        bac_type: The bacteria type. The available types are:
 76            {', '.join(URLS.keys())}
 77        download: Whether to download the data if it is not present.
 78
 79    Returns:
 80        The filepath to the training data.
 81    """
 82    bac_types = list(URLS.keys())
 83    assert bac_type in bac_types, f"{bac_type} is not in expected bacteria types: {bac_types}"
 84
 85    data_folder = os.path.join(path, bac_type)
 86    if os.path.exists(data_folder):
 87        return data_folder
 88
 89    os.makedirs(path, exist_ok=True)
 90    zip_path = os.path.join(path, f"{bac_type}.zip")
 91    if not os.path.exists(zip_path):
 92        util.download_source(zip_path, URLS[bac_type], download, checksum=CHECKSUMS[bac_type])
 93    util.unzip(zip_path, os.path.join(path, bac_type))
 94
 95    # Get a val split for the expected bacteria type.
 96    _assort_val_set(path, bac_type)
 97    return data_folder
 98
 99
def _get_paths(path, bac_type, split):
    """Return the image folder and the label folder for a split of the DeepBacs data.

    Args:
        path: The folder that contains the unzipped data.
        bac_type: The bacteria type. Only 'mixed' is supported.
        split: The name of the split. 'train' is mapped to the folder 'training',
            any other value is used as the folder name as it is.

    Returns:
        The path to the folder with the images ('source').
        The path to the folder with the labels ('target').

    Raises:
        NotImplementedError: If `bac_type` is not 'mixed'.
    """
    # The bacteria types other than mixed are a bit more complicated, so there are no
    # dataloaders for them yet. Mixed is the combination of all other types.
    if bac_type != "mixed":
        raise NotImplementedError(f"Currently only the bacteria type 'mixed' is supported, not {bac_type}")

    folder_name = "training" if split == "train" else split
    split_root = os.path.join(path, bac_type, folder_name)
    return os.path.join(split_root, "source"), os.path.join(split_root, "target")
113
114
def get_deepbacs_dataset(
    path: Union[os.PathLike, str],
    split: str,
    patch_shape: Tuple[int, int],
    bac_type: str = "mixed",
    download: bool = False,
    **kwargs
) -> Dataset:
    """Get the DeepBacs dataset for bacteria segmentation.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        split: The split to use for the dataset. Either 'train', 'val' or 'test'.
        patch_shape: The patch shape to use for training.
        bac_type: The bacteria type. The available types are the keys of `URLS`:
            's_aureus', 'e_coli', 'b_subtilis' and 'mixed'.
            Currently only 'mixed' can be loaded as a dataset.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
       The segmentation dataset.

    Raises:
        AssertionError: If `split` is not one of 'train', 'val' or 'test'.
        NotImplementedError: If `bac_type` is a known type other than 'mixed'.
    """
    # NOTE: The docstring above has to stay a plain string literal. An f-string is an
    # ordinary expression and is never stored as `__doc__`.
    assert split in ("train", "val", "test")
    get_deebacs_data(path, bac_type, download)
    image_folder, label_folder = _get_paths(path, bac_type, split)
    # Images and labels are tif files, which are matched via the same glob pattern.
    dataset = torch_em.default_segmentation_dataset(
        image_folder, "*.tif", label_folder, "*.tif", patch_shape=patch_shape, **kwargs
    )
    return dataset
144
145
def get_deepbacs_loader(
    path: Union[os.PathLike, str],
    split: str,
    patch_shape: Tuple[int, int],
    batch_size: int,
    bac_type: str = "mixed",
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Get the DeepBacs dataloader for bacteria segmentation.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        split: The split to use for the dataset. Either 'train', 'val' or 'test'.
        patch_shape: The patch shape to use for training.
        batch_size: The batch size for training.
        bac_type: The bacteria type. The available types are the keys of `URLS`:
            's_aureus', 'e_coli', 'b_subtilis' and 'mixed'.
            Currently only 'mixed' can be loaded as a dataset.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # NOTE: The docstring above has to stay a plain string literal. An f-string is an
    # ordinary expression and is never stored as `__doc__`.
    # Separate the keyword arguments for the dataset from the ones for the loader.
    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    dataset = get_deepbacs_dataset(path, split, patch_shape, bac_type=bac_type, download=download, **ds_kwargs)
    loader = torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
    return loader
# Download URL (a zip archive hosted on Zenodo) for each bacteria type.
URLS = {'s_aureus': 'https://zenodo.org/record/5550933/files/DeepBacs_Data_Segmentation_Staph_Aureus_dataset.zip?download=1', 'e_coli': 'https://zenodo.org/record/5550935/files/DeepBacs_Data_Segmentation_E.coli_Brightfield_dataset.zip?download=1', 'b_subtilis': 'https://zenodo.org/record/5639253/files/Multilabel_U-Net_dataset_B.subtilis.zip?download=1', 'mixed': 'https://zenodo.org/record/5551009/files/DeepBacs_Data_Segmentation_StarDist_MIXED_dataset.zip?download=1'}
# Checksum for the archive with the same key in `URLS`.
# NOTE(review): the 'b_subtilis' value ('1') looks like a placeholder - confirm the real checksum.
CHECKSUMS = {'s_aureus': '4047792f1248ee82fce34121d0ade84828e55db5a34656cc25beec46eacaf307', 'e_coli': 'f812a2f814c3875c78fcc1609a2e9b34c916c7a9911abbf8117f423536ef1c17', 'b_subtilis': '1', 'mixed': '2730e6b391637d6dc05bbc7b8c915fd8184d835ac3611e13f23ac6f10f86c2a0'}
def get_deebacs_data(path: Union[os.PathLike, str], bac_type: str, download: bool) -> str:
def get_deebacs_data(path: Union[os.PathLike, str], bac_type: str, download: bool) -> str:
    """Download the DeepBacs training data.

    The archive for the chosen bacteria type is downloaded (unless the zip file is already present),
    unzipped to `<path>/<bac_type>` and a validation split is then separated from the training data.
    If the folder `<path>/<bac_type>` already exists, it is returned without any further work.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        bac_type: The bacteria type. The available types are the keys of `URLS`:
            's_aureus', 'e_coli', 'b_subtilis' and 'mixed'.
        download: Whether to download the data if it is not present.

    Returns:
        The filepath to the folder with the data for the chosen bacteria type.

    Raises:
        AssertionError: If `bac_type` is not one of the available types.
    """
    # NOTE: The docstring above has to stay a plain string literal. An f-string is an
    # ordinary expression and is never stored as `__doc__`.
    bac_types = list(URLS.keys())
    assert bac_type in bac_types, f"{bac_type} is not in expected bacteria types: {bac_types}"

    data_folder = os.path.join(path, bac_type)
    if os.path.exists(data_folder):
        return data_folder

    os.makedirs(path, exist_ok=True)
    zip_path = os.path.join(path, f"{bac_type}.zip")
    if not os.path.exists(zip_path):
        util.download_source(zip_path, URLS[bac_type], download, checksum=CHECKSUMS[bac_type])
    util.unzip(zip_path, os.path.join(path, bac_type))

    # Get a val split for the expected bacteria type.
    _assort_val_set(path, bac_type)
    return data_folder
def get_deepbacs_dataset( path: Union[os.PathLike, str], split: str, patch_shape: Tuple[int, int], bac_type: str = 'mixed', download: bool = False, **kwargs) -> torch.utils.data.dataset.Dataset:
def get_deepbacs_dataset(
    path: Union[os.PathLike, str],
    split: str,
    patch_shape: Tuple[int, int],
    bac_type: str = "mixed",
    download: bool = False,
    **kwargs
) -> Dataset:
    """Get the DeepBacs dataset for bacteria segmentation.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        split: The split to use for the dataset. Either 'train', 'val' or 'test'.
        patch_shape: The patch shape to use for training.
        bac_type: The bacteria type. The available types are the keys of `URLS`:
            's_aureus', 'e_coli', 'b_subtilis' and 'mixed'.
            Currently only 'mixed' can be loaded as a dataset.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
       The segmentation dataset.

    Raises:
        AssertionError: If `split` is not one of 'train', 'val' or 'test'.
        NotImplementedError: If `bac_type` is a known type other than 'mixed'.
    """
    # NOTE: The docstring above has to stay a plain string literal. An f-string is an
    # ordinary expression and is never stored as `__doc__`.
    assert split in ("train", "val", "test")
    get_deebacs_data(path, bac_type, download)
    image_folder, label_folder = _get_paths(path, bac_type, split)
    # Images and labels are tif files, which are matched via the same glob pattern.
    dataset = torch_em.default_segmentation_dataset(
        image_folder, "*.tif", label_folder, "*.tif", patch_shape=patch_shape, **kwargs
    )
    return dataset
def get_deepbacs_loader( path: Union[os.PathLike, str], split: str, patch_shape: Tuple[int, int], batch_size: int, bac_type: str = 'mixed', download: bool = False, **kwargs) -> torch.utils.data.dataloader.DataLoader:
def get_deepbacs_loader(
    path: Union[os.PathLike, str],
    split: str,
    patch_shape: Tuple[int, int],
    batch_size: int,
    bac_type: str = "mixed",
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Get the DeepBacs dataloader for bacteria segmentation.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        split: The split to use for the dataset. Either 'train', 'val' or 'test'.
        patch_shape: The patch shape to use for training.
        batch_size: The batch size for training.
        bac_type: The bacteria type. The available types are the keys of `URLS`:
            's_aureus', 'e_coli', 'b_subtilis' and 'mixed'.
            Currently only 'mixed' can be loaded as a dataset.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # NOTE: The docstring above has to stay a plain string literal. An f-string is an
    # ordinary expression and is never stored as `__doc__`.
    # Separate the keyword arguments for the dataset from the ones for the loader.
    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    dataset = get_deepbacs_dataset(path, split, patch_shape, bac_type=bac_type, download=download, **ds_kwargs)
    loader = torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
    return loader