torch_em.data.datasets.light_microscopy.deepbacs

DeepBacs is a dataset for segmenting bacteria in label-free light microscopy.

This dataset is from the publication https://doi.org/10.1038/s42003-022-03634-z. Please cite it if you use this dataset in your research.

  1"""DeepBacs is a dataset for segmenting bacteria in label-free light microscopy.
  2
  3This dataset is from the publication https://doi.org/10.1038/s42003-022-03634-z.
  4Please cite it if you use this dataset in your research.
  5"""
  6
  7import os
  8import shutil
  9from glob import glob
 10from typing import Tuple, Union
 11
 12import numpy as np
 13
 14from torch.utils.data import Dataset, DataLoader
 15
 16import torch_em
 17
 18from .. import util
 19
 20URLS = {
 21    "s_aureus": "https://zenodo.org/record/5550933/files/DeepBacs_Data_Segmentation_Staph_Aureus_dataset.zip?download=1",  # noqa
 22    "e_coli": "https://zenodo.org/record/5550935/files/DeepBacs_Data_Segmentation_E.coli_Brightfield_dataset.zip?download=1",  # noqa
 23    "e_coli_stationary": "https://zenodo.org/records/6400327/files/DeepBacs_Data_Segmentation_Ecoli_stationary_phase.zip?download=1",  # noqa
 24    "b_subtilis": "https://zenodo.org/record/5639253/files/Multilabel_U-Net_dataset_B.subtilis.zip?download=1",
 25    "mixed": "https://zenodo.org/record/5551009/files/DeepBacs_Data_Segmentation_StarDist_MIXED_dataset.zip?download=1",
 26}
 27CHECKSUMS = {
 28    "s_aureus": "4047792f1248ee82fce34121d0ade84828e55db5a34656cc25beec46eacaf307",
 29    "e_coli": "f812a2f814c3875c78fcc1609a2e9b34c916c7a9911abbf8117f423536ef1c17",
 30    "e_coli_stationary": None,
 31    "b_subtilis": "1",
 32    "mixed": "2730e6b391637d6dc05bbc7b8c915fd8184d835ac3611e13f23ac6f10f86c2a0",
 33}
 34
 35
 36def _assort_val_set(path, bac_type):
 37    image_paths = glob(os.path.join(path, bac_type, "training", "source", "*"))
 38    image_paths = [os.path.split(_path)[-1] for _path in image_paths]
 39
 40    val_partition = 0.2
 41    # let's get a balanced set of bacterias, if bac_type is mixed
 42    if bac_type == "mixed":
 43        _sort_1, _sort_2, _sort_3 = [], [], []
 44        for _path in image_paths:
 45            if _path.startswith("JE2"):
 46                _sort_1.append(_path)
 47            elif _path.startswith("pos"):
 48                _sort_2.append(_path)
 49            elif _path.startswith("train_"):
 50                _sort_3.append(_path)
 51
 52        val_image_paths = [
 53            *np.random.choice(_sort_1, size=int(val_partition * len(_sort_1)), replace=False),
 54            *np.random.choice(_sort_2, size=int(val_partition * len(_sort_2)), replace=False),
 55            *np.random.choice(_sort_3, size=int(val_partition * len(_sort_3)), replace=False)
 56        ]
 57    else:
 58        val_image_paths = np.random.choice(image_paths, size=int(val_partition * len(image_paths)), replace=False)
 59
 60    val_image_dir = os.path.join(path, bac_type, "val", "source")
 61    val_label_dir = os.path.join(path, bac_type, "val", "target")
 62    os.makedirs(val_image_dir, exist_ok=True)
 63    os.makedirs(val_label_dir, exist_ok=True)
 64
 65    for sample_id in val_image_paths:
 66        src_val_image_path = os.path.join(path, bac_type, "training", "source", sample_id)
 67        dst_val_image_path = os.path.join(val_image_dir, sample_id)
 68        shutil.move(src_val_image_path, dst_val_image_path)
 69
 70        src_val_label_path = os.path.join(path, bac_type, "training", "target", sample_id)
 71        dst_val_label_path = os.path.join(val_label_dir, sample_id)
 72        shutil.move(src_val_label_path, dst_val_label_path)
 73
 74
 75def get_deepbacs_data(path: Union[os.PathLike, str], bac_type: str, download: bool) -> str:
 76    f"""Download the DeepBacs training data.
 77
 78    Args:
 79        path: Filepath to a folder where the downloaded data will be saved.
 80        bac_type: The bacteria type. The available types are:
 81            {', '.join(URLS.keys())}
 82        download: Whether to download the data if it is not present.
 83
 84    Returns:
 85        The filepath to the training data.
 86    """
 87    bac_types = list(URLS.keys())
 88    assert bac_type in bac_types, f"{bac_type} is not in expected bacteria types: {bac_types}"
 89
 90    data_folder = os.path.join(path, bac_type)
 91    if os.path.exists(data_folder):
 92        return data_folder
 93
 94    os.makedirs(path, exist_ok=True)
 95    zip_path = os.path.join(path, f"{bac_type}.zip")
 96    if not os.path.exists(zip_path):
 97        util.download_source(zip_path, URLS[bac_type], download, checksum=CHECKSUMS[bac_type])
 98    util.unzip(zip_path, os.path.join(path, bac_type))
 99
100    # e_coli_stationary ships its own train/test splits; no val-splitting needed.
101    if bac_type != "e_coli_stationary":
102        _assort_val_set(path, bac_type)
103    return data_folder
104
105
106def get_deepbacs_paths(
107    path: Union[os.PathLike, str], bac_type: str, split: str, download: bool = False
108) -> Tuple[str, str]:
109    f"""Get paths to the DeepBacs data.
110
111    Args:
112        path: Filepath to a folder where the downloaded data will be saved.
113        split: The split to use for the dataset. Either 'train', 'val' or 'test'.
114        bac_type: The bacteria type. The available types are:
115            {', '.join(URLS.keys())}
116        download: Whether to download the data if it is not present.
117
118    Returns:
119        Filepath to the folder where image data is stored.
120        Filepath to the folder where label data is stored.
121    """
122    get_deepbacs_data(path, bac_type, download)
123
124    # the bacteria types other than mixed are a bit more complicated so we don't have the dataloaders for them yet
125    # mixed is the combination of all other types
126    if split == "train":
127        dir_choice = "training"
128    else:
129        dir_choice = split
130
131    if bac_type == "e_coli_stationary":
132        if split == "val":
133            raise NotImplementedError("The e_coli_stationary dataset does not have a val split.")
134        from natsort import natsorted
135        image_folder = natsorted(glob(os.path.join(path, bac_type, dir_choice, "brightfield", "*.tif")))
136        label_folder = natsorted(glob(os.path.join(path, bac_type, dir_choice, "masks", "*.tif")))
137    elif bac_type != "mixed":
138        raise NotImplementedError(
139            f"Currently only 'mixed' and 'e_coli_stationary' are supported, not {bac_type}"
140        )
141    else:
142        image_folder = os.path.join(path, bac_type, dir_choice, "source")
143        label_folder = os.path.join(path, bac_type, dir_choice, "target")
144
145    return image_folder, label_folder
146
147
def get_deepbacs_dataset(
    path: Union[os.PathLike, str],
    split: str,
    patch_shape: Tuple[int, int],
    bac_type: str = "mixed",
    download: bool = False,
    **kwargs
) -> Dataset:
    """Get the DeepBacs dataset for bacteria segmentation.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        split: The split to use for the dataset. Either 'train', 'val' or 'test'.
        patch_shape: The patch shape to use for training.
        bac_type: The bacteria type. The available types are:
            s_aureus, e_coli, e_coli_stationary, b_subtilis, mixed.
            Only 'mixed' and 'e_coli_stationary' are currently supported.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
       The segmentation dataset.
    """
    # NOTE: an f-string is never a docstring, so the type list above is written out literally.
    assert split in ("train", "val", "test")

    image_paths, label_paths = get_deepbacs_paths(path, bac_type, split, download)

    # 'e_coli_stationary' yields explicit file lists, which need no key.
    # 'mixed' yields folders, for which the key is the glob pattern selecting the tif files.
    raw_key = None if isinstance(image_paths, list) else "*.tif"
    label_key = None if isinstance(label_paths, list) else "*.tif"
    return torch_em.default_segmentation_dataset(
        raw_paths=image_paths,
        raw_key=raw_key,
        label_paths=label_paths,
        label_key=label_key,
        patch_shape=patch_shape,
        **kwargs
    )
185
186
def get_deepbacs_loader(
    path: Union[os.PathLike, str],
    split: str,
    patch_shape: Tuple[int, int],
    batch_size: int,
    bac_type: str = "mixed",
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Get the DeepBacs dataloader for bacteria segmentation.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        split: The split to use for the dataset. Either 'train', 'val' or 'test'.
        patch_shape: The patch shape to use for training.
        batch_size: The batch size for training.
        bac_type: The bacteria type. The available types are:
            s_aureus, e_coli, e_coli_stationary, b_subtilis, mixed.
            Only 'mixed' and 'e_coli_stationary' are currently supported.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # NOTE: an f-string is never a docstring, so the type list above is written out literally.
    # Route each keyword argument either to the dataset or to the DataLoader.
    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    dataset = get_deepbacs_dataset(path, split, patch_shape, bac_type=bac_type, download=download, **ds_kwargs)
    return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
# Download locations of the zipped DeepBacs datasets on zenodo, keyed by bacteria type.
URLS = {
    "s_aureus": "https://zenodo.org/record/5550933/files/DeepBacs_Data_Segmentation_Staph_Aureus_dataset.zip?download=1",  # noqa
    "e_coli": "https://zenodo.org/record/5550935/files/DeepBacs_Data_Segmentation_E.coli_Brightfield_dataset.zip?download=1",  # noqa
    "e_coli_stationary": "https://zenodo.org/records/6400327/files/DeepBacs_Data_Segmentation_Ecoli_stationary_phase.zip?download=1",  # noqa
    "b_subtilis": "https://zenodo.org/record/5639253/files/Multilabel_U-Net_dataset_B.subtilis.zip?download=1",
    "mixed": "https://zenodo.org/record/5551009/files/DeepBacs_Data_Segmentation_StarDist_MIXED_dataset.zip?download=1",
}
# Checksums of the zip archives, passed to `util.download_source`.
# The known values are 64 hex characters (presumably SHA256).
# `None` presumably disables verification, as it is also used for 'e_coli_stationary'.
CHECKSUMS = {
    "s_aureus": "4047792f1248ee82fce34121d0ade84828e55db5a34656cc25beec46eacaf307",
    "e_coli": "f812a2f814c3875c78fcc1609a2e9b34c916c7a9911abbf8117f423536ef1c17",
    "e_coli_stationary": None,
    # TODO: fill in the real checksum. The previous value "1" was a placeholder,
    # and a digest like that could never match the downloaded archive.
    "b_subtilis": None,
    "mixed": "2730e6b391637d6dc05bbc7b8c915fd8184d835ac3611e13f23ac6f10f86c2a0",
}
def get_deepbacs_data(path: Union[os.PathLike, str], bac_type: str, download: bool) -> str:
    """Download the DeepBacs training data.

    The archive for `bac_type` is downloaded to `path/<bac_type>.zip` and extracted to `path/bac_type`.
    Next, for every type except 'e_coli_stationary', a validation split is created in
    `path/bac_type/val` from a random subset of the training samples.
    If `path/bac_type` already exists, nothing is downloaded or extracted.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        bac_type: The bacteria type. The available types are:
            s_aureus, e_coli, e_coli_stationary, b_subtilis, mixed.
        download: Whether to download the data if it is not present.

    Returns:
        The filepath to the folder with the data for `bac_type`.
    """
    # NOTE: an f-string is never a docstring, so the type list above is written out literally.
    bac_types = list(URLS.keys())
    assert bac_type in bac_types, f"{bac_type} is not in expected bacteria types: {bac_types}"

    # An existing folder is taken as complete data; an interrupted extraction is not detected.
    data_folder = os.path.join(path, bac_type)
    if os.path.exists(data_folder):
        return data_folder

    os.makedirs(path, exist_ok=True)
    zip_path = os.path.join(path, f"{bac_type}.zip")
    if not os.path.exists(zip_path):
        util.download_source(zip_path, URLS[bac_type], download, checksum=CHECKSUMS[bac_type])
    util.unzip(zip_path, os.path.join(path, bac_type))

    # e_coli_stationary ships its own train/test splits; no val-splitting needed.
    if bac_type != "e_coli_stationary":
        _assort_val_set(path, bac_type)
    return data_folder
def get_deepbacs_paths( path: Union[os.PathLike, str], bac_type: str, split: str, download: bool = False) -> Tuple[str, str]:
107def get_deepbacs_paths(
108    path: Union[os.PathLike, str], bac_type: str, split: str, download: bool = False
109) -> Tuple[str, str]:
110    f"""Get paths to the DeepBacs data.
111
112    Args:
113        path: Filepath to a folder where the downloaded data will be saved.
114        split: The split to use for the dataset. Either 'train', 'val' or 'test'.
115        bac_type: The bacteria type. The available types are:
116            {', '.join(URLS.keys())}
117        download: Whether to download the data if it is not present.
118
119    Returns:
120        Filepath to the folder where image data is stored.
121        Filepath to the folder where label data is stored.
122    """
123    get_deepbacs_data(path, bac_type, download)
124
125    # the bacteria types other than mixed are a bit more complicated so we don't have the dataloaders for them yet
126    # mixed is the combination of all other types
127    if split == "train":
128        dir_choice = "training"
129    else:
130        dir_choice = split
131
132    if bac_type == "e_coli_stationary":
133        if split == "val":
134            raise NotImplementedError("The e_coli_stationary dataset does not have a val split.")
135        from natsort import natsorted
136        image_folder = natsorted(glob(os.path.join(path, bac_type, dir_choice, "brightfield", "*.tif")))
137        label_folder = natsorted(glob(os.path.join(path, bac_type, dir_choice, "masks", "*.tif")))
138    elif bac_type != "mixed":
139        raise NotImplementedError(
140            f"Currently only 'mixed' and 'e_coli_stationary' are supported, not {bac_type}"
141        )
142    else:
143        image_folder = os.path.join(path, bac_type, dir_choice, "source")
144        label_folder = os.path.join(path, bac_type, dir_choice, "target")
145
146    return image_folder, label_folder
def get_deepbacs_dataset(
    path: Union[os.PathLike, str],
    split: str,
    patch_shape: Tuple[int, int],
    bac_type: str = "mixed",
    download: bool = False,
    **kwargs
) -> Dataset:
    """Get the DeepBacs dataset for bacteria segmentation.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        split: The split to use for the dataset. Either 'train', 'val' or 'test'.
        patch_shape: The patch shape to use for training.
        bac_type: The bacteria type. The available types are:
            s_aureus, e_coli, e_coli_stationary, b_subtilis, mixed.
            Only 'mixed' and 'e_coli_stationary' are currently supported.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
       The segmentation dataset.
    """
    # NOTE: an f-string is never a docstring, so the type list above is written out literally.
    assert split in ("train", "val", "test")

    image_paths, label_paths = get_deepbacs_paths(path, bac_type, split, download)

    # 'e_coli_stationary' yields explicit file lists, which need no key.
    # 'mixed' yields folders, for which the key is the glob pattern selecting the tif files.
    raw_key = None if isinstance(image_paths, list) else "*.tif"
    label_key = None if isinstance(label_paths, list) else "*.tif"
    return torch_em.default_segmentation_dataset(
        raw_paths=image_paths,
        raw_key=raw_key,
        label_paths=label_paths,
        label_key=label_key,
        patch_shape=patch_shape,
        **kwargs
    )
def get_deepbacs_loader(
    path: Union[os.PathLike, str],
    split: str,
    patch_shape: Tuple[int, int],
    batch_size: int,
    bac_type: str = "mixed",
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Get the DeepBacs dataloader for bacteria segmentation.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        split: The split to use for the dataset. Either 'train', 'val' or 'test'.
        patch_shape: The patch shape to use for training.
        batch_size: The batch size for training.
        bac_type: The bacteria type. The available types are:
            s_aureus, e_coli, e_coli_stationary, b_subtilis, mixed.
            Only 'mixed' and 'e_coli_stationary' are currently supported.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # NOTE: an f-string is never a docstring, so the type list above is written out literally.
    # Route each keyword argument either to the dataset or to the DataLoader.
    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    dataset = get_deepbacs_dataset(path, split, patch_shape, bac_type=bac_type, download=download, **ds_kwargs)
    return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)