torch_em.data.datasets.light_microscopy.xpress

"""The XPRESS dataset contains volumetric microscopy data with voxel-wise labels.

The training data is hosted at:
- https://github.com/htem/xpress-challenge-files/releases/download/v1.0/xpress-training-raw.h5
- https://github.com/htem/xpress-challenge-files/releases/download/v1.0/xpress-training-voxel-labels.h5

After download, the raw volume and the voxel labels are merged into a single HDF5 file
(`xpress-training.h5`) that stores the datasets "raw" and "labels". The raw data in that file
is cropped to the labeled region plus some surrounding context.
"""
  7
  8import os
  9from typing import Optional, Tuple, Union
 10
 11from torch.utils.data import Dataset, DataLoader
 12
 13import torch_em
 14
 15from .. import util
 16
 17
 18URLS = {
 19    "raw": "https://github.com/htem/xpress-challenge-files/releases/download/v1.0/xpress-training-raw.h5",
 20    "labels": "https://github.com/htem/xpress-challenge-files/releases/download/v1.0/xpress-training-voxel-labels.h5",
 21}
 22
 23
 24def _default_chunks(shape):
 25    # Simple heuristic: chunk along z and limit chunk extents to 64.
 26    return tuple(min(64, int(s)) for s in shape)
 27
 28
 29def _merge_to_single_h5(raw_path: Union[os.PathLike, str], label_path: Union[os.PathLike, str], out_path: str):
 30    if os.path.exists(out_path):
 31        return out_path
 32
 33    import h5py
 34    import numpy as np
 35
 36    with h5py.File(raw_path, "r") as fr, h5py.File(label_path, "r") as fl, h5py.File(out_path, "w") as fo:
 37        raw_ds_in = fr["volumes/raw"]
 38        labels_ds_in = fl["volumes/labels"]
 39
 40        raw_resolution = np.array(raw_ds_in.attrs.get("resolution", [1, 1, 1]))
 41        label_offset = np.array(labels_ds_in.attrs.get("offset", [0, 0, 0]))
 42
 43        # Convert the label offset from world coordinates to voxel coordinates in the raw volume.
 44        voxel_offset = (label_offset / raw_resolution).astype(int)
 45        labels_arr = labels_ds_in[...]
 46
 47        # Crop the raw with extra context (128 px padding per side) around the labeled region.
 48        context_pad = 128
 49        raw_shape = np.array(raw_ds_in.shape)
 50        starts = np.clip(voxel_offset - context_pad, 0, raw_shape)
 51        ends = np.clip(voxel_offset + np.array(labels_arr.shape) + context_pad, 0, raw_shape)
 52
 53        raw_slices = tuple(slice(int(s), int(e)) for s, e in zip(starts, ends))
 54        raw_arr = raw_ds_in[raw_slices]
 55
 56        # Place labels inside a zero-padded volume matching the (padded) raw crop.
 57        label_insert_offset = voxel_offset - starts
 58        padded_labels = np.zeros(raw_arr.shape, dtype="int64")
 59        label_slices = tuple(
 60            slice(int(o), int(o) + s) for o, s in zip(label_insert_offset, labels_arr.shape)
 61        )
 62        padded_labels[label_slices] = labels_arr
 63
 64        chunks = _default_chunks(raw_arr.shape)
 65
 66        fo.create_dataset("raw", data=raw_arr, chunks=chunks, compression="gzip", compression_opts=4)
 67        fo.create_dataset("labels", data=padded_labels, chunks=chunks, compression="gzip", compression_opts=4)
 68
 69    return out_path
 70
 71
 72def get_xpress_data(path: Union[os.PathLike, str], download: bool = False) -> Tuple[str, str]:
 73    """Download the XPRESS training data.
 74
 75    Args:
 76        path: Filepath to a folder where the data will be stored.
 77        download: Whether to download the data if it is not present.
 78
 79    Returns:
 80        Filepaths for raw and label data.
 81    """
 82    os.makedirs(path, exist_ok=True)
 83    raw_path = os.path.join(path, "xpress-training-raw.h5")
 84    label_path = os.path.join(path, "xpress-training-voxel-labels.h5")
 85
 86    util.download_source(raw_path, URLS["raw"], download, checksum=None)
 87    util.download_source(label_path, URLS["labels"], download, checksum=None)
 88
 89    merged_path = os.path.join(path, "xpress-training.h5")
 90    _merge_to_single_h5(raw_path, label_path, merged_path)
 91
 92    return merged_path, merged_path
 93
 94
 95def get_xpress_paths(path: Union[os.PathLike, str], download: bool = False) -> Tuple[str, str]:
 96    """Get paths to the XPRESS training data."""
 97    return get_xpress_data(path, download)
 98
 99
def get_xpress_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, int, int],
    raw_key: Optional[str] = None,
    label_key: Optional[str] = None,
    download: bool = False,
    **kwargs,
) -> Dataset:
    """Get the XPRESS dataset for voxel-wise segmentation.

    Args:
        path: Filepath to a folder where the data will be stored.
        patch_shape: The 3D patch shape to use for training.
        raw_key: The HDF5 key for the raw data. If None, defaults to "raw", the key used in the merged file.
        label_key: The HDF5 key for the label data. If None, defaults to "labels", the key used in the merged file.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    assert len(patch_shape) == 3, f"Expected a 3D patch shape for the volumetric XPRESS data, got {patch_shape}."
    raw_path, label_path = get_xpress_paths(path, download)

    return torch_em.default_segmentation_dataset(
        raw_paths=[raw_path],
        raw_key="raw" if raw_key is None else raw_key,
        label_paths=[label_path],
        label_key="labels" if label_key is None else label_key,
        patch_shape=patch_shape,
        is_seg_dataset=True,
        **kwargs,
    )
133
134
def get_xpress_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, int, int],
    raw_key: Optional[str] = None,
    label_key: Optional[str] = None,
    download: bool = False,
    **kwargs,
) -> DataLoader:
    """Get the XPRESS dataloader for voxel-wise segmentation.

    Args:
        path: Filepath to a folder where the data will be stored.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        raw_key: The HDF5 key for the raw data, forwarded to `get_xpress_dataset`.
        label_key: The HDF5 key for the label data, forwarded to `get_xpress_dataset`.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments. They are split via `util.split_kwargs` into arguments
            for `torch_em.default_segmentation_dataset` and for `torch_em.get_data_loader`.

    Returns:
        The DataLoader.
    """
    dataset_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    dataset = get_xpress_dataset(
        path=path,
        patch_shape=patch_shape,
        raw_key=raw_key,
        label_key=label_key,
        download=download,
        **dataset_kwargs,
    )
    return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
# Download locations of the XPRESS training files, keyed by the kind of data they contain.
URLS = dict(
    raw="https://github.com/htem/xpress-challenge-files/releases/download/v1.0/xpress-training-raw.h5",
    labels="https://github.com/htem/xpress-challenge-files/releases/download/v1.0/xpress-training-voxel-labels.h5",
)
def get_xpress_data(path: Union[os.PathLike, str], download: bool = False) -> Tuple[str, str]:
    """Download the XPRESS training data and merge it into a single HDF5 file.

    The raw volume and the voxel labels are downloaded to `path` if they are missing and
    `download` is set. They are then merged into `xpress-training.h5`, which stores the datasets
    "raw" and "labels". If the merged file already exists, it is reused directly and the source
    files are not needed.

    Args:
        path: Filepath to a folder where the data will be stored.
        download: Whether to download the data if it is not present.

    Returns:
        Filepaths for raw and label data. Both point to the same merged HDF5 file.
    """
    os.makedirs(path, exist_ok=True)
    merged_path = os.path.join(path, "xpress-training.h5")

    # The merged file is all the datasets need, so skip the source files once it exists.
    if os.path.exists(merged_path):
        return merged_path, merged_path

    raw_path = os.path.join(path, "xpress-training-raw.h5")
    label_path = os.path.join(path, "xpress-training-voxel-labels.h5")

    util.download_source(raw_path, URLS["raw"], download, checksum=None)
    util.download_source(label_path, URLS["labels"], download, checksum=None)

    _merge_to_single_h5(raw_path, label_path, merged_path)

    return merged_path, merged_path

Download the XPRESS training data.

Arguments:
  • path: Filepath to a folder where the data will be stored.
  • download: Whether to download the data if it is not present.
Returns:

Filepaths for raw and label data. Both point to the same merged HDF5 file.

def get_xpress_paths(path: Union[os.PathLike, str], download: bool = False) -> Tuple[str, str]:
    """Get the filepaths of the XPRESS training data.

    Args:
        path: Filepath to a folder where the data is stored.
        download: Whether to download the data if it is not present.

    Returns:
        Filepaths for raw and label data, as returned by `get_xpress_data`.
    """
    raw_path, label_path = get_xpress_data(path, download=download)
    return raw_path, label_path

Get paths to the XPRESS training data.

def get_xpress_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, int, int],
    raw_key: Optional[str] = None,
    label_key: Optional[str] = None,
    download: bool = False,
    **kwargs,
) -> Dataset:
    """Get the XPRESS dataset for voxel-wise segmentation.

    Args:
        path: Filepath to a folder where the data will be stored.
        patch_shape: The 3D patch shape to use for training.
        raw_key: The HDF5 key for the raw data. If None, defaults to "raw", the key used in the merged file.
        label_key: The HDF5 key for the label data. If None, defaults to "labels", the key used in the merged file.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    assert len(patch_shape) == 3, f"Expected a 3D patch shape for the volumetric XPRESS data, got {patch_shape}."
    raw_path, label_path = get_xpress_paths(path, download)

    return torch_em.default_segmentation_dataset(
        raw_paths=[raw_path],
        raw_key="raw" if raw_key is None else raw_key,
        label_paths=[label_path],
        label_key="labels" if label_key is None else label_key,
        patch_shape=patch_shape,
        is_seg_dataset=True,
        **kwargs,
    )

Get the XPRESS dataset for voxel-wise segmentation.

Arguments:
  • path: Filepath to a folder where the data will be stored.
  • patch_shape: The patch shape to use for training.
  • raw_key: The HDF5 key for the raw data. If None, defaults to "raw", the key used in the merged file.
  • label_key: The HDF5 key for the label data. If None, defaults to "labels", the key used in the merged file.
  • download: Whether to download the data if it is not present.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset.
Returns:

The segmentation dataset.

def get_xpress_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, int, int],
    raw_key: Optional[str] = None,
    label_key: Optional[str] = None,
    download: bool = False,
    **kwargs,
) -> DataLoader:
    """Get the XPRESS dataloader for voxel-wise segmentation.

    Args:
        path: Filepath to a folder where the data will be stored.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        raw_key: The HDF5 key for the raw data, forwarded to `get_xpress_dataset`.
        label_key: The HDF5 key for the label data, forwarded to `get_xpress_dataset`.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments. They are split via `util.split_kwargs` into arguments
            for `torch_em.default_segmentation_dataset` and for `torch_em.get_data_loader`.

    Returns:
        The DataLoader.
    """
    dataset_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    dataset = get_xpress_dataset(
        path=path,
        patch_shape=patch_shape,
        raw_key=raw_key,
        label_key=label_key,
        download=download,
        **dataset_kwargs,
    )
    return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)

Get the XPRESS dataloader for voxel-wise segmentation.