torch_em.data.datasets.medical.lgg_mri
The LGG MRI dataset contains annotations for low grade glioma segmentation in FLAIR MRI scans.
The dataset is located at https://www.kaggle.com/datasets/mateuszbuda/lgg-mri-segmentation. This dataset is from the publication https://www.nejm.org/doi/full/10.1056/NEJMoa1402121. Please cite it if you use this dataset in your research.
"""The LGG MRI dataset contains annotations for low grade glioma segmentation
in FLAIR MRI scans.

The dataset is located at https://www.kaggle.com/datasets/mateuszbuda/lgg-mri-segmentation.
This dataset is from the publication https://www.nejm.org/doi/full/10.1056/NEJMoa1402121.
Please cite it if you use this dataset in your research.
"""

import os
import shutil
import warnings
from glob import glob
from tqdm import tqdm
from natsort import natsorted
from typing import Union, Tuple, List, Literal, Optional

import numpy as np
import imageio.v3 as imageio

from torch.utils.data import Dataset, DataLoader

import torch_em

from .. import util


def _merge_slices_to_volumes(path):
    """Stack the per-slice TIFF files of every patient into one HDF5 volume.

    For each `TCGA_*` folder under `<path>/kaggle_3m`, the `*_mask.tif` files are
    collected in natural sort order and the matching image slice is found by
    dropping the `_mask` suffix. Slices are stacked along axis 0 and written to
    `<path>/data/<patient folder name>.h5` with the datasets `raw/pre_contrast`,
    `raw/flair`, `raw/post_contrast` (last-axis channels 0, 1 and 2 of the image
    slices) and `labels`. The `kaggle_3m` folder is deleted afterwards.

    Args:
        path: Folder that contains the extracted `kaggle_3m` directory.
    """
    import h5py

    volume_dir = os.path.join(path, "data")
    os.makedirs(volume_dir, exist_ok=True)

    patient_dirs = glob(os.path.join(path, "kaggle_3m", "TCGA_*"))
    for patient_dir in tqdm(patient_dirs, desc="Preprocessing inputs"):
        label_slice_paths = natsorted(glob(os.path.join(patient_dir, "*_mask.tif")))
        # The image slice shares the name of its mask, minus the "_mask" suffix.
        raw_slice_paths = [lpath.replace("_mask.tif", ".tif") for lpath in label_slice_paths]

        raw = [imageio.imread(rpath) for rpath in raw_slice_paths]
        labels = [imageio.imread(lpath) for lpath in label_slice_paths]

        raw, labels = np.stack(raw, axis=0), np.stack(labels, axis=0)

        volume_path = os.path.join(volume_dir, f"{os.path.basename(patient_dir)}.h5")

        with h5py.File(volume_path, "w") as f:
            f.create_dataset("raw/pre_contrast", data=raw[..., 0], compression="gzip")
            f.create_dataset("raw/flair", data=raw[..., 1], compression="gzip")
            f.create_dataset("raw/post_contrast", data=raw[..., 2], compression="gzip")
            f.create_dataset("labels", data=labels, compression="gzip")

    # The slice-wise inputs are no longer needed once the volumes are written.
    shutil.rmtree(os.path.join(path, "kaggle_3m"))


def get_lgg_mri_data(path: Union[os.PathLike, str], download: bool = False):
    """Download the LGG MRI data.

    Returns immediately if `<path>/data` already exists.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        download: Whether to download the data if it is not present.
    """
    data_dir = os.path.join(path, "data")
    if os.path.exists(data_dir):
        return

    os.makedirs(path, exist_ok=True)

    util.download_source_kaggle(path=path, dataset_name="mateuszbuda/lgg-mri-segmentation", download=download)
    zip_path = os.path.join(path, "lgg-mri-segmentation.zip")
    util.unzip(zip_path=zip_path, dst=path)

    # Remove redundant volumes
    shutil.rmtree(os.path.join(path, "lgg-mri-segmentation"))

    _merge_slices_to_volumes(path)


def get_lgg_mri_paths(
    path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], download: bool = False
) -> List[str]:
    """Get paths to the LGG MRI data.

    The naturally sorted volume list is split by position: the first 70 volumes
    are 'train', the next 15 are 'val' and the remainder is 'test'.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        split: The choice of data split.
        download: Whether to download the data if it is not present.

    Returns:
        List of filepaths for the input data.

    Raises:
        ValueError: If `split` is not one of 'train', 'val' or 'test'.
    """
    get_lgg_mri_data(path, download)

    volume_paths = natsorted(glob(os.path.join(path, "data", "*.h5")))

    if split == "train":
        volume_paths = volume_paths[:70]
    elif split == "val":
        volume_paths = volume_paths[70:85]
    elif split == "test":
        volume_paths = volume_paths[85:]
    else:
        raise ValueError(f"'{split}' is not a valid split.")

    return volume_paths


def get_lgg_mri_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, ...],
    split: Literal['train', 'val', 'test'],
    channels: Optional[Literal['pre_contrast', 'flair', 'post_contrast']] = None,
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> Dataset:
    """Get the LGG MRI dataset for glioma segmentation.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        channels: The choice of modality as input channel. With `None`, all three
            modalities are loaded as channels.
        resize_inputs: Whether to resize inputs to the desired patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.

    Raises:
        ValueError: If `split` or `channels` is not a valid choice.
    """
    volume_paths = get_lgg_mri_paths(path, split, download)

    # The resize request is honored whenever it is set. The previous guard also
    # required `channels is not None`, which made the warning below unreachable
    # and silently dropped the resizing when no channel was chosen.
    if resize_inputs:
        if channels is None:
            warnings.warn("The default for channels is set to 'None'. Choose one specific channel for resizing inputs.")

        resize_kwargs = {"patch_shape": patch_shape, "is_rgb": False}
        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
            kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs
        )

    available_channels = ["pre_contrast", "flair", "post_contrast"]
    if channels is not None and channels not in available_channels:
        raise ValueError(f"'{channels}' is not a valid channel.")

    use_all_channels = channels is None

    return torch_em.default_segmentation_dataset(
        raw_paths=volume_paths,
        raw_key=[f"raw/{chan}" for chan in available_channels] if use_all_channels else f"raw/{channels}",
        label_paths=volume_paths,
        label_key="labels",
        patch_shape=patch_shape,
        is_seg_dataset=True,
        with_channels=use_all_channels,
        **kwargs
    )


def get_lgg_mri_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, ...],
    split: Literal['train', 'val', 'test'],
    channels: Optional[Literal['pre_contrast', 'flair', 'post_contrast']] = None,
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Get the LGG MRI dataloader for glioma segmentation.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        channels: The choice of modality as input channel.
        resize_inputs: Whether to resize inputs to the desired patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    dataset = get_lgg_mri_dataset(path, patch_shape, split, channels, resize_inputs, download, **ds_kwargs)
    return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
def
get_lgg_mri_data(path: Union[os.PathLike, str], download: bool = False):
def get_lgg_mri_data(path: Union[os.PathLike, str], download: bool = False):
    """Make the LGG MRI data available under `path`.

    If `<path>/data` already exists, the function returns without doing anything.
    Otherwise the Kaggle archive is fetched and unzipped, the extracted
    `lgg-mri-segmentation` folder is removed and the remaining slices are merged
    into volumes.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        download: Whether to download the data if it is not present.
    """
    if os.path.exists(os.path.join(path, "data")):
        # The processed volumes are already in place.
        return

    os.makedirs(path, exist_ok=True)

    util.download_source_kaggle(path=path, dataset_name="mateuszbuda/lgg-mri-segmentation", download=download)
    util.unzip(zip_path=os.path.join(path, "lgg-mri-segmentation.zip"), dst=path)

    # Remove redundant volumes
    redundant_dir = os.path.join(path, "lgg-mri-segmentation")
    shutil.rmtree(redundant_dir)

    _merge_slices_to_volumes(path)
Download the LGG MRI data.
Arguments:
- path: Filepath to a folder where the data is downloaded for further processing.
- download: Whether to download the data if it is not present.
def
get_lgg_mri_paths( path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], download: bool = False) -> List[str]:
def get_lgg_mri_paths(
    path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], download: bool = False
) -> List[str]:
    """Return the volume filepaths that belong to one split of the LGG MRI data.

    The data is prepared first if needed. All `*.h5` files in `<path>/data` are
    then sorted naturally and divided by position: the first 70 are 'train',
    the following 15 are 'val' and everything after that is 'test'.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        split: The choice of data split.
        download: Whether to download the data if it is not present.

    Returns:
        List of filepaths for the input data.

    Raises:
        ValueError: If `split` is not one of 'train', 'val' or 'test'.
    """
    get_lgg_mri_data(path, download)

    search_pattern = os.path.join(path, "data", "*.h5")
    all_volumes = natsorted(glob(search_pattern))

    if split == "train":
        return all_volumes[:70]
    if split == "val":
        return all_volumes[70:85]
    if split == "test":
        return all_volumes[85:]

    raise ValueError(f"'{split}' is not a valid split.")
Get paths to the LGG MRI data.
Arguments:
- path: Filepath to a folder where the data is downloaded for further processing.
- split: The choice of data split.
- download: Whether to download the data if it is not present.
Returns:
List of filepaths for the input data.
def
get_lgg_mri_dataset( path: Union[os.PathLike, str], patch_shape: Tuple[int, ...], split: Literal['train', 'val', 'test'], channels: Optional[Literal['pre_contrast', 'flair', 'post_contrast']] = None, resize_inputs: bool = False, download: bool = False, **kwargs) -> torch.utils.data.dataset.Dataset:
def get_lgg_mri_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, ...],
    split: Literal['train', 'val', 'test'],
    channels: Optional[Literal['pre_contrast', 'flair', 'post_contrast']] = None,
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> Dataset:
    """Get the LGG MRI dataset for glioma segmentation.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        channels: The choice of modality as input channel. With `None`, all three
            modalities are loaded as channels.
        resize_inputs: Whether to resize inputs to the desired patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.

    Raises:
        ValueError: If `channels` is neither `None` nor one of the available modalities.
    """
    volume_paths = get_lgg_mri_paths(path, split, download)

    # The resize request is honored whenever it is set. The previous guard also
    # required `channels is not None`, which made the warning below unreachable
    # and silently dropped the resizing when no channel was chosen.
    if resize_inputs:
        if channels is None:
            warnings.warn("The default for channels is set to 'None'. Choose one specific channel for resizing inputs.")

        resize_kwargs = {"patch_shape": patch_shape, "is_rgb": False}
        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
            kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs
        )

    available_channels = ["pre_contrast", "flair", "post_contrast"]
    if channels is not None and channels not in available_channels:
        raise ValueError(f"'{channels}' is not a valid channel.")

    # Without an explicit choice, every modality is loaded and stacked as channels.
    use_all_channels = channels is None

    return torch_em.default_segmentation_dataset(
        raw_paths=volume_paths,
        raw_key=[f"raw/{chan}" for chan in available_channels] if use_all_channels else f"raw/{channels}",
        label_paths=volume_paths,
        label_key="labels",
        patch_shape=patch_shape,
        is_seg_dataset=True,
        with_channels=use_all_channels,
        **kwargs
    )
Get the LGG MRI dataset for glioma segmentation.
Arguments:
- path: Filepath to a folder where the data is downloaded for further processing.
- patch_shape: The patch shape to use for training.
- split: The choice of data split.
- channels: The choice of modality as input channel.
- resize_inputs: Whether to resize inputs to the desired patch shape.
- download: Whether to download the data if it is not present.
- kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
Returns:
The segmentation dataset.
def
get_lgg_mri_loader( path: Union[os.PathLike, str], batch_size: int, patch_shape: Tuple[int, ...], split: Literal['train', 'val', 'test'], channels: Optional[Literal['pre_contrast', 'flair', 'post_contrast']] = None, resize_inputs: bool = False, download: bool = False, **kwargs) -> torch.utils.data.dataloader.DataLoader:
def get_lgg_mri_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, ...],
    split: Literal['train', 'val', 'test'],
    channels: Optional[Literal['pre_contrast', 'flair', 'post_contrast']] = None,
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Build a dataloader over the LGG MRI dataset for glioma segmentation.

    The extra keyword arguments are divided between the dataset and the loader
    via `util.split_kwargs`; the dataset is created first and then wrapped.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        channels: The choice of modality as input channel.
        resize_inputs: Whether to resize inputs to the desired patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    dataset_options, loader_options = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)

    lgg_dataset = get_lgg_mri_dataset(
        path, patch_shape, split, channels, resize_inputs, download, **dataset_options
    )

    return torch_em.get_data_loader(lgg_dataset, batch_size, **loader_options)
Get the LGG MRI dataloader for glioma segmentation.
Arguments:
- path: Filepath to a folder where the data is downloaded for further processing.
- batch_size: The batch size for training.
- patch_shape: The patch shape to use for training.
- split: The choice of data split.
- channels: The choice of modality as input channel.
- resize_inputs: Whether to resize inputs to the desired patch shape.
- download: Whether to download the data if it is not present.
- kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.
Returns:
The DataLoader.