torch_em.data.datasets.medical.leg_3d_us

The LEG 3D US dataset contains annotations for leg muscle segmentation in 3D ultrasound scans.

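NOTE: The label values in the original annotations are described as follows (during preprocessing they are remapped to 1, 2 and 3, respectively, and all other values are set to background):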

  • background: 0
  • soleus (SOL): 100
  • gastrocnemius medialis (GM): 150
  • gastrocnemius lateralis (GL): 200

The dataset is located at https://www.cs.cit.tum.de/camp/publications/leg-3d-us-dataset/.

This dataset is from the article: https://doi.org/10.1007/s11548-024-03170-7. Please cite it if you use this dataset in your research.

"""The LEG 3D US dataset contains annotations for leg muscle segmentation
in 3D ultrasound scans.

NOTE: The label legends are described as follows:
- background: 0
- soleus (SOL): 100
- gastrocnemius medialis (GM): 150
- gastrocnemius lateralis (GL): 200

During preprocessing, these label ids are remapped to 1 (SOL), 2 (GM) and 3 (GL),
and all other label ids are set to background (0).

The dataset is located at https://www.cs.cit.tum.de/camp/publications/leg-3d-us-dataset/.

This dataset is from the article: https://doi.org/10.1007/s11548-024-03170-7.
Please cite it if you use this dataset in your research.
"""
 15
 16import os
 17from glob import glob
 18from tqdm import tqdm
 19from natsort import natsorted
 20from typing import Union, Tuple, Literal, List
 21
 22import numpy as np
 23
 24from torch.utils.data import Dataset, DataLoader
 25
 26import torch_em
 27
 28from .. import util
 29
 30
 31URLS = {
 32    "train": "https://www.campar.in.tum.de/public_datasets/2024_IPCAI_Vanessa/leg_train_data.zip",
 33    "val": "https://www.campar.in.tum.de/public_datasets/2024_IPCAI_Vanessa/leg_validation_data.zip",
 34    "test": "https://www.campar.in.tum.de/public_datasets/2024_IPCAI_Vanessa/leg_test_data.zip",
 35}
 36
 37CHECKSUMS = {
 38    "train": "747e9ada7135979218d93022ac46d40a3a85119e2ea7aebcda4b13f7dfda70d6",
 39    "val": "c204fa0759dd279de722a423401da60657bc0d1ab5f57d135cd0ad55c32af70f",
 40    "test": "42ad341e8133f827d35f9cb3afde3ffbe5ae97dc2af448b6f9af6d4ea6ac99f0",
 41}
 42
 43
 44def get_leg_3d_us_data(
 45    path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], download: bool = False
 46):
 47    """Download the LEG 3D US data.
 48
 49    Args:
 50        path: Filepath to a folder where the data is downloaded for further processing.
 51        split: The data split to use. Either 'train', 'val' or 'test'.
 52        download: Whether to download the data if it is not present.
 53    """
 54    data_dir = os.path.join(path, split)
 55    if os.path.exists(data_dir):
 56        return
 57
 58    os.makedirs(path, exist_ok=True)
 59
 60    if split not in URLS:
 61        raise ValueError(f"'{split}' is not a valid split choice.")
 62
 63    zip_name = "validation" if split == "val" else split
 64    zip_path = os.path.join(path, f"leg_{zip_name}_data.zip")
 65    util.download_source(path=zip_path, url=URLS[split], download=download, checksum=CHECKSUMS[split])
 66    util.unzip(zip_path=zip_path, dst=path)
 67
 68
 69def _preprocess_labels(label_paths):
 70    neu_label_paths = []
 71    for lpath in tqdm(label_paths, desc="Preprocessing labels"):
 72        neu_label_path = lpath.replace(".mha", "_preprocessed.mha")
 73        neu_label_paths.append(neu_label_path)
 74        if os.path.exists(neu_label_path):
 75            continue
 76
 77        import SimpleITK as sitk
 78
 79        labels = sitk.ReadImage(lpath)
 80        larray = sitk.GetArrayFromImage(labels)
 81
 82        # NOTE: Remove other label ids not matching the specified task.
 83        valid_labels = [100, 150, 200]
 84        larray[~np.isin(larray, valid_labels)] = 0
 85
 86        for i, lid in enumerate(valid_labels, start=1):
 87            larray[larray == lid] = i
 88
 89        sitk_label = sitk.GetImageFromArray(larray)
 90        sitk.WriteImage(sitk_label, neu_label_path)
 91
 92    return neu_label_paths
 93
 94
 95def get_leg_3d_us_paths(
 96    path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], download: bool = False
 97) -> Tuple[List[str], List[str]]:
 98    """Get paths to the LEG 3D US data.
 99
100    Args:
101        path: Filepath to a folder where the data is downloaded for further processing.
102        split: The data split to use. Either 'train', 'val' or 'test'.
103        download: Whether to download the data if it is not present.
104
105    Returns:
106        List of filepaths for the image data.
107        List of filepaths for the label data.
108    """
109    get_leg_3d_us_data(path, split, download)
110
111    raw_paths = natsorted(glob(os.path.join(path, split, "*", "x*.mha")))
112    label_paths = [fpath.replace("x", "masksX") for fpath in raw_paths]
113    label_paths = _preprocess_labels(label_paths)
114
115    assert len(raw_paths) == len(label_paths)
116
117    return raw_paths, label_paths
118
119
def get_leg_3d_us_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, ...],
    split: Literal['train', 'val', 'test'],
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> Dataset:
    """Get the LEG 3D US dataset for leg muscle segmentation.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        patch_shape: The patch shape to use for training.
        split: The data split to use. Either 'train', 'val' or 'test'.
        resize_inputs: Whether to resize inputs to the desired patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    image_paths, mask_paths = get_leg_3d_us_paths(path, split, download)

    # Optionally let the util helper update the kwargs and patch shape for resizing the inputs.
    if resize_inputs:
        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
            kwargs=kwargs,
            patch_shape=patch_shape,
            resize_inputs=resize_inputs,
            resize_kwargs=dict(patch_shape=patch_shape, is_rgb=False),
        )

    seg_dataset = torch_em.default_segmentation_dataset(
        raw_paths=image_paths,
        raw_key=None,
        label_paths=mask_paths,
        label_key=None,
        patch_shape=patch_shape,
        is_seg_dataset=True,
        **kwargs
    )

    # Allow up to 10000 sampling attempts in every per-volume sub-dataset.
    # NOTE(review): presumably raised because acceptable patches are rare; confirm.
    for sub_dataset in seg_dataset.datasets:
        sub_dataset.max_sampling_attempts = 10000

    return seg_dataset
163
164
def get_leg_3d_us_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, ...],
    split: Literal['train', 'val', 'test'],
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Get the LEG 3D US dataloader for leg muscle segmentation.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        split: The data split to use. Either 'train', 'val' or 'test'.
        resize_inputs: Whether to resize inputs to the desired patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # Separate the keyword arguments meant for the dataset from those passed on to the loader.
    dataset_kwargs, dataloader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    return torch_em.get_data_loader(
        get_leg_3d_us_dataset(path, patch_shape, split, resize_inputs, download, **dataset_kwargs),
        batch_size,
        **dataloader_kwargs,
    )
# Download location of the zip archive for each data split.
# Note that the archive of the 'val' split is named 'validation'.
URLS = dict(
    train="https://www.campar.in.tum.de/public_datasets/2024_IPCAI_Vanessa/leg_train_data.zip",
    val="https://www.campar.in.tum.de/public_datasets/2024_IPCAI_Vanessa/leg_validation_data.zip",
    test="https://www.campar.in.tum.de/public_datasets/2024_IPCAI_Vanessa/leg_test_data.zip",
)

# Expected checksum of each split's archive (64 hex characters; presumably SHA-256,
# confirm against `util.download_source`).
CHECKSUMS = dict(
    train="747e9ada7135979218d93022ac46d40a3a85119e2ea7aebcda4b13f7dfda70d6",
    val="c204fa0759dd279de722a423401da60657bc0d1ab5f57d135cd0ad55c32af70f",
    test="42ad341e8133f827d35f9cb3afde3ffbe5ae97dc2af448b6f9af6d4ea6ac99f0",
)
def get_leg_3d_us_data( path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], download: bool = False):
def get_leg_3d_us_data(
    path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], download: bool = False
):
    """Download the LEG 3D US data.

    The zip archive of the requested split is downloaded into `path` and extracted there.
    Nothing is done if the folder `path/split` already exists.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        split: The data split to use. Either 'train', 'val' or 'test'.
        download: Whether to download the data if it is not present.

    Raises:
        ValueError: If `split` is not one of 'train', 'val' or 'test'.
    """
    # Validate the split before touching the filesystem, so that an invalid choice neither
    # creates directories nor silently returns because a folder with that name exists.
    if split not in URLS:
        raise ValueError(f"'{split}' is not a valid split choice.")

    data_dir = os.path.join(path, split)
    if os.path.exists(data_dir):
        return

    os.makedirs(path, exist_ok=True)

    # The archive of the 'val' split is named 'validation'.
    zip_name = "validation" if split == "val" else split
    zip_path = os.path.join(path, f"leg_{zip_name}_data.zip")
    util.download_source(path=zip_path, url=URLS[split], download=download, checksum=CHECKSUMS[split])
    util.unzip(zip_path=zip_path, dst=path)

Download the LEG 3D US data.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • split: The data split to use. Either 'train', 'val' or 'test'.
  • download: Whether to download the data if it is not present.
def get_leg_3d_us_paths( path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], download: bool = False) -> Tuple[List[str], List[str]]:
 96def get_leg_3d_us_paths(
 97    path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], download: bool = False
 98) -> Tuple[List[str], List[str]]:
 99    """Get paths to the LEG 3D US data.
100
101    Args:
102        path: Filepath to a folder where the data is downloaded for further processing.
103        split: The data split to use. Either 'train', 'val' or 'test'.
104        download: Whether to download the data if it is not present.
105
106    Returns:
107        List of filepaths for the image data.
108        List of filepaths for the label data.
109    """
110    get_leg_3d_us_data(path, split, download)
111
112    raw_paths = natsorted(glob(os.path.join(path, split, "*", "x*.mha")))
113    label_paths = [fpath.replace("x", "masksX") for fpath in raw_paths]
114    label_paths = _preprocess_labels(label_paths)
115
116    assert len(raw_paths) == len(label_paths)
117
118    return raw_paths, label_paths

Get paths to the LEG 3D US data.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • split: The data split to use. Either 'train', 'val' or 'test'.
  • download: Whether to download the data if it is not present.
Returns:

List of filepaths for the image data, and list of filepaths for the label data.

def get_leg_3d_us_dataset( path: Union[os.PathLike, str], patch_shape: Tuple[int, ...], split: Literal['train', 'val', 'test'], resize_inputs: bool = False, download: bool = False, **kwargs) -> torch.utils.data.dataset.Dataset:
def get_leg_3d_us_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, ...],
    split: Literal['train', 'val', 'test'],
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> Dataset:
    """Get the LEG 3D US dataset for leg muscle segmentation.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        patch_shape: The patch shape to use for training.
        split: The data split to use. Either 'train', 'val' or 'test'.
        resize_inputs: Whether to resize inputs to the desired patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    image_paths, mask_paths = get_leg_3d_us_paths(path, split, download)

    # Optionally let the util helper update the kwargs and patch shape for resizing the inputs.
    if resize_inputs:
        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
            kwargs=kwargs,
            patch_shape=patch_shape,
            resize_inputs=resize_inputs,
            resize_kwargs=dict(patch_shape=patch_shape, is_rgb=False),
        )

    seg_dataset = torch_em.default_segmentation_dataset(
        raw_paths=image_paths,
        raw_key=None,
        label_paths=mask_paths,
        label_key=None,
        patch_shape=patch_shape,
        is_seg_dataset=True,
        **kwargs
    )

    # Allow up to 10000 sampling attempts in every per-volume sub-dataset.
    # NOTE(review): presumably raised because acceptable patches are rare; confirm.
    for sub_dataset in seg_dataset.datasets:
        sub_dataset.max_sampling_attempts = 10000

    return seg_dataset

Get the LEG 3D US dataset for leg muscle segmentation.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • patch_shape: The patch shape to use for training.
  • split: The data split to use. Either 'train', 'val' or 'test'.
  • resize_inputs: Whether to resize inputs to the desired patch shape.
  • download: Whether to download the data if it is not present.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset.
Returns:

The segmentation dataset.

def get_leg_3d_us_loader( path: Union[os.PathLike, str], batch_size: int, patch_shape: Tuple[int, ...], split: Literal['train', 'val', 'test'], resize_inputs: bool = False, download: bool = False, **kwargs) -> torch.utils.data.dataloader.DataLoader:
def get_leg_3d_us_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, ...],
    split: Literal['train', 'val', 'test'],
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Get the LEG 3D US dataloader for leg muscle segmentation.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        split: The data split to use. Either 'train', 'val' or 'test'.
        resize_inputs: Whether to resize inputs to the desired patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # Separate the keyword arguments meant for the dataset from those passed on to the loader.
    dataset_kwargs, dataloader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    return torch_em.get_data_loader(
        get_leg_3d_us_dataset(path, patch_shape, split, resize_inputs, download, **dataset_kwargs),
        batch_size,
        **dataloader_kwargs,
    )

Get the LEG 3D US dataloader for leg muscle segmentation.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • batch_size: The batch size for training.
  • patch_shape: The patch shape to use for training.
  • split: The data split to use. Either 'train', 'val' or 'test'.
  • resize_inputs: Whether to resize inputs to the desired patch shape.
  • download: Whether to download the data if it is not present.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset or for the PyTorch DataLoader.
Returns:

The DataLoader.