torch_em.data.datasets.light_microscopy.hpa

This dataset was part of the HPA Kaggle challenge for protein identification. It contains confocal microscopy images and annotations for cell segmentation.

The dataset is described in the publication https://doi.org/10.1038/s41592-019-0658-6. Please cite it if you use this dataset in your research.
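
A minimal usage sketch (the target folder "./hpa", patch shape, and batch size below are illustrative; on first use the data is downloaded from Zenodo and preprocessed into per-image h5 files):

    from torch_em.data.datasets.light_microscopy.hpa import get_hpa_segmentation_loader

    # Training loader over all four stainings; each staining becomes one raw channel.
    loader = get_hpa_segmentation_loader(
        path="./hpa",             # illustrative download / cache folder
        split="train",
        patch_shape=(512, 512),   # illustrative patch size
        batch_size=4,
        channels=["microtubules", "protein", "nuclei", "er"],
        download=True,
    )

    for raw, labels in loader:
        print(raw.shape, labels.shape)
        break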

  1"""This dataset was part of the HPA Kaggle challenge for protein identification.
  2It contains confocal microscopy images and annotations for cell segmentation.
  3
  4The dataset is described in the publication https://doi.org/10.1038/s41592-019-0658-6.
  5Please cite it if you use this dataset in your research.
  6"""
  7
  8import os
  9import json
 10import shutil
 11from glob import glob
 12from tqdm import tqdm
 13from concurrent import futures
 14from functools import partial
 15from typing import List, Optional, Sequence, Tuple, Union
 16
 17import imageio
 18import numpy as np
 19from skimage import morphology
 20from PIL import Image, ImageDraw
 21from skimage import draw as skimage_draw
 22
 23from torch.utils.data import Dataset, DataLoader
 24
 25import torch_em
 26
 27from .. import util
 28
 29
 30URLS = {
 31    "segmentation": "https://zenodo.org/record/4665863/files/hpa_dataset_v2.zip"
 32}
 33CHECKSUMS = {
 34    "segmentation": "dcd6072293d88d49c71376d3d99f3f4f102e4ee83efb0187faa89c95ec49faa9"
 35}
 36VALID_CHANNELS = ["microtubules", "protein", "nuclei", "er"]
 37
 38
 39def _download_hpa_data(path, name, download):
 40    os.makedirs(path, exist_ok=True)
 41    url = URLS[name]
 42    checksum = CHECKSUMS[name]
 43    zip_path = os.path.join(path, "data.zip")
 44    util.download_source(zip_path, url, download=download, checksum=checksum)
 45    util.unzip(zip_path, path, remove=True)
 46
 47
 48def _load_features(features):
 49    # Loop over list and create simple dictionary & get size of annotations
 50    annot_dict = {}
 51    skipped = []
 52
 53    for feat_idx, feat in enumerate(features):
 54        if feat["geometry"]["type"] not in ["Polygon", "LineString"]:
 55            skipped.append(feat["geometry"]["type"])
 56            continue
 57
 58        # skip empty roi
 59        if len(feat["geometry"]["coordinates"][0]) <= 0:
 60            continue
 61
 62        key_annot = "annot_" + str(feat_idx)
 63        annot_dict[key_annot] = {}
 64        annot_dict[key_annot]["type"] = feat["geometry"]["type"]
 65        annot_dict[key_annot]["pos"] = np.squeeze(
 66            np.asarray(feat["geometry"]["coordinates"])
 67        )
 68        annot_dict[key_annot]["properties"] = feat["properties"]
 69
 70    # print("Skipped geometry type(s):", skipped)
 71    return annot_dict
 72
 73
 74def _generate_binary_masks(annot_dict, shape, erose_size=5, obj_size_rem=500, save_indiv=False):
 75    # Get dimensions of image and created masks of same size
 76    # This we need to save somewhere (e.g. as part of the geojson file?)
 77
 78    # Filled masks and edge mask for polygons
 79    mask_fill = np.zeros(shape, dtype=np.uint8)
 80    mask_edge = np.zeros(shape, dtype=np.uint8)
 81    mask_labels = np.zeros(shape, dtype=np.uint16)
 82
 83    rr_all = []
 84    cc_all = []
 85
 86    if save_indiv is True:
 87        mask_edge_indiv = np.zeros(
 88            (shape[0], shape[1], len(annot_dict)), dtype="bool"
 89        )
 90        mask_fill_indiv = np.zeros(
 91            (shape[0], shape[1], len(annot_dict)), dtype="bool"
 92        )
 93
 94    # Image used to draw lines - for edge mask for freelines
 95    im_freeline = Image.new("1", (shape[1], shape[0]), color=0)
 96    draw = ImageDraw.Draw(im_freeline)
 97
 98    # Loop over all roi
 99    i_roi = 0
100    for roi_key, roi in annot_dict.items():
101        roi_pos = roi["pos"]
102
103        # Check region type
104
105        # freeline - line
106        if roi["type"] == "freeline" or roi["type"] == "LineString":
107
108            # Loop over all pairs of points to draw the line
109
110            for ind in range(roi_pos.shape[0] - 1):
111                line_pos = (
112                    roi_pos[ind, 1],
113                    roi_pos[ind, 0],
114                    roi_pos[ind + 1, 1],
115                    roi_pos[ind + 1, 0],
116                )
117                draw.line(line_pos, fill=1, width=erose_size)
118
119        # freehand - polygon
120        elif (
121            roi["type"] == "freehand"
122            or roi["type"] == "polygon"
123            or roi["type"] == "polyline"
124            or roi["type"] == "Polygon"
125        ):
126
127            # Draw polygon
128            rr, cc = skimage_draw.polygon(
129                [shape[0] - r for r in roi_pos[:, 1]], roi_pos[:, 0]
130            )
131
132            # Make sure it's not outside
133            rr[rr < 0] = 0
134            rr[rr > shape[0] - 1] = shape[0] - 1
135
136            cc[cc < 0] = 0
137            cc[cc > shape[1] - 1] = shape[1] - 1
138
139            # Test if this region has already been added
140            if any(np.array_equal(rr, rr_test) for rr_test in rr_all) and any(
141                np.array_equal(cc, cc_test) for cc_test in cc_all
142            ):
143                # print('Region #{} has already been used'.format(i +
144                # 1))
145                continue
146
147            rr_all.append(rr)
148            cc_all.append(cc)
149
150            # Generate mask
151            mask_fill_roi = np.zeros(shape, dtype=np.uint8)
152            mask_fill_roi[rr, cc] = 1
153
154            # Erode to get cell edge - both arrays are boolean to be used as
155            # index arrays later
156            mask_fill_roi_erode = morphology.binary_erosion(
157                mask_fill_roi, np.ones((erose_size, erose_size))
158            )
159            mask_edge_roi = (
160                mask_fill_roi.astype("int") - mask_fill_roi_erode.astype("int")
161            ).astype("bool")
162
163            # Save array for mask and edge
164            mask_fill[mask_fill_roi > 0] = 1
165            mask_edge[mask_edge_roi] = 1
166            mask_labels[mask_fill_roi > 0] = i_roi + 1
167
168            if save_indiv is True:
169                mask_edge_indiv[:, :, i_roi] = mask_edge_roi.astype("bool")
170                mask_fill_indiv[:, :, i_roi] = mask_fill_roi_erode.astype("bool")
171
172            i_roi = i_roi + 1
173
174        else:
175            roi_type = roi["type"]
176            raise NotImplementedError(
177                f'Mask for roi type "{roi_type}" can not be created'
178            )
179
180    del draw
181
182    # Convert mask from free-lines to numpy array
183    mask_edge_freeline = np.asarray(im_freeline)
184    mask_edge_freeline = mask_edge_freeline.astype("bool")
185
186    # Post-processing of fill and edge mask - if defined
187    mask_dict = {}
188    if np.any(mask_fill):
189
190        # (1) remove edges , (2) remove small  objects
191        mask_fill = mask_fill & ~mask_edge
192        mask_fill = morphology.remove_small_objects(
193            mask_fill.astype("bool"), obj_size_rem
194        )
195
196        # For edge - consider also freeline edge mask
197
198        mask_edge = mask_edge.astype("bool")
199        mask_edge = np.logical_or(mask_edge, mask_edge_freeline)
200
201        # Assign to dictionary for return
202        mask_dict["edge"] = mask_edge
203        mask_dict["fill"] = mask_fill.astype("bool")
204        mask_dict["labels"] = mask_labels.astype("uint16")
205
206        if save_indiv is True:
207            mask_dict["edge_indiv"] = mask_edge_indiv
208            mask_dict["fill_indiv"] = mask_fill_indiv
209        else:
210            mask_dict["edge_indiv"] = np.zeros(shape + (1,), dtype=np.uint8)
211            mask_dict["fill_indiv"] = np.zeros(shape + (1,), dtype=np.uint8)
212
213    # Only edge mask present
214    elif np.any(mask_edge_freeline):
215        mask_dict["edge"] = mask_edge_freeline
216        mask_dict["fill"] = mask_fill.astype("bool")
217        mask_dict["labels"] = mask_labels.astype("uint16")
218
219        mask_dict["edge_indiv"] = np.zeros(shape + (1,), dtype=np.uint8)
220        mask_dict["fill_indiv"] = np.zeros(shape + (1,), dtype=np.uint8)
221
222    else:
223        raise Exception("No mask has been created.")
224
225    return mask_dict
226
227
228# adapted from
229# https://github.com/imjoy-team/kaibu-utils/blob/main/kaibu_utils/__init__.py#L267
230def _get_labels(annotation_file, shape, label="*"):
231    with open(annotation_file) as f:
232        features = json.load(f)["features"]
233    if len(features) == 0:
234        return np.zeros(shape, dtype="uint16")
235
236    annot_dict_all = _load_features(features)
237    annot_types = set(
238        annot_dict_all[k]["properties"].get("label", "default")
239        for k in annot_dict_all.keys()
240    )
241    for annot_type in annot_types:
242        if label and label != "*" and annot_type != label:
243            continue
244        # print("annot_type: ", annot_type)
245        # Filter the annotations by label
246        annot_dict = {
247            k: annot_dict_all[k]
248            for k in annot_dict_all.keys()
249            if label == "*"
250            or annot_dict_all[k]["properties"].get("label", "default") == annot_type
251        }
252        mask_dict = _generate_binary_masks(
253            annot_dict, shape,
254            erose_size=5,
255            obj_size_rem=500,
256            save_indiv=True,
257        )
258        mask = mask_dict["labels"]
259        return mask
260    raise RuntimeError
261
262
263def _process_image(in_folder, out_path, with_labels):
264    import h5py
265
266    # TODO double check the default order and color matching
267    # correspondence to the HPA kaggle data:
268    # microtubules: red
269    # nuclei: blue
270    # er: yellow
271    # protein: green
272    # default order: rgby = micro, prot, nuclei, er
273    raw = np.concatenate([
274        imageio.imread(os.path.join(in_folder, f"{chan}.png"))[None] for chan in VALID_CHANNELS
275    ], axis=0)
276
277    if with_labels:
278        annotation_file = os.path.join(in_folder, "annotation.json")
279        assert os.path.exists(annotation_file), annotation_file
280        labels = _get_labels(annotation_file, raw.shape[1:])
281        assert labels.shape == raw.shape[1:]
282
283    with h5py.File(out_path, "w") as f:
284        f.create_dataset("raw/microtubules", data=raw[0], compression="gzip")
285        f.create_dataset("raw/protein", data=raw[1], compression="gzip")
286        f.create_dataset("raw/nuclei", data=raw[2], compression="gzip")
287        f.create_dataset("raw/er", data=raw[3], compression="gzip")
288        if with_labels:
289            f.create_dataset("labels", data=labels, compression="gzip")
290
291
292def _process_split(root_in, root_out, n_workers, with_labels):
293    os.makedirs(root_out, exist_ok=True)
294    inputs = glob(os.path.join(root_in, "*"))
295    outputs = [os.path.join(root_out, f"{os.path.split(inp)[1]}.h5") for inp in inputs]
296    process = partial(_process_image, with_labels=with_labels)
297    with futures.ProcessPoolExecutor(n_workers) as pp:
298        list(tqdm(pp.map(process, inputs, outputs), total=len(inputs), desc=f"Process data in {root_in}"))
299
300
301# save data as h5 in 4 separate channel raw data and labels extracted from the geo json
302def _process_hpa_data(path, n_workers, remove):
303    in_path = os.path.join(path, "hpa_dataset_v2")
304    assert os.path.exists(in_path), in_path
305    for split in ("train", "test", "valid"):
306        out_split = "val" if split == "valid" else split
307        _process_split(
308            root_in=os.path.join(in_path, split),
309            root_out=os.path.join(path, out_split),
310            n_workers=n_workers,
311            with_labels=(split != "test")
312        )
313    if remove:
314        shutil.rmtree(in_path)
315
316
317def _check_data(path):
318    have_train = len(glob(os.path.join(path, "train", "*.h5"))) == 257
319    have_test = len(glob(os.path.join(path, "test", "*.h5"))) == 10
320    have_val = len(glob(os.path.join(path, "val", "*.h5"))) == 9
321    return have_train and have_test and have_val
322
323
324def get_hpa_segmentation_data(path: Union[os.PathLike, str], download: bool, n_workers_preproc: int = 8) -> str:
325    """Download the HPA training data.
326
327    Args:
328        path: Filepath to a folder where the downloaded data will be saved.
329        download: Whether to download the data if it is not present.
330        n_workers_preproc: The number of workers to use for preprocessing.
331
332    Returns:
333        The filepath to the training data.
334    """
335    data_is_complete = _check_data(path)
336    if not data_is_complete:
337        _download_hpa_data(path, "segmentation", download)
338        _process_hpa_data(path, n_workers_preproc, remove=True)
339    return path
340
341
342def get_hpa_segmentation_paths(
343    path: Union[os.PathLike, str], split: str, download: bool = False, n_workers_preproc: int = 8,
344) -> List[str]:
345    """Get paths to the HPA data.
346
347    Args:
348        path: Filepath to a folder where the downloaded data will be saved.
349        split: The split for the dataset. Available splits are 'train', 'val' or 'test'.
350        download: Whether to download the data if it is not present.
351        n_workers_preproc: The number of workers to use for preprocessing.
352
353    Returns:
354        List of filepaths to the stored data.
355    """
356    get_hpa_segmentation_data(path, download, n_workers_preproc)
357    paths = glob(os.path.join(path, split, "*.h5"))
358    return paths
359
360
361def get_hpa_segmentation_dataset(
362    path: Union[os.PathLike, str],
363    split: str,
364    patch_shape: Tuple[int, int],
365    offsets: Optional[List[List[int]]] = None,
366    boundaries: bool = False,
367    binary: bool = False,
368    channels: Sequence[str] = ["microtubules", "protein", "nuclei", "er"],
369    download: bool = False,
370    n_workers_preproc: int = 8,
371    **kwargs
372) -> Dataset:
373    """Get the HPA dataset for segmenting cells in confocal microscopy.
374
375    Args:
376        path: Filepath to a folder where the downloaded data will be saved.
377        split: The split for the dataset. Available splits are 'train', 'val' or 'test'.
378        patch_shape: The patch shape to use for training.
379        offsets: Offset values for affinity computation used as target.
380        boundaries: Whether to compute boundaries as the target.
381        binary: Whether to use a binary segmentation target.
382        channels: The image channels to extract. Available channels are
383            'microtubules', 'protein', 'nuclei' or 'er'.
384        download: Whether to download the data if it is not present.
385        n_workers_preproc: The number of workers to use for preprocessing.
386        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
387
388    Returns:
389       The segmentation dataset.
390    """
391    assert isinstance(channels, list), "The 'channels' argument expects the desired channel(s) in a list."
392    for chan in channels:
393        if chan not in VALID_CHANNELS:
394            raise ValueError(f"'{chan}' is not a valid channel for HPA dataset.")
395
396    kwargs, _ = util.add_instance_label_transform(
397        kwargs, add_binary_target=True, binary=binary, boundaries=boundaries, offsets=offsets
398    )
399    kwargs = util.update_kwargs(kwargs, "ndim", 2)
400    kwargs = util.update_kwargs(kwargs, "with_channels", True)
401
402    paths = get_hpa_segmentation_paths(path, split, download, n_workers_preproc)
403
404    return torch_em.default_segmentation_dataset(
405        raw_paths=paths,
406        raw_key=[f"raw/{chan}" for chan in channels],
407        label_paths=paths,
408        label_key="labels",
409        patch_shape=patch_shape,
410        **kwargs
411    )
412
413
414def get_hpa_segmentation_loader(
415    path: Union[os.PathLike, str],
416    split: str,
417    patch_shape: Tuple[int, int],
418    batch_size: int,
419    offsets: Optional[List[List[int]]] = None,
420    boundaries: bool = False,
421    binary: bool = False,
422    channels: Sequence[str] = ["microtubules", "protein", "nuclei", "er"],
423    download: bool = False,
424    n_workers_preproc: int = 8,
425    **kwargs
426) -> DataLoader:
427    """Get the HPA dataloader for segmenting cells in confocal microscopy.
428
429    Args:
430        path: Filepath to a folder where the downloaded data will be saved.
431        split: The split for the dataset. Available splits are 'train', 'val' or 'test'.
432        patch_shape: The patch shape to use for training.
433        batch_size: The batch size for training.
434        offsets: Offset values for affinity computation used as target.
435        boundaries: Whether to compute boundaries as the target.
436        binary: Whether to use a binary segmentation target.
437        channels: The image channels to extract. Available channels are
438            'microtubules', 'protein', 'nuclei' or 'er'.
439        download: Whether to download the data if it is not present.
440        n_workers_preproc: The number of workers to use for preprocessing.
441        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.
442
443    Returns:
444       The DataLoader.
445    """
446    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
447    dataset = get_hpa_segmentation_dataset(
448        path, split, patch_shape,
449        offsets=offsets, boundaries=boundaries, binary=binary,
450        channels=channels, download=download, n_workers_preproc=n_workers_preproc,
451        **ds_kwargs
452    )
453    return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
URLS = {'segmentation': 'https://zenodo.org/record/4665863/files/hpa_dataset_v2.zip'}
CHECKSUMS = {'segmentation': 'dcd6072293d88d49c71376d3d99f3f4f102e4ee83efb0187faa89c95ec49faa9'}
VALID_CHANNELS = ['microtubules', 'protein', 'nuclei', 'er']
def get_hpa_segmentation_data(path: Union[os.PathLike, str], download: bool, n_workers_preproc: int = 8) -> str:

Download the HPA training data.

Arguments:
  • path: Filepath to a folder where the downloaded data will be saved.
  • download: Whether to download the data if it is not present.
  • n_workers_preproc: The number of workers to use for preprocessing.
Returns:

The filepath to the training data.
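
For example (the folder "./hpa" below is illustrative):

    from torch_em.data.datasets.light_microscopy.hpa import get_hpa_segmentation_data

    # Download the zip from Zenodo (if not yet present) and preprocess it into
    # per-image h5 files under <path>/train, <path>/val and <path>/test.
    data_root = get_hpa_segmentation_data("./hpa", download=True, n_workers_preproc=8)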

def get_hpa_segmentation_paths(path: Union[os.PathLike, str], split: str, download: bool = False, n_workers_preproc: int = 8) -> List[str]:

Get paths to the HPA data.

Arguments:
  • path: Filepath to a folder where the downloaded data will be saved.
  • split: The split for the dataset. Available splits are 'train', 'val' or 'test'.
  • download: Whether to download the data if it is not present.
  • n_workers_preproc: The number of workers to use for preprocessing.
Returns:

List of filepaths to the stored data.
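
For example, to list the preprocessed files of the validation split (the folder "./hpa" is illustrative):

    from torch_em.data.datasets.light_microscopy.hpa import get_hpa_segmentation_paths

    val_paths = get_hpa_segmentation_paths("./hpa", split="val", download=True)
    print(len(val_paths))  # the val split contains 9 images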

def get_hpa_segmentation_dataset(path: Union[os.PathLike, str], split: str, patch_shape: Tuple[int, int], offsets: Optional[List[List[int]]] = None, boundaries: bool = False, binary: bool = False, channels: Sequence[str] = ['microtubules', 'protein', 'nuclei', 'er'], download: bool = False, n_workers_preproc: int = 8, **kwargs) -> torch.utils.data.dataset.Dataset:

Get the HPA dataset for segmenting cells in confocal microscopy.

Arguments:
  • path: Filepath to a folder where the downloaded data will be saved.
  • split: The split for the dataset. Available splits are 'train', 'val' or 'test'.
  • patch_shape: The patch shape to use for training.
  • offsets: Offset values for affinity computation used as target.
  • boundaries: Whether to compute boundaries as the target.
  • binary: Whether to use a binary segmentation target.
  • channels: The image channels to extract. Available channels are 'microtubules', 'protein', 'nuclei' or 'er'.
  • download: Whether to download the data if it is not present.
  • n_workers_preproc: The number of workers to use for preprocessing.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset.
Returns:

The segmentation dataset.
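
A minimal sketch of creating the dataset with boundary targets and a channel subset (the folder and patch shape are illustrative):

    from torch_em.data.datasets.light_microscopy.hpa import get_hpa_segmentation_dataset

    dataset = get_hpa_segmentation_dataset(
        "./hpa", split="train", patch_shape=(512, 512),
        boundaries=True, channels=["nuclei", "protein"], download=True,
    )
    raw, target = dataset[0]  # raw has one channel per requested staining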

def get_hpa_segmentation_loader(path: Union[os.PathLike, str], split: str, patch_shape: Tuple[int, int], batch_size: int, offsets: Optional[List[List[int]]] = None, boundaries: bool = False, binary: bool = False, channels: Sequence[str] = ['microtubules', 'protein', 'nuclei', 'er'], download: bool = False, n_workers_preproc: int = 8, **kwargs) -> torch.utils.data.dataloader.DataLoader:

Get the HPA dataloader for segmenting cells in confocal microscopy.

Arguments:
  • path: Filepath to a folder where the downloaded data will be saved.
  • split: The split for the dataset. Available splits are 'train', 'val' or 'test'.
  • patch_shape: The patch shape to use for training.
  • batch_size: The batch size for training.
  • offsets: Offset values for affinity computation used as target.
  • boundaries: Whether to compute boundaries as the target.
  • binary: Whether to use a binary segmentation target.
  • channels: The image channels to extract. Available channels are 'microtubules', 'protein', 'nuclei' or 'er'.
  • download: Whether to download the data if it is not present.
  • n_workers_preproc: The number of workers to use for preprocessing.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset or for the PyTorch DataLoader.
Returns:

The DataLoader.
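
For example, DataLoader-specific keyword arguments such as num_workers or shuffle can be passed alongside the dataset arguments and are split internally (the folder and shape values are illustrative):

    from torch_em.data.datasets.light_microscopy.hpa import get_hpa_segmentation_loader

    loader = get_hpa_segmentation_loader(
        "./hpa", split="val", patch_shape=(512, 512), batch_size=1,
        binary=True, download=True,
        num_workers=4, shuffle=True,  # forwarded to the PyTorch DataLoader
    )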