torch_em.data.datasets.light_microscopy.tissuenet

The TissueNet dataset contains annotations for cell segmentation in microscopy images of different tissue types.

This dataset is from the publication https://doi.org/10.1038/s41587-021-01094-0. Please cite it if you use this dataset for your research.

This dataset cannot be downloaded automatically. Please visit https://datasets.deepcell.org/data and download it manually.

  1"""The TissueNet dataset contains annotations for cell segmentation in microscopy images of different tissue types.
  2
  3This dataset is from the publication https://doi.org/10.1038/s41587-021-01094-0.
  4Please cite it if you use this dataset for your research.
  5
  6This dataset cannot be downloaded automatically, please visit https://datasets.deepcell.org/data
  7and download it yourself.
  8"""
  9
 10import os
 11from glob import glob
 12from typing import Tuple, Union
 13
 14import numpy as np
 15import pandas as pd
 16import torch_em
 17import z5py
 18
 19from tqdm import tqdm
 20from torch.utils.data import Dataset, DataLoader
 21from .. import util
 22
 23
 24def _create_split(path, split):
 25    split_file = os.path.join(path, f"tissuenet_v1.1_{split}.npz")
 26    split_folder = os.path.join(path, split)
 27    os.makedirs(split_folder, exist_ok=True)
 28    data = np.load(split_file, allow_pickle=True)
 29
 30    x, y = data["X"], data["y"]
 31    metadata = data["meta"]
 32    metadata = pd.DataFrame(metadata[1:], columns=metadata[0])
 33
 34    for i, (im, label) in tqdm(enumerate(zip(x, y)), total=len(x), desc=f"Creating files for {split}-split"):
 35        out_path = os.path.join(split_folder, f"image_{i:04}.zarr")
 36        nucleus_channel = im[..., 0]
 37        cell_channel = im[..., 1]
 38        rgb = np.stack([np.zeros_like(nucleus_channel), cell_channel, nucleus_channel])
 39        chunks = cell_channel.shape
 40        with z5py.File(out_path, "a") as f:
 41
 42            f.create_dataset("raw/nucleus", data=im[..., 0], compression="gzip", chunks=chunks)
 43            f.create_dataset("raw/cell", data=cell_channel, compression="gzip", chunks=chunks)
 44            f.create_dataset("raw/rgb", data=rgb, compression="gzip", chunks=(3,) + chunks)
 45
 46            # the switch 0<->1 is intentional, the data format is chaotic...
 47            f.create_dataset("labels/nucleus", data=label[..., 1], compression="gzip", chunks=chunks)
 48            f.create_dataset("labels/cell", data=label[..., 0], compression="gzip", chunks=chunks)
 49    os.remove(split_file)
 50
 51
 52def _create_dataset(path, zip_path):
 53    util.unzip(zip_path, path, remove=False)
 54    splits = ["train", "val", "test"]
 55    assert all([os.path.exists(os.path.join(path, f"tissuenet_v1.1_{split}.npz")) for split in splits])
 56    for split in splits:
 57        _create_split(path, split)
 58
 59
 60def get_tissuenet_dataset(
 61    path: Union[os.PathLike, str],
 62    split: str,
 63    patch_shape: Tuple[int, int],
 64    raw_channel: str,
 65    label_channel: str,
 66    download: bool = False,
 67    **kwargs
 68) -> Dataset:
 69    """Get the TissueNet dataset for segmenting cells in microscopy tissue images.
 70
 71    Args:
 72        path: Filepath to a folder where the downloaded data will be saved.
 73        split: The data split to use. Either 'train', 'val' or 'test'.
 74        patch_shape: The patch shape to use for training.
 75        raw_channel: The channel to load for the raw data. Either 'nucleus', 'cell' or 'rgb'.
 76        label_channel: The channel to load for the label data. Either 'nucleus' or 'cell'.
 77        download: Whether to download the data if it is not present.
 78        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
 79
 80    Returns:
 81        The segmentation dataset.
 82    """
 83    assert raw_channel in ("nucleus", "cell", "rgb")
 84    assert label_channel in ("nucleus", "cell")
 85
 86    splits = ["train", "val", "test"]
 87    assert split in splits
 88
 89    # check if the dataset exists already
 90    zip_path = os.path.join(path, "tissuenet_v1.1.zip")
 91    if all([os.path.exists(os.path.join(path, split)) for split in splits]):  # yes it does
 92        pass
 93    elif os.path.exists(zip_path):  # no it does not, but we have the zip there and can unpack it
 94        _create_dataset(path, zip_path)
 95    else:
 96        raise RuntimeError(
 97            "We do not support automatic download for the tissuenet datasets yet."
 98            f"Please download the dataset from https://datasets.deepcell.org/data and put it here: {zip_path}"
 99        )
100
101    split_folder = os.path.join(path, split)
102    assert os.path.exists(split_folder)
103    data_path = glob(os.path.join(split_folder, "*.zarr"))
104    assert len(data_path) > 0
105
106    raw_key, label_key = f"raw/{raw_channel}", f"labels/{label_channel}"
107
108    with_channels = True if raw_channel == "rgb" else False
109    kwargs = util.update_kwargs(kwargs, "with_channels", with_channels)
110    kwargs = util.update_kwargs(kwargs, "is_seg_dataset", True)
111    kwargs = util.update_kwargs(kwargs, "ndim", 2)
112
113    return torch_em.default_segmentation_dataset(data_path, raw_key, data_path, label_key, patch_shape, **kwargs)
114
115
116# TODO enable loading specific tissue types etc. (from the 'meta' attributes)
def get_tissuenet_loader(
    path: Union[os.PathLike, str],
    split: str,
    patch_shape: Tuple[int, int],
    batch_size: int,
    raw_channel: str,
    label_channel: str,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Get the TissueNet dataloader for segmenting cells in microscopy tissue images.

    The keyword arguments are divided with `util.split_kwargs` into arguments for the dataset
    and arguments for the loader. The dataset is created with `get_tissuenet_dataset`.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        split: The data split to use. Either 'train', 'val' or 'test'.
        patch_shape: The patch shape to use for training.
        batch_size: The batch size for training.
        raw_channel: The channel to load for the raw data. Either 'nucleus', 'cell' or 'rgb'.
        label_channel: The channel to load for the label data. Either 'nucleus' or 'cell'.
        download: Whether to download the data if it is not present. Passed on to `get_tissuenet_dataset`.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    dataset_kwargs, dataloader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    tissuenet_dataset = get_tissuenet_dataset(
        path, split, patch_shape, raw_channel, label_channel, download, **dataset_kwargs
    )
    return torch_em.get_data_loader(tissuenet_dataset, batch_size, **dataloader_kwargs)
def get_tissuenet_dataset( path: Union[os.PathLike, str], split: str, patch_shape: Tuple[int, int], raw_channel: str, label_channel: str, download: bool = False, **kwargs) -> torch.utils.data.dataset.Dataset:
 61def get_tissuenet_dataset(
 62    path: Union[os.PathLike, str],
 63    split: str,
 64    patch_shape: Tuple[int, int],
 65    raw_channel: str,
 66    label_channel: str,
 67    download: bool = False,
 68    **kwargs
 69) -> Dataset:
 70    """Get the TissueNet dataset for segmenting cells in microscopy tissue images.
 71
 72    Args:
 73        path: Filepath to a folder where the downloaded data will be saved.
 74        split: The data split to use. Either 'train', 'val' or 'test'.
 75        patch_shape: The patch shape to use for training.
 76        raw_channel: The channel to load for the raw data. Either 'nucleus', 'cell' or 'rgb'.
 77        label_channel: The channel to load for the label data. Either 'nucleus' or 'cell'.
 78        download: Whether to download the data if it is not present.
 79        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
 80
 81    Returns:
 82        The segmentation dataset.
 83    """
 84    assert raw_channel in ("nucleus", "cell", "rgb")
 85    assert label_channel in ("nucleus", "cell")
 86
 87    splits = ["train", "val", "test"]
 88    assert split in splits
 89
 90    # check if the dataset exists already
 91    zip_path = os.path.join(path, "tissuenet_v1.1.zip")
 92    if all([os.path.exists(os.path.join(path, split)) for split in splits]):  # yes it does
 93        pass
 94    elif os.path.exists(zip_path):  # no it does not, but we have the zip there and can unpack it
 95        _create_dataset(path, zip_path)
 96    else:
 97        raise RuntimeError(
 98            "We do not support automatic download for the tissuenet datasets yet."
 99            f"Please download the dataset from https://datasets.deepcell.org/data and put it here: {zip_path}"
100        )
101
102    split_folder = os.path.join(path, split)
103    assert os.path.exists(split_folder)
104    data_path = glob(os.path.join(split_folder, "*.zarr"))
105    assert len(data_path) > 0
106
107    raw_key, label_key = f"raw/{raw_channel}", f"labels/{label_channel}"
108
109    with_channels = True if raw_channel == "rgb" else False
110    kwargs = util.update_kwargs(kwargs, "with_channels", with_channels)
111    kwargs = util.update_kwargs(kwargs, "is_seg_dataset", True)
112    kwargs = util.update_kwargs(kwargs, "ndim", 2)
113
114    return torch_em.default_segmentation_dataset(data_path, raw_key, data_path, label_key, patch_shape, **kwargs)

Get the TissueNet dataset for segmenting cells in microscopy tissue images.

Arguments:
  • path: Filepath to a folder where the downloaded data will be saved.
  • split: The data split to use. Either 'train', 'val' or 'test'.
  • patch_shape: The patch shape to use for training.
  • raw_channel: The channel to load for the raw data. Either 'nucleus', 'cell' or 'rgb'.
  • label_channel: The channel to load for the label data. Either 'nucleus' or 'cell'.
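  • download: Whether to download the data if it is not present. Currently has no effect: automatic download is not supported, so the archive tissuenet_v1.1.zip has to be downloaded manually and placed in path.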
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset.
Returns:

The segmentation dataset.

def get_tissuenet_loader( path: Union[os.PathLike, str], split: str, patch_shape: Tuple[int, int], batch_size: int, raw_channel: str, label_channel: str, download: bool = False, **kwargs) -> torch.utils.data.dataloader.DataLoader:
def get_tissuenet_loader(
    path: Union[os.PathLike, str],
    split: str,
    patch_shape: Tuple[int, int],
    batch_size: int,
    raw_channel: str,
    label_channel: str,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Get the TissueNet dataloader for segmenting cells in microscopy tissue images.

    The keyword arguments are divided with `util.split_kwargs` into arguments for the dataset
    and arguments for the loader. The dataset is created with `get_tissuenet_dataset`.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        split: The data split to use. Either 'train', 'val' or 'test'.
        patch_shape: The patch shape to use for training.
        batch_size: The batch size for training.
        raw_channel: The channel to load for the raw data. Either 'nucleus', 'cell' or 'rgb'.
        label_channel: The channel to load for the label data. Either 'nucleus' or 'cell'.
        download: Whether to download the data if it is not present. Passed on to `get_tissuenet_dataset`.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    dataset_kwargs, dataloader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    tissuenet_dataset = get_tissuenet_dataset(
        path, split, patch_shape, raw_channel, label_channel, download, **dataset_kwargs
    )
    return torch_em.get_data_loader(tissuenet_dataset, batch_size, **dataloader_kwargs)

Get the TissueNet dataloader for segmenting cells in microscopy tissue images.

Arguments:
  • path: Filepath to a folder where the downloaded data will be saved.
  • split: The data split to use. Either 'train', 'val' or 'test'.
  • patch_shape: The patch shape to use for training.
  • batch_size: The batch size for training.
  • raw_channel: The channel to load for the raw data. Either 'nucleus', 'cell' or 'rgb'.
  • label_channel: The channel to load for the label data. Either 'nucleus' or 'cell'.
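  • download: Whether to download the data if it is not present. Currently has no effect: automatic download is not supported, so the archive tissuenet_v1.1.zip has to be downloaded manually and placed in path.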
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset or for the PyTorch DataLoader.
Returns:

The DataLoader.