torch_em.data.datasets.medical.leg_3d_us
The LEG 3D US dataset contains annotations for leg muscle segmentation in 3D ultrasound scans.
NOTE: The label legends are described as follows:
- background: 0
- soleus (SOL): 100
- gastrocnemius medialis (GM): 150
- gastrocnemius lateralis (GL): 200
The dataset is located at https://www.cs.cit.tum.de/camp/publications/leg-3d-us-dataset/.
This dataset is from the article: https://doi.org/10.1007/s11548-024-03170-7. Please cite it if you use this dataset in your research.
"""The LEG 3D US dataset contains annotations for leg muscle segmentation
in 3D ultrasound scans.

NOTE: The label legends are described as follows:
- background: 0
- soleus (SOL): 100
- gastrocnemius medialis (GM): 150
- gastrocnemius lateralis (GL): 200

The labels used by the datasets in this module are preprocessed and remapped to
SOL -> 1, GM -> 2, GL -> 3; any other label id becomes background (0).

The dataset is located at https://www.cs.cit.tum.de/camp/publications/leg-3d-us-dataset/.

This dataset is from the article: https://doi.org/10.1007/s11548-024-03170-7.
Please cite it if you use this dataset in your research.
"""

import os
from glob import glob
from tqdm import tqdm
from natsort import natsorted
from typing import Union, Tuple, Literal, List

import numpy as np

from torch.utils.data import Dataset, DataLoader

import torch_em

from .. import util


URLS = {
    "train": "https://www.campar.in.tum.de/public_datasets/2024_IPCAI_Vanessa/leg_train_data.zip",
    "val": "https://www.campar.in.tum.de/public_datasets/2024_IPCAI_Vanessa/leg_validation_data.zip",
    "test": "https://www.campar.in.tum.de/public_datasets/2024_IPCAI_Vanessa/leg_test_data.zip",
}

CHECKSUMS = {
    "train": "747e9ada7135979218d93022ac46d40a3a85119e2ea7aebcda4b13f7dfda70d6",
    "val": "c204fa0759dd279de722a423401da60657bc0d1ab5f57d135cd0ad55c32af70f",
    "test": "42ad341e8133f827d35f9cb3afde3ffbe5ae97dc2af448b6f9af6d4ea6ac99f0",
}


def get_leg_3d_us_data(
    path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], download: bool = False
):
    """Download the LEG 3D US data.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        split: The data split to use. Either 'train', 'val' or 'test'.
        download: Whether to download the data if it is not present.

    Raises:
        ValueError: If `split` is not one of 'train', 'val' or 'test'.
    """
    # Validate the split before touching the filesystem, so an invalid choice is always reported
    # (even if a folder with that name happens to exist) and no folder is created for it.
    if split not in URLS:
        raise ValueError(f"'{split}' is not a valid split choice.")

    data_dir = os.path.join(path, split)
    if os.path.exists(data_dir):
        return

    os.makedirs(path, exist_ok=True)

    # The archive for the 'val' split is called 'leg_validation_data.zip'.
    zip_name = "validation" if split == "val" else split
    zip_path = os.path.join(path, f"leg_{zip_name}_data.zip")
    util.download_source(path=zip_path, url=URLS[split], download=download, checksum=CHECKSUMS[split])
    util.unzip(zip_path=zip_path, dst=path)


def _preprocess_labels(label_paths):
    """Remap the label ids to consecutive ids and store the result next to the original files.

    Label ids 100 (SOL), 150 (GM) and 200 (GL) are mapped to 1, 2 and 3; every other id is set to 0.
    The result for '<name>.mha' is written to '<name>_preprocessed.mha' and reused if it already exists.

    Args:
        label_paths: Filepaths to the original label volumes.

    Returns:
        List of filepaths to the preprocessed label volumes, in the same order as `label_paths`.
    """
    valid_labels = [100, 150, 200]

    neu_label_paths = []
    for lpath in tqdm(label_paths, desc="Preprocessing labels"):
        # Only modify the file extension, not other occurrences of '.mha' in the path.
        stem, ext = os.path.splitext(lpath)
        neu_label_path = f"{stem}_preprocessed{ext}"
        neu_label_paths.append(neu_label_path)
        if os.path.exists(neu_label_path):
            continue

        # Imported lazily so SimpleITK is only needed when preprocessing actually has to run.
        import SimpleITK as sitk

        labels = sitk.ReadImage(lpath)
        larray = sitk.GetArrayFromImage(labels)

        # NOTE: Remove other label ids not matching the specified task.
        larray[~np.isin(larray, valid_labels)] = 0

        # Safe to do sequentially: the target ids 1, 2, 3 were zeroed above, so they cannot collide.
        for i, lid in enumerate(valid_labels, start=1):
            larray[larray == lid] = i

        sitk_label = sitk.GetImageFromArray(larray)
        # Keep the physical metadata (spacing, origin, direction) of the original label volume.
        sitk_label.CopyInformation(labels)
        sitk.WriteImage(sitk_label, neu_label_path)

    return neu_label_paths


def get_leg_3d_us_paths(
    path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], download: bool = False
) -> Tuple[List[str], List[str]]:
    """Get paths to the LEG 3D US data.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        split: The data split to use. Either 'train', 'val' or 'test'.
        download: Whether to download the data if it is not present.

    Returns:
        List of filepaths for the image data.
        List of filepaths for the label data.
    """
    get_leg_3d_us_data(path, split, download)

    raw_paths = natsorted(glob(os.path.join(path, split, "*", "x*.mha")))

    # The labels for 'x*.mha' are stored as 'masksX*.mha' in the same folder. Only the leading 'x' of the
    # file name is replaced; replacing every 'x' in the full path would break directories containing an 'x'.
    label_paths = [
        os.path.join(os.path.dirname(fpath), "masksX" + os.path.basename(fpath)[1:]) for fpath in raw_paths
    ]
    label_paths = _preprocess_labels(label_paths)

    assert len(raw_paths) == len(label_paths)

    return raw_paths, label_paths


def get_leg_3d_us_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, ...],
    split: Literal['train', 'val', 'test'],
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> Dataset:
    """Get the LEG 3D US dataset for leg muscle segmentation.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        patch_shape: The patch shape to use for training.
        split: The data split to use. Either 'train', 'val' or 'test'.
        resize_inputs: Whether to resize inputs to the desired patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    raw_paths, label_paths = get_leg_3d_us_paths(path, split, download)

    if resize_inputs:
        resize_kwargs = {"patch_shape": patch_shape, "is_rgb": False}
        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
            kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs
        )

    dataset = torch_em.default_segmentation_dataset(
        raw_paths=raw_paths,
        raw_key=None,
        label_paths=label_paths,
        label_key=None,
        patch_shape=patch_shape,
        is_seg_dataset=True,
        **kwargs
    )

    # NOTE(review): presumably raised because valid patches are hard to sample with the default — TODO confirm.
    for d in dataset.datasets:
        d.max_sampling_attempts = 10000

    return dataset


def get_leg_3d_us_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, ...],
    split: Literal['train', 'val', 'test'],
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Get the LEG 3D US dataloader for leg muscle segmentation.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        split: The data split to use. Either 'train', 'val' or 'test'.
        resize_inputs: Whether to resize inputs to the desired patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    dataset = get_leg_3d_us_dataset(path, patch_shape, split, resize_inputs, download, **ds_kwargs)
    return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
URLS =
{'train': 'https://www.campar.in.tum.de/public_datasets/2024_IPCAI_Vanessa/leg_train_data.zip', 'val': 'https://www.campar.in.tum.de/public_datasets/2024_IPCAI_Vanessa/leg_validation_data.zip', 'test': 'https://www.campar.in.tum.de/public_datasets/2024_IPCAI_Vanessa/leg_test_data.zip'}
CHECKSUMS =
{'train': '747e9ada7135979218d93022ac46d40a3a85119e2ea7aebcda4b13f7dfda70d6', 'val': 'c204fa0759dd279de722a423401da60657bc0d1ab5f57d135cd0ad55c32af70f', 'test': '42ad341e8133f827d35f9cb3afde3ffbe5ae97dc2af448b6f9af6d4ea6ac99f0'}
def
get_leg_3d_us_data( path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], download: bool = False):
def get_leg_3d_us_data(
    path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], download: bool = False
):
    """Download the LEG 3D US data.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        split: The data split to use. Either 'train', 'val' or 'test'.
        download: Whether to download the data if it is not present.

    Raises:
        ValueError: If `split` is not one of 'train', 'val' or 'test'.
    """
    # Validate the split before touching the filesystem, so an invalid choice is always reported
    # (even if a folder with that name happens to exist) and no folder is created for it.
    if split not in URLS:
        raise ValueError(f"'{split}' is not a valid split choice.")

    data_dir = os.path.join(path, split)
    if os.path.exists(data_dir):
        return

    os.makedirs(path, exist_ok=True)

    # The archive for the 'val' split is called 'leg_validation_data.zip'.
    zip_name = "validation" if split == "val" else split
    zip_path = os.path.join(path, f"leg_{zip_name}_data.zip")
    util.download_source(path=zip_path, url=URLS[split], download=download, checksum=CHECKSUMS[split])
    util.unzip(zip_path=zip_path, dst=path)
Download the LEG 3D US data.
Arguments:
- path: Filepath to a folder where the data is downloaded for further processing.
- split: The data split to use. Either 'train', 'val' or 'test'.
- download: Whether to download the data if it is not present.
def
get_leg_3d_us_paths( path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], download: bool = False) -> Tuple[List[str], List[str]]:
def get_leg_3d_us_paths(
    path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], download: bool = False
) -> Tuple[List[str], List[str]]:
    """Get paths to the LEG 3D US data.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        split: The data split to use. Either 'train', 'val' or 'test'.
        download: Whether to download the data if it is not present.

    Returns:
        List of filepaths for the image data.
        List of filepaths for the label data.
    """
    get_leg_3d_us_data(path, split, download)

    raw_paths = natsorted(glob(os.path.join(path, split, "*", "x*.mha")))

    # The labels for 'x*.mha' are stored as 'masksX*.mha' in the same folder. Only the leading 'x' of the
    # file name is replaced; replacing every 'x' in the full path would break directories containing an 'x'.
    label_paths = [
        os.path.join(os.path.dirname(fpath), "masksX" + os.path.basename(fpath)[1:]) for fpath in raw_paths
    ]
    label_paths = _preprocess_labels(label_paths)

    assert len(raw_paths) == len(label_paths)

    return raw_paths, label_paths
Get paths to the LEG 3D US data.
Arguments:
- path: Filepath to a folder where the data is downloaded for further processing.
- split: The data split to use. Either 'train', 'val' or 'test'.
- download: Whether to download the data if it is not present.
Returns:
List of filepaths for the image data. List of filepaths for the label data.
def
get_leg_3d_us_dataset( path: Union[os.PathLike, str], patch_shape: Tuple[int, ...], split: Literal['train', 'val', 'test'], resize_inputs: bool = False, download: bool = False, **kwargs) -> torch.utils.data.dataset.Dataset:
def get_leg_3d_us_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, ...],
    split: Literal['train', 'val', 'test'],
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> Dataset:
    """Build the LEG 3D US segmentation dataset for leg muscle segmentation.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        patch_shape: The patch shape to use for training.
        split: The data split to use. Either 'train', 'val' or 'test'.
        resize_inputs: Whether to resize inputs to the desired patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    image_files, mask_files = get_leg_3d_us_paths(path, split, download)

    if resize_inputs:
        # The resize transform is configured for the requested (original) patch shape.
        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
            kwargs=kwargs,
            patch_shape=patch_shape,
            resize_inputs=resize_inputs,
            resize_kwargs={"patch_shape": patch_shape, "is_rgb": False},
        )

    segmentation_dataset = torch_em.default_segmentation_dataset(
        raw_paths=image_files,
        raw_key=None,
        label_paths=mask_files,
        label_key=None,
        patch_shape=patch_shape,
        is_seg_dataset=True,
        **kwargs
    )

    # NOTE(review): presumably raised because valid patches are hard to sample with the default — TODO confirm.
    for sub_dataset in segmentation_dataset.datasets:
        sub_dataset.max_sampling_attempts = 10000
    return segmentation_dataset
Get the LEG 3D US dataset for leg muscle segmentation.
Arguments:
- path: Filepath to a folder where the data is downloaded for further processing.
- patch_shape: The patch shape to use for training.
- split: The data split to use. Either 'train', 'val' or 'test'.
- resize_inputs: Whether to resize inputs to the desired patch shape.
- download: Whether to download the data if it is not present.
- kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
Returns:
The segmentation dataset.
def
get_leg_3d_us_loader( path: Union[os.PathLike, str], batch_size: int, patch_shape: Tuple[int, ...], split: Literal['train', 'val', 'test'], resize_inputs: bool = False, download: bool = False, **kwargs) -> torch.utils.data.dataloader.DataLoader:
def get_leg_3d_us_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, ...],
    split: Literal['train', 'val', 'test'],
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Build a DataLoader over the LEG 3D US dataset for leg muscle segmentation.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        split: The data split to use. Either 'train', 'val' or 'test'.
        resize_inputs: Whether to resize inputs to the desired patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # Route the keyword arguments accepted by the dataset factory to it; the rest go to the loader.
    dataset_kwargs, dataloader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    return torch_em.get_data_loader(
        get_leg_3d_us_dataset(path, patch_shape, split, resize_inputs, download, **dataset_kwargs),
        batch_size,
        **dataloader_kwargs
    )
Get the LEG 3D US dataloader for leg muscle segmentation.
Arguments:
- path: Filepath to a folder where the data is downloaded for further processing.
- batch_size: The batch size for training.
- patch_shape: The patch shape to use for training.
- split: The data split to use. Either 'train', 'val' or 'test'.
- resize_inputs: Whether to resize inputs to the desired patch shape.
- download: Whether to download the data if it is not present.
- kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.
Returns:
The DataLoader.