torch_em.data.datasets.light_microscopy.deepbacs
DeepBacs is a dataset for segmenting bacteria in label-free light microscopy.
This dataset is from the publication https://doi.org/10.1038/s42003-022-03634-z. Please cite it if you use this dataset in your research.
"""DeepBacs is a dataset for segmenting bacteria in label-free light microscopy.

This dataset is from the publication https://doi.org/10.1038/s42003-022-03634-z.
Please cite it if you use this dataset in your research.
"""

import os
import shutil
import numpy as np
from glob import glob
from typing import Tuple, Union

import torch_em
from torch.utils.data import Dataset, DataLoader
from .. import util

# Download links for the zipped DeepBacs data, keyed by bacteria type.
URLS = {
    "s_aureus": "https://zenodo.org/record/5550933/files/DeepBacs_Data_Segmentation_Staph_Aureus_dataset.zip?download=1",
    "e_coli": "https://zenodo.org/record/5550935/files/DeepBacs_Data_Segmentation_E.coli_Brightfield_dataset.zip?download=1",
    "b_subtilis": "https://zenodo.org/record/5639253/files/Multilabel_U-Net_dataset_B.subtilis.zip?download=1",
    "mixed": "https://zenodo.org/record/5551009/files/DeepBacs_Data_Segmentation_StarDist_MIXED_dataset.zip?download=1",
}
# Checksums handed to `util.download_source` together with the matching URL.
CHECKSUMS = {
    "s_aureus": "4047792f1248ee82fce34121d0ade84828e55db5a34656cc25beec46eacaf307",
    "e_coli": "f812a2f814c3875c78fcc1609a2e9b34c916c7a9911abbf8117f423536ef1c17",
    # NOTE(review): "1" does not look like a real digest, presumably a placeholder - TODO confirm.
    "b_subtilis": "1",
    "mixed": "2730e6b391637d6dc05bbc7b8c915fd8184d835ac3611e13f23ac6f10f86c2a0",
}


def _assort_val_set(path, bac_type):
    """Move a random 20% of the training samples into a new 'val' folder.

    The images are taken from `path/bac_type/training/source` and the labels
    with the same file name from `path/bac_type/training/target`.
    For the bacteria type 'mixed' the samples are drawn separately for each of the
    file name prefixes 'JE2', 'pos' and 'train_', so that the val set is balanced.

    Args:
        path: The folder that contains the unzipped data.
        bac_type: The bacteria type, which is also the name of the sub-folder in `path`.
    """
    train_image_dir = os.path.join(path, bac_type, "training", "source")
    train_label_dir = os.path.join(path, bac_type, "training", "target")
    image_names = [os.path.split(_path)[-1] for _path in glob(os.path.join(train_image_dir, "*"))]

    val_partition = 0.2

    def _sample(names):
        # Draw the validation fraction without replacement, so that no sample is picked twice.
        return np.random.choice(names, size=int(val_partition * len(names)), replace=False)

    if bac_type == "mixed":
        # Let's get a balanced set of bacteria: sample from each file name prefix on its own.
        # File names that match none of the prefixes always stay in the training set.
        val_image_names = []
        for prefix in ("JE2", "pos", "train_"):
            val_image_names.extend(_sample([name for name in image_names if name.startswith(prefix)]))
    else:
        val_image_names = _sample(image_names)

    val_image_dir = os.path.join(path, bac_type, "val", "source")
    val_label_dir = os.path.join(path, bac_type, "val", "target")
    os.makedirs(val_image_dir, exist_ok=True)
    os.makedirs(val_label_dir, exist_ok=True)

    # Image and label of a sample share the same file name.
    for sample_id in val_image_names:
        shutil.move(os.path.join(train_image_dir, sample_id), os.path.join(val_image_dir, sample_id))
        shutil.move(os.path.join(train_label_dir, sample_id), os.path.join(val_label_dir, sample_id))


def get_deebacs_data(path: Union[os.PathLike, str], bac_type: str, download: bool) -> str:
    """Download the DeepBacs training data.

    The zip file for the bacteria type is downloaded (unless it already exists) and
    unzipped to `path/bac_type`. Afterwards a validation split is created from the training data.
    If `path/bac_type` already exists the function returns without doing anything.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        bac_type: The bacteria type. The available types are:
            's_aureus', 'e_coli', 'b_subtilis', 'mixed'.
        download: Whether to download the data if it is not present.

    Returns:
        The filepath to the folder with the data for the bacteria type.
    """
    # NOTE: The docstring must be a plain string literal. An f-string in its place
    # is not stored as the docstring and would leave `__doc__` empty.
    bac_types = list(URLS.keys())
    assert bac_type in bac_types, f"{bac_type} is not in expected bacteria types: {bac_types}"

    data_folder = os.path.join(path, bac_type)
    if os.path.exists(data_folder):
        return data_folder

    os.makedirs(path, exist_ok=True)
    zip_path = os.path.join(path, f"{bac_type}.zip")
    if not os.path.exists(zip_path):
        util.download_source(zip_path, URLS[bac_type], download, checksum=CHECKSUMS[bac_type])
    util.unzip(zip_path, data_folder)

    # Get a val split for the expected bacteria type.
    _assort_val_set(path, bac_type)
    return data_folder


def _get_paths(path, bac_type, split):
    """Return the image folder and the label folder for a bacteria type and split."""
    # The bacteria types other than mixed are a bit more complicated, so we don't have the dataloaders for them yet.
    # 'mixed' is the combination of all other types.
    if bac_type != "mixed":
        raise NotImplementedError(f"Currently only the bacteria type 'mixed' is supported, not {bac_type}")

    # The training data is stored in the folder 'training', the other folders are named like the split.
    dir_choice = "training" if split == "train" else split
    image_folder = os.path.join(path, bac_type, dir_choice, "source")
    label_folder = os.path.join(path, bac_type, dir_choice, "target")
    return image_folder, label_folder


def get_deepbacs_dataset(
    path: Union[os.PathLike, str],
    split: str,
    patch_shape: Tuple[int, int],
    bac_type: str = "mixed",
    download: bool = False,
    **kwargs
) -> Dataset:
    """Get the DeepBacs dataset for bacteria segmentation.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        split: The split to use for the dataset. Either 'train', 'val' or 'test'.
        patch_shape: The patch shape to use for training.
        bac_type: The bacteria type. The available types are:
            's_aureus', 'e_coli', 'b_subtilis', 'mixed'.
            Currently only 'mixed' is supported for creating the dataset.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.

    Raises:
        NotImplementedError: If `bac_type` is a valid type other than 'mixed'.
    """
    assert split in ("train", "val", "test"), f"'{split}' is not a valid split, expected 'train', 'val' or 'test'."
    get_deebacs_data(path, bac_type, download)
    image_folder, label_folder = _get_paths(path, bac_type, split)
    return torch_em.default_segmentation_dataset(
        image_folder, "*.tif", label_folder, "*.tif", patch_shape=patch_shape, **kwargs
    )


def get_deepbacs_loader(
    path: Union[os.PathLike, str],
    split: str,
    patch_shape: Tuple[int, int],
    batch_size: int,
    bac_type: str = "mixed",
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Get the DeepBacs dataloader for bacteria segmentation.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        split: The split to use for the dataset. Either 'train', 'val' or 'test'.
        patch_shape: The patch shape to use for training.
        batch_size: The batch size for training.
        bac_type: The bacteria type. The available types are:
            's_aureus', 'e_coli', 'b_subtilis', 'mixed'.
            Currently only 'mixed' is supported for creating the dataset.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`
            or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # Separate the keyword arguments for the dataset from the ones for the loader.
    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    dataset = get_deepbacs_dataset(path, split, patch_shape, bac_type=bac_type, download=download, **ds_kwargs)
    return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
# Download links for the zipped DeepBacs data, keyed by bacteria type.
URLS = {
    "s_aureus": "https://zenodo.org/record/5550933/files/DeepBacs_Data_Segmentation_Staph_Aureus_dataset.zip?download=1",
    "e_coli": "https://zenodo.org/record/5550935/files/DeepBacs_Data_Segmentation_E.coli_Brightfield_dataset.zip?download=1",
    "b_subtilis": "https://zenodo.org/record/5639253/files/Multilabel_U-Net_dataset_B.subtilis.zip?download=1",
    "mixed": "https://zenodo.org/record/5551009/files/DeepBacs_Data_Segmentation_StarDist_MIXED_dataset.zip?download=1",
}
# Expected checksums of the zipped DeepBacs data, keyed by bacteria type.
CHECKSUMS = {
    "s_aureus": "4047792f1248ee82fce34121d0ade84828e55db5a34656cc25beec46eacaf307",
    "e_coli": "f812a2f814c3875c78fcc1609a2e9b34c916c7a9911abbf8117f423536ef1c17",
    # NOTE(review): "1" does not look like a real digest, presumably a placeholder - TODO confirm.
    "b_subtilis": "1",
    "mixed": "2730e6b391637d6dc05bbc7b8c915fd8184d835ac3611e13f23ac6f10f86c2a0",
}
def
get_deebacs_data(path: Union[os.PathLike, str], bac_type: str, download: bool) -> str:
def get_deebacs_data(path: Union[os.PathLike, str], bac_type: str, download: bool) -> str:
    """Download the DeepBacs training data.

    The zip file for the bacteria type is downloaded (unless it already exists) and
    unzipped to `path/bac_type`. Afterwards a validation split is created from the training data.
    If `path/bac_type` already exists the function returns without doing anything.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        bac_type: The bacteria type. The available types are:
            's_aureus', 'e_coli', 'b_subtilis', 'mixed'.
        download: Whether to download the data if it is not present.

    Returns:
        The filepath to the folder with the data for the bacteria type.
    """
    # NOTE: The docstring must be a plain string literal. An f-string in its place
    # is not stored as the docstring and would leave `__doc__` empty.
    bac_types = list(URLS.keys())
    assert bac_type in bac_types, f"{bac_type} is not in expected bacteria types: {bac_types}"

    data_folder = os.path.join(path, bac_type)
    if os.path.exists(data_folder):
        return data_folder

    os.makedirs(path, exist_ok=True)
    zip_path = os.path.join(path, f"{bac_type}.zip")
    if not os.path.exists(zip_path):
        util.download_source(zip_path, URLS[bac_type], download, checksum=CHECKSUMS[bac_type])
    util.unzip(zip_path, data_folder)

    # Get a val split for the expected bacteria type.
    _assort_val_set(path, bac_type)
    return data_folder
def
get_deepbacs_dataset( path: Union[os.PathLike, str], split: str, patch_shape: Tuple[int, int], bac_type: str = 'mixed', download: bool = False, **kwargs) -> torch.utils.data.dataset.Dataset:
def get_deepbacs_dataset(
    path: Union[os.PathLike, str],
    split: str,
    patch_shape: Tuple[int, int],
    bac_type: str = "mixed",
    download: bool = False,
    **kwargs
) -> Dataset:
    """Get the DeepBacs dataset for bacteria segmentation.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        split: The split to use for the dataset. Either 'train', 'val' or 'test'.
        patch_shape: The patch shape to use for training.
        bac_type: The bacteria type. The available types are:
            's_aureus', 'e_coli', 'b_subtilis', 'mixed'.
            Currently only 'mixed' is supported for creating the dataset.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.

    Raises:
        NotImplementedError: If `bac_type` is a valid type other than 'mixed'.
    """
    # NOTE: The docstring must be a plain string literal. An f-string in its place
    # is not stored as the docstring and would leave `__doc__` empty.
    assert split in ("train", "val", "test"), f"'{split}' is not a valid split, expected 'train', 'val' or 'test'."
    get_deebacs_data(path, bac_type, download)
    image_folder, label_folder = _get_paths(path, bac_type, split)
    return torch_em.default_segmentation_dataset(
        image_folder, "*.tif", label_folder, "*.tif", patch_shape=patch_shape, **kwargs
    )
def
get_deepbacs_loader( path: Union[os.PathLike, str], split: str, patch_shape: Tuple[int, int], batch_size: int, bac_type: str = 'mixed', download: bool = False, **kwargs) -> torch.utils.data.dataloader.DataLoader:
def get_deepbacs_loader(
    path: Union[os.PathLike, str],
    split: str,
    patch_shape: Tuple[int, int],
    batch_size: int,
    bac_type: str = "mixed",
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Get the DeepBacs dataloader for bacteria segmentation.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        split: The split to use for the dataset. Either 'train', 'val' or 'test'.
        patch_shape: The patch shape to use for training.
        batch_size: The batch size for training.
        bac_type: The bacteria type. The available types are:
            's_aureus', 'e_coli', 'b_subtilis', 'mixed'.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`
            or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # NOTE: The docstring must be a plain string literal. An f-string in its place
    # is not stored as the docstring and would leave `__doc__` empty.
    # Separate the keyword arguments for the dataset from the ones for the loader.
    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    dataset = get_deepbacs_dataset(path, split, patch_shape, bac_type=bac_type, download=download, **ds_kwargs)
    return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)