torch_em.data.datasets.light_microscopy.deepbacs
DeepBacs is a dataset for segmenting bacteria in label-free light microscopy.
This dataset is from the publication https://doi.org/10.1038/s42003-022-03634-z. Please cite it if you use this dataset in your research.
"""DeepBacs is a dataset for segmenting bacteria in label-free light microscopy.

This dataset is from the publication https://doi.org/10.1038/s42003-022-03634-z.
Please cite it if you use this dataset in your research.
"""

import os
import shutil
from glob import glob
from typing import Tuple, Union

import numpy as np

from torch.utils.data import Dataset, DataLoader

import torch_em

from .. import util

# Download location of the zip archive for each bacteria type.
URLS = {
    "s_aureus": "https://zenodo.org/record/5550933/files/DeepBacs_Data_Segmentation_Staph_Aureus_dataset.zip?download=1",  # noqa
    "e_coli": "https://zenodo.org/record/5550935/files/DeepBacs_Data_Segmentation_E.coli_Brightfield_dataset.zip?download=1",  # noqa
    "e_coli_stationary": "https://zenodo.org/records/6400327/files/DeepBacs_Data_Segmentation_Ecoli_stationary_phase.zip?download=1",  # noqa
    "b_subtilis": "https://zenodo.org/record/5639253/files/Multilabel_U-Net_dataset_B.subtilis.zip?download=1",
    "mixed": "https://zenodo.org/record/5551009/files/DeepBacs_Data_Segmentation_StarDist_MIXED_dataset.zip?download=1",
}
# Checksums passed to `util.download_source` for each archive.
CHECKSUMS = {
    "s_aureus": "4047792f1248ee82fce34121d0ade84828e55db5a34656cc25beec46eacaf307",
    "e_coli": "f812a2f814c3875c78fcc1609a2e9b34c916c7a9911abbf8117f423536ef1c17",
    "e_coli_stationary": None,
    # NOTE(review): "1" looks like a placeholder rather than a real checksum - confirm before enabling b_subtilis.
    "b_subtilis": "1",
    "mixed": "2730e6b391637d6dc05bbc7b8c915fd8184d835ac3611e13f23ac6f10f86c2a0",
}


def _assort_val_set(path, bac_type):
    """Move a random 20% of the training samples into a newly created 'val' split.

    Images are taken from '<path>/<bac_type>/training/source' and the labels with the same
    file name from '<path>/<bac_type>/training/target'. Both are moved (not copied) to
    '<path>/<bac_type>/val/source' and '<path>/<bac_type>/val/target'.
    The selection uses the global numpy random state, so it is not reproducible unless
    the caller seeds numpy beforehand.

    Args:
        path: Filepath to the folder that contains the extracted data.
        bac_type: The bacteria type, i.e. the name of the sub-folder in `path`.
    """
    image_paths = glob(os.path.join(path, bac_type, "training", "source", "*"))
    image_paths = [os.path.split(_path)[-1] for _path in image_paths]

    val_partition = 0.2
    # Get a balanced set of bacteria if bac_type is mixed: sample the validation images
    # separately from each file name prefix (presumably one prefix per source dataset).
    # Files that match none of the prefixes are never selected for validation.
    if bac_type == "mixed":
        _sort_1, _sort_2, _sort_3 = [], [], []
        for _path in image_paths:
            if _path.startswith("JE2"):
                _sort_1.append(_path)
            elif _path.startswith("pos"):
                _sort_2.append(_path)
            elif _path.startswith("train_"):
                _sort_3.append(_path)

        val_image_paths = [
            *np.random.choice(_sort_1, size=int(val_partition * len(_sort_1)), replace=False),
            *np.random.choice(_sort_2, size=int(val_partition * len(_sort_2)), replace=False),
            *np.random.choice(_sort_3, size=int(val_partition * len(_sort_3)), replace=False)
        ]
    else:
        val_image_paths = np.random.choice(image_paths, size=int(val_partition * len(image_paths)), replace=False)

    val_image_dir = os.path.join(path, bac_type, "val", "source")
    val_label_dir = os.path.join(path, bac_type, "val", "target")
    os.makedirs(val_image_dir, exist_ok=True)
    os.makedirs(val_label_dir, exist_ok=True)

    for sample_id in val_image_paths:
        src_val_image_path = os.path.join(path, bac_type, "training", "source", sample_id)
        dst_val_image_path = os.path.join(val_image_dir, sample_id)
        shutil.move(src_val_image_path, dst_val_image_path)

        # The label is expected to have the same file name as its image.
        src_val_label_path = os.path.join(path, bac_type, "training", "target", sample_id)
        dst_val_label_path = os.path.join(val_label_dir, sample_id)
        shutil.move(src_val_label_path, dst_val_label_path)


def get_deepbacs_data(path: Union[os.PathLike, str], bac_type: str, download: bool) -> str:
    """Download the DeepBacs training data.

    If the folder '<path>/<bac_type>' already exists it is returned as is, without
    downloading, extracting or splitting anything.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        bac_type: The bacteria type. The available types are:
            's_aureus', 'e_coli', 'e_coli_stationary', 'b_subtilis', 'mixed'.
        download: Whether to download the data if it is not present.

    Returns:
        The filepath to the training data.
    """
    bac_types = list(URLS.keys())
    assert bac_type in bac_types, f"{bac_type} is not in expected bacteria types: {bac_types}"

    data_folder = os.path.join(path, bac_type)
    if os.path.exists(data_folder):
        return data_folder

    os.makedirs(path, exist_ok=True)
    zip_path = os.path.join(path, f"{bac_type}.zip")
    # Re-use an archive that is already on disk instead of downloading it again.
    if not os.path.exists(zip_path):
        util.download_source(zip_path, URLS[bac_type], download, checksum=CHECKSUMS[bac_type])
    util.unzip(zip_path, data_folder)

    # e_coli_stationary ships its own train/test splits; no val-splitting needed.
    if bac_type != "e_coli_stationary":
        _assort_val_set(path, bac_type)
    return data_folder


def get_deepbacs_paths(
    path: Union[os.PathLike, str], bac_type: str, split: str, download: bool = False
) -> Tuple[Union[str, list], Union[str, list]]:
    """Get paths to the DeepBacs data.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        bac_type: The bacteria type. The available types are:
            's_aureus', 'e_coli', 'e_coli_stationary', 'b_subtilis', 'mixed'.
            Currently only 'mixed' and 'e_coli_stationary' are supported here.
        split: The split to use for the dataset. Either 'train', 'val' or 'test'.
            'e_coli_stationary' does not have a 'val' split.
        download: Whether to download the data if it is not present.

    Returns:
        For 'mixed': the filepath to the folder where the image data is stored.
            For 'e_coli_stationary': the naturally sorted list of image file paths.
        For 'mixed': the filepath to the folder where the label data is stored.
            For 'e_coli_stationary': the naturally sorted list of label file paths.

    Raises:
        NotImplementedError: If the bacteria type is not supported yet, or if the 'val'
            split is requested for 'e_coli_stationary'.
    """
    # Reject unsupported requests before any data is downloaded or re-arranged on disk.
    if bac_type == "e_coli_stationary":
        if split == "val":
            raise NotImplementedError("The e_coli_stationary dataset does not have a val split.")
    elif bac_type != "mixed" and bac_type in URLS:
        raise NotImplementedError(
            f"Currently only 'mixed' and 'e_coli_stationary' are supported, not {bac_type}"
        )

    get_deepbacs_data(path, bac_type, download)

    # The training data is stored in a folder called 'training'; the other splits use their own name.
    dir_choice = "training" if split == "train" else split

    if bac_type == "e_coli_stationary":
        # Imported here so that natsort is only needed for this bacteria type.
        from natsort import natsorted
        split_root = os.path.join(path, bac_type, dir_choice)
        image_folder = natsorted(glob(os.path.join(split_root, "brightfield", "*.tif")))
        label_folder = natsorted(glob(os.path.join(split_root, "masks", "*.tif")))
    elif bac_type == "mixed":
        image_folder = os.path.join(path, bac_type, dir_choice, "source")
        label_folder = os.path.join(path, bac_type, dir_choice, "target")
    else:
        # Only reachable for an unknown type when the assertion in get_deepbacs_data is disabled.
        raise NotImplementedError(
            f"Currently only 'mixed' and 'e_coli_stationary' are supported, not {bac_type}"
        )

    return image_folder, label_folder


def get_deepbacs_dataset(
    path: Union[os.PathLike, str],
    split: str,
    patch_shape: Tuple[int, int],
    bac_type: str = "mixed",
    download: bool = False,
    **kwargs
) -> Dataset:
    """Get the DeepBacs dataset for bacteria segmentation.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        split: The split to use for the dataset. Either 'train', 'val' or 'test'.
        patch_shape: The patch shape to use for training.
        bac_type: The bacteria type. The available types are:
            's_aureus', 'e_coli', 'e_coli_stationary', 'b_subtilis', 'mixed'.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    assert split in ("train", "val", "test"), f"'{split}' is not a valid split."

    image_folder, label_folder = get_deepbacs_paths(path, bac_type, split, download)

    # e_coli_stationary returns explicit file lists; mixed returns folder+glob strings.
    raw_key = None if isinstance(image_folder, list) else "*.tif"
    label_key = None if isinstance(label_folder, list) else "*.tif"
    return torch_em.default_segmentation_dataset(
        raw_paths=image_folder,
        raw_key=raw_key,
        label_paths=label_folder,
        label_key=label_key,
        patch_shape=patch_shape,
        **kwargs
    )


def get_deepbacs_loader(
    path: Union[os.PathLike, str],
    split: str,
    patch_shape: Tuple[int, int],
    batch_size: int,
    bac_type: str = "mixed",
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Get the DeepBacs dataloader for bacteria segmentation.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        split: The split to use for the dataset. Either 'train', 'val' or 'test'.
        patch_shape: The patch shape to use for training.
        batch_size: The batch size for training.
        bac_type: The bacteria type. The available types are:
            's_aureus', 'e_coli', 'e_coli_stationary', 'b_subtilis', 'mixed'.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # Separate the arguments for the dataset from the ones for the DataLoader.
    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    dataset = get_deepbacs_dataset(path, split, patch_shape, bac_type=bac_type, download=download, **ds_kwargs)
    return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
URLS =
{'s_aureus': 'https://zenodo.org/record/5550933/files/DeepBacs_Data_Segmentation_Staph_Aureus_dataset.zip?download=1', 'e_coli': 'https://zenodo.org/record/5550935/files/DeepBacs_Data_Segmentation_E.coli_Brightfield_dataset.zip?download=1', 'e_coli_stationary': 'https://zenodo.org/records/6400327/files/DeepBacs_Data_Segmentation_Ecoli_stationary_phase.zip?download=1', 'b_subtilis': 'https://zenodo.org/record/5639253/files/Multilabel_U-Net_dataset_B.subtilis.zip?download=1', 'mixed': 'https://zenodo.org/record/5551009/files/DeepBacs_Data_Segmentation_StarDist_MIXED_dataset.zip?download=1'}
CHECKSUMS =
{'s_aureus': '4047792f1248ee82fce34121d0ade84828e55db5a34656cc25beec46eacaf307', 'e_coli': 'f812a2f814c3875c78fcc1609a2e9b34c916c7a9911abbf8117f423536ef1c17', 'e_coli_stationary': None, 'b_subtilis': '1', 'mixed': '2730e6b391637d6dc05bbc7b8c915fd8184d835ac3611e13f23ac6f10f86c2a0'}
def
get_deepbacs_data(path: Union[os.PathLike, str], bac_type: str, download: bool) -> str:
def get_deepbacs_data(path: Union[os.PathLike, str], bac_type: str, download: bool) -> str:
    """Download the DeepBacs training data.

    If the folder '<path>/<bac_type>' already exists it is returned as is, without
    downloading, extracting or splitting anything.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        bac_type: The bacteria type. The available types are:
            's_aureus', 'e_coli', 'e_coli_stationary', 'b_subtilis', 'mixed'.
        download: Whether to download the data if it is not present.

    Returns:
        The filepath to the training data.
    """
    bac_types = list(URLS.keys())
    assert bac_type in bac_types, f"{bac_type} is not in expected bacteria types: {bac_types}"

    data_folder = os.path.join(path, bac_type)
    if os.path.exists(data_folder):
        return data_folder

    os.makedirs(path, exist_ok=True)
    zip_path = os.path.join(path, f"{bac_type}.zip")
    # Re-use an archive that is already on disk instead of downloading it again.
    if not os.path.exists(zip_path):
        util.download_source(zip_path, URLS[bac_type], download, checksum=CHECKSUMS[bac_type])
    util.unzip(zip_path, data_folder)

    # e_coli_stationary ships its own train/test splits; no val-splitting needed.
    if bac_type != "e_coli_stationary":
        _assort_val_set(path, bac_type)
    return data_folder
def
get_deepbacs_paths( path: Union[os.PathLike, str], bac_type: str, split: str, download: bool = False) -> Tuple[str, str]:
def get_deepbacs_paths(
    path: Union[os.PathLike, str], bac_type: str, split: str, download: bool = False
) -> Tuple[Union[str, list], Union[str, list]]:
    """Get paths to the DeepBacs data.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        bac_type: The bacteria type. The available types are:
            's_aureus', 'e_coli', 'e_coli_stationary', 'b_subtilis', 'mixed'.
            Currently only 'mixed' and 'e_coli_stationary' are supported here.
        split: The split to use for the dataset. Either 'train', 'val' or 'test'.
            'e_coli_stationary' does not have a 'val' split.
        download: Whether to download the data if it is not present.

    Returns:
        For 'mixed': the filepath to the folder where the image data is stored.
            For 'e_coli_stationary': the naturally sorted list of image file paths.
        For 'mixed': the filepath to the folder where the label data is stored.
            For 'e_coli_stationary': the naturally sorted list of label file paths.

    Raises:
        NotImplementedError: If the bacteria type is not supported yet, or if the 'val'
            split is requested for 'e_coli_stationary'.
    """
    # Reject unsupported requests before any data is downloaded or re-arranged on disk.
    if bac_type == "e_coli_stationary":
        if split == "val":
            raise NotImplementedError("The e_coli_stationary dataset does not have a val split.")
    elif bac_type != "mixed" and bac_type in URLS:
        raise NotImplementedError(
            f"Currently only 'mixed' and 'e_coli_stationary' are supported, not {bac_type}"
        )

    get_deepbacs_data(path, bac_type, download)

    # The training data is stored in a folder called 'training'; the other splits use their own name.
    dir_choice = "training" if split == "train" else split

    if bac_type == "e_coli_stationary":
        # Imported here so that natsort is only needed for this bacteria type.
        from natsort import natsorted
        split_root = os.path.join(path, bac_type, dir_choice)
        image_folder = natsorted(glob(os.path.join(split_root, "brightfield", "*.tif")))
        label_folder = natsorted(glob(os.path.join(split_root, "masks", "*.tif")))
    elif bac_type == "mixed":
        image_folder = os.path.join(path, bac_type, dir_choice, "source")
        label_folder = os.path.join(path, bac_type, dir_choice, "target")
    else:
        # Only reachable for an unknown type when the assertion in get_deepbacs_data is disabled.
        raise NotImplementedError(
            f"Currently only 'mixed' and 'e_coli_stationary' are supported, not {bac_type}"
        )

    return image_folder, label_folder
def
get_deepbacs_dataset( path: Union[os.PathLike, str], split: str, patch_shape: Tuple[int, int], bac_type: str = 'mixed', download: bool = False, **kwargs) -> torch.utils.data.dataset.Dataset:
def get_deepbacs_dataset(
    path: Union[os.PathLike, str],
    split: str,
    patch_shape: Tuple[int, int],
    bac_type: str = "mixed",
    download: bool = False,
    **kwargs
) -> Dataset:
    """Get the DeepBacs dataset for bacteria segmentation.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        split: The split to use for the dataset. Either 'train', 'val' or 'test'.
        patch_shape: The patch shape to use for training.
        bac_type: The bacteria type. The available types are:
            's_aureus', 'e_coli', 'e_coli_stationary', 'b_subtilis', 'mixed'.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    assert split in ("train", "val", "test"), f"'{split}' is not a valid split."

    image_folder, label_folder = get_deepbacs_paths(path, bac_type, split, download)

    # e_coli_stationary returns explicit file lists; mixed returns folder+glob strings.
    raw_key = None if isinstance(image_folder, list) else "*.tif"
    label_key = None if isinstance(label_folder, list) else "*.tif"
    return torch_em.default_segmentation_dataset(
        raw_paths=image_folder,
        raw_key=raw_key,
        label_paths=label_folder,
        label_key=label_key,
        patch_shape=patch_shape,
        **kwargs
    )
def
get_deepbacs_loader( path: Union[os.PathLike, str], split: str, patch_shape: Tuple[int, int], batch_size: int, bac_type: str = 'mixed', download: bool = False, **kwargs) -> torch.utils.data.dataloader.DataLoader:
def get_deepbacs_loader(
    path: Union[os.PathLike, str],
    split: str,
    patch_shape: Tuple[int, int],
    batch_size: int,
    bac_type: str = "mixed",
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Get the DeepBacs dataloader for bacteria segmentation.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        split: The split to use for the dataset. Either 'train', 'val' or 'test'.
        patch_shape: The patch shape to use for training.
        batch_size: The batch size for training.
        bac_type: The bacteria type. The available types are:
            's_aureus', 'e_coli', 'e_coli_stationary', 'b_subtilis', 'mixed'.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # Separate the arguments for the dataset from the ones for the DataLoader.
    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    dataset = get_deepbacs_dataset(path, split, patch_shape, bac_type=bac_type, download=download, **ds_kwargs)
    return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)