torch_em.data.datasets.medical.psfhs

The PSFHS dataset contains annotations for segmentation of pubic symphysis and fetal head in ultrasound images.

This dataset is located at https://zenodo.org/records/10969427. The dataset is from the publication https://doi.org/10.1038/s41597-024-03266-4. Please cite it if you use this dataset for your research.

  1"""The PSFHS dataset contains annotations for segmentation of pubic symphysis and fetal head
  2in ultrasound images.
  3
  4This dataset is located at https://zenodo.org/records/10969427.
  5The dataset is from the publication https://doi.org/10.1038/s41597-024-03266-4.
  6Please cite it if you use this dataset for your research.
  7"""
  8
  9import os
 10from glob import glob
 11from natsort import natsorted
 12from typing import Union, Tuple, Literal, List
 13
 14from torch.utils.data import Dataset, DataLoader
 15
 16import torch_em
 17
 18from .. import util
 19
 20
 21URL = "https://zenodo.org/records/10969427/files/PSFHS.zip"
 22CHECKSUM = "3f4a8126c84640e4d1b8a4e296d0dfd599cea6529b64b9ee00e5489bfd17ea95"
 23
 24
 25def get_psfhs_data(path: Union[os.PathLike, str], download: bool = False) -> str:
 26    """Download the PSFHS data.
 27
 28    Args:
 29        path: Filepath to a folder where the data is downloaded for further processing.
 30        download: Whether to download the data if it is not present.
 31
 32    Returns:
 33        Filepath where the data is downloaded.
 34    """
 35    data_dir = os.path.join(path, "PSFHS")
 36    if os.path.exists(data_dir):
 37        return data_dir
 38
 39    os.makedirs(path, exist_ok=True)
 40
 41    zip_path = os.path.join(path, "PSFHS.zip")
 42    util.download_source(path=zip_path, url=URL, download=download, checksum=CHECKSUM)
 43    util.unzip(zip_path=zip_path, dst=path)
 44
 45    return data_dir
 46
 47
 48def get_psfhs_paths(
 49    path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], download: bool = False
 50) -> Tuple[List[int], List[int]]:
 51    """Get paths to the PSFHS dataset.
 52
 53    Args:
 54        path: Filepath to a folder where the data is downloaded for further processing.
 55        split: The choice of data split.
 56        download: Whether to download the data if it is not present.
 57
 58    Returns:
 59        List of filepaths for the image data.
 60        List of filepaths for the label data.
 61    """
 62    data_dir = get_psfhs_data(path, download)
 63
 64    raw_paths = natsorted(glob(os.path.join(data_dir, "image_mha", "*.mha")))
 65    label_paths = natsorted(glob(os.path.join(data_dir, "label_mha", "*.mha")))
 66
 67    if split == "train":
 68        raw_paths, label_paths = raw_paths[:900], label_paths[:900]
 69    elif split == "val":
 70        raw_paths, label_paths = raw_paths[900:1050], label_paths[900:1050]
 71    elif split == "test":
 72        raw_paths, label_paths = raw_paths[1050:], label_paths[1050:]
 73    else:
 74        raise ValueError(f"'{split}' is not a valid split.")
 75
 76    assert len(raw_paths) == len(label_paths) and len(raw_paths) > 0
 77
 78    return raw_paths, label_paths
 79
 80
 81def get_psfhs_dataset(
 82    path: Union[os.PathLike, str],
 83    patch_shape: Tuple[int, int],
 84    split: Literal['train', 'val', 'test'],
 85    resize_inputs: bool = False,
 86    download: bool = False,
 87    **kwargs
 88) -> Dataset:
 89    """Get the PSFHS dataset for segmentation of pubic symphysis and fetal head.
 90
 91    Args:
 92        path: Filepath to a folder where the data is downloaded for further processing.
 93        patch_shape: The patch shape to use for training.
 94        split: The choice of data split.
 95        resize_inputs: Whether to resize the inputs to the patch shape.
 96        download: Whether to download the data if it is not present.
 97        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
 98
 99    Returns:
100        The segmentation dataset.
101    """
102    raw_paths, label_paths = get_psfhs_paths(path, split, download)
103
104    if resize_inputs:
105        resize_kwargs = {"patch_shape": patch_shape, "is_rgb": True}
106        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
107            kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs
108        )
109
110    return torch_em.default_segmentation_dataset(
111        raw_paths=raw_paths,
112        raw_key=None,
113        label_paths=label_paths,
114        label_key=None,
115        ndim=2,
116        is_seg_dataset=False,
117        with_channels=True,
118        patch_shape=patch_shape,
119        **kwargs
120    )
121
122
def get_psfhs_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, int],
    split: Literal['train', 'val', 'test'],
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Get the PSFHS dataloader for segmentation of pubic symphysis and fetal head.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        resize_inputs: Whether to resize the inputs to the patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`
            or for `torch_em.get_data_loader`.

    Returns:
        The DataLoader.
    """
    # `util.split_kwargs` separates the kwargs for `torch_em.default_segmentation_dataset`
    # (forwarded to the dataset) from the rest (forwarded to `torch_em.get_data_loader`).
    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    dataset = get_psfhs_dataset(path, patch_shape, split, resize_inputs, download, **ds_kwargs)
    return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
# Zenodo download location of the PSFHS zip archive.
URL = 'https://zenodo.org/records/10969427/files/PSFHS.zip'
# Expected checksum of the downloaded archive; passed to util.download_source in get_psfhs_data.
CHECKSUM = '3f4a8126c84640e4d1b8a4e296d0dfd599cea6529b64b9ee00e5489bfd17ea95'
def get_psfhs_data(path: Union[os.PathLike, str], download: bool = False) -> str:
def get_psfhs_data(path: Union[os.PathLike, str], download: bool = False) -> str:
    """Make sure the PSFHS data is available locally, fetching it when it is missing.

    If a `PSFHS` folder already exists inside `path`, it is returned as-is. Otherwise the
    zip archive is obtained via `util.download_source` (with `URL` and `CHECKSUM`) and
    extracted into `path` via `util.unzip`.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        download: Whether to download the data if it is not present.

    Returns:
        Filepath where the data is downloaded.
    """
    target_dir = os.path.join(path, "PSFHS")

    # Only fetch and extract the archive when the extracted folder is not there yet.
    if not os.path.exists(target_dir):
        os.makedirs(path, exist_ok=True)
        archive = os.path.join(path, "PSFHS.zip")
        util.download_source(path=archive, url=URL, download=download, checksum=CHECKSUM)
        util.unzip(zip_path=archive, dst=path)

    return target_dir

Download the PSFHS data.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • download: Whether to download the data if it is not present.
Returns:

Filepath where the data is downloaded.

def get_psfhs_paths( path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], download: bool = False) -> Tuple[List[int], List[int]]:
def get_psfhs_paths(
    path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], download: bool = False
) -> Tuple[List[str], List[str]]:
    """Get paths to the PSFHS dataset.

    Images and labels are sorted in natural order and split by position: the first 900
    pairs form the 'train' split, the next 150 the 'val' split and all remaining pairs
    the 'test' split.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        split: The choice of data split. One of 'train', 'val' or 'test'.
        download: Whether to download the data if it is not present.

    Returns:
        List of filepaths for the image data.
        List of filepaths for the label data.

    Raises:
        ValueError: If `split` is not one of 'train', 'val' or 'test'.
    """
    # Validate the split before touching the filesystem, so that an invalid choice
    # does not first trigger a download and extraction of the whole archive.
    if split == "train":
        split_slice = slice(None, 900)
    elif split == "val":
        split_slice = slice(900, 1050)
    elif split == "test":
        split_slice = slice(1050, None)
    else:
        raise ValueError(f"'{split}' is not a valid split.")

    data_dir = get_psfhs_data(path, download)

    raw_paths = natsorted(glob(os.path.join(data_dir, "image_mha", "*.mha")))
    label_paths = natsorted(glob(os.path.join(data_dir, "label_mha", "*.mha")))

    raw_paths, label_paths = raw_paths[split_slice], label_paths[split_slice]

    # Images and labels are paired by position, so both lists must line up.
    assert len(raw_paths) == len(label_paths) and len(raw_paths) > 0, (
        f"Expected matching, non-empty image/label lists for split '{split}', "
        f"got {len(raw_paths)} images and {len(label_paths)} labels."
    )

    return raw_paths, label_paths

Get paths to the PSFHS dataset.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • split: The choice of data split.
  • download: Whether to download the data if it is not present.
Returns:

A list of filepaths for the image data and a list of filepaths for the label data.

def get_psfhs_dataset( path: Union[os.PathLike, str], patch_shape: Tuple[int, int], split: Literal['train', 'val', 'test'], resize_inputs: bool = False, download: bool = False, **kwargs) -> torch.utils.data.dataset.Dataset:
 82def get_psfhs_dataset(
 83    path: Union[os.PathLike, str],
 84    patch_shape: Tuple[int, int],
 85    split: Literal['train', 'val', 'test'],
 86    resize_inputs: bool = False,
 87    download: bool = False,
 88    **kwargs
 89) -> Dataset:
 90    """Get the PSFHS dataset for segmentation of pubic symphysis and fetal head.
 91
 92    Args:
 93        path: Filepath to a folder where the data is downloaded for further processing.
 94        patch_shape: The patch shape to use for training.
 95        split: The choice of data split.
 96        resize_inputs: Whether to resize the inputs to the patch shape.
 97        download: Whether to download the data if it is not present.
 98        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
 99
100    Returns:
101        The segmentation dataset.
102    """
103    raw_paths, label_paths = get_psfhs_paths(path, split, download)
104
105    if resize_inputs:
106        resize_kwargs = {"patch_shape": patch_shape, "is_rgb": True}
107        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
108            kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs
109        )
110
111    return torch_em.default_segmentation_dataset(
112        raw_paths=raw_paths,
113        raw_key=None,
114        label_paths=label_paths,
115        label_key=None,
116        ndim=2,
117        is_seg_dataset=False,
118        with_channels=True,
119        patch_shape=patch_shape,
120        **kwargs
121    )

Get the PSFHS dataset for segmentation of pubic symphysis and fetal head.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • patch_shape: The patch shape to use for training.
  • split: The choice of data split.
  • resize_inputs: Whether to resize the inputs to the patch shape.
  • download: Whether to download the data if it is not present.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset.
Returns:

The segmentation dataset.

def get_psfhs_loader( path: Union[os.PathLike, str], batch_size: int, patch_shape: Tuple[int, int], split: Literal['train', 'val', 'test'], resize_inputs: bool = False, download: bool = False, **kwargs) -> torch.utils.data.dataloader.DataLoader:
def get_psfhs_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, int],
    split: Literal['train', 'val', 'test'],
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Get the PSFHS dataloader for segmentation of pubic symphysis and fetal head.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        resize_inputs: Whether to resize the inputs to the patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`
            or for `torch_em.get_data_loader`.

    Returns:
        The DataLoader.
    """
    # `util.split_kwargs` separates the kwargs for `torch_em.default_segmentation_dataset`
    # (forwarded to the dataset) from the rest (forwarded to `torch_em.get_data_loader`).
    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    dataset = get_psfhs_dataset(path, patch_shape, split, resize_inputs, download, **ds_kwargs)
    return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)

Get the PSFHS dataloader for segmentation of pubic symphysis and fetal head.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • batch_size: The batch size for training.
  • patch_shape: The patch shape to use for training.
  • split: The choice of data split.
  • download: Whether to download the data if it is not present.
  • resize_inputs: Whether to resize the inputs to the patch shape.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset or for torch_em.get_data_loader.
Returns:

The DataLoader.