torch_em.data.datasets.light_microscopy.deepbacs

DeepBacs is a dataset for segmenting bacteria in label-free light microscopy.

This dataset is from the publication https://doi.org/10.1038/s42003-022-03634-z. Please cite it if you use this dataset in your research.

  1"""DeepBacs is a dataset for segmenting bacteria in label-free light microscopy.
  2
  3This dataset is from the publication https://doi.org/10.1038/s42003-022-03634-z.
  4Please cite it if you use this dataset in your research.
  5"""
  6
  7import os
  8import shutil
  9from glob import glob
 10from typing import Tuple, Union
 11
 12import numpy as np
 13
 14from torch.utils.data import Dataset, DataLoader
 15
 16import torch_em
 17
 18from .. import util
 19
# Download URLs of the zip archives (hosted on Zenodo) for each DeepBacs subset, keyed by bacteria type.
# The keys of this dictionary define the bacteria types accepted by `get_deepbacs_data`.
URLS = {
    "s_aureus": "https://zenodo.org/record/5550933/files/DeepBacs_Data_Segmentation_Staph_Aureus_dataset.zip?download=1",  # noqa
    "e_coli": "https://zenodo.org/record/5550935/files/DeepBacs_Data_Segmentation_E.coli_Brightfield_dataset.zip?download=1",  # noqa
    "b_subtilis": "https://zenodo.org/record/5639253/files/Multilabel_U-Net_dataset_B.subtilis.zip?download=1",
    "mixed": "https://zenodo.org/record/5551009/files/DeepBacs_Data_Segmentation_StarDist_MIXED_dataset.zip?download=1",
}
# Checksums handed to `util.download_source` for the archive with the same key in `URLS`.
# NOTE(review): the "b_subtilis" value "1" does not have the 64-character hex form of the
# other entries and looks like a placeholder - confirm before relying on this download.
CHECKSUMS = {
    "s_aureus": "4047792f1248ee82fce34121d0ade84828e55db5a34656cc25beec46eacaf307",
    "e_coli": "f812a2f814c3875c78fcc1609a2e9b34c916c7a9911abbf8117f423536ef1c17",
    "b_subtilis": "1",
    "mixed": "2730e6b391637d6dc05bbc7b8c915fd8184d835ac3611e13f23ac6f10f86c2a0",
}
 32
 33
 34def _assort_val_set(path, bac_type):
 35    image_paths = glob(os.path.join(path, bac_type, "training", "source", "*"))
 36    image_paths = [os.path.split(_path)[-1] for _path in image_paths]
 37
 38    val_partition = 0.2
 39    # let's get a balanced set of bacterias, if bac_type is mixed
 40    if bac_type == "mixed":
 41        _sort_1, _sort_2, _sort_3 = [], [], []
 42        for _path in image_paths:
 43            if _path.startswith("JE2"):
 44                _sort_1.append(_path)
 45            elif _path.startswith("pos"):
 46                _sort_2.append(_path)
 47            elif _path.startswith("train_"):
 48                _sort_3.append(_path)
 49
 50        val_image_paths = [
 51            *np.random.choice(_sort_1, size=int(val_partition * len(_sort_1)), replace=False),
 52            *np.random.choice(_sort_2, size=int(val_partition * len(_sort_2)), replace=False),
 53            *np.random.choice(_sort_3, size=int(val_partition * len(_sort_3)), replace=False)
 54        ]
 55    else:
 56        val_image_paths = np.random.choice(image_paths, size=int(val_partition * len(image_paths)), replace=False)
 57
 58    val_image_dir = os.path.join(path, bac_type, "val", "source")
 59    val_label_dir = os.path.join(path, bac_type, "val", "target")
 60    os.makedirs(val_image_dir, exist_ok=True)
 61    os.makedirs(val_label_dir, exist_ok=True)
 62
 63    for sample_id in val_image_paths:
 64        src_val_image_path = os.path.join(path, bac_type, "training", "source", sample_id)
 65        dst_val_image_path = os.path.join(val_image_dir, sample_id)
 66        shutil.move(src_val_image_path, dst_val_image_path)
 67
 68        src_val_label_path = os.path.join(path, bac_type, "training", "target", sample_id)
 69        dst_val_label_path = os.path.join(val_label_dir, sample_id)
 70        shutil.move(src_val_label_path, dst_val_label_path)
 71
 72
 73def get_deepbacs_data(path: Union[os.PathLike, str], bac_type: str, download: bool) -> str:
 74    f"""Download the DeepBacs training data.
 75
 76    Args:
 77        path: Filepath to a folder where the downloaded data will be saved.
 78        bac_type: The bacteria type. The available types are:
 79            {', '.join(URLS.keys())}
 80        download: Whether to download the data if it is not present.
 81
 82    Returns:
 83        The filepath to the training data.
 84    """
 85    bac_types = list(URLS.keys())
 86    assert bac_type in bac_types, f"{bac_type} is not in expected bacteria types: {bac_types}"
 87
 88    data_folder = os.path.join(path, bac_type)
 89    if os.path.exists(data_folder):
 90        return data_folder
 91
 92    os.makedirs(path, exist_ok=True)
 93    zip_path = os.path.join(path, f"{bac_type}.zip")
 94    if not os.path.exists(zip_path):
 95        util.download_source(zip_path, URLS[bac_type], download, checksum=CHECKSUMS[bac_type])
 96    util.unzip(zip_path, os.path.join(path, bac_type))
 97
 98    # Get a val split for the expected bacteria type.
 99    _assort_val_set(path, bac_type)
100    return data_folder
101
102
def get_deepbacs_paths(
    path: Union[os.PathLike, str], bac_type: str, split: str, download: bool = False
) -> Tuple[str, str]:
    """Get paths to the DeepBacs data.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        bac_type: The bacteria type. The available types are the keys of `URLS`:
            's_aureus', 'e_coli', 'b_subtilis' and 'mixed'.
            Currently only 'mixed' is supported by this function.
        split: The split to use for the dataset. Either 'train', 'val' or 'test'.
        download: Whether to download the data if it is not present.

    Returns:
        Filepath to the folder where image data is stored.
        Filepath to the folder where label data is stored.

    Raises:
        NotImplementedError: If `bac_type` is not 'mixed'.
    """
    # The bacteria types other than mixed are a bit more complicated, so we don't have the dataloaders for them yet.
    # Mixed is the combination of all other types.
    is_supported = bac_type == "mixed"
    not_supported_msg = f"Currently only the bacteria type 'mixed' is supported, not {bac_type}"

    # Fail before downloading anything for a known but unsupported bacteria type.
    # Unknown types are still passed on, so that `get_deepbacs_data` reports them as before.
    if not is_supported and bac_type in URLS:
        raise NotImplementedError(not_supported_msg)

    get_deepbacs_data(path, bac_type, download)

    # Only reachable for an unsupported type if the validation in `get_deepbacs_data` did not trigger.
    if not is_supported:
        raise NotImplementedError(not_supported_msg)

    # The training data is stored in the folder 'training'; the other splits use the split name.
    dir_choice = "training" if split == "train" else split

    image_folder = os.path.join(path, bac_type, dir_choice, "source")
    label_folder = os.path.join(path, bac_type, dir_choice, "target")

    return image_folder, label_folder
135
136
def get_deepbacs_dataset(
    path: Union[os.PathLike, str],
    split: str,
    patch_shape: Tuple[int, int],
    bac_type: str = "mixed",
    download: bool = False,
    **kwargs
) -> Dataset:
    """Get the DeepBacs dataset for bacteria segmentation.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        split: The split to use for the dataset. Either 'train', 'val' or 'test'.
        patch_shape: The patch shape to use for training.
        bac_type: The bacteria type. The available types are the keys of `URLS`:
            's_aureus', 'e_coli', 'b_subtilis' and 'mixed'.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
       The segmentation dataset.
    """
    # NOTE: The docstring is deliberately a plain string literal. An f-string in that
    # position is an ordinary expression, not a docstring, so `__doc__` would be None.
    assert split in ("train", "val", "test"), f"Invalid split '{split}', expected one of 'train', 'val' or 'test'"

    image_folder, label_folder = get_deepbacs_paths(path, bac_type, split, download)

    # Images and labels are selected from their folders with the same '*.tif' pattern.
    return torch_em.default_segmentation_dataset(
        raw_paths=image_folder,
        raw_key="*.tif",
        label_paths=label_folder,
        label_key="*.tif",
        patch_shape=patch_shape,
        **kwargs
    )
171
172
def get_deepbacs_loader(
    path: Union[os.PathLike, str],
    split: str,
    patch_shape: Tuple[int, int],
    batch_size: int,
    bac_type: str = "mixed",
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Get the DeepBacs dataloader for bacteria segmentation.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        split: The split to use for the dataset. Either 'train', 'val' or 'test'.
        patch_shape: The patch shape to use for training.
        batch_size: The batch size for training.
        bac_type: The bacteria type. The available types are the keys of `URLS`:
            's_aureus', 'e_coli', 'b_subtilis' and 'mixed'.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # NOTE: The docstring is deliberately a plain string literal. An f-string in that
    # position is an ordinary expression, not a docstring, so `__doc__` would be None.
    # Separate the keyword arguments for the dataset from the remaining ones, which go to the loader.
    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    dataset = get_deepbacs_dataset(path, split, patch_shape, bac_type=bac_type, download=download, **ds_kwargs)
    return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
URLS = {'s_aureus': 'https://zenodo.org/record/5550933/files/DeepBacs_Data_Segmentation_Staph_Aureus_dataset.zip?download=1', 'e_coli': 'https://zenodo.org/record/5550935/files/DeepBacs_Data_Segmentation_E.coli_Brightfield_dataset.zip?download=1', 'b_subtilis': 'https://zenodo.org/record/5639253/files/Multilabel_U-Net_dataset_B.subtilis.zip?download=1', 'mixed': 'https://zenodo.org/record/5551009/files/DeepBacs_Data_Segmentation_StarDist_MIXED_dataset.zip?download=1'}
CHECKSUMS = {'s_aureus': '4047792f1248ee82fce34121d0ade84828e55db5a34656cc25beec46eacaf307', 'e_coli': 'f812a2f814c3875c78fcc1609a2e9b34c916c7a9911abbf8117f423536ef1c17', 'b_subtilis': '1', 'mixed': '2730e6b391637d6dc05bbc7b8c915fd8184d835ac3611e13f23ac6f10f86c2a0'}
def get_deepbacs_data(path: Union[os.PathLike, str], bac_type: str, download: bool) -> str:
 74def get_deepbacs_data(path: Union[os.PathLike, str], bac_type: str, download: bool) -> str:
 75    f"""Download the DeepBacs training data.
 76
 77    Args:
 78        path: Filepath to a folder where the downloaded data will be saved.
 79        bac_type: The bacteria type. The available types are:
 80            {', '.join(URLS.keys())}
 81        download: Whether to download the data if it is not present.
 82
 83    Returns:
 84        The filepath to the training data.
 85    """
 86    bac_types = list(URLS.keys())
 87    assert bac_type in bac_types, f"{bac_type} is not in expected bacteria types: {bac_types}"
 88
 89    data_folder = os.path.join(path, bac_type)
 90    if os.path.exists(data_folder):
 91        return data_folder
 92
 93    os.makedirs(path, exist_ok=True)
 94    zip_path = os.path.join(path, f"{bac_type}.zip")
 95    if not os.path.exists(zip_path):
 96        util.download_source(zip_path, URLS[bac_type], download, checksum=CHECKSUMS[bac_type])
 97    util.unzip(zip_path, os.path.join(path, bac_type))
 98
 99    # Get a val split for the expected bacteria type.
100    _assort_val_set(path, bac_type)
101    return data_folder
def get_deepbacs_paths( path: Union[os.PathLike, str], bac_type: str, split: str, download: bool = False) -> Tuple[str, str]:
def get_deepbacs_paths(
    path: Union[os.PathLike, str], bac_type: str, split: str, download: bool = False
) -> Tuple[str, str]:
    """Get paths to the DeepBacs data.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        bac_type: The bacteria type. The available types are the keys of `URLS`:
            's_aureus', 'e_coli', 'b_subtilis' and 'mixed'.
            Currently only 'mixed' is supported by this function.
        split: The split to use for the dataset. Either 'train', 'val' or 'test'.
        download: Whether to download the data if it is not present.

    Returns:
        Filepath to the folder where image data is stored.
        Filepath to the folder where label data is stored.

    Raises:
        NotImplementedError: If `bac_type` is not 'mixed'.
    """
    # The bacteria types other than mixed are a bit more complicated, so we don't have the dataloaders for them yet.
    # Mixed is the combination of all other types.
    is_supported = bac_type == "mixed"
    not_supported_msg = f"Currently only the bacteria type 'mixed' is supported, not {bac_type}"

    # Fail before downloading anything for a known but unsupported bacteria type.
    # Unknown types are still passed on, so that `get_deepbacs_data` reports them as before.
    if not is_supported and bac_type in URLS:
        raise NotImplementedError(not_supported_msg)

    get_deepbacs_data(path, bac_type, download)

    # Only reachable for an unsupported type if the validation in `get_deepbacs_data` did not trigger.
    if not is_supported:
        raise NotImplementedError(not_supported_msg)

    # The training data is stored in the folder 'training'; the other splits use the split name.
    dir_choice = "training" if split == "train" else split

    image_folder = os.path.join(path, bac_type, dir_choice, "source")
    label_folder = os.path.join(path, bac_type, dir_choice, "target")

    return image_folder, label_folder
def get_deepbacs_dataset( path: Union[os.PathLike, str], split: str, patch_shape: Tuple[int, int], bac_type: str = 'mixed', download: bool = False, **kwargs) -> torch.utils.data.dataset.Dataset:
def get_deepbacs_dataset(
    path: Union[os.PathLike, str],
    split: str,
    patch_shape: Tuple[int, int],
    bac_type: str = "mixed",
    download: bool = False,
    **kwargs
) -> Dataset:
    """Get the DeepBacs dataset for bacteria segmentation.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        split: The split to use for the dataset. Either 'train', 'val' or 'test'.
        patch_shape: The patch shape to use for training.
        bac_type: The bacteria type. The available types are the keys of `URLS`:
            's_aureus', 'e_coli', 'b_subtilis' and 'mixed'.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
       The segmentation dataset.
    """
    # NOTE: The docstring is deliberately a plain string literal. An f-string in that
    # position is an ordinary expression, not a docstring, so `__doc__` would be None.
    assert split in ("train", "val", "test"), f"Invalid split '{split}', expected one of 'train', 'val' or 'test'"

    image_folder, label_folder = get_deepbacs_paths(path, bac_type, split, download)

    # Images and labels are selected from their folders with the same '*.tif' pattern.
    return torch_em.default_segmentation_dataset(
        raw_paths=image_folder,
        raw_key="*.tif",
        label_paths=label_folder,
        label_key="*.tif",
        patch_shape=patch_shape,
        **kwargs
    )
def get_deepbacs_loader( path: Union[os.PathLike, str], split: str, patch_shape: Tuple[int, int], batch_size: int, bac_type: str = 'mixed', download: bool = False, **kwargs) -> torch.utils.data.dataloader.DataLoader:
def get_deepbacs_loader(
    path: Union[os.PathLike, str],
    split: str,
    patch_shape: Tuple[int, int],
    batch_size: int,
    bac_type: str = "mixed",
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Get the DeepBacs dataloader for bacteria segmentation.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        split: The split to use for the dataset. Either 'train', 'val' or 'test'.
        patch_shape: The patch shape to use for training.
        batch_size: The batch size for training.
        bac_type: The bacteria type. The available types are the keys of `URLS`:
            's_aureus', 'e_coli', 'b_subtilis' and 'mixed'.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # NOTE: The docstring is deliberately a plain string literal. An f-string in that
    # position is an ordinary expression, not a docstring, so `__doc__` would be None.
    # Separate the keyword arguments for the dataset from the remaining ones, which go to the loader.
    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    dataset = get_deepbacs_dataset(path, split, patch_shape, bac_type=bac_type, download=download, **ds_kwargs)
    return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)