torch_em.data.datasets.light_microscopy.xpress
The XPRESS dataset contains volumetric microscopy data with voxel-wise labels.
The training data is hosted at:
"""The XPRESS dataset contains volumetric microscopy data with voxel-wise labels.

The training data is hosted at:
- https://github.com/htem/xpress-challenge-files/releases/download/v1.0/xpress-training-raw.h5
- https://github.com/htem/xpress-challenge-files/releases/download/v1.0/xpress-training-voxel-labels.h5
"""

import os
from typing import Optional, Tuple, Union

from torch.utils.data import Dataset, DataLoader

import torch_em

from .. import util


URLS = {
    "raw": "https://github.com/htem/xpress-challenge-files/releases/download/v1.0/xpress-training-raw.h5",
    "labels": "https://github.com/htem/xpress-challenge-files/releases/download/v1.0/xpress-training-voxel-labels.h5",
}


def _default_chunks(shape, max_extent: int = 64):
    """Return an HDF5 chunk shape that caps every axis at `max_extent` voxels.

    Axes shorter than `max_extent` use their full length, so the chunk never exceeds the dataset shape.
    """
    return tuple(min(max_extent, int(s)) for s in shape)


def _merge_to_single_h5(
    raw_path: Union[os.PathLike, str],
    label_path: Union[os.PathLike, str],
    out_path: str,
    context_pad: int = 128,
) -> str:
    """Merge the XPRESS raw and label files into a single HDF5 file.

    The raw volume is cropped to the labeled region plus `context_pad` voxels of context per side
    (clipped to the raw bounds). The labels are embedded into a zero-filled int64 volume with the same
    shape as the raw crop. Both arrays are written gzip-compressed under the keys "raw" and "labels".

    Args:
        raw_path: Path to the HDF5 file holding the raw data at "volumes/raw".
        label_path: Path to the HDF5 file holding the labels at "volumes/labels".
        out_path: Path of the merged file. If it already exists, it is returned without being rewritten.
        context_pad: Number of raw voxels of context to keep on each side of the labeled region.

    Returns:
        The path to the merged file.

    Raises:
        ValueError: If the labeled region does not lie inside the raw volume.
    """
    if os.path.exists(out_path):
        return out_path

    import h5py
    import numpy as np

    # Write to a temporary file and move it into place only once it is complete. Otherwise an
    # interrupted run would leave a truncated file that the early return above accepts on the next call.
    tmp_path = f"{out_path}.tmp"
    try:
        with h5py.File(raw_path, "r") as fr, h5py.File(label_path, "r") as fl, h5py.File(tmp_path, "w") as fo:
            raw_ds_in = fr["volumes/raw"]
            labels_ds_in = fl["volumes/labels"]

            raw_resolution = np.array(raw_ds_in.attrs.get("resolution", [1, 1, 1]))
            raw_offset = np.array(raw_ds_in.attrs.get("offset", [0, 0, 0]))
            label_offset = np.array(labels_ds_in.attrs.get("offset", [0, 0, 0]))

            # Offsets are treated as world coordinates (hence the division by the resolution).
            # The label position relative to the raw array is therefore the offset difference.
            voxel_offset = ((label_offset - raw_offset) / raw_resolution).astype(int)
            labels_arr = labels_ds_in[...]
            label_shape = np.array(labels_arr.shape)
            raw_shape = np.array(raw_ds_in.shape)

            if np.any(voxel_offset < 0) or np.any(voxel_offset + label_shape > raw_shape):
                raise ValueError(
                    f"The labeled region (voxel offset {voxel_offset.tolist()}, shape {label_shape.tolist()}) "
                    f"does not lie inside the raw volume of shape {raw_shape.tolist()}."
                )

            # Crop the raw data with extra context around the labeled region, clipped to the volume bounds.
            starts = np.clip(voxel_offset - context_pad, 0, raw_shape)
            ends = np.clip(voxel_offset + label_shape + context_pad, 0, raw_shape)
            raw_slices = tuple(slice(int(s), int(e)) for s, e in zip(starts, ends))
            raw_arr = raw_ds_in[raw_slices]

            # Place the labels inside a zero-filled volume that matches the (padded) raw crop.
            label_insert_offset = voxel_offset - starts
            padded_labels = np.zeros(raw_arr.shape, dtype="int64")
            label_slices = tuple(
                slice(int(o), int(o) + int(s)) for o, s in zip(label_insert_offset, label_shape)
            )
            padded_labels[label_slices] = labels_arr

            chunks = _default_chunks(raw_arr.shape)
            fo.create_dataset("raw", data=raw_arr, chunks=chunks, compression="gzip", compression_opts=4)
            fo.create_dataset("labels", data=padded_labels, chunks=chunks, compression="gzip", compression_opts=4)

        # The output file is closed at this point, so it can be safely renamed.
        os.replace(tmp_path, out_path)
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

    return out_path


def get_xpress_data(path: Union[os.PathLike, str], download: bool = False) -> Tuple[str, str]:
    """Download the XPRESS training data.

    Args:
        path: Filepath to a folder where the data will be stored.
        download: Whether to download the data if it is not present.

    Returns:
        Filepaths for raw and label data. Both point to the same merged HDF5 file, which stores
        the data under the keys "raw" and "labels".
    """
    os.makedirs(path, exist_ok=True)
    merged_path = os.path.join(path, "xpress-training.h5")

    # Once the merged file exists, the original downloads are no longer needed (and may have been deleted).
    if os.path.exists(merged_path):
        return merged_path, merged_path

    raw_path = os.path.join(path, "xpress-training-raw.h5")
    label_path = os.path.join(path, "xpress-training-voxel-labels.h5")

    util.download_source(raw_path, URLS["raw"], download, checksum=None)
    util.download_source(label_path, URLS["labels"], download, checksum=None)

    _merge_to_single_h5(raw_path, label_path, merged_path)

    return merged_path, merged_path


def get_xpress_paths(path: Union[os.PathLike, str], download: bool = False) -> Tuple[str, str]:
    """Get paths to the XPRESS training data.

    Args:
        path: Filepath to a folder where the data will be stored.
        download: Whether to download the data if it is not present.

    Returns:
        Filepaths for raw and label data (both refer to the same merged HDF5 file).
    """
    return get_xpress_data(path, download)


def get_xpress_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, int, int],
    raw_key: Optional[str] = None,
    label_key: Optional[str] = None,
    download: bool = False,
    **kwargs,
) -> Dataset:
    """Get the XPRESS dataset for voxel-wise segmentation.

    Args:
        path: Filepath to a folder where the data will be stored.
        patch_shape: The patch shape to use for training.
        raw_key: The HDF5 key for the raw data. If None, defaults to "raw".
        label_key: The HDF5 key for the label data. If None, defaults to "labels".
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    assert len(patch_shape) == 3, f"Expected a 3D patch shape, got {patch_shape}."
    raw_path, label_path = get_xpress_paths(path, download)

    return torch_em.default_segmentation_dataset(
        raw_paths=[raw_path],
        raw_key="raw" if raw_key is None else raw_key,
        label_paths=[label_path],
        label_key="labels" if label_key is None else label_key,
        patch_shape=patch_shape,
        is_seg_dataset=True,
        **kwargs,
    )


def get_xpress_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, int, int],
    raw_key: Optional[str] = None,
    label_key: Optional[str] = None,
    download: bool = False,
    **kwargs,
) -> DataLoader:
    """Get the XPRESS dataloader for voxel-wise segmentation.

    Args:
        path: Filepath to a folder where the data will be stored.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        raw_key: The HDF5 key for the raw data. If None, defaults to "raw".
        label_key: The HDF5 key for the label data. If None, defaults to "labels".
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments. Those accepted by `torch_em.default_segmentation_dataset`
            go to the dataset; the rest are forwarded to `torch_em.get_data_loader`.

    Returns:
        The DataLoader.
    """
    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    dataset = get_xpress_dataset(
        path, patch_shape, raw_key=raw_key, label_key=label_key, download=download, **ds_kwargs
    )
    return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
URLS =
{'raw': 'https://github.com/htem/xpress-challenge-files/releases/download/v1.0/xpress-training-raw.h5', 'labels': 'https://github.com/htem/xpress-challenge-files/releases/download/v1.0/xpress-training-voxel-labels.h5'}
def
get_xpress_data(path: Union[os.PathLike, str], download: bool = False) -> Tuple[str, str]:
def get_xpress_data(path: Union[os.PathLike, str], download: bool = False) -> Tuple[str, str]:
    """Download the XPRESS training data.

    Args:
        path: Filepath to a folder where the data will be stored.
        download: Whether to download the data if it is not present.

    Returns:
        Filepaths for raw and label data. Both point to the same merged HDF5 file, which stores
        the data under the keys "raw" and "labels".
    """
    os.makedirs(path, exist_ok=True)
    merged_path = os.path.join(path, "xpress-training.h5")

    # Once the merged file exists, the original downloads are no longer needed (and may have been deleted).
    if os.path.exists(merged_path):
        return merged_path, merged_path

    raw_path = os.path.join(path, "xpress-training-raw.h5")
    label_path = os.path.join(path, "xpress-training-voxel-labels.h5")

    util.download_source(raw_path, URLS["raw"], download, checksum=None)
    util.download_source(label_path, URLS["labels"], download, checksum=None)

    _merge_to_single_h5(raw_path, label_path, merged_path)

    return merged_path, merged_path
Download the XPRESS training data.
Arguments:
- path: Filepath to a folder where the data will be stored.
- download: Whether to download the data if it is not present.
Returns:
Filepaths for raw and label data (both point to the same merged HDF5 file, which stores them under the keys "raw" and "labels").
def
get_xpress_paths(path: Union[os.PathLike, str], download: bool = False) -> Tuple[str, str]:
def get_xpress_paths(path: Union[os.PathLike, str], download: bool = False) -> Tuple[str, str]:
    """Get paths to the XPRESS training data.

    Args:
        path: Filepath to a folder where the data will be stored.
        download: Whether to download the data if it is not present.

    Returns:
        Filepaths for raw and label data, as returned by `get_xpress_data`.
    """
    return get_xpress_data(path, download)
Get paths to the XPRESS training data.
def
get_xpress_dataset( path: Union[os.PathLike, str], patch_shape: Tuple[int, int, int], raw_key: Optional[str] = None, label_key: Optional[str] = None, download: bool = False, **kwargs) -> torch.utils.data.dataset.Dataset:
def get_xpress_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, int, int],
    raw_key: Optional[str] = None,
    label_key: Optional[str] = None,
    download: bool = False,
    **kwargs,
) -> Dataset:
    """Get the XPRESS dataset for voxel-wise segmentation.

    Args:
        path: Filepath to a folder where the data will be stored.
        patch_shape: The patch shape to use for training.
        raw_key: The HDF5 key for the raw data. If None, defaults to "raw".
        label_key: The HDF5 key for the label data. If None, defaults to "labels".
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    assert len(patch_shape) == 3, f"Expected a 3D patch shape, got {patch_shape}."
    raw_path, label_path = get_xpress_paths(path, download)

    return torch_em.default_segmentation_dataset(
        raw_paths=[raw_path],
        raw_key="raw" if raw_key is None else raw_key,
        label_paths=[label_path],
        label_key="labels" if label_key is None else label_key,
        patch_shape=patch_shape,
        is_seg_dataset=True,
        **kwargs,
    )
Get the XPRESS dataset for voxel-wise segmentation.
Arguments:
- path: Filepath to a folder where the data will be stored.
- patch_shape: The patch shape to use for training.
- raw_key: The HDF5 key for the raw data. If None, defaults to "raw".
- label_key: The HDF5 key for the label data. If None, defaults to "labels".
- download: Whether to download the data if it is not present.
- kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
Returns:
The segmentation dataset.
def
get_xpress_loader( path: Union[os.PathLike, str], batch_size: int, patch_shape: Tuple[int, int, int], raw_key: Optional[str] = None, label_key: Optional[str] = None, download: bool = False, **kwargs) -> torch.utils.data.dataloader.DataLoader:
def get_xpress_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, int, int],
    raw_key: Optional[str] = None,
    label_key: Optional[str] = None,
    download: bool = False,
    **kwargs,
) -> DataLoader:
    """Get the XPRESS dataloader for voxel-wise segmentation.

    Args:
        path: Filepath to a folder where the data will be stored.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        raw_key: The HDF5 key for the raw data. If None, defaults to "raw".
        label_key: The HDF5 key for the label data. If None, defaults to "labels".
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments. Those accepted by `torch_em.default_segmentation_dataset`
            go to the dataset; the rest are forwarded to `torch_em.get_data_loader`.

    Returns:
        The DataLoader.
    """
    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    dataset = get_xpress_dataset(
        path, patch_shape, raw_key=raw_key, label_key=label_key, download=download, **ds_kwargs
    )
    return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
Get the XPRESS dataloader for voxel-wise segmentation.