torch_em.data.datasets.light_microscopy.cshaper

The CShaper dataset contains 3D fluorescence microscopy images of Caenorhabditis elegans early embryos with cell instance segmentation annotations.

The dataset is organised into training and evaluation splits:

  • Training: Sample01, Sample02 (27 timepoints each)
  • Evaluation: Sample02, Sample03, Sample04 (7 timepoints each)

Each timepoint is a separate 3D NIfTI volume (.nii.gz):

  • Raw membrane images: RawMemb/{sample}_{tp}_rawMemb.nii.gz
  • Cell segmentation: SegCell/{sample}_{tp}_segCell.nii.gz

NOTE: The data must be downloaded manually. Download the zip from the SharePoint link provided at https://doi.org/10.6084/m9.figshare.12839315 and place it as {path}/OneDrive.zip (or whatever filename it downloads as).

The dataset is from the publication https://doi.org/10.1038/s41467-020-19863-x. Please cite it if you use this dataset in your research.

  1"""The CShaper dataset contains 3D fluorescence microscopy images of Caenorhabditis
  2elegans early embryos with cell instance segmentation annotations.
  3
  4The dataset is organised into training and evaluation splits:
  5- Training: Sample01, Sample02 (27 timepoints each)
  6- Evaluation: Sample02, Sample03, Sample04 (7 timepoints each)
  7
  8Each timepoint is a separate 3D NIfTI volume (.nii.gz):
  9- Raw membrane images: RawMemb/{sample}_{tp}_rawMemb.nii.gz
 10- Cell segmentation: SegCell/{sample}_{tp}_segCell.nii.gz
 11
 12NOTE: The data must be downloaded manually. Download the zip from the SharePoint link
 13provided at https://doi.org/10.6084/m9.figshare.12839315 and place it as
 14`{path}/OneDrive.zip` (or whatever filename it downloads as).
 15
 16The dataset is from the publication https://doi.org/10.1038/s41467-020-19863-x.
 17Please cite it if you use this dataset in your research.
 18"""
 19
 20import os
 21from glob import glob
 22from natsort import natsorted
 23from typing import List, Literal, Optional, Tuple, Union
 24
 25from torch.utils.data import Dataset, DataLoader
 26
 27import torch_em
 28
 29from .. import util
 30
 31
 32# Root path inside the zip after extraction
 33_ZIP_ROOT = "CShaper Supplementary Data/DMapNet Training and Evaluation"
 34
 35TRAIN_SAMPLES = ["Sample01", "Sample02"]
 36EVAL_SAMPLES = ["Sample02", "Sample03", "Sample04"]
 37
 38
 39def get_cshaper_data(path: Union[os.PathLike, str], download: bool = False) -> str:
 40    """Extract the CShaper dataset zip.
 41
 42    NOTE: The zip must be downloaded manually from the SharePoint link at
 43    https://doi.org/10.6084/m9.figshare.12839315 and placed inside `path`.
 44    Any zip file found in `path` will be extracted automatically.
 45
 46    Args:
 47        path: Filepath to a folder containing the downloaded CShaper zip.
 48        download: Ignored (manual download required).
 49
 50    Returns:
 51        The filepath to the extracted data root directory.
 52    """
 53    data_dir = os.path.join(path, _ZIP_ROOT)
 54    if os.path.exists(data_dir):
 55        return data_dir
 56
 57    # Find any zip in path
 58    zips = glob(os.path.join(path, "*.zip"))
 59    if not zips:
 60        raise RuntimeError(
 61            f"No zip file found in {path}. "
 62            "Please download the CShaper data manually from the SharePoint link at "
 63            "https://doi.org/10.6084/m9.figshare.12839315 and place the zip in `path`."
 64        )
 65
 66    util.unzip(zips[0], path)
 67    return data_dir
 68
 69
 70def _convert_to_h5(data_dir: str, split: str) -> str:
 71    """Convert NIfTI timepoint files to per-timepoint HDF5 files.
 72
 73    Args:
 74        data_dir: The extracted CShaper root directory.
 75        split: "train" or "val".
 76
 77    Returns:
 78        The directory containing the converted HDF5 files.
 79    """
 80    try:
 81        import nibabel as nib
 82    except ImportError:
 83        raise RuntimeError(
 84            "The 'nibabel' package is required to read CShaper NIfTI files. "
 85            "Install with: pip install nibabel"
 86        )
 87    import h5py
 88
 89    split_subdir = "TrainingData" if split == "train" else "EvaluationData"
 90    split_dir = os.path.join(data_dir, split_subdir)
 91
 92    h5_dir = os.path.join(data_dir, f"h5_{split}")
 93    if os.path.exists(h5_dir) and len(glob(os.path.join(h5_dir, "*.h5"))) > 0:
 94        return h5_dir
 95    os.makedirs(h5_dir, exist_ok=True)
 96
 97    sample_dirs = natsorted([
 98        d for d in glob(os.path.join(split_dir, "*/")) if os.path.isdir(d)
 99    ])
100
101    for sample_dir in sample_dirs:
102        sample_name = os.path.basename(sample_dir.rstrip("/"))
103        raw_files = natsorted(glob(os.path.join(sample_dir, "RawMemb", "*.nii.gz")))
104        seg_dir = os.path.join(sample_dir, "SegCell")
105
106        for raw_path in raw_files:
107            # e.g. Sample01_030_rawMemb.nii.gz → Sample01_030
108            basename = os.path.basename(raw_path)
109            tp_stem = basename.replace("_rawMemb.nii.gz", "")
110            h5_path = os.path.join(h5_dir, f"{tp_stem}.h5")
111
112            if os.path.exists(h5_path):
113                continue
114
115            seg_path = os.path.join(seg_dir, f"{tp_stem}_segCell.nii.gz")
116            if not os.path.exists(seg_path):
117                continue
118
119            raw_vol = nib.load(raw_path).get_fdata().astype("float32")
120            seg_vol = nib.load(seg_path).get_fdata().astype("int32")
121
122            with h5py.File(h5_path, "w") as f:
123                f.create_dataset("raw", data=raw_vol, compression="gzip")
124                f.create_dataset("labels", data=seg_vol, compression="gzip")
125
126    return h5_dir
127
128
def get_cshaper_paths(
    path: Union[os.PathLike, str],
    split: Literal["train", "val"] = "train",
    samples: Optional[List[str]] = None,
    download: bool = False,
) -> Tuple[List[str], List[str]]:
    """Get paths to the CShaper data.

    Args:
        path: Filepath to a folder containing the downloaded CShaper zip.
        split: The data split to use. Either "train" (Sample01, Sample02) or
            "val" (Sample02, Sample03, Sample04).
        samples: Optional list of sample names to restrict to (e.g., ["Sample01"]).
            If None, all samples for the split are used.
        download: Ignored (manual download required).

    Returns:
        List of filepaths for the HDF5 image data (key: "raw").
        List of filepaths for the HDF5 label data (key: "labels").

    Raises:
        ValueError: If `split` is not "train" or "val".
        RuntimeError: If no HDF5 files remain after conversion and filtering.
    """
    if split not in ("train", "val"):
        raise ValueError(f"Invalid split '{split}'. Choose 'train' or 'val'.")

    data_dir = get_cshaper_data(path, download)
    h5_dir = _convert_to_h5(data_dir, split)

    h5_files = natsorted(glob(os.path.join(h5_dir, "*.h5")))

    if samples is not None:
        h5_files = [p for p in h5_files if any(os.path.basename(p).startswith(s) for s in samples)]

    # Check AFTER the sample filter: a non-matching `samples` list would
    # otherwise silently return empty path lists.
    if len(h5_files) == 0:
        raise RuntimeError(
            f"No HDF5 files found in {h5_dir} (samples={samples}). Check the dataset structure."
        )

    # Raw and labels live in the same per-timepoint HDF5 file under
    # different keys, so the same path list is returned for both.
    return h5_files, h5_files
164
165
def get_cshaper_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, ...],
    split: Literal["train", "val"] = "train",
    samples: Optional[List[str]] = None,
    raw_key: str = "raw",
    label_key: str = "labels",
    download: bool = False,
    **kwargs,
) -> Dataset:
    """Get the CShaper dataset for C. elegans embryo cell segmentation.

    Args:
        path: Filepath to a folder containing the downloaded CShaper zip.
        patch_shape: The patch shape to use for training.
        split: The data split to use. Either "train" or "val".
        samples: Optional list of sample names to restrict to (e.g., ["Sample01"]).
        raw_key: The HDF5 key for raw image data.
        label_key: The HDF5 key for label data.
        download: Ignored (manual download required).
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    raw_paths, label_paths = get_cshaper_paths(path, split, samples, download)

    # Collect all dataset arguments and hand them over in one go.
    dataset_kwargs = dict(
        raw_paths=raw_paths,
        raw_key=raw_key,
        label_paths=label_paths,
        label_key=label_key,
        patch_shape=patch_shape,
    )
    dataset_kwargs.update(kwargs)
    return torch_em.default_segmentation_dataset(**dataset_kwargs)
201
202
def get_cshaper_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, ...],
    split: Literal["train", "val"] = "train",
    samples: Optional[List[str]] = None,
    raw_key: str = "raw",
    label_key: str = "labels",
    download: bool = False,
    **kwargs,
) -> DataLoader:
    """Get the CShaper dataloader for C. elegans embryo cell segmentation.

    Args:
        path: Filepath to a folder containing the downloaded CShaper zip.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        split: The data split to use. Either "train" or "val".
        samples: Optional list of sample names to restrict to (e.g., ["Sample01"]).
        raw_key: The HDF5 key for raw image data.
        label_key: The HDF5 key for label data.
        download: Ignored (manual download required).
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # Separate the dataset-construction kwargs from the DataLoader kwargs.
    dataset_kwargs, dataloader_kwargs = util.split_kwargs(
        torch_em.default_segmentation_dataset, **kwargs
    )
    ds = get_cshaper_dataset(
        path=path,
        patch_shape=patch_shape,
        split=split,
        samples=samples,
        raw_key=raw_key,
        label_key=label_key,
        download=download,
        **dataset_kwargs,
    )
    return torch_em.get_data_loader(ds, batch_size, **dataloader_kwargs)
TRAIN_SAMPLES = ['Sample01', 'Sample02']
EVAL_SAMPLES = ['Sample02', 'Sample03', 'Sample04']
def get_cshaper_data(path: Union[os.PathLike, str], download: bool = False) -> str:
41    """Extract the CShaper dataset zip.
42
43    NOTE: The zip must be downloaded manually from the SharePoint link at
44    https://doi.org/10.6084/m9.figshare.12839315 and placed inside `path`.
45    Any zip file found in `path` will be extracted automatically.
46
47    Args:
48        path: Filepath to a folder containing the downloaded CShaper zip.
49        download: Ignored (manual download required).
50
51    Returns:
52        The filepath to the extracted data root directory.
53    """
54    data_dir = os.path.join(path, _ZIP_ROOT)
55    if os.path.exists(data_dir):
56        return data_dir
57
58    # Find any zip in path
59    zips = glob(os.path.join(path, "*.zip"))
60    if not zips:
61        raise RuntimeError(
62            f"No zip file found in {path}. "
63            "Please download the CShaper data manually from the SharePoint link at "
64            "https://doi.org/10.6084/m9.figshare.12839315 and place the zip in `path`."
65        )
66
67    util.unzip(zips[0], path)
68    return data_dir

Extract the CShaper dataset zip.

NOTE: The zip must be downloaded manually from the SharePoint link at https://doi.org/10.6084/m9.figshare.12839315 and placed inside path. Any zip file found in path will be extracted automatically.

Arguments:
  • path: Filepath to a folder containing the downloaded CShaper zip.
  • download: Ignored (manual download required).
Returns:

The filepath to the extracted data root directory.

def get_cshaper_paths( path: Union[os.PathLike, str], split: Literal['train', 'val'] = 'train', samples: Optional[List[str]] = None, download: bool = False) -> Tuple[List[str], List[str]]:
130def get_cshaper_paths(
131    path: Union[os.PathLike, str],
132    split: Literal["train", "val"] = "train",
133    samples: Optional[List[str]] = None,
134    download: bool = False,
135) -> Tuple[List[str], List[str]]:
136    """Get paths to the CShaper data.
137
138    Args:
139        path: Filepath to a folder containing the downloaded CShaper zip.
140        split: The data split to use. Either "train" (Sample01, Sample02) or
141            "val" (Sample02, Sample03, Sample04).
142        samples: Optional list of sample names to restrict to (e.g., ["Sample01"]).
143            If None, all samples for the split are used.
144        download: Ignored (manual download required).
145
146    Returns:
147        List of filepaths for the HDF5 image data (key: "raw").
148        List of filepaths for the HDF5 label data (key: "labels").
149    """
150    if split not in ("train", "val"):
151        raise ValueError(f"Invalid split '{split}'. Choose 'train' or 'val'.")
152
153    data_dir = get_cshaper_data(path, download)
154    h5_dir = _convert_to_h5(data_dir, split)
155
156    h5_files = natsorted(glob(os.path.join(h5_dir, "*.h5")))
157
158    if len(h5_files) == 0:
159        raise RuntimeError(f"No HDF5 files found in {h5_dir}. Check the dataset structure.")
160
161    if samples is not None:
162        h5_files = [p for p in h5_files if any(os.path.basename(p).startswith(s) for s in samples)]
163
164    return h5_files, h5_files

Get paths to the CShaper data.

Arguments:
  • path: Filepath to a folder containing the downloaded CShaper zip.
  • split: The data split to use. Either "train" (Sample01, Sample02) or "val" (Sample02, Sample03, Sample04).
  • samples: Optional list of sample names to restrict to (e.g., ["Sample01"]). If None, all samples for the split are used.
  • download: Ignored (manual download required).
Returns:

List of filepaths for the HDF5 image data (key: "raw"). List of filepaths for the HDF5 label data (key: "labels").

def get_cshaper_dataset( path: Union[os.PathLike, str], patch_shape: Tuple[int, ...], split: Literal['train', 'val'] = 'train', samples: Optional[List[str]] = None, raw_key: str = 'raw', label_key: str = 'labels', download: bool = False, **kwargs) -> torch.utils.data.dataset.Dataset:
167def get_cshaper_dataset(
168    path: Union[os.PathLike, str],
169    patch_shape: Tuple[int, ...],
170    split: Literal["train", "val"] = "train",
171    samples: Optional[List[str]] = None,
172    raw_key: str = "raw",
173    label_key: str = "labels",
174    download: bool = False,
175    **kwargs,
176) -> Dataset:
177    """Get the CShaper dataset for C. elegans embryo cell segmentation.
178
179    Args:
180        path: Filepath to a folder containing the downloaded CShaper zip.
181        patch_shape: The patch shape to use for training.
182        split: The data split to use. Either "train" or "val".
183        samples: Optional list of sample names to restrict to (e.g., ["Sample01"]).
184        raw_key: The HDF5 key for raw image data.
185        label_key: The HDF5 key for label data.
186        download: Ignored (manual download required).
187        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
188
189    Returns:
190        The segmentation dataset.
191    """
192    raw_paths, label_paths = get_cshaper_paths(path, split, samples, download)
193
194    return torch_em.default_segmentation_dataset(
195        raw_paths=raw_paths,
196        raw_key=raw_key,
197        label_paths=label_paths,
198        label_key=label_key,
199        patch_shape=patch_shape,
200        **kwargs,
201    )

Get the CShaper dataset for C. elegans embryo cell segmentation.

Arguments:
  • path: Filepath to a folder containing the downloaded CShaper zip.
  • patch_shape: The patch shape to use for training.
  • split: The data split to use. Either "train" or "val".
  • samples: Optional list of sample names to restrict to (e.g., ["Sample01"]).
  • raw_key: The HDF5 key for raw image data.
  • label_key: The HDF5 key for label data.
  • download: Ignored (manual download required).
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset.
Returns:

The segmentation dataset.

def get_cshaper_loader( path: Union[os.PathLike, str], batch_size: int, patch_shape: Tuple[int, ...], split: Literal['train', 'val'] = 'train', samples: Optional[List[str]] = None, raw_key: str = 'raw', label_key: str = 'labels', download: bool = False, **kwargs) -> torch.utils.data.dataloader.DataLoader:
204def get_cshaper_loader(
205    path: Union[os.PathLike, str],
206    batch_size: int,
207    patch_shape: Tuple[int, ...],
208    split: Literal["train", "val"] = "train",
209    samples: Optional[List[str]] = None,
210    raw_key: str = "raw",
211    label_key: str = "labels",
212    download: bool = False,
213    **kwargs,
214) -> DataLoader:
215    """Get the CShaper dataloader for C. elegans embryo cell segmentation.
216
217    Args:
218        path: Filepath to a folder containing the downloaded CShaper zip.
219        batch_size: The batch size for training.
220        patch_shape: The patch shape to use for training.
221        split: The data split to use. Either "train" or "val".
222        samples: Optional list of sample names to restrict to (e.g., ["Sample01"]).
223        raw_key: The HDF5 key for raw image data.
224        label_key: The HDF5 key for label data.
225        download: Ignored (manual download required).
226        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.
227
228    Returns:
229        The DataLoader.
230    """
231    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
232    dataset = get_cshaper_dataset(path, patch_shape, split, samples, raw_key, label_key, download, **ds_kwargs)
233    return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)

Get the CShaper dataloader for C. elegans embryo cell segmentation.

Arguments:
  • path: Filepath to a folder containing the downloaded CShaper zip.
  • batch_size: The batch size for training.
  • patch_shape: The patch shape to use for training.
  • split: The data split to use. Either "train" or "val".
  • samples: Optional list of sample names to restrict to (e.g., ["Sample01"]).
  • raw_key: The HDF5 key for raw image data.
  • label_key: The HDF5 key for label data.
  • download: Ignored (manual download required).
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset or for the PyTorch DataLoader.
Returns:

The DataLoader.