torch_em.data.datasets.light_microscopy.xpress

  1"""The XPRESS dataset contains volumetric microscopy data with voxel-wise labels.
  2
  3The training data is hosted at:
  4- https://github.com/htem/xpress-challenge-files/releases/download/v1.0/xpress-training-raw.h5
  5- https://github.com/htem/xpress-challenge-files/releases/download/v1.0/xpress-training-voxel-labels.h5
  6"""
  7
  8import os
  9from typing import Optional, Tuple, Union
 10
 11from torch.utils.data import Dataset, DataLoader
 12
 13import torch_em
 14
 15from .. import util
 16
 17
 18URLS = {
 19    "raw": "https://github.com/htem/xpress-challenge-files/releases/download/v1.0/xpress-training-raw.h5",
 20    "labels": "https://github.com/htem/xpress-challenge-files/releases/download/v1.0/xpress-training-voxel-labels.h5",
 21}
 22
 23
 24def _default_chunks(shape):
 25    # Simple heuristic: chunk along z and limit chunk extents to 64.
 26    return tuple(min(64, int(s)) for s in shape)
 27
 28
 29
 30def _merge_to_single_h5(raw_path: Union[os.PathLike, str], label_path: Union[os.PathLike, str], out_path: str):
 31    if os.path.exists(out_path):
 32        return out_path
 33
 34    import h5py
 35    import numpy as np
 36
 37    with h5py.File(raw_path, "r") as fr, h5py.File(label_path, "r") as fl, h5py.File(out_path, "w") as fo:
 38        raw_ds_in = fr["volumes/raw"]
 39        labels_ds_in = fl["volumes/labels"]
 40
 41        raw_resolution = np.array(raw_ds_in.attrs.get("resolution", [1, 1, 1]))
 42        label_offset = np.array(labels_ds_in.attrs.get("offset", [0, 0, 0]))
 43
 44        # Convert the label offset from world coordinates to voxel coordinates in the raw volume.
 45        voxel_offset = (label_offset / raw_resolution).astype(int)
 46        labels_arr = labels_ds_in[...]
 47
 48        # Crop the raw with extra context (128 px padding per side) around the labeled region.
 49        context_pad = 128
 50        raw_shape = np.array(raw_ds_in.shape)
 51        starts = np.clip(voxel_offset - context_pad, 0, raw_shape)
 52        ends = np.clip(voxel_offset + np.array(labels_arr.shape) + context_pad, 0, raw_shape)
 53
 54        raw_slices = tuple(slice(int(s), int(e)) for s, e in zip(starts, ends))
 55        raw_arr = raw_ds_in[raw_slices]
 56
 57        # Place labels inside a zero-padded volume matching the (padded) raw crop.
 58        label_insert_offset = voxel_offset - starts
 59        padded_labels = np.zeros(raw_arr.shape, dtype="int64")
 60        label_slices = tuple(
 61            slice(int(o), int(o) + s) for o, s in zip(label_insert_offset, labels_arr.shape)
 62        )
 63        padded_labels[label_slices] = labels_arr
 64
 65        chunks = _default_chunks(raw_arr.shape)
 66
 67        fo.create_dataset("raw", data=raw_arr, chunks=chunks, compression="gzip", compression_opts=4)
 68        fo.create_dataset("labels", data=padded_labels, chunks=chunks, compression="gzip", compression_opts=4)
 69
 70    return out_path
 71
 72
 73def get_xpress_data(path: Union[os.PathLike, str], download: bool = False) -> Tuple[str, str]:
 74    """Download the XPRESS training data.
 75
 76    Args:
 77        path: Filepath to a folder where the data will be stored.
 78        download: Whether to download the data if it is not present.
 79
 80    Returns:
 81        Filepaths for raw and label data.
 82    """
 83    os.makedirs(path, exist_ok=True)
 84    raw_path = os.path.join(path, "xpress-training-raw.h5")
 85    label_path = os.path.join(path, "xpress-training-voxel-labels.h5")
 86
 87    util.download_source(raw_path, URLS["raw"], download, checksum=None)
 88    util.download_source(label_path, URLS["labels"], download, checksum=None)
 89
 90    merged_path = os.path.join(path, "xpress-training.h5")
 91    _merge_to_single_h5(raw_path, label_path, merged_path)
 92
 93    return merged_path, merged_path
 94
 95
 96def get_xpress_paths(path: Union[os.PathLike, str], download: bool = False) -> Tuple[str, str]:
 97    """Get paths to the XPRESS training data."""
 98    return get_xpress_data(path, download)
 99
100
def get_xpress_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, int, int],
    raw_key: Optional[str] = None,
    label_key: Optional[str] = None,
    download: bool = False,
    **kwargs,
) -> Dataset:
    """Get the XPRESS dataset for voxel-wise segmentation.

    Args:
        path: Filepath to a folder where the data will be stored.
        patch_shape: The patch shape to use for training. Must have exactly three entries.
        raw_key: The HDF5 key for the raw data. If None, the key "raw" is used.
        label_key: The HDF5 key for the label data. If None, the key "labels" is used.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    assert len(patch_shape) == 3
    raw_path, label_path = get_xpress_paths(path, download)

    # Fall back to the default dataset names when no explicit keys are given.
    if raw_key is None:
        raw_key = "raw"
    if label_key is None:
        label_key = "labels"

    dataset = torch_em.default_segmentation_dataset(
        raw_paths=[raw_path],
        raw_key=raw_key,
        label_paths=[label_path],
        label_key=label_key,
        patch_shape=patch_shape,
        is_seg_dataset=True,
        **kwargs,
    )
    return dataset
134
135
def get_xpress_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, int, int],
    raw_key: Optional[str] = None,
    label_key: Optional[str] = None,
    download: bool = False,
    **kwargs,
) -> DataLoader:
    """Get the XPRESS dataloader for voxel-wise segmentation.

    Args:
        path: Filepath to a folder where the data will be stored.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        raw_key: The HDF5 key for the raw data, passed on to `get_xpress_dataset`.
        label_key: The HDF5 key for the label data, passed on to `get_xpress_dataset`.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments. `util.split_kwargs` divides them into arguments for
            `torch_em.default_segmentation_dataset` and arguments for `torch_em.get_data_loader`.

    Returns:
        The DataLoader.
    """
    split = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    ds_kwargs, loader_kwargs = split

    dataset = get_xpress_dataset(
        path,
        patch_shape,
        raw_key=raw_key,
        label_key=label_key,
        download=download,
        **ds_kwargs,
    )
    loader = torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
    return loader
# Download locations of the two XPRESS training files, keyed by their content ("raw" / "labels").
URLS = dict(
    raw="https://github.com/htem/xpress-challenge-files/releases/download/v1.0/xpress-training-raw.h5",
    labels="https://github.com/htem/xpress-challenge-files/releases/download/v1.0/xpress-training-voxel-labels.h5",
)
def get_xpress_data(path: Union[os.PathLike, str], download: bool = False) -> Tuple[str, str]:
def get_xpress_data(path: Union[os.PathLike, str], download: bool = False) -> Tuple[str, str]:
    """Download the XPRESS training data and merge it into a single HDF5 file.

    Args:
        path: Filepath to a folder where the data will be stored.
        download: Whether to download the data if it is not present.

    Returns:
        Filepaths for raw and label data. Both entries point to the same merged file
        "xpress-training.h5" inside of `path`.
    """
    os.makedirs(path, exist_ok=True)
    merged_path = os.path.join(path, "xpress-training.h5")

    # The merged file is the only one that is returned. If it is already there, the two
    # source files are not needed anymore, so do not require (or re-download) them.
    if os.path.exists(merged_path):
        return merged_path, merged_path

    raw_path = os.path.join(path, "xpress-training-raw.h5")
    label_path = os.path.join(path, "xpress-training-voxel-labels.h5")

    util.download_source(raw_path, URLS["raw"], download, checksum=None)
    util.download_source(label_path, URLS["labels"], download, checksum=None)

    _merge_to_single_h5(raw_path, label_path, merged_path)

    return merged_path, merged_path

Download the XPRESS training data.

Arguments:
  • path: Filepath to a folder where the data will be stored.
  • download: Whether to download the data if it is not present.
Returns:

Filepaths for raw and label data. Both entries point to the same merged HDF5 file.

def get_xpress_paths(path: Union[os.PathLike, str], download: bool = False) -> Tuple[str, str]:
def get_xpress_paths(path: Union[os.PathLike, str], download: bool = False) -> Tuple[str, str]:
    """Get paths to the XPRESS training data.

    Args:
        path: Filepath to a folder where the data will be stored.
        download: Whether to download the data if it is not present.

    Returns:
        Filepaths for raw and label data, exactly as returned by `get_xpress_data`.
    """
    data_paths = get_xpress_data(path, download)
    return data_paths

Get paths to the XPRESS training data.

def get_xpress_dataset( path: Union[os.PathLike, str], patch_shape: Tuple[int, int, int], raw_key: Optional[str] = None, label_key: Optional[str] = None, download: bool = False, **kwargs) -> torch.utils.data.dataset.Dataset:
def get_xpress_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, int, int],
    raw_key: Optional[str] = None,
    label_key: Optional[str] = None,
    download: bool = False,
    **kwargs,
) -> Dataset:
    """Get the XPRESS dataset for voxel-wise segmentation.

    Args:
        path: Filepath to a folder where the data will be stored.
        patch_shape: The patch shape to use for training. Must have exactly three entries.
        raw_key: The HDF5 key for the raw data. If None, the key "raw" is used.
        label_key: The HDF5 key for the label data. If None, the key "labels" is used.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    assert len(patch_shape) == 3
    raw_path, label_path = get_xpress_paths(path, download)

    # Fall back to the default dataset names when no explicit keys are given.
    if raw_key is None:
        raw_key = "raw"
    if label_key is None:
        label_key = "labels"

    dataset = torch_em.default_segmentation_dataset(
        raw_paths=[raw_path],
        raw_key=raw_key,
        label_paths=[label_path],
        label_key=label_key,
        patch_shape=patch_shape,
        is_seg_dataset=True,
        **kwargs,
    )
    return dataset

Get the XPRESS dataset for voxel-wise segmentation.

Arguments:
  • path: Filepath to a folder where the data will be stored.
  • patch_shape: The patch shape to use for training.
  • raw_key: The HDF5 key for the raw data. If None, the key "raw" is used.
  • label_key: The HDF5 key for the label data. If None, the key "labels" is used.
  • download: Whether to download the data if it is not present.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset.
Returns:

The segmentation dataset.

def get_xpress_loader( path: Union[os.PathLike, str], batch_size: int, patch_shape: Tuple[int, int, int], raw_key: Optional[str] = None, label_key: Optional[str] = None, download: bool = False, **kwargs) -> torch.utils.data.dataloader.DataLoader:
def get_xpress_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, int, int],
    raw_key: Optional[str] = None,
    label_key: Optional[str] = None,
    download: bool = False,
    **kwargs,
) -> DataLoader:
    """Get the XPRESS dataloader for voxel-wise segmentation.

    Args:
        path: Filepath to a folder where the data will be stored.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        raw_key: The HDF5 key for the raw data, passed on to `get_xpress_dataset`.
        label_key: The HDF5 key for the label data, passed on to `get_xpress_dataset`.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments. `util.split_kwargs` divides them into arguments for
            `torch_em.default_segmentation_dataset` and arguments for `torch_em.get_data_loader`.

    Returns:
        The DataLoader.
    """
    split = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    ds_kwargs, loader_kwargs = split

    dataset = get_xpress_dataset(
        path,
        patch_shape,
        raw_key=raw_key,
        label_key=label_key,
        download=download,
        **ds_kwargs,
    )
    loader = torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
    return loader

Get the XPRESS dataloader for voxel-wise segmentation.