torch_em.data.datasets.histopathology.srsanet

The SRSA-Net dataset contains annotations for nucleus segmentation in immunohistochemistry (IHC) stained tissue microarray (TMA) histological images from non-small cell lung cancer (NSCLC) patients.

The dataset is located at https://doi.org/10.5281/zenodo.7647846. This dataset is from the publication https://doi.org/10.1016/j.bspc.2024.106143. Please cite it if you use this dataset for your research.
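
A minimal quick-start sketch for this module. The target folder "./data/srsanet" and the patch shape are placeholder choices for illustration, not values prescribed by the dataset:

    from torch_em.data.datasets.histopathology.srsanet import get_srsanet_loader

    # "./data/srsanet" is a placeholder download folder; adjust it to your setup.
    loader = get_srsanet_loader(
        path="./data/srsanet",
        batch_size=2,
        patch_shape=(256, 256),
        split="train",
        download=True,
    )
    for images, labels in loader:
        print(images.shape, labels.shape)
        break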

"""The SRSA-Net dataset contains annotations for nucleus segmentation
in IHC stained TMA histological images from NSCLC patients.

The dataset is located at https://doi.org/10.5281/zenodo.7647846.
This dataset is from the publication https://doi.org/10.1016/j.bspc.2024.106143.
Please cite it if you use this dataset for your research.
"""

import os
from glob import glob
from tqdm import tqdm
from pathlib import Path
from natsort import natsorted
from typing import Union, Tuple, Literal, List

import numpy as np
import imageio.v3 as imageio
from skimage.measure import label as connected_components

import torch_em

from torch.utils.data import Dataset, DataLoader

from .. import util


URL = "https://zenodo.org/records/7647846/files/IHC_TMA_dataset.zip"
CHECKSUM = "9dcc1c94b5d8af5383d3c91141617b1621904ee9bd6f69d2223e7f4363cc80d9"


def _preprocess_data(data_dir):
    """Convert the downloaded '.npy' mask arrays into instance label images stored as compressed '.tif' files."""
    preprocessed_label_dir = os.path.join(data_dir, "preprocessed_labels")
    os.makedirs(preprocessed_label_dir, exist_ok=True)

    label_paths = glob(os.path.join(data_dir, "masks", "*.npy"))
    for lpath in tqdm(label_paths, desc="Preprocessing labels"):
        fname = Path(lpath).stem
        larray = np.load(lpath)
        # Merge the two mask channels and relabel connected components to get instance labels.
        labels = larray[0] + larray[1]
        labels = connected_components(labels)

        imageio.imwrite(os.path.join(preprocessed_label_dir, f"{fname}.tif"), labels, compression="zlib")


def get_srsanet_data(path: Union[os.PathLike, str], download: bool = False) -> str:
    """Download the SRSA-Net dataset for nucleus segmentation.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        download: Whether to download the data if it is not present.

    Returns:
        The filepath to the downloaded data.
    """
    data_dir = os.path.join(path, "IHC_TMA_dataset")
    if os.path.exists(data_dir):
        return data_dir

    os.makedirs(path, exist_ok=True)

    zip_path = os.path.join(path, "IHC_TMA_dataset.zip")
    util.download_source(path=zip_path, url=URL, download=download, checksum=CHECKSUM)
    util.unzip(zip_path=zip_path, dst=path)

    _preprocess_data(data_dir)

    return data_dir


def get_srsanet_paths(
    path: Union[os.PathLike, str],
    split: Literal['train', 'val', 'test'],
    download: bool = False
) -> Tuple[List[str], List[str]]:
    """Get paths to the SRSA-Net data.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        split: The split to use for the dataset. Either 'train', 'val' or 'test'.
        download: Whether to download the data if it is not present.

    Returns:
        List of filepaths to the image data.
        List of filepaths to the label data.
    """
    data_dir = get_srsanet_data(path, download)

    if split == "train":
        dname = "fold1"
    elif split == "val":
        dname = "fold2"
    elif split == "test":
        dname = "fold3"
    else:
        raise ValueError(f"'{split}' is not a valid split choice.")

    raw_paths = natsorted(glob(os.path.join(data_dir, "images", f"{dname}_*.png")))
    label_paths = natsorted(glob(os.path.join(data_dir, "preprocessed_labels", f"{dname}_*.tif")))

    return raw_paths, label_paths


def get_srsanet_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, int],
    split: Literal['train', 'val', 'test'],
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> Dataset:
    """Get the SRSA-Net dataset for nucleus segmentation.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        patch_shape: The patch shape to use for training.
        split: The split to use for the dataset. Either 'train', 'val' or 'test'.
        resize_inputs: Whether to resize the inputs.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    raw_paths, label_paths = get_srsanet_paths(path, split, download)

    if resize_inputs:
        resize_kwargs = {"patch_shape": patch_shape, "is_rgb": True}
        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
            kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs
        )

    return torch_em.default_segmentation_dataset(
        raw_paths=raw_paths,
        raw_key=None,
        label_paths=label_paths,
        label_key=None,
        patch_shape=patch_shape,
        is_seg_dataset=False,
        **kwargs
    )


def get_srsanet_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, int],
    split: Literal['train', 'val', 'test'],
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Get the SRSA-Net dataloader for nucleus segmentation.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        split: The split to use for the dataset. Either 'train', 'val' or 'test'.
        resize_inputs: Whether to resize the inputs.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    dataset = get_srsanet_dataset(path, patch_shape, split, resize_inputs, download, **ds_kwargs)
    return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
URL = 'https://zenodo.org/records/7647846/files/IHC_TMA_dataset.zip'
CHECKSUM = '9dcc1c94b5d8af5383d3c91141617b1621904ee9bd6f69d2223e7f4363cc80d9'
def get_srsanet_data(path: Union[os.PathLike, str], download: bool = False) -> str:

Download the SRSA-Net dataset for nucleus segmentation.

Arguments:
  • path: Filepath to a folder where the downloaded data will be saved.
  • download: Whether to download the data if it is not present.
Returns:
  The filepath to the downloaded data.
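
A usage sketch. The download folder "./data/srsanet" is an arbitrary placeholder; on first use the archive is downloaded, unzipped and the labels are preprocessed:

    from torch_em.data.datasets.histopathology.srsanet import get_srsanet_data

    # "./data/srsanet" is a placeholder folder of your choice.
    data_dir = get_srsanet_data(path="./data/srsanet", download=True)
    print(data_dir)  # points to the extracted "IHC_TMA_dataset" folder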

def get_srsanet_paths(path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], download: bool = False) -> Tuple[List[str], List[str]]:

Get paths to the SRSA-Net data.

Arguments:
  • path: Filepath to a folder where the downloaded data will be saved.
  • split: The split to use for the dataset. Either 'train', 'val' or 'test'.
  • download: Whether to download the data if it is not present.
Returns:
  List of filepaths to the image data.
  List of filepaths to the label data.
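
A usage sketch with a placeholder folder; the chosen split maps to one of the dataset folds (fold1/fold2/fold3):

    from torch_em.data.datasets.histopathology.srsanet import get_srsanet_paths

    # "./data/srsanet" is a placeholder folder.
    raw_paths, label_paths = get_srsanet_paths(path="./data/srsanet", split="val", download=True)
    assert len(raw_paths) == len(label_paths)
    print(raw_paths[0], label_paths[0])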

def get_srsanet_dataset(path: Union[os.PathLike, str], patch_shape: Tuple[int, int], split: Literal['train', 'val', 'test'], resize_inputs: bool = False, download: bool = False, **kwargs) -> torch.utils.data.dataset.Dataset:

Get the SRSA-Net dataset for nucleus segmentation.

Arguments:
  • path: Filepath to a folder where the downloaded data will be saved.
  • patch_shape: The patch shape to use for training.
  • split: The split to use for the dataset. Either 'train', 'val' or 'test'.
  • resize_inputs: Whether to resize the inputs.
  • download: Whether to download the data if it is not present.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset.
Returns:
  The segmentation dataset.
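
A usage sketch; the folder and patch shape below are illustrative choices, not requirements of the dataset:

    from torch_em.data.datasets.histopathology.srsanet import get_srsanet_dataset

    # Placeholder folder and an illustrative patch shape.
    dataset = get_srsanet_dataset(
        path="./data/srsanet",
        patch_shape=(256, 256),
        split="train",
        resize_inputs=True,
        download=True,
    )
    print(len(dataset))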

def get_srsanet_loader(path: Union[os.PathLike, str], batch_size: int, patch_shape: Tuple[int, int], split: Literal['train', 'val', 'test'], resize_inputs: bool = False, download: bool = False, **kwargs) -> torch.utils.data.dataloader.DataLoader:

Get the SRSA-Net dataloader for nucleus segmentation.

Arguments:
  • path: Filepath to a folder where the downloaded data will be saved.
  • batch_size: The batch size for training.
  • patch_shape: The patch shape to use for training.
  • split: The split to use for the dataset. Either 'train', 'val' or 'test'.
  • resize_inputs: Whether to resize the inputs.
  • download: Whether to download the data if it is not present.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset or for the PyTorch DataLoader.
Returns:
  The DataLoader.
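
A usage sketch. Keyword arguments that `torch_em.default_segmentation_dataset` does not accept (e.g. num_workers or shuffle below) are split off and forwarded to the PyTorch DataLoader; folder and patch shape are placeholders:

    from torch_em.data.datasets.histopathology.srsanet import get_srsanet_loader

    # Placeholder folder and illustrative patch shape; num_workers/shuffle go to the DataLoader.
    loader = get_srsanet_loader(
        path="./data/srsanet",
        batch_size=4,
        patch_shape=(256, 256),
        split="test",
        download=True,
        num_workers=2,
        shuffle=False,
    )
    images, labels = next(iter(loader))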