torch_em.segmentation

  1import os
  2from glob import glob
  3from typing import Any, Dict, Optional, Union, Tuple, List, Callable
  4
  5import numpy as np
  6import torch
  7import torch.utils.data
  8from torch.utils.data import DataLoader
  9
 10from .loss import DiceLoss
 11from .util import load_data
 12from .trainer import DefaultTrainer
 13from .trainer.tensorboard_logger import TensorboardLogger
 14from .transform import get_augmentations, get_raw_transform
 15from .data import ConcatDataset, ImageCollectionDataset, SegmentationDataset, TensorDataset
 16
 17
 18# TODO add a heuristic to estimate this from the number of epochs
 19DEFAULT_SCHEDULER_KWARGS = {"mode": "min", "factor": 0.5, "patience": 5}
 20"""@private
 21"""
 22
 23
 24#
 25# convenience functions for segmentation loaders
 26#
 27
 28# TODO implement balanced and make it the default
 29# def samples_to_datasets(n_samples, raw_paths, raw_key, split="balanced"):
 30def samples_to_datasets(n_samples, raw_paths, raw_key, split="uniform"):
 31    """@private
 32    """
 33    assert split in ("balanced", "uniform")
 34    n_datasets = len(raw_paths)
 35    if split == "uniform":
 36        # even distribution of samples to datasets
 37        samples_per_ds = n_samples // n_datasets
 38        divider = n_samples % n_datasets
 39        return [samples_per_ds + 1 if ii < divider else samples_per_ds for ii in range(n_datasets)]
 40    else:
 41        # distribution of samples to dataset based on the dataset lens
 42        raise NotImplementedError
 43
 44
 45def check_paths(raw_paths, label_paths):
 46    """@private
 47    """
 48    if not isinstance(raw_paths, type(label_paths)):
 49        raise ValueError(f"Expect raw and label paths of same type, got {type(raw_paths)}, {type(label_paths)}")
 50
 51    # This is a tensor dataset and we don't need to verify the paths.
 52    if isinstance(raw_paths, list) and isinstance(raw_paths[0], (torch.Tensor, np.ndarray)):
 53        return
 54
 55    def _check_path(path):
 56        if isinstance(path, str):
 57            if not os.path.exists(path):
 58                raise ValueError(f"Could not find path {path}")
 59        else:
 60            # check for single path or multiple paths (for same volume - supports multi-modal inputs)
 61            for per_path in path:
 62                if not os.path.exists(per_path):
 63                    raise ValueError(f"Could not find path {per_path}")
 64
 65    if isinstance(raw_paths, str):
 66        _check_path(raw_paths)
 67        _check_path(label_paths)
 68    else:
 69        if len(raw_paths) != len(label_paths):
 70            raise ValueError(f"Expect same number of raw and label paths, got {len(raw_paths)}, {len(label_paths)}")
 71        for rp, lp in zip(raw_paths, label_paths):
 72            _check_path(rp)
 73            _check_path(lp)
 74
 75
 76# Check if we can load the data as SegmentationDataset.
 77def is_segmentation_dataset(raw_paths, raw_key, label_paths, label_key):
 78    """@private
 79    """
 80    if isinstance(raw_paths, list) and isinstance(raw_paths[0], (np.ndarray, torch.Tensor)):
 81        if not all(isinstance(rp, (np.ndarray, torch.Tensor)) for rp in raw_paths):
 82            raise ValueError("Inconsistent raw data")
 83        if not all(isinstance(lp, (np.ndarray, torch.Tensor)) for lp in label_paths):
 84            raise ValueError("Inconsistent label data")
 85        return False
 86
 87    def _can_open(path, key):
 88        try:
 89            load_data(path, key)
 90            return True
 91        except Exception:
 92            return False
 93
 94    if isinstance(raw_paths, str):
 95        can_open_raw = _can_open(raw_paths, raw_key)
 96        can_open_label = _can_open(label_paths, label_key)
 97    else:
 98        can_open_raw = [_can_open(rp, raw_key) for rp in raw_paths]
 99        if not can_open_raw.count(can_open_raw[0]) == len(can_open_raw):
100            raise ValueError("Inconsistent raw data")
101        can_open_raw = can_open_raw[0]
102
103        can_open_label = [_can_open(lp, label_key) for lp in label_paths]
104        if not can_open_label.count(can_open_label[0]) == len(can_open_label):
105            raise ValueError("Inconsistent label data")
106        can_open_label = can_open_label[0]
107
108    if can_open_raw != can_open_label:
109        raise ValueError("Inconsistent raw and label data")
110
111    return can_open_raw
112
113
def _load_segmentation_dataset(raw_paths, raw_key, label_paths, label_key, **kwargs):
    rois = kwargs.pop("rois", None)

    # Single volume: build one SegmentationDataset directly.
    if isinstance(raw_paths, str):
        if rois is not None:
            assert isinstance(rois, (tuple, slice))
            if isinstance(rois, tuple):
                assert all(isinstance(roi, slice) for roi in rois)
        return SegmentationDataset(raw_paths, raw_key, label_paths, label_key, roi=rois, **kwargs)

    # Multiple volumes: build one dataset per volume and concatenate them.
    assert len(raw_paths) > 0
    if rois is not None:
        assert len(rois) == len(label_paths)
        assert all(isinstance(roi, tuple) for roi in rois), f"{rois}"
    n_samples = kwargs.pop("n_samples", None)
    if n_samples is None:
        samples_per_ds = [None] * len(raw_paths)
    else:
        samples_per_ds = samples_to_datasets(n_samples, raw_paths, raw_key)

    datasets = [
        SegmentationDataset(
            raw_path, raw_key, label_path, label_key,
            roi=None if rois is None else rois[i],
            n_samples=samples_per_ds[i],
            **kwargs,
        )
        for i, (raw_path, label_path) in enumerate(zip(raw_paths, label_paths))
    ]
    return ConcatDataset(*datasets)
141
142
def _load_image_collection_dataset(raw_paths, raw_key, label_paths, label_key, roi, with_channels, **kwargs):
    # In-memory inputs (tensors / arrays) go to the TensorDataset.
    if isinstance(raw_paths[0], (torch.Tensor, np.ndarray)):
        assert raw_key is None and label_key is None
        assert roi is None
        kwargs.pop("pre_label_transform")  # NOTE: The 'TensorDataset' currently does not support samplers.
        return TensorDataset(raw_paths, label_paths, with_channels=with_channels, **kwargs)

    def _get_paths(rpath, rkey, lpath, lkey, this_roi):
        # Keep the glob pattern before reassigning rpath, so the error message can report it.
        # (Previously the message joined the already-globbed list, which would raise a TypeError.)
        raw_pattern = os.path.join(rpath, rkey)
        rpath = glob(raw_pattern)
        rpath.sort()
        if len(rpath) == 0:
            raise ValueError(f"Could not find any images for pattern {raw_pattern}")

        lpath = glob(os.path.join(lpath, lkey))
        lpath.sort()
        if len(rpath) != len(lpath):
            raise ValueError(f"Expect same number of raw and label images, got {len(rpath)}, {len(lpath)}")

        if this_roi is not None:
            # BUGFIX: use the per-dataset 'this_roi' parameter; the previous code
            # accidentally indexed with the outer 'roi' from the closure.
            rpath, lpath = rpath[this_roi], lpath[this_roi]

        return rpath, lpath

    patch_shape = kwargs.pop("patch_shape")
    if patch_shape is not None:
        # The image collection dataset is 2d; a leading singleton z-axis is allowed and stripped.
        if len(patch_shape) == 3:
            if patch_shape[0] != 1:
                raise ValueError(f"Image collection dataset expects 2d patch shape, got {patch_shape}")
            patch_shape = patch_shape[1:]
        assert len(patch_shape) == 2

    if isinstance(raw_paths, str):
        # Single folder with glob patterns for raw and label images.
        raw_paths, label_paths = _get_paths(raw_paths, raw_key, label_paths, label_key, roi)
        ds = ImageCollectionDataset(raw_paths, label_paths, patch_shape=patch_shape, **kwargs)

    elif raw_key is None:
        # Explicit lists of image file paths.
        assert label_key is None
        assert isinstance(raw_paths, (list, tuple)) and isinstance(label_paths, (list, tuple))
        assert len(raw_paths) == len(label_paths)
        ds = ImageCollectionDataset(raw_paths, label_paths, patch_shape=patch_shape, **kwargs)

    else:
        # Multiple folders with glob patterns: one dataset per folder, concatenated.
        ds = []
        n_samples = kwargs.pop("n_samples", None)
        samples_per_ds = (
            [None] * len(raw_paths) if n_samples is None else samples_to_datasets(n_samples, raw_paths, raw_key)
        )
        if roi is None:
            roi = len(raw_paths) * [None]
        assert len(roi) == len(raw_paths)
        for i, (raw_path, label_path, this_roi) in enumerate(zip(raw_paths, label_paths, roi)):
            rpath, lpath = _get_paths(raw_path, raw_key, label_path, label_key, this_roi)
            dset = ImageCollectionDataset(rpath, lpath, patch_shape=patch_shape, n_samples=samples_per_ds[i], **kwargs)
            ds.append(dset)
        ds = ConcatDataset(*ds)

    return ds
201
202
def _get_default_transform(path, key, is_seg_dataset, ndim):
    # Image collection datasets always use 2d augmentations.
    if not is_seg_dataset:
        return get_augmentations(2)

    # For segmentation datasets with unknown dimensionality, inspect the data shape.
    if ndim is None:
        shape = load_data(path, key).shape
        if len(shape) == 2:
            ndim = 2
        else:
            # Heuristic: a z-extent much smaller than the in-plane extent indicates
            # anisotropic data, which gets dedicated augmentations.
            ndim = "anisotropic" if shape[0] < shape[1] // 2 else 3

    return get_augmentations(ndim)
220
221
def default_segmentation_loader(
    raw_paths: Union[List[Any], str, os.PathLike],
    raw_key: Optional[str],
    label_paths: Union[List[Any], str, os.PathLike],
    label_key: Optional[str],
    batch_size: int,
    patch_shape: Tuple[int, ...],
    label_transform: Optional[Callable] = None,
    label_transform2: Optional[Callable] = None,
    raw_transform: Optional[Callable] = None,
    transform: Optional[Callable] = None,
    dtype: torch.dtype = torch.float32,
    label_dtype: torch.dtype = torch.float32,
    rois: Optional[Union[slice, Tuple[slice, ...]]] = None,
    n_samples: Optional[int] = None,
    sampler: Optional[Callable] = None,
    ndim: Optional[int] = None,
    is_seg_dataset: Optional[bool] = None,
    with_channels: bool = False,
    with_label_channels: bool = False,
    verify_paths: bool = True,
    with_padding: bool = True,
    z_ext: Optional[int] = None,
    pre_label_transform: Optional[Callable] = None,
    **loader_kwargs,
) -> torch.utils.data.DataLoader:
    """Get data loader for training a segmentation network.

    See `torch_em.data.SegmentationDataset` and `torch_em.data.ImageCollectionDataset` for details
    on the data formats that are supported.

    Args:
        raw_paths: The file path(s) to the raw data. Can either be a single path or multiple file paths.
            This argument also accepts a list of numpy arrays or torch tensors.
        raw_key: The name of the internal dataset containing the raw data.
            Set to None for regular image files, numpy arrays, or torch tensors.
        label_paths: The file path(s) to the label data. Can either be a single path or multiple file paths.
            This argument also accepts a list of numpy arrays or torch tensors.
        label_key: The name of the internal dataset containing the label data.
            Set to None for regular image files, numpy arrays, or torch tensors.
        batch_size: The batch size for the data loader.
        patch_shape: The patch shape for the training samples.
        label_transform: Transformation applied to the label data of a sample,
            before applying augmentations via `transform`.
        label_transform2: Transformation applied to the label data of a sample,
            after applying augmentations via `transform`.
        raw_transform: Transformation applied to the raw data of a sample,
            before applying augmentations via `transform`.
        transform: Transformation applied to both the raw data and label data of a sample.
            This can be used to implement data augmentations.
        dtype: The return data type of the raw data.
        label_dtype: The return data type of the label data.
        rois: Regions of interest in the data. If given, the data will only be loaded from the corresponding area.
        n_samples: The length of the underlying dataset. If None, the length will be set to `len(raw_paths)`.
        sampler: Sampler for rejecting samples according to a defined criterion.
            The sampler must be a callable that accepts the raw data (as numpy arrays) as input.
        ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
        is_seg_dataset: Whether this is a segmentation dataset or an image collection dataset.
            If None, the type of dataset will be derived from the data.
        with_channels: Whether the raw data has channels.
        with_label_channels: Whether the label data has channels.
        verify_paths: Whether to verify all paths before creating the dataset.
        with_padding: Whether to pad samples to `patch_shape` if their shape is smaller.
        z_ext: Extra bounding box for loading the data across z.
        pre_label_transform: Transformation applied to the label data of a chosen random sample,
            before applying the sample validity via the `sampler`.
        loader_kwargs: Keyword arguments for `torch.utils.data.DataLoader`.

    Returns:
        The torch data loader.
    """
    # Build the dataset with the shared convenience function and wrap it in a loader.
    ds = default_segmentation_dataset(
        raw_paths=raw_paths,
        raw_key=raw_key,
        label_paths=label_paths,
        label_key=label_key,
        patch_shape=patch_shape,
        label_transform=label_transform,
        label_transform2=label_transform2,
        raw_transform=raw_transform,
        transform=transform,
        dtype=dtype,
        label_dtype=label_dtype,
        rois=rois,
        n_samples=n_samples,
        sampler=sampler,
        ndim=ndim,
        is_seg_dataset=is_seg_dataset,
        with_channels=with_channels,
        with_label_channels=with_label_channels,
        with_padding=with_padding,
        z_ext=z_ext,
        verify_paths=verify_paths,
        pre_label_transform=pre_label_transform,
    )
    return get_data_loader(ds, batch_size=batch_size, **loader_kwargs)
318
319
def default_segmentation_dataset(
    raw_paths: Union[List[Any], str, os.PathLike],
    raw_key: Optional[str],
    label_paths: Union[List[Any], str, os.PathLike],
    label_key: Optional[str],
    patch_shape: Tuple[int, ...],
    label_transform: Optional[Callable] = None,
    label_transform2: Optional[Callable] = None,
    raw_transform: Optional[Callable] = None,
    transform: Optional[Callable] = None,
    dtype: torch.dtype = torch.float32,
    label_dtype: torch.dtype = torch.float32,
    rois: Optional[Union[slice, Tuple[slice, ...]]] = None,
    n_samples: Optional[int] = None,
    sampler: Optional[Callable] = None,
    ndim: Optional[int] = None,
    is_seg_dataset: Optional[bool] = None,
    with_channels: bool = False,
    with_label_channels: bool = False,
    verify_paths: bool = True,
    with_padding: bool = True,
    z_ext: Optional[int] = None,
    pre_label_transform: Optional[Callable] = None,
) -> torch.utils.data.Dataset:
    """Get data set for training a segmentation network.

    See `torch_em.data.SegmentationDataset` and `torch_em.data.ImageCollectionDataset` for details
    on the data formats that are supported.

    Args:
        raw_paths: The file path(s) to the raw data. Can either be a single path or multiple file paths.
            This argument also accepts a list of numpy arrays or torch tensors.
        raw_key: The name of the internal dataset containing the raw data.
            Set to None for regular image files, numpy arrays, or torch tensors.
        label_paths: The file path(s) to the label data. Can either be a single path or multiple file paths.
            This argument also accepts a list of numpy arrays or torch tensors.
        label_key: The name of the internal dataset containing the label data.
            Set to None for regular image files, numpy arrays, or torch tensors.
        patch_shape: The patch shape for the training samples.
        label_transform: Transformation applied to the label data of a sample,
            before applying augmentations via `transform`.
        label_transform2: Transformation applied to the label data of a sample,
            after applying augmentations via `transform`.
        raw_transform: Transformation applied to the raw data of a sample,
            before applying augmentations via `transform`.
        transform: Transformation applied to both the raw data and label data of a sample.
            This can be used to implement data augmentations.
        dtype: The return data type of the raw data.
        label_dtype: The return data type of the label data.
        rois: Regions of interest in the data. If given, the data will only be loaded from the corresponding area.
        n_samples: The length of the dataset. If None, the length will be set to `len(raw_paths)`.
        sampler: Sampler for rejecting samples according to a defined criterion.
            The sampler must be a callable that accepts the raw data (as numpy arrays) as input.
        ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
        is_seg_dataset: Whether this is a segmentation dataset or an image collection dataset.
            If None, the type of dataset will be derived from the data.
        with_channels: Whether the raw data has channels.
        with_label_channels: Whether the label data has channels.
        verify_paths: Whether to verify all paths before creating the dataset.
        with_padding: Whether to pad samples to `patch_shape` if their shape is smaller.
        z_ext: Extra bounding box for loading the data across z.
        pre_label_transform: Transformation applied to the label data of a chosen random sample,
            before applying the sample validity via the `sampler`.

    Returns:
        The torch dataset.
    """
    if verify_paths:
        check_paths(raw_paths, label_paths)

    # Derive the dataset type from the data unless it was specified explicitly.
    if is_seg_dataset is None:
        is_seg_dataset = is_segmentation_dataset(raw_paths, raw_key, label_paths, label_key)

    # The convenience function always applies a raw transform and augmentations.
    if raw_transform is None:
        raw_transform = get_raw_transform()
    if transform is None:
        transform = _get_default_transform(
            raw_paths if isinstance(raw_paths, str) else raw_paths[0], raw_key, is_seg_dataset, ndim
        )

    # Arguments accepted by both dataset constructors.
    shared_kwargs = dict(
        patch_shape=patch_shape,
        raw_transform=raw_transform,
        label_transform=label_transform,
        label_transform2=label_transform2,
        transform=transform,
        n_samples=n_samples,
        sampler=sampler,
        dtype=dtype,
        label_dtype=label_dtype,
        with_padding=with_padding,
        with_channels=with_channels,
        pre_label_transform=pre_label_transform,
    )

    if is_seg_dataset:
        return _load_segmentation_dataset(
            raw_paths, raw_key, label_paths, label_key,
            rois=rois,
            ndim=ndim,
            with_label_channels=with_label_channels,
            z_ext=z_ext,
            **shared_kwargs,
        )
    return _load_image_collection_dataset(
        raw_paths, raw_key, label_paths, label_key, roi=rois, **shared_kwargs,
    )
449
450
def get_data_loader(dataset: torch.utils.data.Dataset, batch_size: int, **loader_kwargs) -> torch.utils.data.DataLoader:
    """@private
    """
    # Pin memory by default, but let callers override it via loader_kwargs.
    loader_kwargs.setdefault("pin_memory", True)
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, **loader_kwargs)
    # Monkey patch the shuffle setting onto the loader, since DataLoader does not retain it.
    loader.shuffle = loader_kwargs.get("shuffle", False)
    return loader
459
460
461#
462# convenience functions for segmentation trainers
463#
464
465
def default_segmentation_trainer(
    name: str,
    model: torch.nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    loss: Optional[torch.nn.Module] = None,
    metric: Optional[Callable] = None,
    learning_rate: float = 1e-3,
    device: Optional[Union[str, torch.device]] = None,
    log_image_interval: int = 100,
    mixed_precision: bool = True,
    early_stopping: Optional[int] = None,
    logger=TensorboardLogger,
    logger_kwargs: Optional[Dict[str, Any]] = None,
    scheduler_kwargs: Dict[str, Any] = DEFAULT_SCHEDULER_KWARGS,
    optimizer_kwargs: Optional[Dict[str, Any]] = None,
    trainer_class=DefaultTrainer,
    id_: Optional[str] = None,
    save_root: Optional[str] = None,
    compile_model: Optional[Union[bool, str]] = None,
    rank: Optional[int] = None,
):
    """Get a trainer for a segmentation network.

    It creates a `torch.optim.AdamW` optimizer and learning rate scheduler that reduces the learning rate on plateau.
    By default, it uses the dice score as loss and metric.
    This can be changed by passing arguments for `loss` and/or `metric`.
    See `torch_em.trainer.DefaultTrainer` for additional details on how to configure and use the trainer.

    Here's an example for training a 2D U-Net with this function:
    ```python
    import torch_em
    from torch_em.model import UNet2d
    from torch_em.data.datasets.light_microscopy import get_dsb_loader

    # The training data will be downloaded to this location.
    data_root = "/path/to/save/the/training/data"
    patch_shape = (256, 256)
    trainer = default_segmentation_trainer(
        name="unet-training",
        model=UNet2d(in_channels=1, out_channels=1),
        train_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="train"),
        val_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="test"),
    )
    trainer.fit(iterations=int(2.5e4))  # Train for 25,000 iterations.
    ```

    Args:
        name: The name of the checkpoint that will be created by the trainer.
        model: The model to train.
        train_loader: The data loader containing the training data.
        val_loader: The data loader containing the validation data.
        loss: The loss function for training.
        metric: The metric for validation.
        learning_rate: The initial learning rate for the AdamW optimizer.
        device: The torch device to use for training. If None, will use a GPU if available.
        log_image_interval: The interval for saving images during logging, in training iterations.
        mixed_precision: Whether to train with mixed precision.
        early_stopping: The patience for early stopping in epochs. If None, early stopping will not be used.
        logger: The logger class. Will be instantiated for logging.
            By default uses `torch_em.training.tensorboard_logger.TensorboardLogger`.
        logger_kwargs: The keyword arguments for the logger class.
        scheduler_kwargs: The keyword arguments for ReduceLROnPlateau.
        optimizer_kwargs: The keyword arguments for the AdamW optimizer. If None, no extra arguments are passed.
        trainer_class: The trainer class. Uses `torch_em.trainer.DefaultTrainer` by default,
            but can be set to a custom trainer class to enable custom training procedures.
        id_: Unique identifier for the trainer. If None then `name` will be used.
        save_root: The root folder for saving the checkpoint and logs.
        compile_model: Whether to compile the model before training.
        rank: Rank argument for distributed training. See `torch_em.multi_gpu_training` for details.

    Returns:
        The trainer.
    """
    # Avoid the mutable default argument pitfall: a shared {} default would be
    # reused (and could be mutated) across calls.
    optimizer_kwargs = {} if optimizer_kwargs is None else optimizer_kwargs
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, **optimizer_kwargs)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, **scheduler_kwargs)

    # The dice score is used for both the loss and the validation metric by default.
    loss = DiceLoss() if loss is None else loss
    metric = DiceLoss() if metric is None else metric

    if device is None:
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    else:
        device = torch.device(device)

    # CPU does not support mixed precision training.
    if device.type == "cpu":
        mixed_precision = False

    return trainer_class(
        name=name,
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        loss=loss,
        metric=metric,
        optimizer=optimizer,
        device=device,
        lr_scheduler=scheduler,
        mixed_precision=mixed_precision,
        early_stopping=early_stopping,
        log_image_interval=log_image_interval,
        logger=logger,
        logger_kwargs=logger_kwargs,
        id_=id_,
        save_root=save_root,
        compile_model=compile_model,
        rank=rank,
    )
def default_segmentation_loader( raw_paths: Union[List[Any], str, os.PathLike], raw_key: Optional[str], label_paths: Union[List[Any], str, os.PathLike], label_key: Optional[str], batch_size: int, patch_shape: Tuple[int, ...], label_transform: Optional[Callable] = None, label_transform2: Optional[Callable] = None, raw_transform: Optional[Callable] = None, transform: Optional[Callable] = None, dtype: torch.device = torch.float32, label_dtype: torch.device = torch.float32, rois: Union[slice, Tuple[slice, ...], NoneType] = None, n_samples: Optional[int] = None, sampler: Optional[Callable] = None, ndim: Optional[int] = None, is_seg_dataset: Optional[bool] = None, with_channels: bool = False, with_label_channels: bool = False, verify_paths: bool = True, with_padding: bool = True, z_ext: Optional[int] = None, pre_label_transform: Optional[Callable] = None, **loader_kwargs) -> torch.utils.data.dataloader.DataLoader:
223def default_segmentation_loader(
224    raw_paths: Union[List[Any], str, os.PathLike],
225    raw_key: Optional[str],
226    label_paths: Union[List[Any], str, os.PathLike],
227    label_key: Optional[str],
228    batch_size: int,
229    patch_shape: Tuple[int, ...],
230    label_transform: Optional[Callable] = None,
231    label_transform2: Optional[Callable] = None,
232    raw_transform: Optional[Callable] = None,
233    transform: Optional[Callable] = None,
234    dtype: torch.device = torch.float32,
235    label_dtype: torch.device = torch.float32,
236    rois: Optional[Union[slice, Tuple[slice, ...]]] = None,
237    n_samples: Optional[int] = None,
238    sampler: Optional[Callable] = None,
239    ndim: Optional[int] = None,
240    is_seg_dataset: Optional[bool] = None,
241    with_channels: bool = False,
242    with_label_channels: bool = False,
243    verify_paths: bool = True,
244    with_padding: bool = True,
245    z_ext: Optional[int] = None,
246    pre_label_transform: Optional[Callable] = None,
247    **loader_kwargs,
248) -> torch.utils.data.DataLoader:
249    """Get data loader for training a segmentation network.
250
251    See `torch_em.data.SegmentationDataset` and `torch_em.data.ImageCollectionDataset` for details
252    on the data formats that are supported.
253
254    Args:
255        raw_paths: The file path(s) to the raw data. Can either be a single path or multiple file paths.
256            This argument also accepts a list of numpy arrays or torch tensors.
257        raw_key: The name of the internal dataset containing the raw data.
258            Set to None for regular image files, numpy arrays, or torch tensors.
259        label_paths: The file path(s) to the label data. Can either be a single path or multiple file paths.
260            This argument also accepts a list of numpy arrays or torch tensors.
261        label_key: The name of the internal dataset containing the raw data.
262            Set to None for regular image files, numpy arrays, or torch tensors.
263        batch_size: The batch size for the data loader.
264        patch_shape: The patch shape for the training samples.
265        label_transform: Transformation applied to the label data of a sample,
266            before applying augmentations via `transform`.
267        label_transform2: Transformation applied to the label data of a sample,
268            after applying augmentations via `transform`.
269        raw_transform: Transformation applied to the raw data of a sample,
270            before applying augmentations via `transform`.
271        transform: Transformation applied to both the raw data and label data of a sample.
272            This can be used to implement data augmentations.
273        dtype: The return data type of the raw data.
274        label_dtype: The return data type of the label data.
275        rois: Regions of interest in the data.  If given, the data will only be loaded from the corresponding area.
276        n_samples: The length of the underlying dataset. If None, the length will be set to `len(raw_paths)`.
277        sampler: Sampler for rejecting samples according to a defined criterion.
278            The sampler must be a callable that accepts the raw data (as numpy arrays) as input.
279        ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
280        is_seg_dataset: Whether this is a segmentation dataset or an image collection dataset.
281            If None, the type of dataset will be derived from the data.
282        with_channels: Whether the raw data has channels.
283        with_label_channels: Whether the label data has channels.
284        verify_paths: Whether to verify all paths before creating the dataset.
285        with_padding: Whether to pad samples to `patch_shape` if their shape is smaller.
286        z_ext: Extra bounding box for loading the data across z.
287        pre_label_transform: Transformation applied to the label data of a chosen random sample,
288            before applying the sample validity via the `sampler`.
289        loader_kwargs: Keyword arguments for `torch.utils.data.DataLoader`.
290
291    Returns:
292        The torch data loader.
293    """
294    ds = default_segmentation_dataset(
295        raw_paths=raw_paths,
296        raw_key=raw_key,
297        label_paths=label_paths,
298        label_key=label_key,
299        patch_shape=patch_shape,
300        label_transform=label_transform,
301        label_transform2=label_transform2,
302        raw_transform=raw_transform,
303        transform=transform,
304        dtype=dtype,
305        label_dtype=label_dtype,
306        rois=rois,
307        n_samples=n_samples,
308        sampler=sampler,
309        ndim=ndim,
310        is_seg_dataset=is_seg_dataset,
311        with_channels=with_channels,
312        with_label_channels=with_label_channels,
313        with_padding=with_padding,
314        z_ext=z_ext,
315        verify_paths=verify_paths,
316        pre_label_transform=pre_label_transform,
317    )
318    return get_data_loader(ds, batch_size=batch_size, **loader_kwargs)

Get data loader for training a segmentation network.

See torch_em.data.SegmentationDataset and torch_em.data.ImageCollectionDataset for details on the data formats that are supported.

Arguments:
  • raw_paths: The file path(s) to the raw data. Can either be a single path or multiple file paths. This argument also accepts a list of numpy arrays or torch tensors.
  • raw_key: The name of the internal dataset containing the raw data. Set to None for regular image files, numpy arrays, or torch tensors.
  • label_paths: The file path(s) to the label data. Can either be a single path or multiple file paths. This argument also accepts a list of numpy arrays or torch tensors.
  • label_key: The name of the internal dataset containing the raw data. Set to None for regular image files, numpy arrays, or torch tensors.
  • batch_size: The batch size for the data loader.
  • patch_shape: The patch shape for the training samples.
  • label_transform: Transformation applied to the label data of a sample, before applying augmentations via transform.
  • label_transform2: Transformation applied to the label data of a sample, after applying augmentations via transform.
  • raw_transform: Transformation applied to the raw data of a sample, before applying augmentations via transform.
  • transform: Transformation applied to both the raw data and label data of a sample. This can be used to implement data augmentations.
  • dtype: The return data type of the raw data.
  • label_dtype: The return data type of the label data.
  • rois: Regions of interest in the data. If given, the data will only be loaded from the corresponding area.
  • n_samples: The length of the underlying dataset. If None, the length will be set to len(raw_paths).
  • sampler: Sampler for rejecting samples according to a defined criterion. The sampler must be a callable that accepts the raw data (as numpy arrays) as input.
  • ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
  • is_seg_dataset: Whether this is a segmentation dataset or an image collection dataset. If None, the type of dataset will be derived from the data.
  • with_channels: Whether the raw data has channels.
  • with_label_channels: Whether the label data has channels.
  • verify_paths: Whether to verify all paths before creating the dataset.
  • with_padding: Whether to pad samples to patch_shape if their shape is smaller.
  • z_ext: Extra bounding box for loading the data across z.
  • pre_label_transform: Transformation applied to the label data of a chosen random sample, before applying the sample validity via the sampler.
  • loader_kwargs: Keyword arguments for torch.utils.data.DataLoader.
Returns:

The torch data loader.

def default_segmentation_dataset( raw_paths: Union[List[Any], str, os.PathLike], raw_key: Optional[str], label_paths: Union[List[Any], str, os.PathLike], label_key: Optional[str], patch_shape: Tuple[int, ...], label_transform: Optional[Callable] = None, label_transform2: Optional[Callable] = None, raw_transform: Optional[Callable] = None, transform: Optional[Callable] = None, dtype: torch.dtype = torch.float32, label_dtype: torch.dtype = torch.float32, rois: Union[slice, Tuple[slice, ...], NoneType] = None, n_samples: Optional[int] = None, sampler: Optional[Callable] = None, ndim: Optional[int] = None, is_seg_dataset: Optional[bool] = None, with_channels: bool = False, with_label_channels: bool = False, verify_paths: bool = True, with_padding: bool = True, z_ext: Optional[int] = None, pre_label_transform: Optional[Callable] = None) -> torch.utils.data.dataset.Dataset:
def default_segmentation_dataset(
    raw_paths: Union[List[Any], str, os.PathLike],
    raw_key: Optional[str],
    label_paths: Union[List[Any], str, os.PathLike],
    label_key: Optional[str],
    patch_shape: Tuple[int, ...],
    label_transform: Optional[Callable] = None,
    label_transform2: Optional[Callable] = None,
    raw_transform: Optional[Callable] = None,
    transform: Optional[Callable] = None,
    dtype: torch.dtype = torch.float32,
    label_dtype: torch.dtype = torch.float32,
    rois: Optional[Union[slice, Tuple[slice, ...]]] = None,
    n_samples: Optional[int] = None,
    sampler: Optional[Callable] = None,
    ndim: Optional[int] = None,
    is_seg_dataset: Optional[bool] = None,
    with_channels: bool = False,
    with_label_channels: bool = False,
    verify_paths: bool = True,
    with_padding: bool = True,
    z_ext: Optional[int] = None,
    pre_label_transform: Optional[Callable] = None,
) -> torch.utils.data.Dataset:
    """Get data set for training a segmentation network.

    See `torch_em.data.SegmentationDataset` and `torch_em.data.ImageCollectionDataset` for details
    on the data formats that are supported.

    Args:
        raw_paths: The file path(s) to the raw data. Can either be a single path or multiple file paths.
            This argument also accepts a list of numpy arrays or torch tensors.
        raw_key: The name of the internal dataset containing the raw data.
            Set to None for regular image files, numpy arrays, or torch tensors.
        label_paths: The file path(s) to the label data. Can either be a single path or multiple file paths.
            This argument also accepts a list of numpy arrays or torch tensors.
        label_key: The name of the internal dataset containing the label data.
            Set to None for regular image files, numpy arrays, or torch tensors.
        patch_shape: The patch shape for the training samples.
        label_transform: Transformation applied to the label data of a sample,
            before applying augmentations via `transform`.
        label_transform2: Transformation applied to the label data of a sample,
            after applying augmentations via `transform`.
        raw_transform: Transformation applied to the raw data of a sample,
            before applying augmentations via `transform`.
        transform: Transformation applied to both the raw data and label data of a sample.
            This can be used to implement data augmentations.
        dtype: The return data type of the raw data.
        label_dtype: The return data type of the label data.
        rois: Regions of interest in the data. If given, the data will only be loaded from the corresponding area.
        n_samples: The length of the dataset. If None, the length will be set to `len(raw_paths)`.
        sampler: Sampler for rejecting samples according to a defined criterion.
            The sampler must be a callable that accepts the raw data (as numpy arrays) as input.
        ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
        is_seg_dataset: Whether this is a segmentation dataset or an image collection dataset.
            If None, the type of dataset will be derived from the data.
        with_channels: Whether the raw data has channels.
        with_label_channels: Whether the label data has channels.
        verify_paths: Whether to verify all paths before creating the dataset.
        with_padding: Whether to pad samples to `patch_shape` if their shape is smaller.
        z_ext: Extra bounding box for loading the data across z.
        pre_label_transform: Transformation applied to the label data of a chosen random sample,
            before applying the sample validity via the `sampler`.

    Returns:
        The torch dataset.
    """
    if verify_paths:
        check_paths(raw_paths, label_paths)

    # Derive the dataset type from the data layout unless the caller fixed it explicitly.
    if is_seg_dataset is None:
        is_seg_dataset = is_segmentation_dataset(raw_paths, raw_key, label_paths, label_key)

    # We always use a raw transform in the convenience function.
    if raw_transform is None:
        raw_transform = get_raw_transform()

    # We always use augmentations in the convenience function.
    if transform is None:
        transform = _get_default_transform(
            raw_paths if isinstance(raw_paths, str) else raw_paths[0], raw_key, is_seg_dataset, ndim
        )

    if is_seg_dataset:
        ds = _load_segmentation_dataset(
            raw_paths,
            raw_key,
            label_paths,
            label_key,
            patch_shape=patch_shape,
            raw_transform=raw_transform,
            label_transform=label_transform,
            label_transform2=label_transform2,
            transform=transform,
            rois=rois,
            n_samples=n_samples,
            sampler=sampler,
            ndim=ndim,
            dtype=dtype,
            label_dtype=label_dtype,
            with_channels=with_channels,
            with_label_channels=with_label_channels,
            with_padding=with_padding,
            z_ext=z_ext,
            pre_label_transform=pre_label_transform,
        )

    else:
        ds = _load_image_collection_dataset(
            raw_paths,
            raw_key,
            label_paths,
            label_key,
            roi=rois,
            patch_shape=patch_shape,
            label_transform=label_transform,
            raw_transform=raw_transform,
            label_transform2=label_transform2,
            transform=transform,
            n_samples=n_samples,
            sampler=sampler,
            dtype=dtype,
            label_dtype=label_dtype,
            with_padding=with_padding,
            with_channels=with_channels,
            pre_label_transform=pre_label_transform,
        )

    return ds

Get data set for training a segmentation network.

See torch_em.data.SegmentationDataset and torch_em.data.ImageCollectionDataset for details on the data formats that are supported.

Arguments:
  • raw_paths: The file path(s) to the raw data. Can either be a single path or multiple file paths. This argument also accepts a list of numpy arrays or torch tensors.
  • raw_key: The name of the internal dataset containing the raw data. Set to None for regular image files, numpy arrays, or torch tensors.
  • label_paths: The file path(s) to the label data. Can either be a single path or multiple file paths. This argument also accepts a list of numpy arrays or torch tensors.
  • label_key: The name of the internal dataset containing the label data. Set to None for regular image files, numpy arrays, or torch tensors.
  • patch_shape: The patch shape for the training samples.
  • label_transform: Transformation applied to the label data of a sample, before applying augmentations via transform.
  • label_transform2: Transformation applied to the label data of a sample, after applying augmentations via transform.
  • raw_transform: Transformation applied to the raw data of a sample, before applying augmentations via transform.
  • transform: Transformation applied to both the raw data and label data of a sample. This can be used to implement data augmentations.
  • dtype: The return data type of the raw data.
  • label_dtype: The return data type of the label data.
  • rois: Regions of interest in the data. If given, the data will only be loaded from the corresponding area.
  • n_samples: The length of the dataset. If None, the length will be set to len(raw_paths).
  • sampler: Sampler for rejecting samples according to a defined criterion. The sampler must be a callable that accepts the raw data (as numpy arrays) as input.
  • ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
  • is_seg_dataset: Whether this is a segmentation dataset or an image collection dataset. If None, the type of dataset will be derived from the data.
  • with_channels: Whether the raw data has channels.
  • with_label_channels: Whether the label data has channels.
  • verify_paths: Whether to verify all paths before creating the dataset.
  • with_padding: Whether to pad samples to patch_shape if their shape is smaller.
  • z_ext: Extra bounding box for loading the data across z.
  • pre_label_transform: Transformation applied to the label data of a chosen random sample, before applying the sample validity via the sampler.
Returns:

The torch dataset.

def default_segmentation_trainer( name: str, model: torch.nn.modules.module.Module, train_loader: torch.utils.data.dataloader.DataLoader, val_loader: torch.utils.data.dataloader.DataLoader, loss: Optional[torch.nn.modules.module.Module] = None, metric: Optional[Callable] = None, learning_rate: float = 0.001, device: Union[str, torch.device, NoneType] = None, log_image_interval: int = 100, mixed_precision: bool = True, early_stopping: Optional[int] = None, logger=<class 'torch_em.trainer.tensorboard_logger.TensorboardLogger'>, logger_kwargs: Optional[Dict[str, Any]] = None, scheduler_kwargs: Dict[str, Any] = {'mode': 'min', 'factor': 0.5, 'patience': 5}, optimizer_kwargs: Dict[str, Any] = {}, trainer_class=<class 'torch_em.trainer.default_trainer.DefaultTrainer'>, id_: Optional[str] = None, save_root: Optional[str] = None, compile_model: Union[bool, str, NoneType] = None, rank: Optional[int] = None):
def default_segmentation_trainer(
    name: str,
    model: torch.nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    loss: Optional[torch.nn.Module] = None,
    metric: Optional[Callable] = None,
    learning_rate: float = 1e-3,
    device: Optional[Union[str, torch.device]] = None,
    log_image_interval: int = 100,
    mixed_precision: bool = True,
    early_stopping: Optional[int] = None,
    logger=TensorboardLogger,
    logger_kwargs: Optional[Dict[str, Any]] = None,
    scheduler_kwargs: Dict[str, Any] = DEFAULT_SCHEDULER_KWARGS,
    optimizer_kwargs: Optional[Dict[str, Any]] = None,
    trainer_class=DefaultTrainer,
    id_: Optional[str] = None,
    save_root: Optional[str] = None,
    compile_model: Optional[Union[bool, str]] = None,
    rank: Optional[int] = None,
):
    """Get a trainer for a segmentation network.

    It creates a `torch.optim.AdamW` optimizer and learning rate scheduler that reduces the learning rate on plateau.
    By default, it uses the dice score as loss and metric.
    This can be changed by passing arguments for `loss` and/or `metric`.
    See `torch_em.trainer.DefaultTrainer` for additional details on how to configure and use the trainer.

    Here's an example for training a 2D U-Net with this function:
    ```python
    import torch_em
    from torch_em.model import UNet2d
    from torch_em.data.datasets.light_microscopy import get_dsb_loader

    # The training data will be downloaded to this location.
    data_root = "/path/to/save/the/training/data"
    patch_shape = (256, 256)
    trainer = default_segmentation_trainer(
        name="unet-training",
        model=UNet2d(in_channels=1, out_channels=1),
        train_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="train"),
        val_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="test"),
    )
    trainer.fit(iterations=int(2.5e4))  # Train for 25,000 iterations.
    ```

    Args:
        name: The name of the checkpoint that will be created by the trainer.
        model: The model to train.
        train_loader: The data loader containing the training data.
        val_loader: The data loader containing the validation data.
        loss: The loss function for training.
        metric: The metric for validation.
        learning_rate: The initial learning rate for the AdamW optimizer.
        device: The torch device to use for training. If None, will use a GPU if available.
        log_image_interval: The interval for saving images during logging, in training iterations.
        mixed_precision: Whether to train with mixed precision.
        early_stopping: The patience for early stopping in epochs. If None, early stopping will not be used.
        logger: The logger class. Will be instantiated for logging.
            By default uses `torch_em.training.tensorboard_logger.TensorboardLogger`.
        logger_kwargs: The keyword arguments for the logger class.
        scheduler_kwargs: The keyword arguments for ReduceLROnPlateau.
        optimizer_kwargs: The keyword arguments for the AdamW optimizer. If None, no extra arguments are passed.
        trainer_class: The trainer class. Uses `torch_em.trainer.DefaultTrainer` by default,
            but can be set to a custom trainer class to enable custom training procedures.
        id_: Unique identifier for the trainer. If None then `name` will be used.
        save_root: The root folder for saving the checkpoint and logs.
        compile_model: Whether to compile the model before training.
        rank: Rank argument for distributed training. See `torch_em.multi_gpu_training` for details.

    Returns:
        The trainer.
    """
    # Avoid a shared mutable default argument for the optimizer kwargs.
    optimizer_kwargs = {} if optimizer_kwargs is None else optimizer_kwargs
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, **optimizer_kwargs)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, **scheduler_kwargs)

    # The dice score serves as both the default loss and the default validation metric.
    loss = DiceLoss() if loss is None else loss
    metric = DiceLoss() if metric is None else metric

    if device is None:
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    else:
        device = torch.device(device)

    # CPU does not support mixed precision training.
    if device.type == "cpu":
        mixed_precision = False

    return trainer_class(
        name=name,
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        loss=loss,
        metric=metric,
        optimizer=optimizer,
        device=device,
        lr_scheduler=scheduler,
        mixed_precision=mixed_precision,
        early_stopping=early_stopping,
        log_image_interval=log_image_interval,
        logger=logger,
        logger_kwargs=logger_kwargs,
        id_=id_,
        save_root=save_root,
        compile_model=compile_model,
        rank=rank,
    )

Get a trainer for a segmentation network.

It creates a torch.optim.AdamW optimizer and learning rate scheduler that reduces the learning rate on plateau. By default, it uses the dice score as loss and metric. This can be changed by passing arguments for loss and/or metric. See torch_em.trainer.DefaultTrainer for additional details on how to configure and use the trainer.

Here's an example for training a 2D U-Net with this function:

import torch_em
from torch_em.model import UNet2d
from torch_em.data.datasets.light_microscopy import get_dsb_loader

# The training data will be downloaded to this location.
data_root = "/path/to/save/the/training/data"
patch_shape = (256, 256)
trainer = default_segmentation_trainer(
    name="unet-training",
    model=UNet2d(in_channels=1, out_channels=1),
    train_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="train"),
    val_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="test"),
)
trainer.fit(iterations=int(2.5e4))  # Train for 25,000 iterations.
Arguments:
  • name: The name of the checkpoint that will be created by the trainer.
  • model: The model to train.
  • train_loader: The data loader containing the training data.
  • val_loader: The data loader containing the validation data.
  • loss: The loss function for training.
  • metric: The metric for validation.
  • learning_rate: The initial learning rate for the AdamW optimizer.
  • device: The torch device to use for training. If None, will use a GPU if available.
  • log_image_interval: The interval for saving images during logging, in training iterations.
  • mixed_precision: Whether to train with mixed precision.
  • early_stopping: The patience for early stopping in epochs. If None, early stopping will not be used.
  • logger: The logger class. Will be instantiated for logging. By default uses torch_em.training.tensorboard_logger.TensorboardLogger.
  • logger_kwargs: The keyword arguments for the logger class.
  • scheduler_kwargs: The keyword arguments for ReduceLROnPlateau.
  • optimizer_kwargs: The keyword arguments for the AdamW optimizer.
  • trainer_class: The trainer class. Uses torch_em.trainer.DefaultTrainer by default, but can be set to a custom trainer class to enable custom training procedures.
  • id_: Unique identifier for the trainer. If None then name will be used.
  • save_root: The root folder for saving the checkpoint and logs.
  • compile_model: Whether to compile the model before training.
  • rank: Rank argument for distributed training. See torch_em.multi_gpu_training for details.
Returns:

The trainer.