torch_em.data.datasets.histopathology.hest

The HEST-1k dataset contains 1,276 paired H&E whole-slide images and spatial transcriptomics profiles across 26 organ types. For each sample, pre-extracted 224x224 H&E patches at 0.5 um/px, CellViT nuclei instance segmentation masks, Xenium DAPI-derived nucleus boundaries (for Xenium samples), and cell-level spatial transcriptomics gene expression profiles are available.

Three types of segmentation labels are supported:

  • 'instances': nuclei instance masks derived from CellViT (H&E-based, all samples).
  • 'xenium_instances': nuclei instance masks from DAPI segmentation (Xenium samples only).
  • 'semantic': cell-type semantic masks derived from spatial transcriptomics via Leiden clustering and PanglaoDB marker-gene voting (Xenium samples only). Classes: 0=background, 1=Epithelial, 2=Inflammatory, 3=Connective, 4=Neoplastic, 5=Unknown.
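The rank-weighted marker-gene vote that assigns a coarse category to each Leiden cluster can be sketched in a few lines. This is a simplified illustration only: the `vote_category` helper and the marker genes below are hypothetical, while the defaults mirror the `_compute_cell_type_map` parameters in the source.

```python
# Rank-weighted marker-gene vote for one Leiden cluster (simplified sketch).
# top_genes: differentially expressed genes for the cluster, best-ranked first.
# gene_to_cats: marker gene -> coarse categories (derived from PanglaoDB keywords).
def vote_category(top_genes, gene_to_cats, top_n=10, tau_vote=5):
    votes = {"Epithelial": 0.0, "Inflammatory": 0.0, "Connective": 0.0}
    total = 0.0
    for rank, gene in enumerate(top_genes[:top_n]):
        weight = top_n - rank  # higher-ranked markers carry more weight
        for cat in gene_to_cats.get(gene, []):
            if cat in votes:
                votes[cat] += weight
                total += weight
    if total < tau_vote:  # too little marker evidence -> Unknown
        return "Unknown"
    return max(votes, key=votes.get)

# Hypothetical markers: EPCAM/KRT8 as epithelial, PTPRC as pan-immune.
markers = {"EPCAM": ["Epithelial"], "KRT8": ["Epithelial"], "PTPRC": ["Inflammatory"]}
print(vote_category(["EPCAM", "KRT8", "PTPRC"], markers))  # -> Epithelial
```

In the full pipeline an "Epithelial" winner is additionally promoted to "Neoplastic" when enough of the cluster's top genes overlap the cancer-gene set.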

This dataset is used in the paper https://doi.org/10.48550/arXiv.2604.23481 as a scalable alternative to manually annotated datasets for nuclei segmentation and classification training.

The dataset is located at https://huggingface.co/datasets/MahmoodLab/hest. This dataset is from the following publication:

  • Jaume et al. (2024): https://doi.org/10.48550/arXiv.2406.16192

Please cite it if you use this dataset in your research.

NOTE: Requires huggingface_hub for download: pip install huggingface_hub
NOTE: Requires geopandas, rasterio, and scipy for preprocessing: pip install geopandas rasterio scipy
NOTE: Requires scanpy, python-igraph, and leidenalg for semantic labels: pip install scanpy igraph leidenalg
NOTE: The full dataset is ~2 TB. Use the organs argument to download only a subset.
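Nucleus coordinates ship in native WSI pixel space, while patches are resampled to 0.5 um/px, so preprocessing must translate and rescale every polygon into patch coordinates. The arithmetic can be sketched as follows (a minimal illustration; `native_to_patch` is a hypothetical helper, not part of the module):

```python
# Map a vertex from native WSI pixels into 224x224 patch pixels (sketch).
def native_to_patch(x, y, patch_x, patch_y, pixel_size_um, patch_size=224):
    # native WSI pixels per 0.5 um/px patch pixel
    native_scale = 0.5 / pixel_size_um
    u = (x - patch_x) / native_scale  # translate to the patch origin, then rescale
    v = (y - patch_y) / native_scale
    inside = 0 <= u < patch_size and 0 <= v < patch_size
    return u, v, inside

# Hypothetical slide scanned at 0.25 um/px: native_scale = 2, so a vertex
# 100 native pixels right of the patch origin lands at patch column 50.
print(native_to_patch(1100, 2060, 1000, 2000, pixel_size_um=0.25))  # (50.0, 30.0, True)
```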

"""The HEST-1k dataset contains 1,276 paired H&E whole-slide images and spatial transcriptomics
profiles across 26 organ types. For each sample, pre-extracted 224x224 H&E patches at 0.5 um/px,
CellViT nuclei instance segmentation masks, Xenium DAPI-derived nucleus boundaries (for Xenium
samples), and cell-level spatial transcriptomics gene expression profiles are available.

Three types of segmentation labels are supported:
- 'instances': nuclei instance masks derived from CellViT (H&E-based, all samples).
- 'xenium_instances': nuclei instance masks from DAPI segmentation (Xenium samples only).
- 'semantic': cell-type semantic masks derived from spatial transcriptomics via Leiden clustering
  and PanglaoDB marker-gene voting (Xenium samples only). Classes: 0=background, 1=Epithelial,
  2=Inflammatory, 3=Connective, 4=Neoplastic, 5=Unknown.

This dataset is used in the paper https://doi.org/10.48550/arXiv.2604.23481 as a scalable
alternative to manually annotated datasets for nuclei segmentation and classification training.

The dataset is located at https://huggingface.co/datasets/MahmoodLab/hest.
This dataset is from the following publication:
- Jaume et al. (2024): https://doi.org/10.48550/arXiv.2406.16192
Please cite it if you use this dataset in your research.

NOTE: Requires huggingface_hub for download: pip install huggingface_hub
NOTE: Requires geopandas, rasterio, and scipy for preprocessing: pip install geopandas rasterio scipy
NOTE: Requires scanpy, python-igraph, and leidenalg for semantic labels: pip install scanpy igraph leidenalg
NOTE: The full dataset is ~2 TB. Use the `organs` argument to download only a subset.
"""

import json
import os
import zipfile
from glob import glob
from tqdm import tqdm
from natsort import natsorted
from typing import Callable, Dict, List, Literal, Optional, Tuple, Union

import h5py
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

import torch_em


HF_REPO = "MahmoodLab/hest"
METADATA_FILENAME = "HEST_v1_3_0.csv"
PANGLAODB_URL = "https://panglaodb.se/markers/PanglaoDB_markers_27_Mar_2020.tsv.gz"

# Integer label for each cell-type category (0 = background).
CELL_TYPE_LABELS = {"Epithelial": 1, "Inflammatory": 2, "Connective": 3, "Neoplastic": 4, "Unknown": 5}

# Map from public label_choice strings to HDF5 dataset paths.
LABEL_KEYS = {
    "instances": "labels/instances/h&e",
    "xenium_instances": "labels/instances/xenium",
    "semantic": "labels/semantic/st",
}

# Organs present in both HEST-1k and PanNuke (used in arXiv 2604.23481).
PANNUKE_ORGANS = [
    "Breast", "Colon", "Kidney", "Liver", "Lung", "Ovarian", "Pancreatic", "Prostate", "Skin", "Stomach",
]

# Keyword fragments for mapping PanglaoDB cell type names to coarse categories.
EPITHELIAL_KEYWORDS = [
    "acinar", "airway epithelial", "airway goblet", "alveolar type", "alpha cell", "basal cell",
    "beta cell", "cholangiocyte", "ciliated", "clara", "crypt", "delta cell", "ductal",
    "enterocyte", "epithelial", "goblet", "hepatocyte", "keratinocyte", "mesothelial",
    "paneth", "pneumocyte", "proximal tubule", "renal tubule", "squamous", "thyroid",
    "trophoblast", "tuft", "urothelial",
]
INFLAMMATORY_KEYWORDS = [
    "alveolar macrophage", "b cell", "basophil", "dendritic", "eosinophil",
    "innate lymphoid", "lymphocyte", "macrophage", "mast cell", "monocyte",
    "natural killer", "neutrophil", "nk cell", "plasma cell", "regulatory t", "t cell",
]
CONNECTIVE_KEYWORDS = [
    "adipocyte", "chondrocyte", "endothelial", "fibroblast", "mesenchymal",
    "myofibroblast", "osteoblast", "osteoclast", "pericyte", "smooth muscle",
    "stellate", "stromal", "vascular",
]

# Well-known cancer-associated genes (COSMIC Cancer Gene Census, tier 1).
CANCER_GENES = {
    "ABL1", "AKT1", "ALK", "APC", "ATM", "BRAF", "BRCA1", "BRCA2", "CDH1", "CDKN2A",
    "CTNNB1", "EGFR", "ERBB2", "ESR1", "EZH2", "FBXW7", "FGFR1", "FGFR2", "FGFR3",
    "FLT3", "GATA3", "GNAQ", "GNAS", "HNF1A", "HRAS", "IDH1", "IDH2", "JAK2", "KIT",
    "KRAS", "MAP2K1", "MDM2", "MET", "MLH1", "MSH2", "MSH6", "MTOR", "MYC", "MYCN",
    "NF1", "NF2", "NFE2L2", "NOTCH1", "NOTCH2", "NRAS", "PALB2", "PBRM1", "PIK3CA",
    "PIK3R1", "PMS2", "POLE", "PTCH1", "PTEN", "RB1", "RET", "RNF43", "SETD2", "SF3B1",
    "SMAD4", "SMARCA4", "SMARCB1", "SMO", "STK11", "TERT", "TET2", "TP53", "TSC1",
    "TSC2", "VHL", "BAP1", "CDK12", "CHEK2", "CREBBP", "DNMT3A", "EP300", "FANCD2",
    "KDM5C", "KDM6A", "KEAP1", "MAP3K1", "MUTYH", "NBN", "PDGFRA", "PPP2R1A", "RAD51C",
    "RUNX1", "SDHA", "SDHB", "SDHC", "SDHD", "SUFU", "TP63", "XRCC2", "AXIN1", "AXIN2",
    "BRIP1", "CHD4", "ELOC", "FANCA", "FH", "FLCN", "MRE11", "RAD50", "RAD51B", "RAD51D",
}


def _download_hest(path: str, sample_ids: List[str], include_xenium: bool, include_st: bool) -> None:
    try:
        from huggingface_hub import snapshot_download
    except ImportError:
        raise ImportError("huggingface_hub is required. Install with: pip install huggingface_hub")

    patterns = [METADATA_FILENAME]
    for sid in sample_ids:
        patterns += [f"patches/{sid}.h5", f"cellvit_seg/{sid}_cellvit_seg.geojson.zip"]
        if include_xenium:
            patterns += [f"xenium_seg/{sid}_xenium_nucleus_seg.parquet"]
        if include_st:
            patterns += [f"st/{sid}.h5ad"]

    os.makedirs(path, exist_ok=True)
    snapshot_download(repo_id=HF_REPO, repo_type="dataset", local_dir=path, allow_patterns=patterns)


def _load_metadata(path: str) -> "pd.DataFrame":  # noqa
    try:
        import pandas as pd
    except ImportError:
        raise ImportError("pandas is required. Install with: pip install pandas")

    csv_path = os.path.join(path, METADATA_FILENAME)
    if not os.path.exists(csv_path):
        raise RuntimeError(f"Metadata not found at {csv_path}. Run get_hest_data() first.")
    return pd.read_csv(csv_path)


def _filter_sample_ids(path: str, organs: Optional[List[str]]) -> List[str]:
    meta = _load_metadata(path)
    if organs is not None:
        meta = meta[meta["organ"].isin(organs)]
    return meta["id"].tolist()


def _unzip_cellvit(zip_path: str, out_dir: str) -> Optional[str]:
    if not os.path.exists(zip_path):
        return None
    # Strip both extensions: "SAMPLEID_cellvit_seg.geojson.zip" -> "SAMPLEID"
    sample_id = os.path.basename(zip_path).replace("_cellvit_seg.geojson.zip", "")
    extract_dir = os.path.join(out_dir, sample_id)
    if not os.path.exists(extract_dir):
        with zipfile.ZipFile(zip_path, "r") as zf:
            zf.extractall(extract_dir)
    matches = glob(os.path.join(extract_dir, "**", "*.geojson"), recursive=True)
    return matches[0] if matches else None


def _gdf_from_xenium_parquet(parquet_path: str) -> "gpd.GeoDataFrame":  # noqa
    """Load a Xenium nucleus segmentation parquet into a GeoDataFrame.

    Expected format: index is cell_id, single 'geometry' column with WKB-encoded polygons.
    """
    try:
        import pandas as pd
        import geopandas as gpd
        import shapely
    except ImportError:
        raise ImportError("geopandas and shapely are required. Install with: pip install geopandas rasterio")

    df = pd.read_parquet(parquet_path)
    geometries = shapely.from_wkb(df["geometry"].values)
    return gpd.GeoDataFrame({"cell_id": df.index.astype(str), "geometry": geometries}, geometry="geometry")


def _gdf_from_cellvit_geojson(geojson_path: str) -> "gpd.GeoDataFrame":  # noqa
    """Load CellViT segmentation GeoJSON into a GeoDataFrame with one row per nucleus.

    The file is a JSON list of features with MultiPolygon geometries (one per cell-type class).
    Each MultiPolygon is exploded into individual Polygon rows.
    """
    try:
        import geopandas as gpd
        from shapely.geometry import shape, MultiPolygon
    except ImportError:
        raise ImportError("geopandas and shapely are required. Install with: pip install geopandas rasterio")

    with open(geojson_path) as fh:
        data = json.load(fh)

    records = []
    for feat in data:
        geom = shape(feat["geometry"])
        if isinstance(geom, MultiPolygon):
            for poly in geom.geoms:
                records.append({"geometry": poly})
        else:
            records.append({"geometry": geom})

    if not records:
        return gpd.GeoDataFrame(columns=["geometry"])
    return gpd.GeoDataFrame(records, geometry="geometry")


def _rasterize_patch_instances(
    patch_x: int,
    patch_y: int,
    patch_size: int,
    cells_gdf: "gpd.GeoDataFrame",  # noqa
    native_scale: float = 1.0,
) -> np.ndarray:
    """Rasterize nucleus polygons within one patch to an instance mask.

    native_scale: native WSI pixels per 0.5 um/px patch pixel (= 0.5 / pixel_size_um).
    Patches are stored at 0.5 um/px but cell coords are in native WSI pixel space.
    """
    try:
        from shapely.geometry import box
        from shapely.affinity import translate, scale as affine_scale
        from rasterio.features import rasterize as rio_rasterize
    except ImportError:
        raise ImportError("rasterio and shapely are required. Install with: pip install geopandas rasterio")

    native_size = round(patch_size * native_scale)
    patch_box = box(patch_x, patch_y, patch_x + native_size, patch_y + native_size)
    local = cells_gdf[cells_gdf.geometry.intersects(patch_box)].copy()
    mask = np.zeros((patch_size, patch_size), dtype=np.int32)
    if local.empty:
        return mask

    inv = 1.0 / native_scale
    local["geometry"] = local["geometry"].apply(
        lambda g: affine_scale(translate(g, xoff=-patch_x, yoff=-patch_y), xfact=inv, yfact=inv, origin=(0, 0))
    )
    shapes = ((geom, i + 1) for i, geom in enumerate(local.geometry))
    return rio_rasterize(shapes, out_shape=(patch_size, patch_size), fill=0, dtype=np.int32)


def _rasterize_patch_semantic(
    patch_x: int,
    patch_y: int,
    patch_size: int,
    cells_gdf: "gpd.GeoDataFrame",  # noqa
    spot_labels: np.ndarray,
    native_scale: float = 1.0,
    spot_tree=None,
) -> np.ndarray:
    """Rasterize nucleus polygons within one patch to a semantic (cell-type) mask.

    native_scale: native WSI pixels per 0.5 um/px patch pixel (= 0.5 / pixel_size_um).
    spot_labels: (N, 3) array of (x, y, label) for each ST spot in native WSI coordinates.
    spot_tree: pre-built cKDTree over spot_labels[:, :2]. Built locally if None (slow per-patch).
    Each nucleus is assigned the label of its nearest ST spot via KDTree.
    """
    try:
        from shapely.geometry import box
        from shapely.affinity import translate, scale as affine_scale
        from rasterio.features import rasterize as rio_rasterize
        from scipy.spatial import cKDTree
    except ImportError:
        raise ImportError("rasterio, shapely, and scipy are required. Install with: pip install geopandas rasterio scipy")  # noqa

    native_size = round(patch_size * native_scale)
    patch_box = box(patch_x, patch_y, patch_x + native_size, patch_y + native_size)
    local = cells_gdf[cells_gdf.geometry.intersects(patch_box)].copy()
    mask = np.zeros((patch_size, patch_size), dtype=np.int32)
    if local.empty:
        return mask

    # Assign each nucleus its nearest ST spot's label via KDTree on native coords.
    tree = spot_tree if spot_tree is not None else cKDTree(spot_labels[:, :2])
    centroids = np.array([[g.centroid.x, g.centroid.y] for g in local.geometry])
    _, idx = tree.query(centroids)
    local["label"] = spot_labels[idx, 2].astype(int)

    inv = 1.0 / native_scale
    local["geometry"] = local["geometry"].apply(
        lambda g: affine_scale(translate(g, xoff=-patch_x, yoff=-patch_y), xfact=inv, yfact=inv, origin=(0, 0))
    )
    shapes = ((geom, int(label)) for geom, label in zip(local.geometry, local["label"]))
    return rio_rasterize(shapes, out_shape=(patch_size, patch_size), fill=0, dtype=np.int32)


def _load_panglaodb(cache_path: str) -> "pd.DataFrame":  # noqa
    """Download (once) and return the PanglaoDB marker-gene table."""
    try:
        import pandas as pd
    except ImportError:
        raise ImportError("pandas is required. Install with: pip install pandas")

    tsv_path = os.path.join(cache_path, "PanglaoDB_markers.tsv.gz")
    if not os.path.exists(tsv_path):
        import urllib.request
        os.makedirs(cache_path, exist_ok=True)
        req = urllib.request.Request(PANGLAODB_URL, headers={"User-Agent": "Mozilla/5.0"})
        with urllib.request.urlopen(req) as resp, open(tsv_path, "wb") as fh:
            fh.write(resp.read())

    df = pd.read_csv(tsv_path, sep="\t")
    # Keep only human genes.
    df = df[df["species"].str.contains("Hs", na=False)]
    return df[["official gene symbol", "cell type"]].copy()


def _cell_type_to_category(cell_type_name: str) -> str:
    """Map a PanglaoDB cell type name to one of the four coarse categories."""
    name = cell_type_name.lower()
    for kw in EPITHELIAL_KEYWORDS:
        if kw in name:
            return "Epithelial"
    for kw in INFLAMMATORY_KEYWORDS:
        if kw in name:
            return "Inflammatory"
    for kw in CONNECTIVE_KEYWORDS:
        if kw in name:
            return "Connective"
    return "Unknown"


def _compute_cell_type_map(
    h5ad_path: str,
    marker_db: "pd.DataFrame",  # noqa
    top_n: int = 10,
    tau_vote: int = 5,
    top_m: int = 20,
    tau_cancer: float = 0.25,
) -> np.ndarray:
    """Run the ST cell-type assignment pipeline from the paper (arXiv 2604.23481).

    Returns an (N, 3) float32 array of (x, y, label) for each ST spot, where x and y
    are native WSI pixel coordinates (pxl_col_in_fullres / pxl_row_in_fullres). Callers
    use a KDTree to assign each segmented nucleus to its nearest ST spot.
    """
    try:
        import scanpy as sc
    except ImportError:
        raise ImportError("scanpy is required for semantic labels. Install with: pip install scanpy")

    adata = sc.read_h5ad(h5ad_path)

    if "pxl_col_in_fullres" not in adata.obs.columns or "pxl_row_in_fullres" not in adata.obs.columns:
        raise ValueError("h5ad missing pxl_col_in_fullres / pxl_row_in_fullres spot coordinates.")

    # Build gene -> category lookup from PanglaoDB.
    gene_to_cats: Dict[str, List[str]] = {}
    for gene, ct in zip(marker_db["official gene symbol"], marker_db["cell type"]):
        cat = _cell_type_to_category(ct)
        gene_to_cats.setdefault(gene, []).append(cat)

    # Preprocessing and clustering.
    sc.pp.normalize_total(adata, target_sum=1e4)
    sc.pp.log1p(adata)
    sc.pp.pca(adata)
    sc.pp.neighbors(adata)
    sc.tl.leiden(adata, resolution=4.0)
    sc.tl.rank_genes_groups(adata, groupby="leiden", method="wilcoxon")

    cluster_cat: Dict[str, str] = {}
    for cluster in adata.obs["leiden"].unique():
        try:
            top_genes = list(
                sc.get.rank_genes_groups_df(adata, group=cluster)["names"].iloc[:top_m]
            )
        except Exception:
            cluster_cat[cluster] = "Unknown"
            continue

        votes: Dict[str, float] = {"Epithelial": 0.0, "Inflammatory": 0.0, "Connective": 0.0}
        total_vote = 0.0
        for rank, gene in enumerate(top_genes[:top_n]):
            weight = top_n - rank
            for cat in gene_to_cats.get(gene, []):
                if cat in votes:
                    votes[cat] += weight
                    total_vote += weight

        if total_vote < tau_vote:
            cluster_cat[cluster] = "Unknown"
            continue

        best = max(votes, key=votes.get)  # type: ignore[arg-type]
        cluster_cat[cluster] = best

        if best == "Epithelial":
            cancer_overlap = sum(1 for g in top_genes[:top_m] if g in CANCER_GENES)
            if cancer_overlap / top_m > tau_cancer:
                cluster_cat[cluster] = "Neoplastic"

    # Build (N, 3) array: (x, y, label) per ST spot in native WSI pixel coords.
    xs = adata.obs["pxl_col_in_fullres"].values.astype(np.float32)
    ys = adata.obs["pxl_row_in_fullres"].values.astype(np.float32)
    labels = np.array(
        [CELL_TYPE_LABELS[cluster_cat.get(adata.obs["leiden"].iloc[i], "Unknown")]
         for i in range(adata.n_obs)],
        dtype=np.float32,
    )
    return np.stack([xs, ys, labels], axis=1)


def _preprocess_sample(
    patches_h5: str,
    cellvit_geojson: Optional[str],
    xenium_parquet: Optional[str],
    h5ad_path: Optional[str],
    marker_db: Optional["pd.DataFrame"],  # noqa
    out_h5: str,
    patch_size: int = 224,
    pixel_size_um: float = 0.5,
) -> bool:
    # Cell coords are in native WSI pixel space; patches are at 0.5 um/px.
    # native_scale = native WSI pixels per 0.5 um/px patch pixel.
    native_scale = 0.5 / pixel_size_um

    with h5py.File(patches_h5, "r") as f:
        img_key = "img" if "img" in f else ("imgs" if "imgs" in f else "images")
        imgs = f[img_key][:]  # (N, H, W, 3) uint8
        coords = f["coords"][:]  # (N, 2) top-left (x, y) in native WSI pixels

    n = len(imgs)
    if n == 0:
        return False

    # Load GeoDataFrames once per slide.
    cellvit_gdf = None
    if cellvit_geojson is not None and os.path.exists(cellvit_geojson):
        cellvit_gdf = _gdf_from_cellvit_geojson(cellvit_geojson)

    xenium_gdf = None
    if xenium_parquet is not None and os.path.exists(xenium_parquet):
        xenium_gdf = _gdf_from_xenium_parquet(xenium_parquet)

    spot_labels: Optional[np.ndarray] = None
    if h5ad_path is not None and os.path.exists(h5ad_path) and marker_db is not None and xenium_gdf is not None:
        try:
            spot_labels = _compute_cell_type_map(h5ad_path, marker_db)
        except Exception as e:
            print(f"Warning: semantic labels unavailable for {os.path.basename(h5ad_path)}: {e}")

    # Build the KDTree once per slide rather than once per patch.
    spot_tree = None
    if spot_labels is not None:
        try:
            from scipy.spatial import cKDTree
            spot_tree = cKDTree(spot_labels[:, :2])
        except ImportError:
            pass

    raw = np.zeros((n, 3, patch_size, patch_size), dtype=np.uint8)
    instances = np.zeros((n, patch_size, patch_size), dtype=np.int32)
    xenium_instances = np.zeros((n, patch_size, patch_size), dtype=np.int32)
    semantic = np.zeros((n, patch_size, patch_size), dtype=np.int32)

    sid = os.path.splitext(os.path.basename(out_h5))[0]
    for i, (img, coord) in enumerate(tqdm(zip(imgs, coords), total=n, desc=f"Processing {sid}", leave=False)):
        raw[i] = img[:patch_size, :patch_size, :].transpose(2, 0, 1)
        px, py = int(coord[0]), int(coord[1])

        if cellvit_gdf is not None:
            instances[i] = _rasterize_patch_instances(px, py, patch_size, cellvit_gdf, native_scale)

        if xenium_gdf is not None:
            xenium_instances[i] = _rasterize_patch_instances(px, py, patch_size, xenium_gdf, native_scale)

        if spot_labels is not None and xenium_gdf is not None:
            semantic[i] = _rasterize_patch_semantic(
                px, py, patch_size, xenium_gdf, spot_labels, native_scale, spot_tree
            )

    chunk_2d = (1, patch_size, patch_size)
    with h5py.File(out_h5, "w") as f:
        f.create_dataset("raw", data=raw, compression="gzip", chunks=(1, 3, patch_size, patch_size))
        f.create_dataset(LABEL_KEYS["instances"], data=instances, compression="gzip", chunks=chunk_2d)
        f.create_dataset(LABEL_KEYS["xenium_instances"], data=xenium_instances, compression="gzip", chunks=chunk_2d)
        f.create_dataset(LABEL_KEYS["semantic"], data=semantic, compression="gzip", chunks=chunk_2d)

    return True


class HESTDataset(Dataset):
    """2D patch dataset for HEST-1k.

    Indexes all patches across all per-slide H5 files and returns proper 2D tensors:
    raw (3, H, W) float32 in [0, 1] and labels (H, W) int32.
    """

    def __init__(
        self,
        h5_paths: List[str],
        label_key: str,
        raw_transform: Optional[Callable] = None,
        label_transform: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        n_samples: Optional[int] = None,
        seed: Optional[int] = None,
    ):
        self._label_key = label_key
        self._raw_transform = raw_transform
        self._label_transform = label_transform
        self._transform = transform

        # Build flat index: list of (h5_path, patch_idx).
        self._index: List[Tuple[str, int]] = []
        for h5_path in h5_paths:
            with h5py.File(h5_path, "r") as f:
                n = f["raw"].shape[0]  # raw stored as (N, 3, H, W)
            self._index.extend((h5_path, i) for i in range(n))

        if n_samples is not None:
            rng = np.random.default_rng(seed)
            chosen = rng.choice(len(self._index), size=n_samples, replace=n_samples > len(self._index))
            self._index = [self._index[i] for i in chosen]

    def __len__(self) -> int:
        return len(self._index)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        h5_path, patch_idx = self._index[idx]
        with h5py.File(h5_path, "r") as f:
            raw = f["raw"][patch_idx].astype(np.float32) / 255.0  # (3, H, W)
            label = f[self._label_key][patch_idx].astype(np.int32)  # (H, W)

        raw = torch.from_numpy(raw)
        label = torch.from_numpy(label)

        if self._raw_transform is not None:
            raw = self._raw_transform(raw)
        if self._label_transform is not None:
            label = self._label_transform(label)
        if self._transform is not None:
            raw, label = self._transform(raw, label)

        return raw, label


def get_hest_data(
    path: Union[os.PathLike, str],
    organs: Optional[List[str]] = None,
    download: bool = False,
) -> str:
    """Download and preprocess the HEST-1k dataset.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        organs: List of organ types to include. Uses all available organs if None.
            Example: ['Breast', 'Colon']. See PANNUKE_ORGANS for the set used in arXiv 2604.23481.
        download: Whether to download the data if it is not present.

    Returns:
        The filepath to the preprocessed data directory.
    """
    preprocessed_dir = os.path.join(path, "preprocessed")

    if download:
        meta_path = os.path.join(path, METADATA_FILENAME)
        if not os.path.exists(meta_path):
            try:
                from huggingface_hub import hf_hub_download
            except ImportError:
                raise ImportError("huggingface_hub is required. Install with: pip install huggingface_hub")
            hf_hub_download(repo_id=HF_REPO, repo_type="dataset", filename=METADATA_FILENAME, local_dir=path)

        sample_ids = _filter_sample_ids(path, organs)
        xenium_dir = os.path.join(path, "xenium_seg")
        st_dir = os.path.join(path, "st")
        include_xenium = not os.path.exists(xenium_dir)
        include_st = not os.path.exists(st_dir)
        _download_hest(path, sample_ids, include_xenium=include_xenium, include_st=include_st)
    else:
        sample_ids = [
            os.path.splitext(os.path.basename(p))[0]
            for p in glob(os.path.join(path, "patches", "*.h5"))
        ]
        if organs is not None:
            meta_path = os.path.join(path, METADATA_FILENAME)
            if os.path.exists(meta_path):
                allowed = set(_filter_sample_ids(path, organs))
                sample_ids = [s for s in sample_ids if s in allowed]

    # Load PanglaoDB once for all samples.
    db_cache = os.path.join(path, "_db_cache")
    try:
        marker_db = _load_panglaodb(db_cache)
    except Exception:
        marker_db = None

    # Build a pixel_size lookup from the metadata (um/px at native resolution).
    try:
        meta = _load_metadata(path)
        pixel_size_map = dict(zip(meta["id"], meta["pixel_size_um_estimated"].fillna(0.5)))
    except Exception:
        pixel_size_map = {}

    os.makedirs(preprocessed_dir, exist_ok=True)
    cellvit_zip_dir = os.path.join(path, "cellvit_seg")
    cellvit_cache = os.path.join(path, "_cellvit_extracted")
    xenium_dir = os.path.join(path, "xenium_seg")
    st_dir = os.path.join(path, "st")

    for sid in tqdm(sample_ids, desc="Preprocessing HEST samples"):
        out_h5 = os.path.join(preprocessed_dir, f"{sid}.h5")
        if os.path.exists(out_h5):
            continue

        patches_h5 = os.path.join(path, "patches", f"{sid}.h5")
        if not os.path.exists(patches_h5):
            continue

        geojson_path = _unzip_cellvit(
            os.path.join(cellvit_zip_dir, f"{sid}_cellvit_seg.geojson.zip"), cellvit_cache
        )
        xenium_parquet = os.path.join(xenium_dir, f"{sid}_xenium_nucleus_seg.parquet")
        h5ad_path = os.path.join(st_dir, f"{sid}.h5ad")
        pixel_size_um = float(pixel_size_map.get(sid, 0.5))

        _preprocess_sample(
            patches_h5=patches_h5,
            cellvit_geojson=geojson_path,
            xenium_parquet=xenium_parquet if os.path.exists(xenium_parquet) else None,
            h5ad_path=h5ad_path if os.path.exists(h5ad_path) else None,
            marker_db=marker_db,
            out_h5=out_h5,
            pixel_size_um=pixel_size_um,
        )

    return preprocessed_dir


def get_hest_paths(
    path: Union[os.PathLike, str],
    organs: Optional[List[str]] = None,
    download: bool = False,
) -> List[str]:
    """Get paths to the preprocessed HEST-1k H5 files.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        organs: List of organ types to include. Uses all available organs if None.
        download: Whether to download the data if it is not present.

    Returns:
        List of filepaths to the preprocessed H5 files (one per slide).
    """
    preprocessed_dir = get_hest_data(path, organs, download)
    h5_paths = natsorted(glob(os.path.join(preprocessed_dir, "*.h5")))
    if not h5_paths:
        raise RuntimeError(f"No preprocessed data found in {preprocessed_dir}.")

    if organs is not None:
        meta_path = os.path.join(path, METADATA_FILENAME)
        if os.path.exists(meta_path):
            allowed = set(_filter_sample_ids(path, organs))
            h5_paths = [p for p in h5_paths if os.path.splitext(os.path.basename(p))[0] in allowed]

    return h5_paths


def get_hest_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, int],
    organs: Optional[List[str]] = None,
    label_choice: Literal["instances", "xenium_instances", "semantic"] = "instances",
    download: bool = False,
    n_samples: Optional[int] = None,
    seed: Optional[int] = None,
    raw_transform: Optional[Callable] = None,
    label_transform: Optional[Callable] = None,
    transform: Optional[Callable] = None,
) -> Dataset:
    """Get the HEST-1k dataset for nuclei segmentation and cell-type classification.

    Returns a 2D dataset: each item is raw (3, H, W) float32 in [0, 1] and labels (H, W) int32.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        patch_shape: Not used for cropping (patches are already 224x224); kept for API consistency.
        organs: List of organ types to include. Uses all available organs if None.
            Use PANNUKE_ORGANS for the 10-organ subset from arXiv 2604.23481.
        label_choice: Which label type to return:
            - 'instances': CellViT nuclei instance masks (H&E-based, all samples).
            - 'xenium_instances': DAPI nuclei instance masks (Xenium samples only, zeros otherwise).
            - 'semantic': ST-derived cell-type labels 1-5 (Xenium samples only, zeros otherwise).
        download: Whether to download the data if it is not present.
        n_samples: Number of patches to sample (with replacement if larger than total). Uses all if None.
        seed: Random seed for reproducible patch sampling when n_samples is set.
        raw_transform: Transform applied to the raw image tensor.
        label_transform: Transform applied to the label tensor.
        transform: Joint transform applied to both raw and label.

    Returns:
        The segmentation dataset.
    """
    valid = ("instances", "xenium_instances", "semantic")
    if label_choice not in valid:
        raise ValueError(f"'{label_choice}' is not valid. Choose from {valid}.")

    h5_paths = get_hest_paths(path, organs, download)
    return HESTDataset(
        h5_paths=h5_paths,
        label_key=LABEL_KEYS[label_choice],
        raw_transform=raw_transform,
        label_transform=label_transform,
        transform=transform,
        n_samples=n_samples,
        seed=seed,
    )


696def get_hest_loader(
697    path: Union[os.PathLike, str],
698    batch_size: int,
699    patch_shape: Tuple[int, int],
700    organs: Optional[List[str]] = None,
701    label_choice: Literal["instances", "xenium_instances", "semantic"] = "instances",
702    download: bool = False,
703    n_samples: Optional[int] = None,
704    seed: Optional[int] = None,
705    raw_transform: Optional[Callable] = None,
706    label_transform: Optional[Callable] = None,
707    transform: Optional[Callable] = None,
708    **loader_kwargs,
709) -> DataLoader:
710    """Get the HEST-1k dataloader for nuclei segmentation and cell-type classification.
711
712    Returns batches of raw (B, 3, H, W) float32 in [0, 1] and labels (B, H, W) int32.
713
714    Args:
715        path: Filepath to a folder where the downloaded data will be saved.
716        batch_size: The batch size for training.
717        patch_shape: Not used for cropping (patches are already 224x224); kept for API consistency.
718        organs: List of organ types to include. Uses all available organs if None.
719            Use PANNUKE_ORGANS for the 10-organ subset from arXiv 2604.23481.
720        label_choice: Which label type to return. One of 'instances', 'xenium_instances', 'semantic'.
721        download: Whether to download the data if it is not present.
722        n_samples: Number of patches per epoch. Uses all patches if None.
723        seed: Random seed for reproducible patch sampling when n_samples is set.
724        raw_transform: Transform applied to the raw image tensor.
725        label_transform: Transform applied to the label tensor.
726        transform: Joint transform applied to both raw and label.
727        loader_kwargs: Additional keyword arguments for the PyTorch DataLoader.
728
729    Returns:
730        The DataLoader.
731    """
732    dataset = get_hest_dataset(
733        path, patch_shape, organs, label_choice, download, n_samples, seed, raw_transform, label_transform, transform
734    )
735    return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
HF_REPO = 'MahmoodLab/hest'
METADATA_FILENAME = 'HEST_v1_3_0.csv'
PANGLAODB_URL = 'https://panglaodb.se/markers/PanglaoDB_markers_27_Mar_2020.tsv.gz'
CELL_TYPE_LABELS = {'Epithelial': 1, 'Inflammatory': 2, 'Connective': 3, 'Neoplastic': 4, 'Unknown': 5}
LABEL_KEYS = {'instances': 'labels/instances/h&e', 'xenium_instances': 'labels/instances/xenium', 'semantic': 'labels/semantic/st'}
PANNUKE_ORGANS = ['Breast', 'Colon', 'Kidney', 'Liver', 'Lung', 'Ovarian', 'Pancreatic', 'Prostate', 'Skin', 'Stomach']
EPITHELIAL_KEYWORDS = ['acinar', 'airway epithelial', 'airway goblet', 'alveolar type', 'alpha cell', 'basal cell', 'beta cell', 'cholangiocyte', 'ciliated', 'clara', 'crypt', 'delta cell', 'ductal', 'enterocyte', 'epithelial', 'goblet', 'hepatocyte', 'keratinocyte', 'mesothelial', 'paneth', 'pneumocyte', 'proximal tubule', 'renal tubule', 'squamous', 'thyroid', 'trophoblast', 'tuft', 'urothelial']
INFLAMMATORY_KEYWORDS = ['alveolar macrophage', 'b cell', 'basophil', 'dendritic', 'eosinophil', 'innate lymphoid', 'lymphocyte', 'macrophage', 'mast cell', 'monocyte', 'natural killer', 'neutrophil', 'nk cell', 'plasma cell', 'regulatory t', 't cell']
CONNECTIVE_KEYWORDS = ['adipocyte', 'chondrocyte', 'endothelial', 'fibroblast', 'mesenchymal', 'myofibroblast', 'osteoblast', 'osteoclast', 'pericyte', 'smooth muscle', 'stellate', 'stromal', 'vascular']
CANCER_GENES = {'KDM6A', 'FLCN', 'CREBBP', 'MDM2', 'MUTYH', 'SETD2', 'TP53', 'FLT3', 'RAD51B', 'SMAD4', 'TSC1', 'BRIP1', 'MET', 'MSH2', 'SDHC', 'TP63', 'AKT1', 'EZH2', 'NBN', 'FH', 'KDM5C', 'GATA3', 'SF3B1', 'ATM', 'PALB2', 'RB1', 'TET2', 'EGFR', 'HRAS', 'MLH1', 'BRAF', 'MTOR', 'CTNNB1', 'SDHD', 'PTCH1', 'EP300', 'IDH2', 'RAD51C', 'CDK12', 'PMS2', 'SUFU', 'RAD50', 'MYC', 'NF1', 'CHEK2', 'FGFR2', 'APC', 'ELOC', 'MAP3K1', 'BRCA2', 'FANCA', 'HNF1A', 'BRCA1', 'PDGFRA', 'KIT', 'SMARCB1', 'PBRM1', 'JAK2', 'TERT', 'NFE2L2', 'AXIN2', 'RNF43', 'KRAS', 'SDHA', 'TSC2', 'SDHB', 'ERBB2', 'CDH1', 'DNMT3A', 'ALK', 'BAP1', 'FGFR3', 'PIK3R1', 'FANCD2', 'PPP2R1A', 'CDKN2A', 'ESR1', 'PTEN', 'POLE', 'SMARCA4', 'NRAS', 'ABL1', 'FGFR1', 'IDH1', 'RUNX1', 'XRCC2', 'AXIN1', 'GNAS', 'PIK3CA', 'MYCN', 'SMO', 'NOTCH1', 'RET', 'MAP2K1', 'NF2', 'STK11', 'VHL', 'KEAP1', 'FBXW7', 'MSH6', 'NOTCH2', 'CHD4', 'RAD51D', 'GNAQ', 'MRE11'}
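The keyword lists and `CELL_TYPE_LABELS` above suggest how PanglaoDB cell-type names are collapsed into the coarse semantic classes. The actual marker-gene voting in the preprocessing is not shown on this page; the sketch below is a hypothetical, simplified substring-matching version (with abridged keyword lists) to illustrate the mapping, not the real implementation:

```python
# Hypothetical sketch: collapse a PanglaoDB cell-type name into one of the
# coarse semantic classes via substring matching. Keyword lists are abridged
# from the module constants above; the real preprocessing uses marker-gene
# voting and is more involved.
EPITHELIAL_KEYWORDS = ['acinar', 'epithelial', 'goblet', 'hepatocyte', 'keratinocyte']
INFLAMMATORY_KEYWORDS = ['b cell', 'macrophage', 'mast cell', 't cell']
CONNECTIVE_KEYWORDS = ['adipocyte', 'endothelial', 'fibroblast', 'smooth muscle']
CELL_TYPE_LABELS = {'Epithelial': 1, 'Inflammatory': 2, 'Connective': 3, 'Neoplastic': 4, 'Unknown': 5}


def coarse_label(cell_type_name: str) -> int:
    """Map a cell-type name to a coarse class id; unmatched names fall back to Unknown."""
    name = cell_type_name.lower()
    if any(k in name for k in EPITHELIAL_KEYWORDS):
        return CELL_TYPE_LABELS['Epithelial']
    if any(k in name for k in INFLAMMATORY_KEYWORDS):
        return CELL_TYPE_LABELS['Inflammatory']
    if any(k in name for k in CONNECTIVE_KEYWORDS):
        return CELL_TYPE_LABELS['Connective']
    return CELL_TYPE_LABELS['Unknown']
```

Note that the Neoplastic class (4) cannot be assigned from the cell-type name alone; the `CANCER_GENES` set above suggests it is derived from gene expression instead.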
class HESTDataset(torch.utils.data.Dataset):
class HESTDataset(Dataset):
    """2D patch dataset for HEST-1k.

    Indexes all patches across all per-slide H5 files and returns proper 2D tensors:
    raw (3, H, W) float32 in [0, 1] and labels (H, W) int32.
    """

    def __init__(
        self,
        h5_paths: List[str],
        label_key: str,
        raw_transform: Optional[Callable] = None,
        label_transform: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        n_samples: Optional[int] = None,
        seed: Optional[int] = None,
    ):
        self._label_key = label_key
        self._raw_transform = raw_transform
        self._label_transform = label_transform
        self._transform = transform

        # Build flat index: list of (h5_path, patch_idx).
        self._index: List[Tuple[str, int]] = []
        for h5_path in h5_paths:
            with h5py.File(h5_path, "r") as f:
                n = f["raw"].shape[0]  # raw stored as (N, 3, H, W)
            self._index.extend((h5_path, i) for i in range(n))

        if n_samples is not None:
            rng = np.random.default_rng(seed)
            chosen = rng.choice(len(self._index), size=n_samples, replace=n_samples > len(self._index))
            self._index = [self._index[i] for i in chosen]

    def __len__(self) -> int:
        return len(self._index)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        h5_path, patch_idx = self._index[idx]
        with h5py.File(h5_path, "r") as f:
            raw = f["raw"][patch_idx].astype(np.float32) / 255.0  # (3, H, W)
            label = f[self._label_key][patch_idx].astype(np.int32)  # (H, W)

        raw = torch.from_numpy(raw)
        label = torch.from_numpy(label)

        if self._raw_transform is not None:
            raw = self._raw_transform(raw)
        if self._label_transform is not None:
            label = self._label_transform(label)
        if self._transform is not None:
            raw, label = self._transform(raw, label)

        return raw, label

2D patch dataset for HEST-1k.

Indexes all patches across all per-slide H5 files and returns proper 2D tensors: raw (3, H, W) float32 in [0, 1] and labels (H, W) int32.

HESTDataset( h5_paths: List[str], label_key: str, raw_transform: Optional[Callable] = None, label_transform: Optional[Callable] = None, transform: Optional[Callable] = None, n_samples: Optional[int] = None, seed: Optional[int] = None)
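The per-slide H5 files consumed by the constructor follow a simple layout: a `"raw"` dataset of shape (N, 3, H, W) uint8 plus one label dataset per label type (the keys listed in `LABEL_KEYS`). A self-contained sketch with a tiny synthetic file, replicating the flat indexing and per-item loading (the sample id and patch count here are made up):

```python
# Sketch of the per-slide H5 layout and the flat (h5_path, patch_idx) indexing
# used by HESTDataset, on a tiny synthetic file. Real files hold N patches of
# shape (3, 224, 224) per slide.
import os
import tempfile

import h5py
import numpy as np

tmpdir = tempfile.mkdtemp()
h5_path = os.path.join(tmpdir, "TENX01.h5")  # made-up sample id
with h5py.File(h5_path, "w") as f:
    f.create_dataset("raw", data=np.zeros((4, 3, 224, 224), dtype=np.uint8))
    f.create_dataset("labels/instances/h&e", data=np.zeros((4, 224, 224), dtype=np.int64))

# Build the flat index, then load one item the way __getitem__ does.
with h5py.File(h5_path, "r") as f:
    index = [(h5_path, i) for i in range(f["raw"].shape[0])]

path_i, patch_i = index[2]
with h5py.File(path_i, "r") as f:
    raw = f["raw"][patch_i].astype(np.float32) / 255.0       # (3, 224, 224)
    label = f["labels/instances/h&e"][patch_i].astype(np.int32)  # (224, 224)
```

The real dataset additionally wraps `raw` and `label` in `torch.from_numpy` before applying the transforms.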
def get_hest_data( path: Union[os.PathLike, str], organs: Optional[List[str]] = None, download: bool = False) -> str:
def get_hest_data(
    path: Union[os.PathLike, str],
    organs: Optional[List[str]] = None,
    download: bool = False,
) -> str:
    """Download and preprocess the HEST-1k dataset.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        organs: List of organ types to include. Uses all available organs if None.
            Example: ['Breast', 'Colon']. See PANNUKE_ORGANS for the set used in arXiv 2604.23481.
        download: Whether to download the data if it is not present.

    Returns:
        The filepath to the preprocessed data directory.
    """
    preprocessed_dir = os.path.join(path, "preprocessed")

    if download:
        meta_path = os.path.join(path, METADATA_FILENAME)
        if not os.path.exists(meta_path):
            try:
                from huggingface_hub import hf_hub_download
            except ImportError:
                raise ImportError("huggingface_hub is required. Install with: pip install huggingface_hub")
            hf_hub_download(repo_id=HF_REPO, repo_type="dataset", filename=METADATA_FILENAME, local_dir=path)

        sample_ids = _filter_sample_ids(path, organs)
        xenium_dir = os.path.join(path, "xenium_seg")
        st_dir = os.path.join(path, "st")
        include_xenium = not os.path.exists(xenium_dir)
        include_st = not os.path.exists(st_dir)
        _download_hest(path, sample_ids, include_xenium=include_xenium, include_st=include_st)
    else:
        sample_ids = [
            os.path.splitext(os.path.basename(p))[0]
            for p in glob(os.path.join(path, "patches", "*.h5"))
        ]
        if organs is not None:
            meta_path = os.path.join(path, METADATA_FILENAME)
            if os.path.exists(meta_path):
                allowed = set(_filter_sample_ids(path, organs))
                sample_ids = [s for s in sample_ids if s in allowed]

    # Load PanglaoDB once for all samples.
    db_cache = os.path.join(path, "_db_cache")
    try:
        marker_db = _load_panglaodb(db_cache)
    except Exception:
        marker_db = None

    # Build a pixel_size lookup from the metadata (um/px at native resolution).
    try:
        meta = _load_metadata(path)
        pixel_size_map = dict(zip(meta["id"], meta["pixel_size_um_estimated"].fillna(0.5)))
    except Exception:
        pixel_size_map = {}

    os.makedirs(preprocessed_dir, exist_ok=True)
    cellvit_zip_dir = os.path.join(path, "cellvit_seg")
    cellvit_cache = os.path.join(path, "_cellvit_extracted")
    xenium_dir = os.path.join(path, "xenium_seg")
    st_dir = os.path.join(path, "st")

    for sid in tqdm(sample_ids, desc="Preprocessing HEST samples"):
        out_h5 = os.path.join(preprocessed_dir, f"{sid}.h5")
        if os.path.exists(out_h5):
            continue

        patches_h5 = os.path.join(path, "patches", f"{sid}.h5")
        if not os.path.exists(patches_h5):
            continue

        geojson_path = _unzip_cellvit(
            os.path.join(cellvit_zip_dir, f"{sid}_cellvit_seg.geojson.zip"), cellvit_cache
        )
        xenium_parquet = os.path.join(xenium_dir, f"{sid}_xenium_nucleus_seg.parquet")
        h5ad_path = os.path.join(st_dir, f"{sid}.h5ad")
        pixel_size_um = float(pixel_size_map.get(sid, 0.5))

        _preprocess_sample(
            patches_h5=patches_h5,
            cellvit_geojson=geojson_path,
            xenium_parquet=xenium_parquet if os.path.exists(xenium_parquet) else None,
            h5ad_path=h5ad_path if os.path.exists(h5ad_path) else None,
            marker_db=marker_db,
            out_h5=out_h5,
            pixel_size_um=pixel_size_um,
        )

    return preprocessed_dir

Download and preprocess the HEST-1k dataset.

Arguments:
  • path: Filepath to a folder where the downloaded data will be saved.
  • organs: List of organ types to include. Uses all available organs if None. Example: ['Breast', 'Colon']. See PANNUKE_ORGANS for the set used in arXiv 2604.23481.
  • download: Whether to download the data if it is not present.
Returns:
  The filepath to the preprocessed data directory.
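The organ subsetting relies on the helper `_filter_sample_ids`, whose body is not shown on this page. A hedged sketch of how such metadata-based filtering could look, using a synthetic in-memory CSV (the column name `organ` and the sample ids are assumptions; only `id` and `pixel_size_um_estimated` are confirmed by the source above):

```python
# Hypothetical sketch of organ-based sample filtering against the HEST
# metadata CSV. The 'organ' column name and the sample ids are assumptions
# made for illustration; the real _filter_sample_ids may differ.
import csv
import io

metadata_csv = io.StringIO(
    "id,organ\n"
    "TENX01,Breast\n"
    "TENX02,Colon\n"
    "TENX03,Brain\n"
)
organs = {"Breast", "Colon"}
sample_ids = [row["id"] for row in csv.DictReader(metadata_csv) if row["organ"] in organs]
print(sample_ids)  # ['TENX01', 'TENX02']
```

With `download=False`, the same filter is applied to sample ids discovered from already-downloaded `patches/*.h5` files, as the source above shows.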

def get_hest_paths( path: Union[os.PathLike, str], organs: Optional[List[str]] = None, download: bool = False) -> List[str]:
def get_hest_paths(
    path: Union[os.PathLike, str],
    organs: Optional[List[str]] = None,
    download: bool = False,
) -> List[str]:
    """Get paths to the preprocessed HEST-1k H5 files.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        organs: List of organ types to include. Uses all available organs if None.
        download: Whether to download the data if it is not present.

    Returns:
        List of filepaths to the preprocessed H5 files (one per slide).
    """
    preprocessed_dir = get_hest_data(path, organs, download)
    h5_paths = natsorted(glob(os.path.join(preprocessed_dir, "*.h5")))
    if not h5_paths:
        raise RuntimeError(f"No preprocessed data found in {preprocessed_dir}.")

    if organs is not None:
        meta_path = os.path.join(path, METADATA_FILENAME)
        if os.path.exists(meta_path):
            allowed = set(_filter_sample_ids(path, organs))
            h5_paths = [p for p in h5_paths if os.path.splitext(os.path.basename(p))[0] in allowed]

    return h5_paths

Get paths to the preprocessed HEST-1k H5 files.

Arguments:
  • path: Filepath to a folder where the downloaded data will be saved.
  • organs: List of organ types to include. Uses all available organs if None.
  • download: Whether to download the data if it is not present.
Returns:
  List of filepaths to the preprocessed H5 files (one per slide).

def get_hest_dataset( path: Union[os.PathLike, str], patch_shape: Tuple[int, int], organs: Optional[List[str]] = None, label_choice: Literal['instances', 'xenium_instances', 'semantic'] = 'instances', download: bool = False, n_samples: Optional[int] = None, seed: Optional[int] = None, raw_transform: Optional[Callable] = None, label_transform: Optional[Callable] = None, transform: Optional[Callable] = None) -> torch.utils.data.dataset.Dataset:
def get_hest_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, int],
    organs: Optional[List[str]] = None,
    label_choice: Literal["instances", "xenium_instances", "semantic"] = "instances",
    download: bool = False,
    n_samples: Optional[int] = None,
    seed: Optional[int] = None,
    raw_transform: Optional[Callable] = None,
    label_transform: Optional[Callable] = None,
    transform: Optional[Callable] = None,
) -> Dataset:
    """Get the HEST-1k dataset for nuclei segmentation and cell-type classification.

    Returns a 2D dataset: each item is raw (3, H, W) float32 in [0, 1] and labels (H, W) int32.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        patch_shape: Not used for cropping (patches are already 224x224); kept for API consistency.
        organs: List of organ types to include. Uses all available organs if None.
            Use PANNUKE_ORGANS for the 10-organ subset from arXiv 2604.23481.
        label_choice: Which label type to return:
            - 'instances': CellViT nuclei instance masks (H&E-based, all samples).
            - 'xenium_instances': DAPI nuclei instance masks (Xenium samples only, zeros otherwise).
            - 'semantic': ST-derived cell-type labels 1-5 (Xenium samples only, zeros otherwise).
        download: Whether to download the data if it is not present.
        n_samples: Number of patches to sample (with replacement if larger than total). Uses all if None.
        seed: Random seed for reproducible patch sampling when n_samples is set.
        raw_transform: Transform applied to the raw image tensor.
        label_transform: Transform applied to the label tensor.
        transform: Joint transform applied to both raw and label.

    Returns:
        The segmentation dataset.
    """
    valid = ("instances", "xenium_instances", "semantic")
    if label_choice not in valid:
        raise ValueError(f"'{label_choice}' is not valid. Choose from {valid}.")

    h5_paths = get_hest_paths(path, organs, download)
    return HESTDataset(
        h5_paths=h5_paths,
        label_key=LABEL_KEYS[label_choice],
        raw_transform=raw_transform,
        label_transform=label_transform,
        transform=transform,
        n_samples=n_samples,
        seed=seed,
    )

Get the HEST-1k dataset for nuclei segmentation and cell-type classification.

Returns a 2D dataset: each item is raw (3, H, W) float32 in [0, 1] and labels (H, W) int32.

Arguments:
  • path: Filepath to a folder where the downloaded data will be saved.
  • patch_shape: Not used for cropping (patches are already 224x224); kept for API consistency.
  • organs: List of organ types to include. Uses all available organs if None. Use PANNUKE_ORGANS for the 10-organ subset from arXiv 2604.23481.
  • label_choice: Which label type to return:
    • 'instances': CellViT nuclei instance masks (H&E-based, all samples).
    • 'xenium_instances': DAPI nuclei instance masks (Xenium samples only, zeros otherwise).
    • 'semantic': ST-derived cell-type labels 1-5 (Xenium samples only, zeros otherwise).
  • download: Whether to download the data if it is not present.
  • n_samples: Number of patches to sample (with replacement if larger than total). Uses all if None.
  • seed: Random seed for reproducible patch sampling when n_samples is set.
  • raw_transform: Transform applied to the raw image tensor.
  • label_transform: Transform applied to the label tensor.
  • transform: Joint transform applied to both raw and label.
Returns:
  The segmentation dataset.
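The `n_samples`/`seed` semantics described above can be sketched in isolation: sampling is reproducible for a fixed seed, and switches to sampling with replacement when `n_samples` exceeds the number of available patches, mirroring the `rng.choice` call in `HESTDataset.__init__`:

```python
# Sketch of the n_samples/seed semantics used by HESTDataset: reproducible
# subsampling of the flat patch index, with replacement only when asking for
# more patches than exist.
import numpy as np

index = [("slide.h5", i) for i in range(5)]  # pretend 5 patches total
n_samples = 8

rng_a = np.random.default_rng(seed=42)
chosen_a = rng_a.choice(len(index), size=n_samples, replace=n_samples > len(index))
rng_b = np.random.default_rng(seed=42)
chosen_b = rng_b.choice(len(index), size=n_samples, replace=n_samples > len(index))

assert (chosen_a == chosen_b).all()   # same seed -> same subsample
subset = [index[i] for i in chosen_a]
assert len(subset) == 8               # 8 draws from only 5 patches, with replacement
```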

def get_hest_loader( path: Union[os.PathLike, str], batch_size: int, patch_shape: Tuple[int, int], organs: Optional[List[str]] = None, label_choice: Literal['instances', 'xenium_instances', 'semantic'] = 'instances', download: bool = False, n_samples: Optional[int] = None, seed: Optional[int] = None, raw_transform: Optional[Callable] = None, label_transform: Optional[Callable] = None, transform: Optional[Callable] = None, **loader_kwargs) -> torch.utils.data.dataloader.DataLoader:
def get_hest_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, int],
    organs: Optional[List[str]] = None,
    label_choice: Literal["instances", "xenium_instances", "semantic"] = "instances",
    download: bool = False,
    n_samples: Optional[int] = None,
    seed: Optional[int] = None,
    raw_transform: Optional[Callable] = None,
    label_transform: Optional[Callable] = None,
    transform: Optional[Callable] = None,
    **loader_kwargs,
) -> DataLoader:
    """Get the HEST-1k dataloader for nuclei segmentation and cell-type classification.

    Returns batches of raw (B, 3, H, W) float32 in [0, 1] and labels (B, H, W) int32.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        batch_size: The batch size for training.
        patch_shape: Not used for cropping (patches are already 224x224); kept for API consistency.
        organs: List of organ types to include. Uses all available organs if None.
            Use PANNUKE_ORGANS for the 10-organ subset from arXiv 2604.23481.
        label_choice: Which label type to return. One of 'instances', 'xenium_instances', 'semantic'.
        download: Whether to download the data if it is not present.
        n_samples: Number of patches per epoch. Uses all patches if None.
        seed: Random seed for reproducible patch sampling when n_samples is set.
        raw_transform: Transform applied to the raw image tensor.
        label_transform: Transform applied to the label tensor.
        transform: Joint transform applied to both raw and label.
        loader_kwargs: Additional keyword arguments for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    dataset = get_hest_dataset(
        path, patch_shape, organs, label_choice, download, n_samples, seed, raw_transform, label_transform, transform
    )
    return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)

Get the HEST-1k dataloader for nuclei segmentation and cell-type classification.

Returns batches of raw (B, 3, H, W) float32 in [0, 1] and labels (B, H, W) int32.

Arguments:
  • path: Filepath to a folder where the downloaded data will be saved.
  • batch_size: The batch size for training.
  • patch_shape: Not used for cropping (patches are already 224x224); kept for API consistency.
  • organs: List of organ types to include. Uses all available organs if None. Use PANNUKE_ORGANS for the 10-organ subset from arXiv 2604.23481.
  • label_choice: Which label type to return. One of 'instances', 'xenium_instances', 'semantic'.
  • download: Whether to download the data if it is not present.
  • n_samples: Number of patches per epoch. Uses all patches if None.
  • seed: Random seed for reproducible patch sampling when n_samples is set.
  • raw_transform: Transform applied to the raw image tensor.
  • label_transform: Transform applied to the label tensor.
  • transform: Joint transform applied to both raw and label.
  • loader_kwargs: Additional keyword arguments for the PyTorch DataLoader.
Returns:
  The DataLoader.
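The batch contract stated above, raw (B, 3, H, W) and labels (B, H, W), can be sketched without downloading any data, using NumPy stacking in place of the actual PyTorch DataLoader collation (which `torch_em.get_data_loader` handles for real datasets):

```python
# Sketch of the batch contract: per-item (3, H, W) raw / (H, W) label pairs
# collate into (B, 3, H, W) / (B, H, W) batches. NumPy stands in for the
# PyTorch collation here.
import numpy as np

batch = [
    (np.zeros((3, 224, 224), dtype=np.float32), np.zeros((224, 224), dtype=np.int32))
    for _ in range(4)
]
raw_batch = np.stack([raw for raw, _ in batch])    # (4, 3, 224, 224)
label_batch = np.stack([lab for _, lab in batch])  # (4, 224, 224)

# A typical real call (requires the dataset on disk or download=True):
# loader = get_hest_loader(path="./hest", batch_size=4, patch_shape=(224, 224),
#                          organs=PANNUKE_ORGANS, label_choice="instances", download=True)
```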