torch_em.data.datasets.light_microscopy.hpa
This dataset was part of the HPA Kaggle challenge for protein identification. It contains confocal microscopy images and annotations for cell segmentation.
The dataset is described in the publication https://doi.org/10.1038/s41592-019-0658-6. Please cite it if you use this dataset in your research.
1"""This dataset was part of the HPA Kaggle challenge for protein identification. 2It contains confocal microscopy images and annotations for cell segmentation. 3 4The dataset is described in the publication https://doi.org/10.1038/s41592-019-0658-6. 5Please cite it if you use this dataset in your research. 6""" 7 8import os 9import json 10import shutil 11from glob import glob 12from tqdm import tqdm 13from concurrent import futures 14from functools import partial 15from typing import List, Optional, Sequence, Tuple, Union 16 17import imageio 18import numpy as np 19from skimage import morphology 20from PIL import Image, ImageDraw 21from skimage import draw as skimage_draw 22 23from torch.utils.data import Dataset, DataLoader 24 25import torch_em 26 27from .. import util 28 29 30URLS = { 31 "segmentation": "https://zenodo.org/record/4665863/files/hpa_dataset_v2.zip" 32} 33CHECKSUMS = { 34 "segmentation": "dcd6072293d88d49c71376d3d99f3f4f102e4ee83efb0187faa89c95ec49faa9" 35} 36VALID_CHANNELS = ["microtubules", "protein", "nuclei", "er"] 37 38 39def _download_hpa_data(path, name, download): 40 os.makedirs(path, exist_ok=True) 41 url = URLS[name] 42 checksum = CHECKSUMS[name] 43 zip_path = os.path.join(path, "data.zip") 44 util.download_source(zip_path, url, download=download, checksum=checksum) 45 util.unzip(zip_path, path, remove=True) 46 47 48def _load_features(features): 49 # Loop over list and create simple dictionary & get size of annotations 50 annot_dict = {} 51 skipped = [] 52 53 for feat_idx, feat in enumerate(features): 54 if feat["geometry"]["type"] not in ["Polygon", "LineString"]: 55 skipped.append(feat["geometry"]["type"]) 56 continue 57 58 # skip empty roi 59 if len(feat["geometry"]["coordinates"][0]) <= 0: 60 continue 61 62 key_annot = "annot_" + str(feat_idx) 63 annot_dict[key_annot] = {} 64 annot_dict[key_annot]["type"] = feat["geometry"]["type"] 65 annot_dict[key_annot]["pos"] = np.squeeze( 66 np.asarray(feat["geometry"]["coordinates"]) 67 ) 68 annot_dict[key_annot]["properties"] = feat["properties"] 69 70 # print("Skipped geometry type(s):", skipped) 71 return annot_dict 72 73 74def _generate_binary_masks(annot_dict, shape, erose_size=5, obj_size_rem=500, save_indiv=False): 75 # Get dimensions of image and created masks of same size 76 # This we need to save somewhere (e.g. as part of the geojson file?) 
77 78 # Filled masks and edge mask for polygons 79 mask_fill = np.zeros(shape, dtype=np.uint8) 80 mask_edge = np.zeros(shape, dtype=np.uint8) 81 mask_labels = np.zeros(shape, dtype=np.uint16) 82 83 rr_all = [] 84 cc_all = [] 85 86 if save_indiv is True: 87 mask_edge_indiv = np.zeros( 88 (shape[0], shape[1], len(annot_dict)), dtype="bool" 89 ) 90 mask_fill_indiv = np.zeros( 91 (shape[0], shape[1], len(annot_dict)), dtype="bool" 92 ) 93 94 # Image used to draw lines - for edge mask for freelines 95 im_freeline = Image.new("1", (shape[1], shape[0]), color=0) 96 draw = ImageDraw.Draw(im_freeline) 97 98 # Loop over all roi 99 i_roi = 0 100 for roi_key, roi in annot_dict.items(): 101 roi_pos = roi["pos"] 102 103 # Check region type 104 105 # freeline - line 106 if roi["type"] == "freeline" or roi["type"] == "LineString": 107 108 # Loop over all pairs of points to draw the line 109 110 for ind in range(roi_pos.shape[0] - 1): 111 line_pos = ( 112 roi_pos[ind, 1], 113 roi_pos[ind, 0], 114 roi_pos[ind + 1, 1], 115 roi_pos[ind + 1, 0], 116 ) 117 draw.line(line_pos, fill=1, width=erose_size) 118 119 # freehand - polygon 120 elif ( 121 roi["type"] == "freehand" 122 or roi["type"] == "polygon" 123 or roi["type"] == "polyline" 124 or roi["type"] == "Polygon" 125 ): 126 127 # Draw polygon 128 rr, cc = skimage_draw.polygon( 129 [shape[0] - r for r in roi_pos[:, 1]], roi_pos[:, 0] 130 ) 131 132 # Make sure it's not outside 133 rr[rr < 0] = 0 134 rr[rr > shape[0] - 1] = shape[0] - 1 135 136 cc[cc < 0] = 0 137 cc[cc > shape[1] - 1] = shape[1] - 1 138 139 # Test if this region has already been added 140 if any(np.array_equal(rr, rr_test) for rr_test in rr_all) and any( 141 np.array_equal(cc, cc_test) for cc_test in cc_all 142 ): 143 # print('Region #{} has already been used'.format(i + 144 # 1)) 145 continue 146 147 rr_all.append(rr) 148 cc_all.append(cc) 149 150 # Generate mask 151 mask_fill_roi = np.zeros(shape, dtype=np.uint8) 152 mask_fill_roi[rr, cc] = 1 153 154 # Erode to get cell edge - both arrays are boolean to be used as 155 # index arrays later 156 mask_fill_roi_erode = morphology.binary_erosion( 157 mask_fill_roi, np.ones((erose_size, erose_size)) 158 ) 159 mask_edge_roi = ( 160 mask_fill_roi.astype("int") - mask_fill_roi_erode.astype("int") 161 ).astype("bool") 162 163 # Save array for mask and edge 164 mask_fill[mask_fill_roi > 0] = 1 165 mask_edge[mask_edge_roi] = 1 166 mask_labels[mask_fill_roi > 0] = i_roi + 1 167 168 if save_indiv is True: 169 mask_edge_indiv[:, :, i_roi] = mask_edge_roi.astype("bool") 170 mask_fill_indiv[:, :, i_roi] = mask_fill_roi_erode.astype("bool") 171 172 i_roi = i_roi + 1 173 174 else: 175 roi_type = roi["type"] 176 raise NotImplementedError( 177 f'Mask for roi type "{roi_type}" can not be created' 178 ) 179 180 del draw 181 182 # Convert mask from free-lines to numpy array 183 mask_edge_freeline = np.asarray(im_freeline) 184 mask_edge_freeline = mask_edge_freeline.astype("bool") 185 186 # Post-processing of fill and edge mask - if defined 187 mask_dict = {} 188 if np.any(mask_fill): 189 190 # (1) remove edges , (2) remove small objects 191 mask_fill = mask_fill & ~mask_edge 192 mask_fill = morphology.remove_small_objects( 193 mask_fill.astype("bool"), obj_size_rem 194 ) 195 196 # For edge - consider also freeline edge mask 197 198 mask_edge = mask_edge.astype("bool") 199 mask_edge = np.logical_or(mask_edge, mask_edge_freeline) 200 201 # Assign to dictionary for return 202 mask_dict["edge"] = mask_edge 203 mask_dict["fill"] = mask_fill.astype("bool") 204 
mask_dict["labels"] = mask_labels.astype("uint16") 205 206 if save_indiv is True: 207 mask_dict["edge_indiv"] = mask_edge_indiv 208 mask_dict["fill_indiv"] = mask_fill_indiv 209 else: 210 mask_dict["edge_indiv"] = np.zeros(shape + (1,), dtype=np.uint8) 211 mask_dict["fill_indiv"] = np.zeros(shape + (1,), dtype=np.uint8) 212 213 # Only edge mask present 214 elif np.any(mask_edge_freeline): 215 mask_dict["edge"] = mask_edge_freeline 216 mask_dict["fill"] = mask_fill.astype("bool") 217 mask_dict["labels"] = mask_labels.astype("uint16") 218 219 mask_dict["edge_indiv"] = np.zeros(shape + (1,), dtype=np.uint8) 220 mask_dict["fill_indiv"] = np.zeros(shape + (1,), dtype=np.uint8) 221 222 else: 223 raise Exception("No mask has been created.") 224 225 return mask_dict 226 227 228# adapted from 229# https://github.com/imjoy-team/kaibu-utils/blob/main/kaibu_utils/__init__.py#L267 230def _get_labels(annotation_file, shape, label="*"): 231 with open(annotation_file) as f: 232 features = json.load(f)["features"] 233 if len(features) == 0: 234 return np.zeros(shape, dtype="uint16") 235 236 annot_dict_all = _load_features(features) 237 annot_types = set( 238 annot_dict_all[k]["properties"].get("label", "default") 239 for k in annot_dict_all.keys() 240 ) 241 for annot_type in annot_types: 242 if label and label != "*" and annot_type != label: 243 continue 244 # print("annot_type: ", annot_type) 245 # Filter the annotations by label 246 annot_dict = { 247 k: annot_dict_all[k] 248 for k in annot_dict_all.keys() 249 if label == "*" 250 or annot_dict_all[k]["properties"].get("label", "default") == annot_type 251 } 252 mask_dict = _generate_binary_masks( 253 annot_dict, shape, 254 erose_size=5, 255 obj_size_rem=500, 256 save_indiv=True, 257 ) 258 mask = mask_dict["labels"] 259 return mask 260 raise RuntimeError 261 262 263def _process_image(in_folder, out_path, with_labels): 264 import h5py 265 266 # TODO double check the default order and color matching 267 # correspondence to the HPA kaggle data: 268 # microtubules: red 269 # nuclei: blue 270 # er: yellow 271 # protein: green 272 # default order: rgby = micro, prot, nuclei, er 273 raw = np.concatenate([ 274 imageio.imread(os.path.join(in_folder, f"{chan}.png"))[None] for chan in VALID_CHANNELS 275 ], axis=0) 276 277 if with_labels: 278 annotation_file = os.path.join(in_folder, "annotation.json") 279 assert os.path.exists(annotation_file), annotation_file 280 labels = _get_labels(annotation_file, raw.shape[1:]) 281 assert labels.shape == raw.shape[1:] 282 283 with h5py.File(out_path, "w") as f: 284 f.create_dataset("raw/microtubules", data=raw[0], compression="gzip") 285 f.create_dataset("raw/protein", data=raw[1], compression="gzip") 286 f.create_dataset("raw/nuclei", data=raw[2], compression="gzip") 287 f.create_dataset("raw/er", data=raw[3], compression="gzip") 288 if with_labels: 289 f.create_dataset("labels", data=labels, compression="gzip") 290 291 292def _process_split(root_in, root_out, n_workers, with_labels): 293 os.makedirs(root_out, exist_ok=True) 294 inputs = glob(os.path.join(root_in, "*")) 295 outputs = [os.path.join(root_out, f"{os.path.split(inp)[1]}.h5") for inp in inputs] 296 process = partial(_process_image, with_labels=with_labels) 297 with futures.ProcessPoolExecutor(n_workers) as pp: 298 list(tqdm(pp.map(process, inputs, outputs), total=len(inputs), desc=f"Process data in {root_in}")) 299 300 301# save data as h5 in 4 separate channel raw data and labels extracted from the geo json 302def _process_hpa_data(path, n_workers, remove): 303 
in_path = os.path.join(path, "hpa_dataset_v2") 304 assert os.path.exists(in_path), in_path 305 for split in ("train", "test", "valid"): 306 out_split = "val" if split == "valid" else split 307 _process_split( 308 root_in=os.path.join(in_path, split), 309 root_out=os.path.join(path, out_split), 310 n_workers=n_workers, 311 with_labels=(split != "test") 312 ) 313 if remove: 314 shutil.rmtree(in_path) 315 316 317def _check_data(path): 318 have_train = len(glob(os.path.join(path, "train", "*.h5"))) == 257 319 have_test = len(glob(os.path.join(path, "test", "*.h5"))) == 10 320 have_val = len(glob(os.path.join(path, "val", "*.h5"))) == 9 321 return have_train and have_test and have_val 322 323 324def get_hpa_segmentation_data(path: Union[os.PathLike, str], download: bool, n_workers_preproc: int = 8) -> str: 325 """Download the HPA training data. 326 327 Args: 328 path: Filepath to a folder where the downloaded data will be saved. 329 download: Whether to download the data if it is not present. 330 n_workers_preproc: The number of workers to use for preprocessing. 331 332 Returns: 333 The filepath to the training data. 334 """ 335 data_is_complete = _check_data(path) 336 if not data_is_complete: 337 _download_hpa_data(path, "segmentation", download) 338 _process_hpa_data(path, n_workers_preproc, remove=True) 339 return path 340 341 342def get_hpa_segmentation_paths( 343 path: Union[os.PathLike, str], split: str, download: bool = False, n_workers_preproc: int = 8, 344) -> List[str]: 345 """Get paths to the HPA data. 346 347 Args: 348 path: Filepath to a folder where the downloaded data will be saved. 349 split: The split for the dataset. Available splits are 'train', 'val' or 'test'. 350 download: Whether to download the data if it is not present. 351 n_workers_preproc: The number of workers to use for preprocessing. 352 353 Returns: 354 List of filepaths to the stored data. 355 """ 356 get_hpa_segmentation_data(path, download, n_workers_preproc) 357 paths = glob(os.path.join(path, split, "*.h5")) 358 return paths 359 360 361def get_hpa_segmentation_dataset( 362 path: Union[os.PathLike, str], 363 split: str, 364 patch_shape: Tuple[int, int], 365 offsets: Optional[List[List[int]]] = None, 366 boundaries: bool = False, 367 binary: bool = False, 368 channels: Sequence[str] = ["microtubules", "protein", "nuclei", "er"], 369 download: bool = False, 370 n_workers_preproc: int = 8, 371 **kwargs 372) -> Dataset: 373 """Get the HPA dataset for segmenting cells in confocal microscopy. 374 375 Args: 376 path: Filepath to a folder where the downloaded data will be saved. 377 split: The split for the dataset. Available splits are 'train', 'val' or 'test'. 378 patch_shape: The patch shape to use for training. 379 offsets: Offset values for affinity computation used as target. 380 boundaries: Whether to compute boundaries as the target. 381 binary: Whether to use a binary segmentation target. 382 channels: The image channels to extract. Available channels are 383 'microtubules', 'protein', 'nuclei' or 'er'. 384 download: Whether to download the data if it is not present. 385 n_workers_preproc: The number of workers to use for preprocessing. 386 kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`. 387 388 Returns: 389 The segmentation dataset. 390 """ 391 assert isinstance(channels, list), "The 'channels' argument expects the desired channel(s) in a list." 
392 for chan in channels: 393 if chan not in VALID_CHANNELS: 394 raise ValueError(f"'{chan}' is not a valid channel for HPA dataset.") 395 396 kwargs, _ = util.add_instance_label_transform( 397 kwargs, add_binary_target=True, binary=binary, boundaries=boundaries, offsets=offsets 398 ) 399 kwargs = util.update_kwargs(kwargs, "ndim", 2) 400 kwargs = util.update_kwargs(kwargs, "with_channels", True) 401 402 paths = get_hpa_segmentation_paths(path, split, download, n_workers_preproc) 403 404 return torch_em.default_segmentation_dataset( 405 raw_paths=paths, 406 raw_key=[f"raw/{chan}" for chan in channels], 407 label_paths=paths, 408 label_key="labels", 409 patch_shape=patch_shape, 410 **kwargs 411 ) 412 413 414def get_hpa_segmentation_loader( 415 path: Union[os.PathLike, str], 416 split: str, 417 patch_shape: Tuple[int, int], 418 batch_size: int, 419 offsets: Optional[List[List[int]]] = None, 420 boundaries: bool = False, 421 binary: bool = False, 422 channels: Sequence[str] = ["microtubules", "protein", "nuclei", "er"], 423 download: bool = False, 424 n_workers_preproc: int = 8, 425 **kwargs 426) -> DataLoader: 427 """Get the HPA dataloader for segmenting cells in confocal microscopy. 428 429 Args: 430 path: Filepath to a folder where the downloaded data will be saved. 431 split: The split for the dataset. Available splits are 'train', 'val' or 'test'. 432 patch_shape: The patch shape to use for training. 433 batch_size: The batch size for training. 434 offsets: Offset values for affinity computation used as target. 435 boundaries: Whether to compute boundaries as the target. 436 binary: Whether to use a binary segmentation target. 437 channels: The image channels to extract. Available channels are 438 'microtubules', 'protein', 'nuclei' or 'er'. 439 download: Whether to download the data if it is not present. 440 n_workers_preproc: The number of workers to use for preprocessing. 441 kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader. 442 443 Returns: 444 The DataLoader. 445 """ 446 ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) 447 dataset = get_hpa_segmentation_dataset( 448 path, split, patch_shape, 449 offsets=offsets, boundaries=boundaries, binary=binary, 450 channels=channels, download=download, n_workers_preproc=n_workers_preproc, 451 **ds_kwargs 452 ) 453 return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
URLS = {'segmentation': 'https://zenodo.org/record/4665863/files/hpa_dataset_v2.zip'}
CHECKSUMS = {'segmentation': 'dcd6072293d88d49c71376d3d99f3f4f102e4ee83efb0187faa89c95ec49faa9'}
VALID_CHANNELS = ['microtubules', 'protein', 'nuclei', 'er']
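Each entry of VALID_CHANNELS corresponds to one staining that the preprocessing writes to its own HDF5 dataset under raw/<channel> (see _process_image in the source above). A minimal sketch of how the channel names map to the HDF5 keys that the dataset reads:

from torch_em.data.datasets.light_microscopy.hpa import VALID_CHANNELS

# Each staining is stored as its own HDF5 dataset under "raw/<channel>".
raw_keys = [f"raw/{channel}" for channel in VALID_CHANNELS]
print(raw_keys)  # ['raw/microtubules', 'raw/protein', 'raw/nuclei', 'raw/er']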
def get_hpa_segmentation_data(path: Union[os.PathLike, str], download: bool, n_workers_preproc: int = 8) -> str:
Download the HPA training data.
Arguments:
- path: Filepath to a folder where the downloaded data will be saved.
- download: Whether to download the data if it is not present.
- n_workers_preproc: The number of workers to use for preprocessing.
Returns:
The filepath to the training data.
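A minimal usage sketch (the folder "./data/hpa" is an arbitrary example location): the first call downloads the zip from Zenodo and converts the images to HDF5; later calls find the processed files and skip the download.

from torch_em.data.datasets.light_microscopy.hpa import get_hpa_segmentation_data

# "./data/hpa" is only an example path; any writable folder works.
root = get_hpa_segmentation_data("./data/hpa", download=True, n_workers_preproc=4)
# After this call the folder contains train/, val/ and test/ subfolders with one .h5 file per image.
print(root)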
def get_hpa_segmentation_paths(path: Union[os.PathLike, str], split: str, download: bool = False, n_workers_preproc: int = 8) -> List[str]:
Get paths to the HPA data.
Arguments:
- path: Filepath to a folder where the downloaded data will be saved.
- split: The split for the dataset. Available splits are 'train', 'val' or 'test'.
- download: Whether to download the data if it is not present.
- n_workers_preproc: The number of workers to use for preprocessing.
Returns:
List of filepaths to the stored data.
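A sketch of listing the processed files and inspecting one of them; it assumes h5py is installed and uses "./data/hpa" as an example path.

import h5py
from torch_em.data.datasets.light_microscopy.hpa import get_hpa_segmentation_paths

train_paths = get_hpa_segmentation_paths("./data/hpa", split="train", download=True)
print(len(train_paths))  # 257 files are expected for the train split

with h5py.File(train_paths[0], "r") as f:
    # Each file holds the four stainings under "raw/<channel>"; train and val files
    # also hold a "labels" dataset (the test split is processed without labels).
    f.visit(print)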
def get_hpa_segmentation_dataset(path: Union[os.PathLike, str], split: str, patch_shape: Tuple[int, int], offsets: Optional[List[List[int]]] = None, boundaries: bool = False, binary: bool = False, channels: Sequence[str] = ['microtubules', 'protein', 'nuclei', 'er'], download: bool = False, n_workers_preproc: int = 8, **kwargs) -> torch.utils.data.dataset.Dataset:
Get the HPA dataset for segmenting cells in confocal microscopy.
Arguments:
- path: Filepath to a folder where the downloaded data will be saved.
- split: The split for the dataset. Available splits are 'train', 'val' or 'test'.
- patch_shape: The patch shape to use for training.
- offsets: Offset values for affinity computation used as target.
- boundaries: Whether to compute boundaries as the target.
- binary: Whether to use a binary segmentation target.
- channels: The image channels to extract. Available channels are 'microtubules', 'protein', 'nuclei' or 'er'.
- download: Whether to download the data if it is not present.
- n_workers_preproc: The number of workers to use for preprocessing.
- kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset.
Returns:
The segmentation dataset.
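A hedged example of constructing the dataset with a subset of channels; the path, patch shape and channel choice are illustrative, not recommendations.

from torch_em.data.datasets.light_microscopy.hpa import get_hpa_segmentation_dataset

dataset = get_hpa_segmentation_dataset(
    path="./data/hpa",               # example location
    split="train",
    patch_shape=(512, 512),          # example patch size
    channels=["nuclei", "protein"],  # must be a list of entries from VALID_CHANNELS
    boundaries=True,                 # also compute a boundary target
    download=True,
)
x, y = dataset[0]
print(x.shape, y.shape)  # raw patch with one channel per selected staining, and the label target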
def get_hpa_segmentation_loader(path: Union[os.PathLike, str], split: str, patch_shape: Tuple[int, int], batch_size: int, offsets: Optional[List[List[int]]] = None, boundaries: bool = False, binary: bool = False, channels: Sequence[str] = ['microtubules', 'protein', 'nuclei', 'er'], download: bool = False, n_workers_preproc: int = 8, **kwargs) -> torch.utils.data.dataloader.DataLoader:
Get the HPA dataloader for segmenting cells in confocal microscopy.
Arguments:
- path: Filepath to a folder where the downloaded data will be saved.
- split: The split for the dataset. Available splits are 'train', 'val' or 'test'.
- patch_shape: The patch shape to use for training.
- batch_size: The batch size for training.
- offsets: Offset values for affinity computation used as target.
- boundaries: Whether to compute boundaries as the target.
- binary: Whether to use a binary segmentation target.
- channels: The image channels to extract. Available channels are 'microtubules', 'protein', 'nuclei' or 'er'.
- download: Whether to download the data if it is not present.
- n_workers_preproc: The number of workers to use for preprocessing.
- kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset or for the PyTorch DataLoader.
Returns:
The DataLoader.
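A usage sketch for the loader; the path, batch size, patch shape and worker count are example values. Keyword arguments that do not belong to torch_em.default_segmentation_dataset, such as num_workers or shuffle, are forwarded to the PyTorch DataLoader.

from torch_em.data.datasets.light_microscopy.hpa import get_hpa_segmentation_loader

loader = get_hpa_segmentation_loader(
    path="./data/hpa",       # example location
    split="train",
    patch_shape=(512, 512),  # example patch size
    batch_size=2,
    boundaries=True,
    download=True,
    num_workers=4,           # forwarded to the PyTorch DataLoader
    shuffle=True,
)
x, y = next(iter(loader))
print(x.shape, y.shape)  # batch of raw patches and corresponding targets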