torch_em.data.datasets.histopathology.hest
The HEST-1k dataset contains 1,276 paired H&E whole-slide images and spatial transcriptomics profiles across 26 organ types. For each sample, pre-extracted 224x224 H&E patches at 0.5 um/px, CellViT nuclei instance segmentation masks, Xenium DAPI-derived nucleus boundaries (for Xenium samples), and cell-level spatial transcriptomics gene expression profiles are available.
Three types of segmentation labels are supported:
- 'instances': nuclei instance masks derived from CellViT (H&E-based, all samples).
- 'xenium_instances': nuclei instance masks from DAPI segmentation (Xenium samples only).
- 'semantic': cell-type semantic masks derived from spatial transcriptomics via Leiden clustering and PanglaoDB marker-gene voting (Xenium samples only). Classes: 0=background, 1=Epithelial, 2=Inflammatory, 3=Connective, 4=Neoplastic, 5=Unknown.
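To make the integer encoding concrete, here is a minimal sketch (the mask values are synthetic, standing in for a real 224x224 semantic patch) that tallies pixels per cell-type class:

```python
import numpy as np

# Same mapping as CELL_TYPE_LABELS in this module; 0 is background.
CELL_TYPE_LABELS = {"Epithelial": 1, "Inflammatory": 2, "Connective": 3, "Neoplastic": 4, "Unknown": 5}

def class_pixel_counts(semantic_mask: np.ndarray) -> dict:
    # Count how many pixels carry each cell-type label (background excluded).
    return {name: int((semantic_mask == idx).sum()) for name, idx in CELL_TYPE_LABELS.items()}

# Toy 3x3 semantic mask.
mask = np.array([[0, 1, 1], [2, 2, 5], [0, 0, 4]])
counts = class_pixel_counts(mask)
# counts == {'Epithelial': 2, 'Inflammatory': 2, 'Connective': 0, 'Neoplastic': 1, 'Unknown': 1}
```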
This dataset is used in the paper https://doi.org/10.48550/arXiv.2604.23481 as a scalable alternative to manually annotated datasets for nuclei segmentation and classification training.
The dataset is located at https://huggingface.co/datasets/MahmoodLab/hest. This dataset is from the following publication:
- Jaume et al. (2024): https://doi.org/10.48550/arXiv.2406.16192 Please cite it if you use this dataset in your research.
NOTE: Requires huggingface_hub for download: pip install huggingface_hub
NOTE: Requires geopandas, rasterio, and scipy for preprocessing: pip install geopandas rasterio scipy
NOTE: Requires scanpy, python-igraph, and leidenalg for semantic labels: pip install scanpy igraph leidenalg
NOTE: The full dataset is ~2 TB. Use the `organs` argument to download only a subset.
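Assuming the dependencies above are installed, a typical entry point looks like the following sketch. The `data_root` path and keyword values are illustrative; the helper is only defined, not run, since a real call triggers the download:

```python
def build_hest_loader(data_root: str):
    """Illustrative helper: build a HEST-1k loader for a single-organ subset."""
    # Deferred import so the sketch does not require torch_em at definition time.
    from torch_em.data.datasets.histopathology.hest import get_hest_loader
    return get_hest_loader(
        path=data_root,            # download target, e.g. "./data/hest" (illustrative)
        batch_size=8,
        patch_shape=(224, 224),    # kept for API consistency; patches are fixed-size
        organs=["Breast"],         # restrict the ~2 TB download to one organ
        label_choice="instances",  # or "xenium_instances" / "semantic"
        download=True,
        shuffle=True,              # forwarded to the PyTorch DataLoader
    )
```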
"""The HEST-1k dataset contains 1,276 paired H&E whole-slide images and spatial transcriptomics
profiles across 26 organ types. For each sample, pre-extracted 224x224 H&E patches at 0.5 um/px,
CellViT nuclei instance segmentation masks, Xenium DAPI-derived nucleus boundaries (for Xenium
samples), and cell-level spatial transcriptomics gene expression profiles are available.

Three types of segmentation labels are supported:
- 'instances': nuclei instance masks derived from CellViT (H&E-based, all samples).
- 'xenium_instances': nuclei instance masks from DAPI segmentation (Xenium samples only).
- 'semantic': cell-type semantic masks derived from spatial transcriptomics via Leiden clustering
  and PanglaoDB marker-gene voting (Xenium samples only). Classes: 0=background, 1=Epithelial,
  2=Inflammatory, 3=Connective, 4=Neoplastic, 5=Unknown.

This dataset is used in the paper https://doi.org/10.48550/arXiv.2604.23481 as a scalable
alternative to manually annotated datasets for nuclei segmentation and classification training.

The dataset is located at https://huggingface.co/datasets/MahmoodLab/hest.
This dataset is from the following publication:
- Jaume et al. (2024): https://doi.org/10.48550/arXiv.2406.16192
Please cite it if you use this dataset in your research.

NOTE: Requires huggingface_hub for download: pip install huggingface_hub
NOTE: Requires geopandas, rasterio, and scipy for preprocessing: pip install geopandas rasterio scipy
NOTE: Requires scanpy, python-igraph, and leidenalg for semantic labels: pip install scanpy igraph leidenalg
NOTE: The full dataset is ~2 TB. Use the `organs` argument to download only a subset.
"""

import json
import os
import zipfile
from glob import glob
from tqdm import tqdm
from natsort import natsorted
from typing import Callable, Dict, List, Literal, Optional, Tuple, Union

import h5py
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

import torch_em


HF_REPO = "MahmoodLab/hest"
METADATA_FILENAME = "HEST_v1_3_0.csv"
PANGLAODB_URL = "https://panglaodb.se/markers/PanglaoDB_markers_27_Mar_2020.tsv.gz"

# Integer label for each cell-type category (0 = background).
CELL_TYPE_LABELS = {"Epithelial": 1, "Inflammatory": 2, "Connective": 3, "Neoplastic": 4, "Unknown": 5}

# Map from public label_choice strings to HDF5 dataset paths.
LABEL_KEYS = {
    "instances": "labels/instances/h&e",
    "xenium_instances": "labels/instances/xenium",
    "semantic": "labels/semantic/st",
}

# Organs present in both HEST-1k and PanNuke (used in arXiv 2604.23481).
PANNUKE_ORGANS = [
    "Breast", "Colon", "Kidney", "Liver", "Lung", "Ovarian", "Pancreatic", "Prostate", "Skin", "Stomach",
]

# Keyword fragments for mapping PanglaoDB cell type names to coarse categories.
EPITHELIAL_KEYWORDS = [
    "acinar", "airway epithelial", "airway goblet", "alveolar type", "alpha cell", "basal cell",
    "beta cell", "cholangiocyte", "ciliated", "clara", "crypt", "delta cell", "ductal",
    "enterocyte", "epithelial", "goblet", "hepatocyte", "keratinocyte", "mesothelial",
    "paneth", "pneumocyte", "proximal tubule", "renal tubule", "squamous", "thyroid",
    "trophoblast", "tuft", "urothelial",
]
INFLAMMATORY_KEYWORDS = [
    "alveolar macrophage", "b cell", "basophil", "dendritic", "eosinophil",
    "innate lymphoid", "lymphocyte", "macrophage", "mast cell", "monocyte",
    "natural killer", "neutrophil", "nk cell", "plasma cell", "regulatory t", "t cell",
]
CONNECTIVE_KEYWORDS = [
    "adipocyte", "chondrocyte", "endothelial", "fibroblast", "mesenchymal",
    "myofibroblast", "osteoblast", "osteoclast", "pericyte", "smooth muscle",
    "stellate", "stromal", "vascular",
]

# Well-known cancer-associated genes (COSMIC Cancer Gene Census, tier 1).
CANCER_GENES = {
    "ABL1", "AKT1", "ALK", "APC", "ATM", "BRAF", "BRCA1", "BRCA2", "CDH1", "CDKN2A",
    "CTNNB1", "EGFR", "ERBB2", "ESR1", "EZH2", "FBXW7", "FGFR1", "FGFR2", "FGFR3",
    "FLT3", "GATA3", "GNAQ", "GNAS", "HNF1A", "HRAS", "IDH1", "IDH2", "JAK2", "KIT",
    "KRAS", "MAP2K1", "MDM2", "MET", "MLH1", "MSH2", "MSH6", "MTOR", "MYC", "MYCN",
    "NF1", "NF2", "NFE2L2", "NOTCH1", "NOTCH2", "NRAS", "PALB2", "PBRM1", "PIK3CA",
    "PIK3R1", "PMS2", "POLE", "PTCH1", "PTEN", "RB1", "RET", "RNF43", "SETD2", "SF3B1",
    "SMAD4", "SMARCA4", "SMARCB1", "SMO", "STK11", "TERT", "TET2", "TP53", "TSC1",
    "TSC2", "VHL", "BAP1", "CDK12", "CHEK2", "CREBBP", "DNMT3A", "EP300", "FANCD2",
    "KDM5C", "KDM6A", "KEAP1", "MAP3K1", "MUTYH", "NBN", "PDGFRA", "PPP2R1A", "RAD51C",
    "RUNX1", "SDHA", "SDHB", "SDHC", "SDHD", "SUFU", "TP63", "XRCC2", "AXIN1", "AXIN2",
    "BRIP1", "CHD4", "ELOC", "FANCA", "FH", "FLCN", "MRE11", "RAD50", "RAD51B", "RAD51D",
}


def _download_hest(path: str, sample_ids: List[str], include_xenium: bool, include_st: bool) -> None:
    try:
        from huggingface_hub import snapshot_download
    except ImportError:
        raise ImportError("huggingface_hub is required. Install with: pip install huggingface_hub")

    patterns = [METADATA_FILENAME]
    for sid in sample_ids:
        patterns += [f"patches/{sid}.h5", f"cellvit_seg/{sid}_cellvit_seg.geojson.zip"]
        if include_xenium:
            patterns += [f"xenium_seg/{sid}_xenium_nucleus_seg.parquet"]
        if include_st:
            patterns += [f"st/{sid}.h5ad"]

    os.makedirs(path, exist_ok=True)
    snapshot_download(repo_id=HF_REPO, repo_type="dataset", local_dir=path, allow_patterns=patterns)


def _load_metadata(path: str) -> "pd.DataFrame":  # noqa
    try:
        import pandas as pd
    except ImportError:
        raise ImportError("pandas is required. Install with: pip install pandas")

    csv_path = os.path.join(path, METADATA_FILENAME)
    if not os.path.exists(csv_path):
        raise RuntimeError(f"Metadata not found at {csv_path}. Run get_hest_data() first.")
    return pd.read_csv(csv_path)


def _filter_sample_ids(path: str, organs: Optional[List[str]]) -> List[str]:
    meta = _load_metadata(path)
    if organs is not None:
        meta = meta[meta["organ"].isin(organs)]
    return meta["id"].tolist()


def _unzip_cellvit(zip_path: str, out_dir: str) -> Optional[str]:
    if not os.path.exists(zip_path):
        return None
    # Strip both extensions: "SAMPLEID_cellvit_seg.geojson.zip" -> "SAMPLEID"
    sample_id = os.path.basename(zip_path).replace("_cellvit_seg.geojson.zip", "")
    extract_dir = os.path.join(out_dir, sample_id)
    if not os.path.exists(extract_dir):
        with zipfile.ZipFile(zip_path, "r") as zf:
            zf.extractall(extract_dir)
    matches = glob(os.path.join(extract_dir, "**", "*.geojson"), recursive=True)
    return matches[0] if matches else None


def _gdf_from_xenium_parquet(parquet_path: str) -> "gpd.GeoDataFrame":  # noqa
    """Load a Xenium nucleus segmentation parquet into a GeoDataFrame.

    Expected format: index is cell_id, single 'geometry' column with WKB-encoded polygons.
    """
    try:
        import pandas as pd
        import geopandas as gpd
        import shapely
    except ImportError:
        raise ImportError("geopandas and shapely are required. Install with: pip install geopandas rasterio")

    df = pd.read_parquet(parquet_path)
    geometries = shapely.from_wkb(df["geometry"].values)
    return gpd.GeoDataFrame({"cell_id": df.index.astype(str), "geometry": geometries}, geometry="geometry")


def _gdf_from_cellvit_geojson(geojson_path: str) -> "gpd.GeoDataFrame":  # noqa
    """Load CellViT segmentation GeoJSON into a GeoDataFrame with one row per nucleus.

    The file is a JSON list of features with MultiPolygon geometries (one per cell-type class).
    Each MultiPolygon is exploded into individual Polygon rows.
    """
    try:
        import geopandas as gpd
        from shapely.geometry import shape, MultiPolygon
    except ImportError:
        raise ImportError("geopandas and shapely are required. Install with: pip install geopandas rasterio")

    with open(geojson_path) as fh:
        data = json.load(fh)

    records = []
    for feat in data:
        geom = shape(feat["geometry"])
        if isinstance(geom, MultiPolygon):
            for poly in geom.geoms:
                records.append({"geometry": poly})
        else:
            records.append({"geometry": geom})

    if not records:
        return gpd.GeoDataFrame(columns=["geometry"])
    return gpd.GeoDataFrame(records, geometry="geometry")


def _rasterize_patch_instances(
    patch_x: int,
    patch_y: int,
    patch_size: int,
    cells_gdf: "gpd.GeoDataFrame",  # noqa
    native_scale: float = 1.0,
) -> np.ndarray:
    """Rasterize nucleus polygons within one patch to an instance mask.

    native_scale: native WSI pixels per 0.5 um/px patch pixel (= 0.5 / pixel_size_um).
    Patches are stored at 0.5 um/px but cell coords are in native WSI pixel space.
    """
    try:
        from shapely.geometry import box
        from shapely.affinity import translate, scale as affine_scale
        from rasterio.features import rasterize as rio_rasterize
    except ImportError:
        raise ImportError("rasterio and shapely are required. Install with: pip install geopandas rasterio")

    native_size = round(patch_size * native_scale)
    patch_box = box(patch_x, patch_y, patch_x + native_size, patch_y + native_size)
    local = cells_gdf[cells_gdf.geometry.intersects(patch_box)].copy()
    mask = np.zeros((patch_size, patch_size), dtype=np.int32)
    if local.empty:
        return mask

    inv = 1.0 / native_scale
    local["geometry"] = local["geometry"].apply(
        lambda g: affine_scale(translate(g, xoff=-patch_x, yoff=-patch_y), xfact=inv, yfact=inv, origin=(0, 0))
    )
    shapes = ((geom, i + 1) for i, geom in enumerate(local.geometry))
    return rio_rasterize(shapes, out_shape=(patch_size, patch_size), fill=0, dtype=np.int32)


def _rasterize_patch_semantic(
    patch_x: int,
    patch_y: int,
    patch_size: int,
    cells_gdf: "gpd.GeoDataFrame",  # noqa
    spot_labels: np.ndarray,
    native_scale: float = 1.0,
    spot_tree=None,
) -> np.ndarray:
    """Rasterize nucleus polygons within one patch to a semantic (cell-type) mask.

    native_scale: native WSI pixels per 0.5 um/px patch pixel (= 0.5 / pixel_size_um).
    spot_labels: (N, 3) array of (x, y, label) for each ST spot in native WSI coordinates.
    spot_tree: pre-built cKDTree over spot_labels[:, :2]. Built locally if None (slow per-patch).
    Each nucleus is assigned the label of its nearest ST spot via KDTree.
    """
    try:
        from shapely.geometry import box
        from shapely.affinity import translate, scale as affine_scale
        from rasterio.features import rasterize as rio_rasterize
        from scipy.spatial import cKDTree
    except ImportError:
        raise ImportError("rasterio, shapely, and scipy are required. Install with: pip install geopandas rasterio scipy")  # noqa

    native_size = round(patch_size * native_scale)
    patch_box = box(patch_x, patch_y, patch_x + native_size, patch_y + native_size)
    local = cells_gdf[cells_gdf.geometry.intersects(patch_box)].copy()
    mask = np.zeros((patch_size, patch_size), dtype=np.int32)
    if local.empty:
        return mask

    # Assign each nucleus its nearest ST spot's label via KDTree on native coords.
    tree = spot_tree if spot_tree is not None else cKDTree(spot_labels[:, :2])
    centroids = np.array([[g.centroid.x, g.centroid.y] for g in local.geometry])
    _, idx = tree.query(centroids)
    local["label"] = spot_labels[idx, 2].astype(int)

    inv = 1.0 / native_scale
    local["geometry"] = local["geometry"].apply(
        lambda g: affine_scale(translate(g, xoff=-patch_x, yoff=-patch_y), xfact=inv, yfact=inv, origin=(0, 0))
    )
    shapes = ((geom, int(label)) for geom, label in zip(local.geometry, local["label"]))
    return rio_rasterize(shapes, out_shape=(patch_size, patch_size), fill=0, dtype=np.int32)


def _load_panglaodb(cache_path: str) -> "pd.DataFrame":  # noqa
    """Download (once) and return the PanglaoDB marker-gene table."""
    try:
        import pandas as pd
    except ImportError:
        raise ImportError("pandas is required. Install with: pip install pandas")

    tsv_path = os.path.join(cache_path, "PanglaoDB_markers.tsv.gz")
    if not os.path.exists(tsv_path):
        import urllib.request
        os.makedirs(cache_path, exist_ok=True)
        req = urllib.request.Request(PANGLAODB_URL, headers={"User-Agent": "Mozilla/5.0"})
        with urllib.request.urlopen(req) as resp, open(tsv_path, "wb") as fh:
            fh.write(resp.read())

    df = pd.read_csv(tsv_path, sep="\t")
    # Keep only human genes.
    df = df[df["species"].str.contains("Hs", na=False)]
    return df[["official gene symbol", "cell type"]].copy()


def _cell_type_to_category(cell_type_name: str) -> str:
    """Map a PanglaoDB cell type name to one of the four coarse categories."""
    name = cell_type_name.lower()
    for kw in EPITHELIAL_KEYWORDS:
        if kw in name:
            return "Epithelial"
    for kw in INFLAMMATORY_KEYWORDS:
        if kw in name:
            return "Inflammatory"
    for kw in CONNECTIVE_KEYWORDS:
        if kw in name:
            return "Connective"
    return "Unknown"


def _compute_cell_type_map(
    h5ad_path: str,
    marker_db: "pd.DataFrame",  # noqa
    top_n: int = 10,
    tau_vote: int = 5,
    top_m: int = 20,
    tau_cancer: float = 0.25,
) -> np.ndarray:
    """Run the ST cell-type assignment pipeline from the paper (arXiv 2604.23481).

    Returns an (N, 3) float32 array of (x, y, label) for each ST spot, where x and y
    are native WSI pixel coordinates (pxl_col_in_fullres / pxl_row_in_fullres). Callers
    use a KDTree to assign each segmented nucleus to its nearest ST spot.
    """
    try:
        import scanpy as sc
    except ImportError:
        raise ImportError("scanpy is required for semantic labels. Install with: pip install scanpy")

    adata = sc.read_h5ad(h5ad_path)

    if "pxl_col_in_fullres" not in adata.obs.columns or "pxl_row_in_fullres" not in adata.obs.columns:
        raise ValueError("h5ad missing pxl_col_in_fullres / pxl_row_in_fullres spot coordinates.")

    # Build gene -> category lookup from PanglaoDB.
    gene_to_cats: Dict[str, List[str]] = {}
    for gene, ct in zip(marker_db["official gene symbol"], marker_db["cell type"]):
        cat = _cell_type_to_category(ct)
        gene_to_cats.setdefault(gene, []).append(cat)

    # Preprocessing and clustering.
    sc.pp.normalize_total(adata, target_sum=1e4)
    sc.pp.log1p(adata)
    sc.pp.pca(adata)
    sc.pp.neighbors(adata)
    sc.tl.leiden(adata, resolution=4.0)
    sc.tl.rank_genes_groups(adata, groupby="leiden", method="wilcoxon")

    cluster_cat: Dict[str, str] = {}
    for cluster in adata.obs["leiden"].unique():
        try:
            top_genes = list(
                sc.get.rank_genes_groups_df(adata, group=cluster)["names"].iloc[:top_m]
            )
        except Exception:
            cluster_cat[cluster] = "Unknown"
            continue

        votes: Dict[str, float] = {"Epithelial": 0.0, "Inflammatory": 0.0, "Connective": 0.0}
        total_vote = 0.0
        for rank, gene in enumerate(top_genes[:top_n]):
            weight = top_n - rank
            for cat in gene_to_cats.get(gene, []):
                if cat in votes:
                    votes[cat] += weight
                    total_vote += weight

        if total_vote < tau_vote:
            cluster_cat[cluster] = "Unknown"
            continue

        best = max(votes, key=votes.get)  # type: ignore[arg-type]
        cluster_cat[cluster] = best

        if best == "Epithelial":
            cancer_overlap = sum(1 for g in top_genes[:top_m] if g in CANCER_GENES)
            if cancer_overlap / top_m > tau_cancer:
                cluster_cat[cluster] = "Neoplastic"

    # Build (N, 3) array: (x, y, label) per ST spot in native WSI pixel coords.
    xs = adata.obs["pxl_col_in_fullres"].values.astype(np.float32)
    ys = adata.obs["pxl_row_in_fullres"].values.astype(np.float32)
    labels = np.array(
        [CELL_TYPE_LABELS[cluster_cat.get(adata.obs["leiden"].iloc[i], "Unknown")]
         for i in range(adata.n_obs)],
        dtype=np.float32,
    )
    return np.stack([xs, ys, labels], axis=1)


def _preprocess_sample(
    patches_h5: str,
    cellvit_geojson: Optional[str],
    xenium_parquet: Optional[str],
    h5ad_path: Optional[str],
    marker_db: Optional["pd.DataFrame"],  # noqa
    out_h5: str,
    patch_size: int = 224,
    pixel_size_um: float = 0.5,
) -> bool:
    # Cell coords are in native WSI pixel space; patches are at 0.5 um/px.
    # native_scale = native WSI pixels per 0.5 um/px patch pixel.
    native_scale = 0.5 / pixel_size_um

    with h5py.File(patches_h5, "r") as f:
        img_key = "img" if "img" in f else ("imgs" if "imgs" in f else "images")
        imgs = f[img_key][:]  # (N, H, W, 3) uint8
        coords = f["coords"][:]  # (N, 2) top-left (x, y) in native WSI pixels

    n = len(imgs)
    if n == 0:
        return False

    # Load GeoDataFrames once per slide.
    cellvit_gdf = None
    if cellvit_geojson is not None and os.path.exists(cellvit_geojson):
        cellvit_gdf = _gdf_from_cellvit_geojson(cellvit_geojson)

    xenium_gdf = None
    if xenium_parquet is not None and os.path.exists(xenium_parquet):
        xenium_gdf = _gdf_from_xenium_parquet(xenium_parquet)

    spot_labels: Optional[np.ndarray] = None
    if h5ad_path is not None and os.path.exists(h5ad_path) and marker_db is not None and xenium_gdf is not None:
        try:
            spot_labels = _compute_cell_type_map(h5ad_path, marker_db)
        except Exception as e:
            print(f"Warning: semantic labels unavailable for {os.path.basename(h5ad_path)}: {e}")

    # Build the KDTree once per slide rather than once per patch.
    spot_tree = None
    if spot_labels is not None:
        try:
            from scipy.spatial import cKDTree
            spot_tree = cKDTree(spot_labels[:, :2])
        except ImportError:
            pass

    raw = np.zeros((n, 3, patch_size, patch_size), dtype=np.uint8)
    instances = np.zeros((n, patch_size, patch_size), dtype=np.int32)
    xenium_instances = np.zeros((n, patch_size, patch_size), dtype=np.int32)
    semantic = np.zeros((n, patch_size, patch_size), dtype=np.int32)

    sid = os.path.splitext(os.path.basename(out_h5))[0]
    for i, (img, coord) in enumerate(tqdm(zip(imgs, coords), total=n, desc=f"Processing {sid}", leave=False)):
        raw[i] = img[:patch_size, :patch_size, :].transpose(2, 0, 1)
        px, py = int(coord[0]), int(coord[1])

        if cellvit_gdf is not None:
            instances[i] = _rasterize_patch_instances(px, py, patch_size, cellvit_gdf, native_scale)

        if xenium_gdf is not None:
            xenium_instances[i] = _rasterize_patch_instances(px, py, patch_size, xenium_gdf, native_scale)

        if spot_labels is not None and xenium_gdf is not None:
            semantic[i] = _rasterize_patch_semantic(
                px, py, patch_size, xenium_gdf, spot_labels, native_scale, spot_tree
            )

    chunk_2d = (1, patch_size, patch_size)
    with h5py.File(out_h5, "w") as f:
        f.create_dataset("raw", data=raw, compression="gzip", chunks=(1, 3, patch_size, patch_size))
        f.create_dataset(LABEL_KEYS["instances"], data=instances, compression="gzip", chunks=chunk_2d)
        f.create_dataset(LABEL_KEYS["xenium_instances"], data=xenium_instances, compression="gzip", chunks=chunk_2d)
        f.create_dataset(LABEL_KEYS["semantic"], data=semantic, compression="gzip", chunks=chunk_2d)

    return True


class HESTDataset(Dataset):
    """2D patch dataset for HEST-1k.

    Indexes all patches across all per-slide H5 files and returns proper 2D tensors:
    raw (3, H, W) float32 in [0, 1] and labels (H, W) int32.
    """

    def __init__(
        self,
        h5_paths: List[str],
        label_key: str,
        raw_transform: Optional[Callable] = None,
        label_transform: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        n_samples: Optional[int] = None,
        seed: Optional[int] = None,
    ):
        self._label_key = label_key
        self._raw_transform = raw_transform
        self._label_transform = label_transform
        self._transform = transform

        # Build flat index: list of (h5_path, patch_idx).
        self._index: List[Tuple[str, int]] = []
        for h5_path in h5_paths:
            with h5py.File(h5_path, "r") as f:
                n = f["raw"].shape[0]  # raw stored as (N, 3, H, W)
            self._index.extend((h5_path, i) for i in range(n))

        if n_samples is not None:
            rng = np.random.default_rng(seed)
            chosen = rng.choice(len(self._index), size=n_samples, replace=n_samples > len(self._index))
            self._index = [self._index[i] for i in chosen]

    def __len__(self) -> int:
        return len(self._index)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        h5_path, patch_idx = self._index[idx]
        with h5py.File(h5_path, "r") as f:
            raw = f["raw"][patch_idx].astype(np.float32) / 255.0  # (3, H, W)
            label = f[self._label_key][patch_idx].astype(np.int32)  # (H, W)

        raw = torch.from_numpy(raw)
        label = torch.from_numpy(label)

        if self._raw_transform is not None:
            raw = self._raw_transform(raw)
        if self._label_transform is not None:
            label = self._label_transform(label)
        if self._transform is not None:
            raw, label = self._transform(raw, label)

        return raw, label


def get_hest_data(
    path: Union[os.PathLike, str],
    organs: Optional[List[str]] = None,
    download: bool = False,
) -> str:
    """Download and preprocess the HEST-1k dataset.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        organs: List of organ types to include. Uses all available organs if None.
            Example: ['Breast', 'Colon']. See PANNUKE_ORGANS for the set used in arXiv 2604.23481.
        download: Whether to download the data if it is not present.

    Returns:
        The filepath to the preprocessed data directory.
    """
    preprocessed_dir = os.path.join(path, "preprocessed")

    if download:
        meta_path = os.path.join(path, METADATA_FILENAME)
        if not os.path.exists(meta_path):
            try:
                from huggingface_hub import hf_hub_download
            except ImportError:
                raise ImportError("huggingface_hub is required. Install with: pip install huggingface_hub")
            hf_hub_download(repo_id=HF_REPO, repo_type="dataset", filename=METADATA_FILENAME, local_dir=path)

        sample_ids = _filter_sample_ids(path, organs)
        xenium_dir = os.path.join(path, "xenium_seg")
        st_dir = os.path.join(path, "st")
        include_xenium = not os.path.exists(xenium_dir)
        include_st = not os.path.exists(st_dir)
        _download_hest(path, sample_ids, include_xenium=include_xenium, include_st=include_st)
    else:
        sample_ids = [
            os.path.splitext(os.path.basename(p))[0]
            for p in glob(os.path.join(path, "patches", "*.h5"))
        ]
        if organs is not None:
            meta_path = os.path.join(path, METADATA_FILENAME)
            if os.path.exists(meta_path):
                allowed = set(_filter_sample_ids(path, organs))
                sample_ids = [s for s in sample_ids if s in allowed]

    # Load PanglaoDB once for all samples.
    db_cache = os.path.join(path, "_db_cache")
    try:
        marker_db = _load_panglaodb(db_cache)
    except Exception:
        marker_db = None

    # Build a pixel_size lookup from the metadata (um/px at native resolution).
    try:
        meta = _load_metadata(path)
        pixel_size_map = dict(zip(meta["id"], meta["pixel_size_um_estimated"].fillna(0.5)))
    except Exception:
        pixel_size_map = {}

    os.makedirs(preprocessed_dir, exist_ok=True)
    cellvit_zip_dir = os.path.join(path, "cellvit_seg")
    cellvit_cache = os.path.join(path, "_cellvit_extracted")
    xenium_dir = os.path.join(path, "xenium_seg")
    st_dir = os.path.join(path, "st")

    for sid in tqdm(sample_ids, desc="Preprocessing HEST samples"):
        out_h5 = os.path.join(preprocessed_dir, f"{sid}.h5")
        if os.path.exists(out_h5):
            continue

        patches_h5 = os.path.join(path, "patches", f"{sid}.h5")
        if not os.path.exists(patches_h5):
            continue

        geojson_path = _unzip_cellvit(
            os.path.join(cellvit_zip_dir, f"{sid}_cellvit_seg.geojson.zip"), cellvit_cache
        )
        xenium_parquet = os.path.join(xenium_dir, f"{sid}_xenium_nucleus_seg.parquet")
        h5ad_path = os.path.join(st_dir, f"{sid}.h5ad")
        pixel_size_um = float(pixel_size_map.get(sid, 0.5))

        _preprocess_sample(
            patches_h5=patches_h5,
            cellvit_geojson=geojson_path,
            xenium_parquet=xenium_parquet if os.path.exists(xenium_parquet) else None,
            h5ad_path=h5ad_path if os.path.exists(h5ad_path) else None,
            marker_db=marker_db,
            out_h5=out_h5,
            pixel_size_um=pixel_size_um,
        )

    return preprocessed_dir


def get_hest_paths(
    path: Union[os.PathLike, str],
    organs: Optional[List[str]] = None,
    download: bool = False,
) -> List[str]:
    """Get paths to the preprocessed HEST-1k H5 files.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        organs: List of organ types to include. Uses all available organs if None.
        download: Whether to download the data if it is not present.

    Returns:
        List of filepaths to the preprocessed H5 files (one per slide).
    """
    preprocessed_dir = get_hest_data(path, organs, download)
    h5_paths = natsorted(glob(os.path.join(preprocessed_dir, "*.h5")))
    if not h5_paths:
        raise RuntimeError(f"No preprocessed data found in {preprocessed_dir}.")

    if organs is not None:
        meta_path = os.path.join(path, METADATA_FILENAME)
        if os.path.exists(meta_path):
            allowed = set(_filter_sample_ids(path, organs))
            h5_paths = [p for p in h5_paths if os.path.splitext(os.path.basename(p))[0] in allowed]

    return h5_paths


def get_hest_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, int],
    organs: Optional[List[str]] = None,
    label_choice: Literal["instances", "xenium_instances", "semantic"] = "instances",
    download: bool = False,
    n_samples: Optional[int] = None,
    seed: Optional[int] = None,
    raw_transform: Optional[Callable] = None,
    label_transform: Optional[Callable] = None,
    transform: Optional[Callable] = None,
) -> Dataset:
    """Get the HEST-1k dataset for nuclei segmentation and cell-type classification.

    Returns a 2D dataset: each item is raw (3, H, W) float32 in [0, 1] and labels (H, W) int32.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        patch_shape: Not used for cropping (patches are already 224x224); kept for API consistency.
        organs: List of organ types to include. Uses all available organs if None.
            Use PANNUKE_ORGANS for the 10-organ subset from arXiv 2604.23481.
        label_choice: Which label type to return:
            - 'instances': CellViT nuclei instance masks (H&E-based, all samples).
            - 'xenium_instances': DAPI nuclei instance masks (Xenium samples only, zeros otherwise).
            - 'semantic': ST-derived cell-type labels 1-5 (Xenium samples only, zeros otherwise).
        download: Whether to download the data if it is not present.
        n_samples: Number of patches to sample (with replacement if larger than total). Uses all if None.
        seed: Random seed for reproducible patch sampling when n_samples is set.
        raw_transform: Transform applied to the raw image tensor.
        label_transform: Transform applied to the label tensor.
        transform: Joint transform applied to both raw and label.

    Returns:
        The segmentation dataset.
    """
    valid = ("instances", "xenium_instances", "semantic")
    if label_choice not in valid:
        raise ValueError(f"'{label_choice}' is not valid. Choose from {valid}.")

    h5_paths = get_hest_paths(path, organs, download)
    return HESTDataset(
        h5_paths=h5_paths,
        label_key=LABEL_KEYS[label_choice],
        raw_transform=raw_transform,
        label_transform=label_transform,
        transform=transform,
        n_samples=n_samples,
        seed=seed,
    )


def get_hest_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, int],
    organs: Optional[List[str]] = None,
    label_choice: Literal["instances", "xenium_instances", "semantic"] = "instances",
    download: bool = False,
    n_samples: Optional[int] = None,
    seed: Optional[int] = None,
    raw_transform: Optional[Callable] = None,
    label_transform: Optional[Callable] = None,
    transform: Optional[Callable] = None,
    **loader_kwargs,
) -> DataLoader:
    """Get the HEST-1k dataloader for nuclei segmentation and cell-type classification.

    Returns batches of raw (B, 3, H, W) float32 in [0, 1] and labels (B, H, W) int32.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        batch_size: The batch size for training.
        patch_shape: Not used for cropping (patches are already 224x224); kept for API consistency.
        organs: List of organ types to include. Uses all available organs if None.
            Use PANNUKE_ORGANS for the 10-organ subset from arXiv 2604.23481.
        label_choice: Which label type to return. One of 'instances', 'xenium_instances', 'semantic'.
        download: Whether to download the data if it is not present.
        n_samples: Number of patches per epoch. Uses all patches if None.
        seed: Random seed for reproducible patch sampling when n_samples is set.
        raw_transform: Transform applied to the raw image tensor.
        label_transform: Transform applied to the label tensor.
        transform: Joint transform applied to both raw and label.
        loader_kwargs: Additional keyword arguments for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    dataset = get_hest_dataset(
        path, patch_shape, organs, label_choice, download, n_samples, seed, raw_transform, label_transform, transform
    )
    return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
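The semantic labels hinge on a nearest-neighbour step: each segmented nucleus inherits the label of its closest ST spot. A toy sketch of that step, with synthetic spot coordinates and labels mirroring the (x, y, label) layout consumed by `_rasterize_patch_semantic`:

```python
import numpy as np
from scipy.spatial import cKDTree

# Toy ST spots: (x, y, label) rows; labels follow CELL_TYPE_LABELS.
spot_labels = np.array([
    [0.0, 0.0, 1.0],   # Epithelial spot
    [10.0, 0.0, 2.0],  # Inflammatory spot
    [0.0, 10.0, 4.0],  # Neoplastic spot
], dtype=np.float32)

# KDTree over spot coordinates, queried with nucleus centroids.
tree = cKDTree(spot_labels[:, :2])
centroids = np.array([[1.0, 1.0], [9.0, 1.0], [2.0, 9.0]])
_, idx = tree.query(centroids)
labels = spot_labels[idx, 2].astype(int)
# labels == [1, 2, 4]: each nucleus takes its nearest spot's label
```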
468class HESTDataset(Dataset): 469 """2D patch dataset for HEST-1k. 470 471 Indexes all patches across all per-slide H5 files and returns proper 2D tensors: 472 raw (3, H, W) float32 in [0, 1] and labels (H, W) int32. 473 """ 474 475 def __init__( 476 self, 477 h5_paths: List[str], 478 label_key: str, 479 raw_transform: Optional[Callable] = None, 480 label_transform: Optional[Callable] = None, 481 transform: Optional[Callable] = None, 482 n_samples: Optional[int] = None, 483 seed: Optional[int] = None, 484 ): 485 self._label_key = label_key 486 self._raw_transform = raw_transform 487 self._label_transform = label_transform 488 self._transform = transform 489 490 # Build flat index: list of (h5_path, patch_idx). 491 self._index: List[Tuple[str, int]] = [] 492 for h5_path in h5_paths: 493 with h5py.File(h5_path, "r") as f: 494 n = f["raw"].shape[0] # raw stored as (N, 3, H, W) 495 self._index.extend((h5_path, i) for i in range(n)) 496 497 if n_samples is not None: 498 rng = np.random.default_rng(seed) 499 chosen = rng.choice(len(self._index), size=n_samples, replace=n_samples > len(self._index)) 500 self._index = [self._index[i] for i in chosen] 501 502 def __len__(self) -> int: 503 return len(self._index) 504 505 def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]: 506 h5_path, patch_idx = self._index[idx] 507 with h5py.File(h5_path, "r") as f: 508 raw = f["raw"][patch_idx].astype(np.float32) / 255.0 # (3, H, W) 509 label = f[self._label_key][patch_idx].astype(np.int32) # (H, W) 510 511 raw = torch.from_numpy(raw) 512 label = torch.from_numpy(label) 513 514 if self._raw_transform is not None: 515 raw = self._raw_transform(raw) 516 if self._label_transform is not None: 517 label = self._label_transform(label) 518 if self._transform is not None: 519 raw, label = self._transform(raw, label) 520 521 return raw, label
2D patch dataset for HEST-1k.
Indexes all patches across all per-slide H5 files and returns proper 2D tensors: raw (3, H, W) float32 in [0, 1] and labels (H, W) int32.
    def __init__(
        self,
        h5_paths: List[str],
        label_key: str,
        raw_transform: Optional[Callable] = None,
        label_transform: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        n_samples: Optional[int] = None,
        seed: Optional[int] = None,
    ):
        self._label_key = label_key
        self._raw_transform = raw_transform
        self._label_transform = label_transform
        self._transform = transform

        # Build flat index: list of (h5_path, patch_idx).
        self._index: List[Tuple[str, int]] = []
        for h5_path in h5_paths:
            with h5py.File(h5_path, "r") as f:
                n = f["raw"].shape[0]  # raw stored as (N, 3, H, W)
                self._index.extend((h5_path, i) for i in range(n))

        if n_samples is not None:
            rng = np.random.default_rng(seed)
            chosen = rng.choice(len(self._index), size=n_samples, replace=n_samples > len(self._index))
            self._index = [self._index[i] for i in chosen]
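The indexing and sampling behavior above can be sketched in isolation. The helper below is a hypothetical stand-in (plain dicts instead of real H5 files) that mirrors the constructor's logic: build a flat `(path, patch_idx)` index, then optionally resample it, with replacement only when more patches are requested than exist.

```python
import numpy as np

def build_index(file_sizes, n_samples=None, seed=None):
    # file_sizes maps a fake "h5 path" to its patch count (the N in (N, 3, H, W)).
    index = [(path, i) for path, n in file_sizes.items() for i in range(n)]
    if n_samples is not None:
        # Same rule as HESTDataset: sample with replacement only if oversampling.
        rng = np.random.default_rng(seed)
        chosen = rng.choice(len(index), size=n_samples, replace=n_samples > len(index))
        index = [index[i] for i in chosen]
    return index

full = build_index({"a.h5": 3, "b.h5": 2})                        # all 5 patches, file order
sub = build_index({"a.h5": 3, "b.h5": 2}, n_samples=2, seed=0)    # 2, without replacement
over = build_index({"a.h5": 3, "b.h5": 2}, n_samples=8, seed=0)   # 8, with replacement
```

The same seed reproduces the same subset, which is what makes `n_samples` + `seed` deterministic across runs.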
def get_hest_data(
    path: Union[os.PathLike, str],
    organs: Optional[List[str]] = None,
    download: bool = False,
) -> str:
    """Download and preprocess the HEST-1k dataset.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        organs: List of organ types to include. Uses all available organs if None.
            Example: ['Breast', 'Colon']. See PANNUKE_ORGANS for the set used in arXiv 2604.23481.
        download: Whether to download the data if it is not present.

    Returns:
        The filepath to the preprocessed data directory.
    """
    preprocessed_dir = os.path.join(path, "preprocessed")

    if download:
        meta_path = os.path.join(path, METADATA_FILENAME)
        if not os.path.exists(meta_path):
            try:
                from huggingface_hub import hf_hub_download
            except ImportError:
                raise ImportError("huggingface_hub is required. Install with: pip install huggingface_hub")
            hf_hub_download(repo_id=HF_REPO, repo_type="dataset", filename=METADATA_FILENAME, local_dir=path)

        sample_ids = _filter_sample_ids(path, organs)
        xenium_dir = os.path.join(path, "xenium_seg")
        st_dir = os.path.join(path, "st")
        include_xenium = not os.path.exists(xenium_dir)
        include_st = not os.path.exists(st_dir)
        _download_hest(path, sample_ids, include_xenium=include_xenium, include_st=include_st)
    else:
        sample_ids = [
            os.path.splitext(os.path.basename(p))[0]
            for p in glob(os.path.join(path, "patches", "*.h5"))
        ]
        if organs is not None:
            meta_path = os.path.join(path, METADATA_FILENAME)
            if os.path.exists(meta_path):
                allowed = set(_filter_sample_ids(path, organs))
                sample_ids = [s for s in sample_ids if s in allowed]

    # Load PanglaoDB once for all samples.
    db_cache = os.path.join(path, "_db_cache")
    try:
        marker_db = _load_panglaodb(db_cache)
    except Exception:
        marker_db = None

    # Build a pixel_size lookup from the metadata (um/px at native resolution).
    try:
        meta = _load_metadata(path)
        pixel_size_map = dict(zip(meta["id"], meta["pixel_size_um_estimated"].fillna(0.5)))
    except Exception:
        pixel_size_map = {}

    os.makedirs(preprocessed_dir, exist_ok=True)
    cellvit_zip_dir = os.path.join(path, "cellvit_seg")
    cellvit_cache = os.path.join(path, "_cellvit_extracted")
    xenium_dir = os.path.join(path, "xenium_seg")
    st_dir = os.path.join(path, "st")

    for sid in tqdm(sample_ids, desc="Preprocessing HEST samples"):
        out_h5 = os.path.join(preprocessed_dir, f"{sid}.h5")
        if os.path.exists(out_h5):
            continue

        patches_h5 = os.path.join(path, "patches", f"{sid}.h5")
        if not os.path.exists(patches_h5):
            continue

        geojson_path = _unzip_cellvit(
            os.path.join(cellvit_zip_dir, f"{sid}_cellvit_seg.geojson.zip"), cellvit_cache
        )
        xenium_parquet = os.path.join(xenium_dir, f"{sid}_xenium_nucleus_seg.parquet")
        h5ad_path = os.path.join(st_dir, f"{sid}.h5ad")
        pixel_size_um = float(pixel_size_map.get(sid, 0.5))

        _preprocess_sample(
            patches_h5=patches_h5,
            cellvit_geojson=geojson_path,
            xenium_parquet=xenium_parquet if os.path.exists(xenium_parquet) else None,
            h5ad_path=h5ad_path if os.path.exists(h5ad_path) else None,
            marker_db=marker_db,
            out_h5=out_h5,
            pixel_size_um=pixel_size_um,
        )

    return preprocessed_dir
Download and preprocess the HEST-1k dataset.
Arguments:
- path: Filepath to a folder where the downloaded data will be saved.
- organs: List of organ types to include. Uses all available organs if None. Example: ['Breast', 'Colon']. See PANNUKE_ORGANS for the set used in arXiv 2604.23481.
- download: Whether to download the data if it is not present.
Returns:
The filepath to the preprocessed data directory.
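A practical consequence of the skip-if-exists check in the preprocessing loop is that an interrupted run can simply be restarted. The sketch below isolates that resume logic with a temporary directory and made-up sample IDs (`SAMPLE_A`, `SAMPLE_B` are hypothetical, not real HEST IDs).

```python
import os
import tempfile

preprocessed_dir = tempfile.mkdtemp()
# Pretend one sample was already preprocessed on a previous run.
open(os.path.join(preprocessed_dir, "SAMPLE_A.h5"), "w").close()

todo = []
for sid in ["SAMPLE_A", "SAMPLE_B"]:
    out_h5 = os.path.join(preprocessed_dir, f"{sid}.h5")
    if os.path.exists(out_h5):
        continue  # output H5 exists: skipped, exactly as in get_hest_data
    todo.append(sid)
```

Given the ~2 TB full dataset, this idempotent behavior matters: only samples without an output H5 are (re)processed.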
def get_hest_paths(
    path: Union[os.PathLike, str],
    organs: Optional[List[str]] = None,
    download: bool = False,
) -> List[str]:
    """Get paths to the preprocessed HEST-1k H5 files.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        organs: List of organ types to include. Uses all available organs if None.
        download: Whether to download the data if it is not present.

    Returns:
        List of filepaths to the preprocessed H5 files (one per slide).
    """
    preprocessed_dir = get_hest_data(path, organs, download)
    h5_paths = natsorted(glob(os.path.join(preprocessed_dir, "*.h5")))
    if not h5_paths:
        raise RuntimeError(f"No preprocessed data found in {preprocessed_dir}.")

    if organs is not None:
        meta_path = os.path.join(path, METADATA_FILENAME)
        if os.path.exists(meta_path):
            allowed = set(_filter_sample_ids(path, organs))
            h5_paths = [p for p in h5_paths if os.path.splitext(os.path.basename(p))[0] in allowed]

    return h5_paths
Get paths to the preprocessed HEST-1k H5 files.
Arguments:
- path: Filepath to a folder where the downloaded data will be saved.
- organs: List of organ types to include. Uses all available organs if None.
- download: Whether to download the data if it is not present.
Returns:
List of filepaths to the preprocessed H5 files (one per slide).
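The discovery-and-filter pattern above can be demonstrated without any real data. This sketch uses a temporary directory, hypothetical sample IDs, and plain `sorted` in place of `natsort.natsorted`; the real code additionally derives the allowed set from the HEST metadata rather than hard-coding it.

```python
import os
import tempfile
from glob import glob

path = tempfile.mkdtemp()
preprocessed_dir = os.path.join(path, "preprocessed")
os.makedirs(preprocessed_dir)
for sid in ["SAMPLE_A", "SAMPLE_B", "SAMPLE_C"]:
    open(os.path.join(preprocessed_dir, f"{sid}.h5"), "w").close()

# One H5 per slide; sorted stands in for natsorted here.
h5_paths = sorted(glob(os.path.join(preprocessed_dir, "*.h5")))

# Organ filtering keeps only slides whose sample ID (the filename stem) is allowed.
allowed = {"SAMPLE_A", "SAMPLE_C"}
h5_paths = [p for p in h5_paths if os.path.splitext(os.path.basename(p))[0] in allowed]
```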
def get_hest_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, int],
    organs: Optional[List[str]] = None,
    label_choice: Literal["instances", "xenium_instances", "semantic"] = "instances",
    download: bool = False,
    n_samples: Optional[int] = None,
    seed: Optional[int] = None,
    raw_transform: Optional[Callable] = None,
    label_transform: Optional[Callable] = None,
    transform: Optional[Callable] = None,
) -> Dataset:
    """Get the HEST-1k dataset for nuclei segmentation and cell-type classification.

    Returns a 2D dataset: each item is raw (3, H, W) float32 in [0, 1] and labels (H, W) int32.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        patch_shape: Not used for cropping (patches are already 224x224); kept for API consistency.
        organs: List of organ types to include. Uses all available organs if None.
            Use PANNUKE_ORGANS for the 10-organ subset from arXiv 2604.23481.
        label_choice: Which label type to return:
            - 'instances': CellViT nuclei instance masks (H&E-based, all samples).
            - 'xenium_instances': DAPI nuclei instance masks (Xenium samples only, zeros otherwise).
            - 'semantic': ST-derived cell-type labels 1-5 (Xenium samples only, zeros otherwise).
        download: Whether to download the data if it is not present.
        n_samples: Number of patches to sample (with replacement if larger than total). Uses all if None.
        seed: Random seed for reproducible patch sampling when n_samples is set.
        raw_transform: Transform applied to the raw image tensor.
        label_transform: Transform applied to the label tensor.
        transform: Joint transform applied to both raw and label.

    Returns:
        The segmentation dataset.
    """
    valid = ("instances", "xenium_instances", "semantic")
    if label_choice not in valid:
        raise ValueError(f"'{label_choice}' is not valid. Choose from {valid}.")

    h5_paths = get_hest_paths(path, organs, download)
    return HESTDataset(
        h5_paths=h5_paths,
        label_key=LABEL_KEYS[label_choice],
        raw_transform=raw_transform,
        label_transform=label_transform,
        transform=transform,
        n_samples=n_samples,
        seed=seed,
    )
Get the HEST-1k dataset for nuclei segmentation and cell-type classification.
Returns a 2D dataset: each item is raw (3, H, W) float32 in [0, 1] and labels (H, W) int32.
Arguments:
- path: Filepath to a folder where the downloaded data will be saved.
- patch_shape: Not used for cropping (patches are already 224x224); kept for API consistency.
- organs: List of organ types to include. Uses all available organs if None. Use PANNUKE_ORGANS for the 10-organ subset from arXiv 2604.23481.
- label_choice: Which label type to return:
- 'instances': CellViT nuclei instance masks (H&E-based, all samples).
- 'xenium_instances': DAPI nuclei instance masks (Xenium samples only, zeros otherwise).
- 'semantic': ST-derived cell-type labels 1-5 (Xenium samples only, zeros otherwise).
- download: Whether to download the data if it is not present.
- n_samples: Number of patches to sample (with replacement if larger than total). Uses all if None.
- seed: Random seed for reproducible patch sampling when n_samples is set.
- raw_transform: Transform applied to the raw image tensor.
- label_transform: Transform applied to the label tensor.
- transform: Joint transform applied to both raw and label.
Returns:
The segmentation dataset.
def get_hest_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, int],
    organs: Optional[List[str]] = None,
    label_choice: Literal["instances", "xenium_instances", "semantic"] = "instances",
    download: bool = False,
    n_samples: Optional[int] = None,
    seed: Optional[int] = None,
    raw_transform: Optional[Callable] = None,
    label_transform: Optional[Callable] = None,
    transform: Optional[Callable] = None,
    **loader_kwargs,
) -> DataLoader:
    """Get the HEST-1k dataloader for nuclei segmentation and cell-type classification.

    Returns batches of raw (B, 3, H, W) float32 in [0, 1] and labels (B, H, W) int32.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        batch_size: The batch size for training.
        patch_shape: Not used for cropping (patches are already 224x224); kept for API consistency.
        organs: List of organ types to include. Uses all available organs if None.
            Use PANNUKE_ORGANS for the 10-organ subset from arXiv 2604.23481.
        label_choice: Which label type to return. One of 'instances', 'xenium_instances', 'semantic'.
        download: Whether to download the data if it is not present.
        n_samples: Number of patches per epoch. Uses all patches if None.
        seed: Random seed for reproducible patch sampling when n_samples is set.
        raw_transform: Transform applied to the raw image tensor.
        label_transform: Transform applied to the label tensor.
        transform: Joint transform applied to both raw and label.
        loader_kwargs: Additional keyword arguments for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    dataset = get_hest_dataset(
        path, patch_shape, organs, label_choice, download, n_samples, seed, raw_transform, label_transform, transform
    )
    return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
Get the HEST-1k dataloader for nuclei segmentation and cell-type classification.
Returns batches of raw (B, 3, H, W) float32 in [0, 1] and labels (B, H, W) int32.
Arguments:
- path: Filepath to a folder where the downloaded data will be saved.
- batch_size: The batch size for training.
- patch_shape: Not used for cropping (patches are already 224x224); kept for API consistency.
- organs: List of organ types to include. Uses all available organs if None. Use PANNUKE_ORGANS for the 10-organ subset from arXiv 2604.23481.
- label_choice: Which label type to return. One of 'instances', 'xenium_instances', 'semantic'.
- download: Whether to download the data if it is not present.
- n_samples: Number of patches per epoch. Uses all patches if None.
- seed: Random seed for reproducible patch sampling when n_samples is set.
- raw_transform: Transform applied to the raw image tensor.
- label_transform: Transform applied to the label tensor.
- transform: Joint transform applied to both raw and label.
- loader_kwargs: Additional keyword arguments for the PyTorch DataLoader.
Returns:
The DataLoader.
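The batch contract, raw (B, 3, H, W) float32 in [0, 1] and labels (B, H, W) int32, can be verified without downloading any HEST data. The stand-in dataset below (hypothetical, not part of this module) emits tensors with the same shapes and dtypes as the real patches, and a plain `torch.utils.data.DataLoader` batches them the same way the real loader does.

```python
import torch
from torch.utils.data import DataLoader, Dataset

class FakeHESTPatches(Dataset):
    """Stand-in emitting the same shapes/dtypes as preprocessed HEST patches."""

    def __len__(self):
        return 8

    def __getitem__(self, idx):
        raw = torch.rand(3, 224, 224)                     # float32 in [0, 1]
        label = torch.zeros(224, 224, dtype=torch.int32)  # int32 label mask
        return raw, label

loader = DataLoader(FakeHESTPatches(), batch_size=4)
raw_batch, label_batch = next(iter(loader))
```

Any model or loss wired against this stand-in will accept batches from the real `get_hest_loader` unchanged.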