torch_em.data.datasets.light_microscopy.xpress
The XPRESS dataset contains volumetric microscopy data with voxel-wise labels.
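The training data is hosted at:
- https://github.com/htem/xpress-challenge-files/releases/download/v1.0/xpress-training-raw.h5
- https://github.com/htem/xpress-challenge-files/releases/download/v1.0/xpress-training-voxel-labels.h5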
"""The XPRESS dataset contains volumetric microscopy data with voxel-wise labels.

The training data is hosted at:
- https://github.com/htem/xpress-challenge-files/releases/download/v1.0/xpress-training-raw.h5
- https://github.com/htem/xpress-challenge-files/releases/download/v1.0/xpress-training-voxel-labels.h5
"""

import os
from typing import Optional, Tuple, Union

from torch.utils.data import Dataset, DataLoader

import torch_em

from .. import util


URLS = {
    "raw": "https://github.com/htem/xpress-challenge-files/releases/download/v1.0/xpress-training-raw.h5",
    "labels": "https://github.com/htem/xpress-challenge-files/releases/download/v1.0/xpress-training-voxel-labels.h5",
}


def _default_chunks(shape, max_chunk: int = 64) -> Tuple[int, ...]:
    """Return an HDF5 chunk shape that caps every axis at `max_chunk` voxels.

    Axes shorter than `max_chunk` use their full extent. Every chunk extent is at least 1,
    because HDF5 does not accept zero-sized chunk dimensions.
    """
    return tuple(max(1, min(max_chunk, int(s))) for s in shape)


def _merge_to_single_h5(
    raw_path: Union[os.PathLike, str],
    label_path: Union[os.PathLike, str],
    out_path: str,
    context_pad: int = 128,
) -> str:
    """Merge the separate raw and label files into one HDF5 file with aligned volumes.

    The raw volume is cropped to the labeled region plus `context_pad` voxels of context per side,
    clipped to the raw volume bounds. The labels are written into a zero-filled volume with the same
    shape as the raw crop, so both datasets can be sampled with the same coordinates.

    The output is first written to a temporary file and only moved to `out_path` once complete,
    so an interrupted merge never leaves a truncated file that would be reused on the next call.

    Args:
        raw_path: Path to the HDF5 file containing the raw data under the key `volumes/raw`.
        label_path: Path to the HDF5 file containing the labels under the key `volumes/labels`.
        out_path: Path of the merged file. If it already exists, it is returned without recomputation.
        context_pad: Number of raw voxels of context to keep on each side of the labeled region.

    Returns:
        The path to the merged file, which stores the data under the keys `raw` and `labels`.

    Raises:
        ValueError: If the labeled region does not lie fully inside the raw volume.
    """
    if os.path.exists(out_path):
        return out_path

    import h5py
    import numpy as np

    tmp_path = f"{out_path}.tmp"
    try:
        with h5py.File(raw_path, "r") as fr, h5py.File(label_path, "r") as fl, h5py.File(tmp_path, "w") as fo:
            raw_ds_in = fr["volumes/raw"]
            labels_ds_in = fl["volumes/labels"]

            raw_resolution = np.array(raw_ds_in.attrs.get("resolution", [1, 1, 1]))
            raw_offset = np.array(raw_ds_in.attrs.get("offset", [0, 0, 0]))
            label_offset = np.array(labels_ds_in.attrs.get("offset", [0, 0, 0]))

            # Convert the label offset from world coordinates to voxel coordinates in the raw volume.
            # The raw offset is subtracted in case the raw volume does not start at the world origin
            # (it defaults to 0); rounding guards against floating point error in the division.
            voxel_offset = np.round((label_offset - raw_offset) / raw_resolution).astype(int)

            raw_shape = np.array(raw_ds_in.shape)
            label_shape = np.array(labels_ds_in.shape)
            if np.any(voxel_offset < 0) or np.any(voxel_offset + label_shape > raw_shape):
                raise ValueError(
                    f"The labeled region (voxel offset {voxel_offset.tolist()}, shape {label_shape.tolist()}) "
                    f"does not fit inside the raw volume of shape {raw_shape.tolist()}."
                )

            # Crop the raw with extra context around the labeled region, clipped to the raw volume bounds.
            starts = np.clip(voxel_offset - context_pad, 0, raw_shape)
            ends = np.clip(voxel_offset + label_shape + context_pad, 0, raw_shape)
            raw_arr = raw_ds_in[tuple(slice(int(s), int(e)) for s, e in zip(starts, ends))]

            # Place labels inside a zero-filled volume matching the (padded) raw crop.
            label_insert_offset = voxel_offset - starts
            padded_labels = np.zeros(raw_arr.shape, dtype="int64")
            label_slices = tuple(
                slice(int(o), int(o) + int(s)) for o, s in zip(label_insert_offset, label_shape)
            )
            padded_labels[label_slices] = labels_ds_in[...]

            chunks = _default_chunks(raw_arr.shape)
            fo.create_dataset("raw", data=raw_arr, chunks=chunks, compression="gzip", compression_opts=4)
            fo.create_dataset("labels", data=padded_labels, chunks=chunks, compression="gzip", compression_opts=4)

        # Only publish the merged file once it has been written completely.
        os.replace(tmp_path, out_path)
    finally:
        # Remove leftovers of a failed or interrupted merge (no-op after a successful replace).
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

    return out_path


def get_xpress_data(path: Union[os.PathLike, str], download: bool = False) -> Tuple[str, str]:
    """Download the XPRESS training data and merge raw data and labels into a single HDF5 file.

    Args:
        path: Filepath to a folder where the data will be stored.
        download: Whether to download the data if it is not present.

    Returns:
        Filepaths for raw and label data. Both point to the same merged file `xpress-training.h5`,
        which stores the raw data under the key `raw` and the labels under the key `labels`.
    """
    os.makedirs(path, exist_ok=True)
    raw_path = os.path.join(path, "xpress-training-raw.h5")
    label_path = os.path.join(path, "xpress-training-voxel-labels.h5")

    util.download_source(raw_path, URLS["raw"], download, checksum=None)
    util.download_source(label_path, URLS["labels"], download, checksum=None)

    merged_path = _merge_to_single_h5(raw_path, label_path, os.path.join(path, "xpress-training.h5"))
    return merged_path, merged_path


def get_xpress_paths(path: Union[os.PathLike, str], download: bool = False) -> Tuple[str, str]:
    """Get paths to the XPRESS training data.

    Args:
        path: Filepath to a folder where the data will be stored.
        download: Whether to download the data if it is not present.

    Returns:
        Filepaths for raw and label data (both refer to the same merged HDF5 file).
    """
    return get_xpress_data(path, download)


def get_xpress_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, int, int],
    raw_key: Optional[str] = None,
    label_key: Optional[str] = None,
    download: bool = False,
    **kwargs,
) -> Dataset:
    """Get the XPRESS dataset for voxel-wise segmentation.

    Args:
        path: Filepath to a folder where the data will be stored.
        patch_shape: The 3D patch shape to use for training.
        raw_key: The HDF5 key for the raw data. Defaults to "raw", the key used in the merged file.
        label_key: The HDF5 key for the label data. Defaults to "labels", the key used in the merged file.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    assert len(patch_shape) == 3, f"Expected a 3D patch shape, got {patch_shape}."
    raw_path, label_path = get_xpress_paths(path, download)

    return torch_em.default_segmentation_dataset(
        raw_paths=[raw_path],
        raw_key="raw" if raw_key is None else raw_key,
        label_paths=[label_path],
        label_key="labels" if label_key is None else label_key,
        patch_shape=patch_shape,
        is_seg_dataset=True,
        **kwargs,
    )


def get_xpress_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, int, int],
    raw_key: Optional[str] = None,
    label_key: Optional[str] = None,
    download: bool = False,
    **kwargs,
) -> DataLoader:
    """Get the XPRESS dataloader for voxel-wise segmentation.

    Args:
        path: Filepath to a folder where the data will be stored.
        batch_size: The batch size for training.
        patch_shape: The 3D patch shape to use for training.
        raw_key: The HDF5 key for the raw data. Defaults to "raw".
        label_key: The HDF5 key for the label data. Defaults to "labels".
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments, split between `torch_em.default_segmentation_dataset`
            and the data loader.

    Returns:
        The DataLoader.
    """
    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    dataset = get_xpress_dataset(
        path, patch_shape, raw_key=raw_key, label_key=label_key, download=download, **ds_kwargs
    )
    return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
# Download locations of the XPRESS training volumes, keyed by data type.
URLS = {
    "raw": "https://github.com/htem/xpress-challenge-files/releases/download/v1.0/xpress-training-raw.h5",
    "labels": "https://github.com/htem/xpress-challenge-files/releases/download/v1.0/xpress-training-voxel-labels.h5",
}
def
get_xpress_data(path: Union[os.PathLike, str], download: bool = False) -> Tuple[str, str]:
def get_xpress_data(path: Union[os.PathLike, str], download: bool = False) -> Tuple[str, str]:
    """Download the XPRESS training data and merge raw data and labels into a single HDF5 file.

    Args:
        path: Filepath to a folder where the data will be stored.
        download: Whether to download the data if it is not present.

    Returns:
        Filepaths for raw and label data. Both point to the same merged file `xpress-training.h5`,
        which stores the raw data under the key `raw` and the labels under the key `labels`.
    """
    os.makedirs(path, exist_ok=True)
    raw_path = os.path.join(path, "xpress-training-raw.h5")
    label_path = os.path.join(path, "xpress-training-voxel-labels.h5")

    util.download_source(raw_path, URLS["raw"], download, checksum=None)
    util.download_source(label_path, URLS["labels"], download, checksum=None)

    # Raw data and labels are combined into one file so that both can be read with the same coordinates.
    merged_path = os.path.join(path, "xpress-training.h5")
    _merge_to_single_h5(raw_path, label_path, merged_path)

    return merged_path, merged_path
Download the XPRESS training data.
Arguments:
- path: Filepath to a folder where the data will be stored.
- download: Whether to download the data if it is not present.
Returns:
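Filepaths for raw and label data. Both paths point to the same merged file `xpress-training.h5`, which stores the raw data under the key `raw` and the labels under the key `labels`.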
def
get_xpress_paths(path: Union[os.PathLike, str], download: bool = False) -> Tuple[str, str]:
def get_xpress_paths(path: Union[os.PathLike, str], download: bool = False) -> Tuple[str, str]:
    """Get paths to the XPRESS training data.

    Thin alias of `get_xpress_data`.

    Args:
        path: Filepath to a folder where the data will be stored.
        download: Whether to download the data if it is not present.

    Returns:
        Filepaths for raw and label data, as returned by `get_xpress_data`.
    """
    paths = get_xpress_data(path, download=download)
    return paths
Get paths to the XPRESS training data. This forwards to `get_xpress_data` and returns the same filepaths for raw and label data.
def
get_xpress_dataset( path: Union[os.PathLike, str], patch_shape: Tuple[int, int, int], raw_key: Optional[str] = None, label_key: Optional[str] = None, download: bool = False, **kwargs) -> torch.utils.data.dataset.Dataset:
def get_xpress_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, int, int],
    raw_key: Optional[str] = None,
    label_key: Optional[str] = None,
    download: bool = False,
    **kwargs,
) -> Dataset:
    """Get the XPRESS dataset for voxel-wise segmentation.

    Args:
        path: Filepath to a folder where the data will be stored.
        patch_shape: The 3D patch shape to use for training.
        raw_key: The HDF5 key for the raw data. Defaults to "raw", the key used in the merged file.
        label_key: The HDF5 key for the label data. Defaults to "labels", the key used in the merged file.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    assert len(patch_shape) == 3, f"Expected a 3D patch shape, got {patch_shape}."
    raw_path, label_path = get_xpress_paths(path, download)

    return torch_em.default_segmentation_dataset(
        raw_paths=[raw_path],
        raw_key="raw" if raw_key is None else raw_key,
        label_paths=[label_path],
        label_key="labels" if label_key is None else label_key,
        patch_shape=patch_shape,
        is_seg_dataset=True,
        **kwargs,
    )
Get the XPRESS dataset for voxel-wise segmentation.
Arguments:
- path: Filepath to a folder where the data will be stored.
- patch_shape: The patch shape to use for training.
- raw_key: The HDF5 key for the raw data. Defaults to `raw`, the key used in the merged file.
- label_key: The HDF5 key for the label data. Defaults to `labels`, the key used in the merged file.
- download: Whether to download the data if it is not present.
- kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
Returns:
The segmentation dataset.
def
get_xpress_loader( path: Union[os.PathLike, str], batch_size: int, patch_shape: Tuple[int, int, int], raw_key: Optional[str] = None, label_key: Optional[str] = None, download: bool = False, **kwargs) -> torch.utils.data.dataloader.DataLoader:
def get_xpress_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, int, int],
    raw_key: Optional[str] = None,
    label_key: Optional[str] = None,
    download: bool = False,
    **kwargs,
) -> DataLoader:
    """Get the XPRESS dataloader for voxel-wise segmentation.

    Args:
        path: Filepath to a folder where the data will be stored.
        batch_size: The batch size for training.
        patch_shape: The 3D patch shape to use for training.
        raw_key: The HDF5 key for the raw data, forwarded to `get_xpress_dataset`.
        label_key: The HDF5 key for the label data, forwarded to `get_xpress_dataset`.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments; those accepted by `torch_em.default_segmentation_dataset`
            go to the dataset, the rest go to the data loader.

    Returns:
        The DataLoader.
    """
    # Route each keyword argument either to the dataset constructor or to the loader.
    dataset_kwargs, dataloader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)

    xpress_ds = get_xpress_dataset(
        path,
        patch_shape,
        raw_key=raw_key,
        label_key=label_key,
        download=download,
        **dataset_kwargs,
    )
    loader = torch_em.get_data_loader(xpress_ds, batch_size, **dataloader_kwargs)
    return loader
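Get the XPRESS dataloader for voxel-wise segmentation. Additional keyword arguments are split between `torch_em.default_segmentation_dataset` and the data loader.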