torch_em.data.datasets.medical.lgg_mri

The LGG MRI dataset contains annotations for low-grade glioma segmentation in FLAIR MRI scans.

The dataset is located at https://www.kaggle.com/datasets/mateuszbuda/lgg-mri-segmentation. This dataset is from the publication https://www.nejm.org/doi/full/10.1056/NEJMoa1402121. Please cite it if you use this dataset in your research.

  1"""The LGG MRI datasets contains annotations for low grade glioma segmentation
  2in FLAIR MRI scans.
  3
  4The dataset is located at https://www.kaggle.com/datasets/mateuszbuda/lgg-mri-segmentation.
  5This dataset is from the publication https://www.nejm.org/doi/full/10.1056/NEJMoa1402121.
  6Please cite it if you use this dataset in your research.
  7"""
  8
  9import os
 10import shutil
 11import warnings
 12from glob import glob
 13from tqdm import tqdm
 14from natsort import natsorted
 15from typing import Union, Tuple, List, Literal, Optional
 16
 17import numpy as np
 18import imageio.v3 as imageio
 19
 20from torch.utils.data import Dataset, DataLoader
 21
 22import torch_em
 23
 24from .. import util
 25
 26
 27def _merge_slices_to_volumes(path):
 28    import h5py
 29
 30    volume_dir = os.path.join(path, "data")
 31    os.makedirs(volume_dir, exist_ok=True)
 32
 33    patient_dirs = glob(os.path.join(path, "kaggle_3m", "TCGA_*"))
 34    for patient_dir in tqdm(patient_dirs, desc="Preprocessing inputs"):
 35        label_slice_paths = natsorted(glob(os.path.join(patient_dir, "*_mask.tif")))
 36        raw_slice_paths = [lpath.replace("_mask.tif", ".tif") for lpath in label_slice_paths]
 37
 38        raw = [imageio.imread(rpath) for rpath in raw_slice_paths]
 39        labels = [imageio.imread(lpath) for lpath in label_slice_paths]
 40
 41        raw, labels = np.stack(raw, axis=0), np.stack(labels, axis=0)
 42
 43        volume_path = os.path.join(volume_dir, f"{os.path.basename(patient_dir)}.h5")
 44
 45        with h5py.File(volume_path, "w") as f:
 46            f.create_dataset("raw/pre_contrast", data=raw[..., 0], compression="gzip")
 47            f.create_dataset("raw/flair", data=raw[..., 1], compression="gzip")
 48            f.create_dataset("raw/post_contrast", data=raw[..., 2], compression="gzip")
 49            f.create_dataset("labels", data=labels, compression="gzip")
 50
 51    shutil.rmtree(os.path.join(path, "kaggle_3m"))
 52
 53
 54def get_lgg_mri_data(path: Union[os.PathLike, str], download: bool = False):
 55    """Download the LGG MRI data.
 56
 57    Args:
 58        path: Filepath to a folder where the data is downloaded for further processing.
 59        download: Whether to download the data if it is not present.
 60    """
 61    data_dir = os.path.join(path, "data")
 62    if os.path.exists(data_dir):
 63        return
 64
 65    os.makedirs(path, exist_ok=True)
 66
 67    util.download_source_kaggle(path=path, dataset_name="mateuszbuda/lgg-mri-segmentation", download=download)
 68    zip_path = os.path.join(path, "lgg-mri-segmentation.zip")
 69    util.unzip(zip_path=zip_path, dst=path)
 70
 71    # Remove redundant volumes
 72    shutil.rmtree(os.path.join(path, "lgg-mri-segmentation"))
 73
 74    _merge_slices_to_volumes(path)
 75
 76
 77def get_lgg_mri_paths(
 78    path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], download: bool = False
 79) -> List[str]:
 80    """Get paths to the LGG MRI data.
 81
 82    Args:
 83        path: Filepath to a folder where the data is downloaded for further processing.
 84        split: The choice of data split.
 85        download: Whether to download the data if it is not present.
 86
 87    Returns:
 88        List of filepaths for the input data.
 89    """
 90    get_lgg_mri_data(path, download)
 91
 92    volume_paths = natsorted(glob(os.path.join(path, "data", "*.h5")))
 93
 94    if split == "train":
 95        volume_paths = volume_paths[:70]
 96    elif split == "val":
 97        volume_paths = volume_paths[70:85]
 98    elif split == "test":
 99        volume_paths = volume_paths[85:]
100    else:
101        raise ValueError(f"'{split}' is not a valid split.")
102
103    return volume_paths
104
105
def get_lgg_mri_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, ...],
    split: Literal['train', 'val', 'test'],
    channels: Optional[Literal['pre_contrast', 'flair', 'post_contrast']] = None,
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> Dataset:
    """Get the LGG MRI dataset for glioma segmentation.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        channels: The choice of modality as input channel: 'pre_contrast', 'flair' or
            'post_contrast'. By default (None), all three modalities are used as input channels.
        resize_inputs: Whether to resize inputs to the desired patch shape. Resizing is only
            applied when a single modality is chosen via `channels`. Otherwise a warning is
            emitted and the inputs are not resized.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.

    Raises:
        ValueError: If `channels` is not one of the available modalities.
    """
    available_channels = ["pre_contrast", "flair", "post_contrast"]
    # Validate the arguments before triggering a potentially long download.
    if channels is not None and channels not in available_channels:
        raise ValueError(f"'{channels}' is not a valid channel.")

    volume_paths = get_lgg_mri_paths(path, split, download)

    if resize_inputs:
        if channels is None:
            # Resizing requires one specific channel, so it is skipped when all modalities are used.
            warnings.warn("The default for channels is set to 'None'. Choose one specific channel for resizing inputs.")
        else:
            resize_kwargs = {"patch_shape": patch_shape, "is_rgb": False}
            kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
                kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs
            )

    if channels is None:
        raw_key = [f"raw/{chan}" for chan in available_channels]
    else:
        raw_key = f"raw/{channels}"

    return torch_em.default_segmentation_dataset(
        raw_paths=volume_paths,
        raw_key=raw_key,
        label_paths=volume_paths,
        label_key="labels",
        patch_shape=patch_shape,
        is_seg_dataset=True,
        # 'with_channels' is only set when the list of all modality keys is used as input.
        with_channels=channels is None,
        **kwargs
    )
154
155
def get_lgg_mri_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, ...],
    split: Literal['train', 'val', 'test'],
    channels: Optional[Literal['pre_contrast', 'flair', 'post_contrast']] = None,
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Get the LGG MRI dataloader for glioma segmentation.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        channels: The choice of modality as input channel.
        resize_inputs: Whether to resize inputs to the desired patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # Send the keyword arguments understood by the dataset factory to the dataset, the rest to the loader.
    dataset_kwargs, dataloader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    lgg_dataset = get_lgg_mri_dataset(
        path=path,
        patch_shape=patch_shape,
        split=split,
        channels=channels,
        resize_inputs=resize_inputs,
        download=download,
        **dataset_kwargs,
    )
    return torch_em.get_data_loader(lgg_dataset, batch_size, **dataloader_kwargs)
def get_lgg_mri_data(path: Union[os.PathLike, str], download: bool = False):
def get_lgg_mri_data(path: Union[os.PathLike, str], download: bool = False):
    """Download and preprocess the LGG MRI data, unless it is already available.

    Nothing happens if the folder `<path>/data` already exists. Otherwise the Kaggle
    archive is fetched via `util.download_source_kaggle` and extracted into `path`. The
    redundant `lgg-mri-segmentation` folder is then removed, and the slices are merged
    into per-patient volumes under `<path>/data`.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        download: Whether to download the data if it is not present.
    """
    if os.path.exists(os.path.join(path, "data")):
        return

    os.makedirs(path, exist_ok=True)

    util.download_source_kaggle(path=path, dataset_name="mateuszbuda/lgg-mri-segmentation", download=download)
    util.unzip(zip_path=os.path.join(path, "lgg-mri-segmentation.zip"), dst=path)

    # The archive contains a redundant copy of the data in this folder; only 'kaggle_3m' is kept.
    redundant_dir = os.path.join(path, "lgg-mri-segmentation")
    shutil.rmtree(redundant_dir)

    _merge_slices_to_volumes(path)

Download the LGG MRI data.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • download: Whether to download the data if it is not present.
def get_lgg_mri_paths( path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], download: bool = False) -> List[str]:
 78def get_lgg_mri_paths(
 79    path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], download: bool = False
 80) -> List[str]:
 81    """Get paths to the LGG MRI data.
 82
 83    Args:
 84        path: Filepath to a folder where the data is downloaded for further processing.
 85        split: The choice of data split.
 86        download: Whether to download the data if it is not present.
 87
 88    Returns:
 89        List of filepaths for the input data.
 90    """
 91    get_lgg_mri_data(path, download)
 92
 93    volume_paths = natsorted(glob(os.path.join(path, "data", "*.h5")))
 94
 95    if split == "train":
 96        volume_paths = volume_paths[:70]
 97    elif split == "val":
 98        volume_paths = volume_paths[70:85]
 99    elif split == "test":
100        volume_paths = volume_paths[85:]
101    else:
102        raise ValueError(f"'{split}' is not a valid split.")
103
104    return volume_paths

Get paths to the LGG MRI data.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • split: The choice of data split.
  • download: Whether to download the data if it is not present.
Returns:

List of filepaths for the input data.

def get_lgg_mri_dataset( path: Union[os.PathLike, str], patch_shape: Tuple[int, ...], split: Literal['train', 'val', 'test'], channels: Optional[Literal['pre_contrast', 'flair', 'post_contrast']] = None, resize_inputs: bool = False, download: bool = False, **kwargs) -> torch.utils.data.dataset.Dataset:
def get_lgg_mri_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, ...],
    split: Literal['train', 'val', 'test'],
    channels: Optional[Literal['pre_contrast', 'flair', 'post_contrast']] = None,
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> Dataset:
    """Get the LGG MRI dataset for glioma segmentation.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        channels: The choice of modality as input channel: 'pre_contrast', 'flair' or
            'post_contrast'. By default (None), all three modalities are used as input channels.
        resize_inputs: Whether to resize inputs to the desired patch shape. Resizing is only
            applied when a single modality is chosen via `channels`. Otherwise a warning is
            emitted and the inputs are not resized.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.

    Raises:
        ValueError: If `channels` is not one of the available modalities.
    """
    available_channels = ["pre_contrast", "flair", "post_contrast"]
    # Validate the arguments before triggering a potentially long download.
    if channels is not None and channels not in available_channels:
        raise ValueError(f"'{channels}' is not a valid channel.")

    volume_paths = get_lgg_mri_paths(path, split, download)

    if resize_inputs:
        if channels is None:
            # Resizing requires one specific channel, so it is skipped when all modalities are used.
            warnings.warn("The default for channels is set to 'None'. Choose one specific channel for resizing inputs.")
        else:
            resize_kwargs = {"patch_shape": patch_shape, "is_rgb": False}
            kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
                kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs
            )

    if channels is None:
        raw_key = [f"raw/{chan}" for chan in available_channels]
    else:
        raw_key = f"raw/{channels}"

    return torch_em.default_segmentation_dataset(
        raw_paths=volume_paths,
        raw_key=raw_key,
        label_paths=volume_paths,
        label_key="labels",
        patch_shape=patch_shape,
        is_seg_dataset=True,
        # 'with_channels' is only set when the list of all modality keys is used as input.
        with_channels=channels is None,
        **kwargs
    )

Get the LGG MRI dataset for glioma segmentation.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • patch_shape: The patch shape to use for training.
  • split: The choice of data split.
  • channels: The choice of modality as input channel.
  • resize_inputs: Whether to resize inputs to the desired patch shape.
  • download: Whether to download the data if it is not present.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset.
Returns:

The segmentation dataset.

def get_lgg_mri_loader( path: Union[os.PathLike, str], batch_size: int, patch_shape: Tuple[int, ...], split: Literal['train', 'val', 'test'], channels: Optional[Literal['pre_contrast', 'flair', 'post_contrast']] = None, resize_inputs: bool = False, download: bool = False, **kwargs) -> torch.utils.data.dataloader.DataLoader:
def get_lgg_mri_loader(
    path: Union[os.PathLike, str],
    batch_size: int,
    patch_shape: Tuple[int, ...],
    split: Literal['train', 'val', 'test'],
    channels: Optional[Literal['pre_contrast', 'flair', 'post_contrast']] = None,
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
) -> DataLoader:
    """Get the LGG MRI dataloader for glioma segmentation.

    Args:
        path: Filepath to a folder where the data is downloaded for further processing.
        batch_size: The batch size for training.
        patch_shape: The patch shape to use for training.
        split: The choice of data split.
        channels: The choice of modality as input channel.
        resize_inputs: Whether to resize inputs to the desired patch shape.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # Send the keyword arguments understood by the dataset factory to the dataset, the rest to the loader.
    dataset_kwargs, dataloader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    lgg_dataset = get_lgg_mri_dataset(
        path=path,
        patch_shape=patch_shape,
        split=split,
        channels=channels,
        resize_inputs=resize_inputs,
        download=download,
        **dataset_kwargs,
    )
    return torch_em.get_data_loader(lgg_dataset, batch_size, **dataloader_kwargs)

Get the LGG MRI dataloader for glioma segmentation.

Arguments:
  • path: Filepath to a folder where the data is downloaded for further processing.
  • batch_size: The batch size for training.
  • patch_shape: The patch shape to use for training.
  • split: The choice of data split.
  • channels: The choice of modality as input channel.
  • resize_inputs: Whether to resize inputs to the desired patch shape.
  • download: Whether to download the data if it is not present.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset or for the PyTorch DataLoader.
Returns:

The DataLoader.