torch_em.data.datasets.histopathology.srsanet
The SRSA-Net dataset contains annotations for nucleus segmentation in immunohistochemistry (IHC)-stained tissue microarray (TMA) histological images from non-small cell lung cancer (NSCLC) patients.
The dataset is located at https://doi.org/10.5281/zenodo.7647846. This dataset is from the publication https://doi.org/10.1016/j.bspc.2024.106143. Please cite it if you use this dataset for your research.
"""The SRSA-Net dataset contains annotations for nucleus segmentation
in IHC stained TMA histological images from NSCLC patients.

The dataset is located at https://doi.org/10.5281/zenodo.7647846.
This dataset is from the publication https://doi.org/10.1016/j.bspc.2024.106143.
Please cite it if you use this dataset for your research.
"""

import os
from glob import glob
from tqdm import tqdm
from pathlib import Path
from natsort import natsorted
from typing import Union, Tuple, Literal, List

import numpy as np
import imageio.v3 as imageio
from skimage.measure import label as connected_components

import torch_em

from torch.utils.data import Dataset, DataLoader

from .. import util


# Zenodo download location of the dataset archive.
URL = "https://zenodo.org/records/7647846/files/IHC_TMA_dataset.zip"
# Checksum of the archive, passed as `checksum` to `util.download_source`.
CHECKSUM = "9dcc1c94b5d8af5383d3c91141617b1621904ee9bd6f69d2223e7f4363cc80d9"


def _preprocess_data(data_dir):
    """Convert the `masks/*.npy` files into connected-component label images.

    For every mask file, the first two entries along the leading axis are summed, connected
    components are computed on the result, and the labels are written as a zlib-compressed tif
    with the same file stem to `preprocessed_labels/`.

    Args:
        data_dir: The folder of the unzipped dataset.
    """
    preprocessed_label_dir = os.path.join(data_dir, "preprocessed_labels")
    os.makedirs(preprocessed_label_dir, exist_ok=True)

    label_paths = glob(os.path.join(data_dir, "masks", "*.npy"))
    for lpath in tqdm(label_paths, desc="Preprocessing labels"):
        fname = Path(lpath).stem
        larray = np.load(lpath)
        # NOTE(review): assumes larray[0] and larray[1] are two masks to be merged into one
        # foreground - confirm against the data. skimage's `label` only joins neighbouring
        # pixels of equal value, so if the two masks overlap, overlapping pixels would become
        # separate components - TODO confirm the masks are disjoint.
        labels = larray[0] + larray[1]
        labels = connected_components(labels)

        imageio.imwrite(os.path.join(preprocessed_label_dir, f"{fname}.tif"), labels, compression="zlib")


def get_srsanet_data(path: Union[os.PathLike, str], download: bool = False) -> str:
    """Download the SRSA-Net dataset for nucleus segmentation.

    The archive is downloaded and unzipped into `path`, then the masks are converted to label
    images in `preprocessed_labels/`. If the dataset folder already exists but the label
    preprocessing is missing or incomplete, only the preprocessing is (re-)run.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        download: Whether to download the data if it is not present.

    Returns:
        The filepath to the downloaded data.
    """
    data_dir = os.path.join(path, "IHC_TMA_dataset")
    if os.path.exists(data_dir):
        # A previous call may have been interrupted after unzipping but before every mask was
        # converted; finish the preprocessing instead of returning a folder missing labels.
        n_masks = len(glob(os.path.join(data_dir, "masks", "*.npy")))
        n_labels = len(glob(os.path.join(data_dir, "preprocessed_labels", "*.tif")))
        if n_labels < n_masks:
            _preprocess_data(data_dir)
        return data_dir

    os.makedirs(path, exist_ok=True)

    zip_path = os.path.join(path, "IHC_TMA_dataset.zip")
    util.download_source(path=zip_path, url=URL, download=download, checksum=CHECKSUM)
    util.unzip(zip_path=zip_path, dst=path)

    _preprocess_data(data_dir)

    return data_dir


def get_srsanet_paths(
    path: Union[os.PathLike, str],
    split: Literal['train', 'val', 'test'],
    download: bool = False
) -> Tuple[List[str], List[str]]:
    """Get paths to the SRSA-Net data.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        split: The split to use for the dataset. Either 'train', 'val' or 'test'.
        download: Whether to download the data if it is not present.

    Returns:
        List of filepaths to the image data.
        List of filepaths to the label data.

    Raises:
        ValueError: If `split` is not one of 'train', 'val' or 'test'.
        RuntimeError: If no images are found for the split, or if the number of images and
            preprocessed labels differs.
    """
    # Validate the split before any (potentially large) download is triggered.
    # The archive encodes the split as a fold prefix in each file name.
    if split == "train":
        dname = "fold1"
    elif split == "val":
        dname = "fold2"
    elif split == "test":
        dname = "fold3"
    else:
        raise ValueError(f"'{split}' is not a valid split choice.")

    data_dir = get_srsanet_data(path, download)

    raw_paths = natsorted(glob(os.path.join(data_dir, "images", f"{dname}_*.png")))
    label_paths = natsorted(glob(os.path.join(data_dir, "preprocessed_labels", f"{dname}_*.tif")))

    # Both lists are passed on as parallel lists, so a count mismatch would misalign images
    # and labels.
    if not raw_paths:
        raise RuntimeError(f"No images found for the '{split}' split in '{data_dir}'.")
    if len(raw_paths) != len(label_paths):
        raise RuntimeError(
            f"Found {len(raw_paths)} images but {len(label_paths)} labels for the '{split}' split in '{data_dir}'."
        )

    return raw_paths, label_paths


def get_srsanet_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, int],
    split: Literal['train', 'val', 'test'],
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> Dataset:
    """Get the SRSA-Net dataset for nucleus segmentation.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        patch_shape: The patch shape to use for training.
        split: The split to use for the dataset. Either 'train', 'val' or 'test'.
        resize_inputs: Whether to resize the inputs.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    raw_paths, label_paths = get_srsanet_paths(path, split, download)

    if resize_inputs:
        # The images are RGB, so the resize transform is configured with `is_rgb=True`.
        resize_kwargs = {"patch_shape": patch_shape, "is_rgb": True}
        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
            kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs
        )

    return torch_em.default_segmentation_dataset(
        raw_paths=raw_paths,
        raw_key=None,
        label_paths=label_paths,
        label_key=None,
        patch_shape=patch_shape,
        is_seg_dataset=False,
        **kwargs
    )


def get_srsanet_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, int],
    split: Literal['train', 'val', 'test'],
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Get the SRSA-Net dataloader for nucleus segmentation.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        split: The split to use for the dataset. Either 'train', 'val' or 'test'.
        resize_inputs: Whether to resize the inputs.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # Separate the dataset keyword arguments from the ones meant for the DataLoader.
    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    dataset = get_srsanet_dataset(path, patch_shape, split, resize_inputs, download, **ds_kwargs)
    return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
# Zenodo download location of the SRSA-Net dataset archive.
URL = "https://zenodo.org/records/7647846/files/IHC_TMA_dataset.zip"
# Checksum of the archive, passed as `checksum` to `util.download_source` when downloading.
CHECKSUM = "9dcc1c94b5d8af5383d3c91141617b1621904ee9bd6f69d2223e7f4363cc80d9"
def get_srsanet_data(path: Union[os.PathLike, str], download: bool = False) -> str:
    """Download the SRSA-Net dataset for nucleus segmentation.

    The archive is downloaded and unzipped into `path`, then the masks are converted to label
    images in `preprocessed_labels/`. If the dataset folder already exists but the label
    preprocessing is missing or incomplete, only the preprocessing is (re-)run.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        download: Whether to download the data if it is not present.

    Returns:
        The filepath to the downloaded data.
    """
    data_dir = os.path.join(path, "IHC_TMA_dataset")
    if os.path.exists(data_dir):
        # A previous call may have been interrupted after unzipping but before every mask was
        # converted; finish the preprocessing instead of returning a folder missing labels.
        n_masks = len(glob(os.path.join(data_dir, "masks", "*.npy")))
        n_labels = len(glob(os.path.join(data_dir, "preprocessed_labels", "*.tif")))
        if n_labels < n_masks:
            _preprocess_data(data_dir)
        return data_dir

    os.makedirs(path, exist_ok=True)

    zip_path = os.path.join(path, "IHC_TMA_dataset.zip")
    util.download_source(path=zip_path, url=URL, download=download, checksum=CHECKSUM)
    util.unzip(zip_path=zip_path, dst=path)

    _preprocess_data(data_dir)

    return data_dir
Download the SRSA-Net dataset for nucleus segmentation.
Arguments:
- path: Filepath to a folder where the downloaded data will be saved.
- download: Whether to download the data if it is not present.
Returns:
The filepath to the downloaded data.
def get_srsanet_paths(
    path: Union[os.PathLike, str],
    split: Literal['train', 'val', 'test'],
    download: bool = False
) -> Tuple[List[str], List[str]]:
    """Get paths to the SRSA-Net data.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        split: The split to use for the dataset. Either 'train', 'val' or 'test'.
        download: Whether to download the data if it is not present.

    Returns:
        List of filepaths to the image data.
        List of filepaths to the label data.

    Raises:
        ValueError: If `split` is not one of 'train', 'val' or 'test'.
        RuntimeError: If no images are found for the split, or if the number of images and
            preprocessed labels differs.
    """
    # Validate the split before any (potentially large) download is triggered.
    # The archive encodes the split as a fold prefix in each file name.
    if split == "train":
        dname = "fold1"
    elif split == "val":
        dname = "fold2"
    elif split == "test":
        dname = "fold3"
    else:
        raise ValueError(f"'{split}' is not a valid split choice.")

    data_dir = get_srsanet_data(path, download)

    raw_paths = natsorted(glob(os.path.join(data_dir, "images", f"{dname}_*.png")))
    label_paths = natsorted(glob(os.path.join(data_dir, "preprocessed_labels", f"{dname}_*.tif")))

    # Both lists are passed on as parallel lists, so a count mismatch would misalign images
    # and labels.
    if not raw_paths:
        raise RuntimeError(f"No images found for the '{split}' split in '{data_dir}'.")
    if len(raw_paths) != len(label_paths):
        raise RuntimeError(
            f"Found {len(raw_paths)} images but {len(label_paths)} labels for the '{split}' split in '{data_dir}'."
        )

    return raw_paths, label_paths
Get paths to the SRSA-Net data.
Arguments:
- path: Filepath to a folder where the downloaded data will be saved.
- split: The split to use for the dataset. Either 'train', 'val' or 'test'.
- download: Whether to download the data if it is not present.
Returns:
A tuple of two lists: the filepaths to the image data and the filepaths to the label data.
def get_srsanet_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, int],
    split: Literal['train', 'val', 'test'],
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> Dataset:
    """Build the SRSA-Net segmentation dataset for nucleus segmentation.

    Args:
        path: Folder in which the data is stored, or will be downloaded to.
        patch_shape: Shape of the patches sampled for training.
        split: Which data split to load: 'train', 'val' or 'test'.
        resize_inputs: If True, the inputs are resized via a resize transform.
        download: Whether to download the data when it is not available yet.
        kwargs: Further keyword arguments forwarded to `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    image_files, mask_files = get_srsanet_paths(path, split, download)

    if resize_inputs:
        # RGB inputs: the resize transform gets `is_rgb=True`; the helper may also update
        # the patch shape, so both values are taken from its result.
        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
            kwargs=kwargs,
            patch_shape=patch_shape,
            resize_inputs=resize_inputs,
            resize_kwargs={"patch_shape": patch_shape, "is_rgb": True},
        )

    # Images and labels are plain image files, so no internal container keys are needed.
    fixed_kwargs = dict(
        raw_paths=image_files,
        raw_key=None,
        label_paths=mask_files,
        label_key=None,
        patch_shape=patch_shape,
        is_seg_dataset=False,
    )
    return torch_em.default_segmentation_dataset(**fixed_kwargs, **kwargs)
Get the SRSA-Net dataset for nucleus segmentation.
Arguments:
- path: Filepath to a folder where the downloaded data will be saved.
- patch_shape: The patch shape to use for training.
- split: The split to use for the dataset. Either 'train', 'val' or 'test'.
- resize_inputs: Whether to resize the inputs.
- download: Whether to download the data if it is not present.
- kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
Returns:
The segmentation dataset.
def get_srsanet_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, int],
    split: Literal['train', 'val', 'test'],
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Get the SRSA-Net dataloader for nucleus segmentation.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        split: The split to use for the dataset. Either 'train', 'val' or 'test'.
        resize_inputs: Whether to resize the inputs.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # Separate the dataset keyword arguments from the ones meant for the DataLoader.
    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    dataset = get_srsanet_dataset(path, patch_shape, split, resize_inputs, download, **ds_kwargs)
    return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
Get the SRSA-Net dataloader for nucleus segmentation.
Arguments:
- path: Filepath to a folder where the downloaded data will be saved.
- batch_size: The batch size for training.
- patch_shape: The patch shape to use for training.
- split: The split to use for the dataset. Either 'train', 'val' or 'test'.
- resize_inputs: Whether to resize the inputs.
- download: Whether to download the data if it is not present.
- kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.
Returns:
The DataLoader.