torch_em.data.datasets.light_microscopy.deepbacs
DeepBacs is a dataset for segmenting bacteria in label-free light microscopy.
This dataset is from the publication https://doi.org/10.1038/s42003-022-03634-z. Please cite it if you use this dataset in your research.
"""DeepBacs is a dataset for segmenting bacteria in label-free light microscopy.

This dataset is from the publication https://doi.org/10.1038/s42003-022-03634-z.
Please cite it if you use this dataset in your research.
"""

import os
import shutil
from glob import glob
from typing import Tuple, Union

import numpy as np

from torch.utils.data import Dataset, DataLoader

import torch_em

from .. import util

URLS = {
    "s_aureus": "https://zenodo.org/record/5550933/files/DeepBacs_Data_Segmentation_Staph_Aureus_dataset.zip?download=1",  # noqa
    "e_coli": "https://zenodo.org/record/5550935/files/DeepBacs_Data_Segmentation_E.coli_Brightfield_dataset.zip?download=1",  # noqa
    "b_subtilis": "https://zenodo.org/record/5639253/files/Multilabel_U-Net_dataset_B.subtilis.zip?download=1",
    "mixed": "https://zenodo.org/record/5551009/files/DeepBacs_Data_Segmentation_StarDist_MIXED_dataset.zip?download=1",
}
# SHA256-style checksums passed to `util.download_source`.
# NOTE(review): no checksum is known yet for the B. subtilis archive. The former placeholder "1" can never
# match a real hash. `None` is used instead, assuming `util.download_source` skips verification for None - confirm.
CHECKSUMS = {
    "s_aureus": "4047792f1248ee82fce34121d0ade84828e55db5a34656cc25beec46eacaf307",
    "e_coli": "f812a2f814c3875c78fcc1609a2e9b34c916c7a9911abbf8117f423536ef1c17",
    "b_subtilis": None,
    "mixed": "2730e6b391637d6dc05bbc7b8c915fd8184d835ac3611e13f23ac6f10f86c2a0",
}


def _assort_val_set(path, bac_type):
    """Move a random 20% of the training samples into a new 'val' split.

    Each selected image is moved from '<path>/<bac_type>/training/source' to '<path>/<bac_type>/val/source'.
    The label file with the same name is moved from 'training/target' to 'val/target'.
    For the 'mixed' data, the 20% are drawn separately for each file-name prefix ('JE2', 'pos', 'train_')
    to keep the validation set balanced. Files matching none of these prefixes always stay in training.
    The selection uses numpy's global random state and is not seeded.
    """
    train_dir = os.path.join(path, bac_type, "training")
    val_dir = os.path.join(path, bac_type, "val")
    sample_ids = [os.path.basename(p) for p in glob(os.path.join(train_dir, "source", "*"))]

    val_partition = 0.2

    def _sample(ids):
        # Draw int(val_partition * len(ids)) distinct ids; an empty list yields an empty selection.
        return list(np.random.choice(ids, size=int(val_partition * len(ids)), replace=False))

    if bac_type == "mixed":
        # let's get a balanced set of bacterias: sample per sub-dataset prefix.
        val_ids = []
        for prefix in ("JE2", "pos", "train_"):
            val_ids.extend(_sample([sid for sid in sample_ids if sid.startswith(prefix)]))
    else:
        val_ids = _sample(sample_ids)

    for sub in ("source", "target"):
        os.makedirs(os.path.join(val_dir, sub), exist_ok=True)

    for sample_id in val_ids:
        # Images (source) and labels (target) share the same file name.
        for sub in ("source", "target"):
            shutil.move(os.path.join(train_dir, sub, sample_id), os.path.join(val_dir, sub, sample_id))


def get_deepbacs_data(path: Union[os.PathLike, str], bac_type: str, download: bool) -> str:
    """Download the DeepBacs data and create a validation split.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        bac_type: The bacteria type. The available types are:
            's_aureus', 'e_coli', 'b_subtilis', 'mixed'.
        download: Whether to download the data if it is not present.

    Returns:
        The filepath to the folder with the data for the given bacteria type.

    Raises:
        AssertionError: If `bac_type` is not one of the available types.
    """
    bac_types = list(URLS.keys())
    assert bac_type in bac_types, f"{bac_type} is not in expected bacteria types: {bac_types}"

    data_folder = os.path.join(path, bac_type)
    if os.path.exists(data_folder):
        # The val split is created in a second step after unzipping; if an earlier call stopped
        # in between (or the data predates the val split), create the missing split now.
        if not os.path.exists(os.path.join(data_folder, "val")):
            _assort_val_set(path, bac_type)
        return data_folder

    os.makedirs(path, exist_ok=True)
    zip_path = os.path.join(path, f"{bac_type}.zip")
    if not os.path.exists(zip_path):
        util.download_source(zip_path, URLS[bac_type], download, checksum=CHECKSUMS[bac_type])
    util.unzip(zip_path, os.path.join(path, bac_type))

    # Get a val split for the expected bacteria type.
    _assort_val_set(path, bac_type)
    return data_folder


def get_deepbacs_paths(
    path: Union[os.PathLike, str], bac_type: str, split: str, download: bool = False
) -> Tuple[str, str]:
    """Get paths to the DeepBacs data.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        bac_type: The bacteria type. The available types are:
            's_aureus', 'e_coli', 'b_subtilis', 'mixed'. Currently only 'mixed' is supported here.
        split: The split to use for the dataset. Either 'train', 'val' or 'test'.
        download: Whether to download the data if it is not present.

    Returns:
        Filepath to the folder where image data is stored.
        Filepath to the folder where label data is stored.

    Raises:
        NotImplementedError: If `bac_type` is a known type other than 'mixed'.
    """
    # the bacteria types other than mixed are a bit more complicated so we don't have the dataloaders for them yet
    # mixed is the combination of all other types.
    # Reject them before downloading so we don't fetch data that cannot be used; unknown types fall
    # through to the assertion in `get_deepbacs_data`.
    if bac_type in URLS and bac_type != "mixed":
        raise NotImplementedError(f"Currently only the bacteria type 'mixed' is supported, not {bac_type}")

    get_deepbacs_data(path, bac_type, download)

    # The training data lives in a folder called 'training'; the other splits use their own name.
    dir_choice = "training" if split == "train" else split

    image_folder = os.path.join(path, bac_type, dir_choice, "source")
    label_folder = os.path.join(path, bac_type, dir_choice, "target")

    return image_folder, label_folder


def get_deepbacs_dataset(
    path: Union[os.PathLike, str],
    split: str,
    patch_shape: Tuple[int, int],
    bac_type: str = "mixed",
    download: bool = False,
    **kwargs
) -> Dataset:
    """Get the DeepBacs dataset for bacteria segmentation.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        split: The split to use for the dataset. Either 'train', 'val' or 'test'.
        patch_shape: The patch shape to use for training.
        bac_type: The bacteria type. The available types are:
            's_aureus', 'e_coli', 'b_subtilis', 'mixed'. Currently only 'mixed' is supported.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    assert split in ("train", "val", "test")

    image_folder, label_folder = get_deepbacs_paths(path, bac_type, split, download)

    return torch_em.default_segmentation_dataset(
        raw_paths=image_folder,
        raw_key="*.tif",
        label_paths=label_folder,
        label_key="*.tif",
        patch_shape=patch_shape,
        **kwargs
    )


def get_deepbacs_loader(
    path: Union[os.PathLike, str],
    split: str,
    patch_shape: Tuple[int, int],
    batch_size: int,
    bac_type: str = "mixed",
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Get the DeepBacs dataloader for bacteria segmentation.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        split: The split to use for the dataset. Either 'train', 'val' or 'test'.
        patch_shape: The patch shape to use for training.
        batch_size: The batch size for training.
        bac_type: The bacteria type. The available types are:
            's_aureus', 'e_coli', 'b_subtilis', 'mixed'. Currently only 'mixed' is supported.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # Route kwargs accepted by the dataset constructor there; the rest go to the DataLoader.
    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    dataset = get_deepbacs_dataset(path, split, patch_shape, bac_type=bac_type, download=download, **ds_kwargs)
    return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
# Zenodo download links for each DeepBacs bacteria type.
URLS = {
    "s_aureus": (
        "https://zenodo.org/record/5550933/files/DeepBacs_Data_Segmentation_Staph_Aureus_dataset.zip?download=1"
    ),
    "e_coli": (
        "https://zenodo.org/record/5550935/files/DeepBacs_Data_Segmentation_E.coli_Brightfield_dataset.zip?download=1"
    ),
    "b_subtilis": "https://zenodo.org/record/5639253/files/Multilabel_U-Net_dataset_B.subtilis.zip?download=1",
    "mixed": "https://zenodo.org/record/5551009/files/DeepBacs_Data_Segmentation_StarDist_MIXED_dataset.zip?download=1",
}
# SHA256-style checksums passed to `util.download_source`.
# NOTE(review): no checksum is known yet for the B. subtilis archive. The former placeholder "1" can never
# match a real hash. `None` is used instead, assuming `util.download_source` skips verification for None - confirm.
CHECKSUMS = {
    "s_aureus": "4047792f1248ee82fce34121d0ade84828e55db5a34656cc25beec46eacaf307",
    "e_coli": "f812a2f814c3875c78fcc1609a2e9b34c916c7a9911abbf8117f423536ef1c17",
    "b_subtilis": None,
    "mixed": "2730e6b391637d6dc05bbc7b8c915fd8184d835ac3611e13f23ac6f10f86c2a0",
}
def get_deepbacs_data(path: Union[os.PathLike, str], bac_type: str, download: bool) -> str:
    """Download the DeepBacs data and create a validation split.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        bac_type: The bacteria type. The available types are:
            's_aureus', 'e_coli', 'b_subtilis', 'mixed'.
        download: Whether to download the data if it is not present.

    Returns:
        The filepath to the folder with the data for the given bacteria type.

    Raises:
        AssertionError: If `bac_type` is not one of the available types.
    """
    bac_types = list(URLS.keys())
    assert bac_type in bac_types, f"{bac_type} is not in expected bacteria types: {bac_types}"

    data_folder = os.path.join(path, bac_type)
    if os.path.exists(data_folder):
        # The val split is created in a second step after unzipping; if an earlier call stopped
        # in between (or the data predates the val split), create the missing split now.
        if not os.path.exists(os.path.join(data_folder, "val")):
            _assort_val_set(path, bac_type)
        return data_folder

    os.makedirs(path, exist_ok=True)
    zip_path = os.path.join(path, f"{bac_type}.zip")
    if not os.path.exists(zip_path):
        util.download_source(zip_path, URLS[bac_type], download, checksum=CHECKSUMS[bac_type])
    util.unzip(zip_path, os.path.join(path, bac_type))

    # Get a val split for the expected bacteria type.
    _assort_val_set(path, bac_type)
    return data_folder
def get_deepbacs_paths(
    path: Union[os.PathLike, str], bac_type: str, split: str, download: bool = False
) -> Tuple[str, str]:
    """Get paths to the DeepBacs data.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        bac_type: The bacteria type. The available types are:
            's_aureus', 'e_coli', 'b_subtilis', 'mixed'. Currently only 'mixed' is supported here.
        split: The split to use for the dataset. Either 'train', 'val' or 'test'.
        download: Whether to download the data if it is not present.

    Returns:
        Filepath to the folder where image data is stored.
        Filepath to the folder where label data is stored.

    Raises:
        NotImplementedError: If `bac_type` is a known type other than 'mixed'.
    """
    # the bacteria types other than mixed are a bit more complicated so we don't have the dataloaders for them yet
    # mixed is the combination of all other types.
    # Reject them before downloading so we don't fetch data that cannot be used; unknown types fall
    # through to the assertion in `get_deepbacs_data`.
    if bac_type in URLS and bac_type != "mixed":
        raise NotImplementedError(f"Currently only the bacteria type 'mixed' is supported, not {bac_type}")

    get_deepbacs_data(path, bac_type, download)

    # The training data lives in a folder called 'training'; the other splits use their own name.
    dir_choice = "training" if split == "train" else split

    image_folder = os.path.join(path, bac_type, dir_choice, "source")
    label_folder = os.path.join(path, bac_type, dir_choice, "target")

    return image_folder, label_folder
def get_deepbacs_dataset(
    path: Union[os.PathLike, str],
    split: str,
    patch_shape: Tuple[int, int],
    bac_type: str = "mixed",
    download: bool = False,
    **kwargs
) -> Dataset:
    """Get the DeepBacs dataset for bacteria segmentation.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        split: The split to use for the dataset. Either 'train', 'val' or 'test'.
        patch_shape: The patch shape to use for training.
        bac_type: The bacteria type. The available types are:
            's_aureus', 'e_coli', 'b_subtilis', 'mixed'. Currently only 'mixed' is supported.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    assert split in ("train", "val", "test")

    image_folder, label_folder = get_deepbacs_paths(path, bac_type, split, download)

    # Images and labels are tif files with matching names in the source / target folders.
    return torch_em.default_segmentation_dataset(
        raw_paths=image_folder,
        raw_key="*.tif",
        label_paths=label_folder,
        label_key="*.tif",
        patch_shape=patch_shape,
        **kwargs
    )
def get_deepbacs_loader(
    path: Union[os.PathLike, str],
    split: str,
    patch_shape: Tuple[int, int],
    batch_size: int,
    bac_type: str = "mixed",
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Get the DeepBacs dataloader for bacteria segmentation.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        split: The split to use for the dataset. Either 'train', 'val' or 'test'.
        patch_shape: The patch shape to use for training.
        batch_size: The batch size for training.
        bac_type: The bacteria type. The available types are:
            's_aureus', 'e_coli', 'b_subtilis', 'mixed'. Currently only 'mixed' is supported.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # Route kwargs accepted by the dataset constructor there; the rest go to the DataLoader.
    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    dataset = get_deepbacs_dataset(path, split, patch_shape, bac_type=bac_type, download=download, **ds_kwargs)
    return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)