torch_em.segmentation

  1import os
  2from glob import glob
  3from typing import Any, Dict, Optional, Union, Tuple, List, Callable
  4
  5import numpy as np
  6import torch
  7import torch.utils.data
  8from torch.utils.data import DataLoader
  9
 10from .loss import DiceLoss
 11from .util import load_data
 12from .trainer import DefaultTrainer
 13from .trainer.tensorboard_logger import TensorboardLogger
 14from .transform import get_augmentations, get_raw_transform
 15from .data import ConcatDataset, ImageCollectionDataset, SegmentationDataset, TensorDataset
 16
 17
# TODO add a heuristic to estimate this from the number of epochs
# Default keyword arguments for the `ReduceLROnPlateau` learning rate scheduler
# that is created in `default_segmentation_trainer`.
DEFAULT_SCHEDULER_KWARGS = {"mode": "min", "factor": 0.5, "patience": 5}
"""@private
"""
 22
 23
 24#
 25# convenience functions for segmentation loaders
 26#
 27
 28# TODO implement balanced and make it the default
 29# def samples_to_datasets(n_samples, raw_paths, raw_key, split="balanced"):
 30def samples_to_datasets(n_samples, raw_paths, raw_key, split="uniform"):
 31    """@private
 32    """
 33    assert split in ("balanced", "uniform")
 34    n_datasets = len(raw_paths)
 35    if split == "uniform":
 36        # even distribution of samples to datasets
 37        samples_per_ds = n_samples // n_datasets
 38        divider = n_samples % n_datasets
 39        return [samples_per_ds + 1 if ii < divider else samples_per_ds for ii in range(n_datasets)]
 40    else:
 41        # distribution of samples to dataset based on the dataset lens
 42        raise NotImplementedError
 43
 44
 45def check_paths(raw_paths, label_paths):
 46    """@private
 47    """
 48    if not isinstance(raw_paths, type(label_paths)):
 49        raise ValueError(f"Expect raw and label paths of same type, got {type(raw_paths)}, {type(label_paths)}")
 50
 51    # This is a tensor dataset and we don't need to verify the paths.
 52    if isinstance(raw_paths, list) and isinstance(raw_paths[0], (torch.Tensor, np.ndarray)):
 53        return
 54
 55    def _check_path(path):
 56        if isinstance(path, str):
 57            if not os.path.exists(path):
 58                raise ValueError(f"Could not find path {path}")
 59        else:
 60            # check for single path or multiple paths (for same volume - supports multi-modal inputs)
 61            for per_path in path:
 62                if not os.path.exists(per_path):
 63                    raise ValueError(f"Could not find path {per_path}")
 64
 65    if isinstance(raw_paths, str):
 66        _check_path(raw_paths)
 67        _check_path(label_paths)
 68    else:
 69        if len(raw_paths) != len(label_paths):
 70            raise ValueError(f"Expect same number of raw and label paths, got {len(raw_paths)}, {len(label_paths)}")
 71        for rp, lp in zip(raw_paths, label_paths):
 72            _check_path(rp)
 73            _check_path(lp)
 74
 75
 76# Check if we can load the data as SegmentationDataset.
 77def is_segmentation_dataset(raw_paths, raw_key, label_paths, label_key):
 78    """@private
 79    """
 80    if isinstance(raw_paths, list) and isinstance(raw_paths[0], (np.ndarray, torch.Tensor)):
 81        if not all(isinstance(rp, (np.ndarray, torch.Tensor)) for rp in raw_paths):
 82            raise ValueError("Inconsistent raw data")
 83        if not all(isinstance(lp, (np.ndarray, torch.Tensor)) for lp in label_paths):
 84            raise ValueError("Inconsistent label data")
 85        return False
 86
 87    def _can_open(path, key):
 88        try:
 89            load_data(path, key)
 90            return True
 91        except Exception:
 92            return False
 93
 94    if isinstance(raw_paths, str):
 95        can_open_raw = _can_open(raw_paths, raw_key)
 96        can_open_label = _can_open(label_paths, label_key)
 97    else:
 98        can_open_raw = [_can_open(rp, raw_key) for rp in raw_paths]
 99        if not can_open_raw.count(can_open_raw[0]) == len(can_open_raw):
100            raise ValueError("Inconsistent raw data")
101        can_open_raw = can_open_raw[0]
102
103        can_open_label = [_can_open(lp, label_key) for lp in label_paths]
104        if not can_open_label.count(can_open_label[0]) == len(can_open_label):
105            raise ValueError("Inconsistent label data")
106        can_open_label = can_open_label[0]
107
108    if can_open_raw != can_open_label:
109        raise ValueError("Inconsistent raw and label data")
110
111    return can_open_raw
112
113
114def _load_segmentation_dataset(raw_paths, raw_key, label_paths, label_key, **kwargs):
115    rois = kwargs.pop("rois", None)
116    if isinstance(raw_paths, str):
117        if rois is not None:
118            assert isinstance(rois, (tuple, slice))
119            if isinstance(rois, tuple):
120                assert all(isinstance(roi, slice) for roi in rois)
121        ds = SegmentationDataset(raw_paths, raw_key, label_paths, label_key, roi=rois, **kwargs)
122    else:
123        assert len(raw_paths) > 0
124        if rois is not None:
125            assert len(rois) == len(label_paths)
126            assert all(isinstance(roi, tuple) for roi in rois), f"{rois}"
127        n_samples = kwargs.pop("n_samples", None)
128
129        samples_per_ds = (
130            [None] * len(raw_paths) if n_samples is None else samples_to_datasets(n_samples, raw_paths, raw_key)
131        )
132        ds = []
133        for i, (raw_path, label_path) in enumerate(zip(raw_paths, label_paths)):
134            roi = None if rois is None else rois[i]
135            dset = SegmentationDataset(
136                raw_path, raw_key, label_path, label_key, roi=roi, n_samples=samples_per_ds[i], **kwargs
137            )
138            ds.append(dset)
139        ds = ConcatDataset(*ds)
140    return ds
141
142
143def _load_image_collection_dataset(raw_paths, raw_key, label_paths, label_key, roi, with_channels, **kwargs):
144    if isinstance(raw_paths[0], (torch.Tensor, np.ndarray)):
145        assert raw_key is None and label_key is None
146        assert roi is None
147        return TensorDataset(raw_paths, label_paths, with_channels=with_channels, **kwargs)
148
149    def _get_paths(rpath, rkey, lpath, lkey, this_roi):
150        rpath = glob(os.path.join(rpath, rkey))
151        rpath.sort()
152        if len(rpath) == 0:
153            raise ValueError(f"Could not find any images for pattern {os.path.join(rpath, rkey)}")
154
155        lpath = glob(os.path.join(lpath, lkey))
156        lpath.sort()
157        if len(rpath) != len(lpath):
158            raise ValueError(f"Expect same number of raw and label images, got {len(rpath)}, {len(lpath)}")
159
160        if this_roi is not None:
161            rpath, lpath = rpath[roi], lpath[roi]
162
163        return rpath, lpath
164
165    patch_shape = kwargs.pop("patch_shape")
166    if patch_shape is not None:
167        if len(patch_shape) == 3:
168            if patch_shape[0] != 1:
169                raise ValueError(f"Image collection dataset expects 2d patch shape, got {patch_shape}")
170            patch_shape = patch_shape[1:]
171        assert len(patch_shape) == 2
172
173    if isinstance(raw_paths, str):
174        raw_paths, label_paths = _get_paths(raw_paths, raw_key, label_paths, label_key, roi)
175        ds = ImageCollectionDataset(raw_paths, label_paths, patch_shape=patch_shape, **kwargs)
176
177    elif raw_key is None:
178        assert label_key is None
179        assert isinstance(raw_paths, (list, tuple)) and isinstance(label_paths, (list, tuple))
180        assert len(raw_paths) == len(label_paths)
181        ds = ImageCollectionDataset(raw_paths, label_paths, patch_shape=patch_shape, **kwargs)
182
183    else:
184        ds = []
185        n_samples = kwargs.pop("n_samples", None)
186        samples_per_ds = (
187            [None] * len(raw_paths) if n_samples is None else samples_to_datasets(n_samples, raw_paths, raw_key)
188        )
189        if roi is None:
190            roi = len(raw_paths) * [None]
191        assert len(roi) == len(raw_paths)
192        for i, (raw_path, label_path, this_roi) in enumerate(zip(raw_paths, label_paths, roi)):
193            print(raw_path, label_path, this_roi)
194            rpath, lpath = _get_paths(raw_path, raw_key, label_path, label_key, this_roi)
195            dset = ImageCollectionDataset(rpath, lpath, patch_shape=patch_shape, n_samples=samples_per_ds[i], **kwargs)
196            ds.append(dset)
197        ds = ConcatDataset(*ds)
198
199    return ds
200
201
def _get_default_transform(path, key, is_seg_dataset, ndim):
    """Get the default augmentations from `get_augmentations`.

    Args:
        path: The path to the data; only loaded if the dimensionality has to be derived.
        key: The key passed to `load_data` together with `path`.
        is_seg_dataset: Whether the data is loaded as a segmentation dataset.
        ndim: The dimensionality passed to `get_augmentations`. If None, it is derived
            from the shape of the data for a segmentation dataset.

    Returns:
        The result of `get_augmentations` for the chosen dimensionality.
    """
    if not is_seg_dataset:
        # Image collections always use the 2d augmentations.
        ndim = 2
    elif ndim is None:
        data_shape = load_data(path, key).shape
        if len(data_shape) == 2:
            ndim = 2
        elif data_shape[0] < data_shape[1] // 2:
            # Heuristic: a first axis that is much shorter than the second one
            # selects the anisotropic augmentations instead of the default 3d ones.
            ndim = "anisotropic"
        else:
            ndim = 3
    # Otherwise the dimensionality given by the caller is used as it is.

    return get_augmentations(ndim)
219
220
def default_segmentation_loader(
    raw_paths: Union[List[Any], str, os.PathLike],
    raw_key: Optional[str],
    label_paths: Union[List[Any], str, os.PathLike],
    label_key: Optional[str],
    batch_size: int,
    patch_shape: Tuple[int, ...],
    label_transform: Optional[Callable] = None,
    label_transform2: Optional[Callable] = None,
    raw_transform: Optional[Callable] = None,
    transform: Optional[Callable] = None,
    dtype: torch.dtype = torch.float32,
    label_dtype: torch.dtype = torch.float32,
    rois: Optional[Union[slice, Tuple[slice, ...]]] = None,
    n_samples: Optional[int] = None,
    sampler: Optional[Callable] = None,
    ndim: Optional[int] = None,
    is_seg_dataset: Optional[bool] = None,
    with_channels: bool = False,
    with_label_channels: bool = False,
    verify_paths: bool = True,
    with_padding: bool = True,
    z_ext: Optional[int] = None,
    **loader_kwargs,
) -> torch.utils.data.DataLoader:
    """Get data loader for training a segmentation network.

    The dataset is created with `default_segmentation_dataset` and wrapped into a data loader
    with `get_data_loader`.

    See `torch_em.data.SegmentationDataset` and `torch_em.data.ImageCollectionDataset` for details
    on the data formats that are supported.

    Args:
        raw_paths: The file path(s) to the raw data. Can either be a single path or multiple file paths.
            This argument also accepts a list of numpy arrays or torch tensors.
        raw_key: The name of the internal dataset containing the raw data.
            Set to None for regular image files, numpy arrays, or torch tensors.
        label_paths: The file path(s) to the label data. Can either be a single path or multiple file paths.
            This argument also accepts a list of numpy arrays or torch tensors.
        label_key: The name of the internal dataset containing the label data.
            Set to None for regular image files, numpy arrays, or torch tensors.
        batch_size: The batch size for the data loader.
        patch_shape: The patch shape for the training samples.
        label_transform: Transformation applied to the label data of a sample,
            before applying augmentations via `transform`.
        label_transform2: Transformation applied to the label data of a sample,
            after applying augmentations via `transform`.
        raw_transform: Transformation applied to the raw data of a sample,
            before applying augmentations via `transform`.
        transform: Transformation applied to both the raw data and label data of a sample.
            This can be used to implement data augmentations.
        dtype: The return data type of the raw data.
        label_dtype: The return data type of the label data.
        rois: Regions of interest in the data.  If given, the data will only be loaded from the corresponding area.
        n_samples: The length of the underlying dataset. If None, the length will be set to `len(raw_paths)`.
        sampler: Sampler for rejecting samples according to a defined criterion.
            The sampler must be a callable that accepts the raw data (as numpy arrays) as input.
        ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
        is_seg_dataset: Whether this is a segmentation dataset or an image collection dataset.
            If None, the type of dataset will be derived from the data.
        with_channels: Whether the raw data has channels.
        with_label_channels: Whether the label data has channels.
        verify_paths: Whether to verify all paths before creating the dataset.
        with_padding: Whether to pad samples to `patch_shape` if their shape is smaller.
        z_ext: Extra bounding box for loading the data across z.
        loader_kwargs: Keyword arguments for `torch.utils.data.DataLoader`.

    Returns:
        The torch data loader.
    """
    ds = default_segmentation_dataset(
        raw_paths=raw_paths,
        raw_key=raw_key,
        label_paths=label_paths,
        label_key=label_key,
        patch_shape=patch_shape,
        label_transform=label_transform,
        label_transform2=label_transform2,
        raw_transform=raw_transform,
        transform=transform,
        dtype=dtype,
        label_dtype=label_dtype,
        rois=rois,
        n_samples=n_samples,
        sampler=sampler,
        ndim=ndim,
        is_seg_dataset=is_seg_dataset,
        with_channels=with_channels,
        with_label_channels=with_label_channels,
        with_padding=with_padding,
        z_ext=z_ext,
        verify_paths=verify_paths,
    )
    return get_data_loader(ds, batch_size=batch_size, **loader_kwargs)
313
314
def default_segmentation_dataset(
    raw_paths: Union[List[Any], str, os.PathLike],
    raw_key: Optional[str],
    label_paths: Union[List[Any], str, os.PathLike],
    label_key: Optional[str],
    patch_shape: Tuple[int, ...],
    label_transform: Optional[Callable] = None,
    label_transform2: Optional[Callable] = None,
    raw_transform: Optional[Callable] = None,
    transform: Optional[Callable] = None,
    dtype: torch.dtype = torch.float32,
    label_dtype: torch.dtype = torch.float32,
    rois: Optional[Union[slice, Tuple[slice, ...]]] = None,
    n_samples: Optional[int] = None,
    sampler: Optional[Callable] = None,
    ndim: Optional[int] = None,
    is_seg_dataset: Optional[bool] = None,
    with_channels: bool = False,
    with_label_channels: bool = False,
    verify_paths: bool = True,
    with_padding: bool = True,
    z_ext: Optional[int] = None,
) -> torch.utils.data.Dataset:
    """Get data set for training a segmentation network.

    See `torch_em.data.SegmentationDataset` and `torch_em.data.ImageCollectionDataset` for details
    on the data formats that are supported.

    Args:
        raw_paths: The file path(s) to the raw data. Can either be a single path or multiple file paths.
            This argument also accepts a list of numpy arrays or torch tensors.
        raw_key: The name of the internal dataset containing the raw data.
            Set to None for regular image files, numpy arrays, or torch tensors.
        label_paths: The file path(s) to the label data. Can either be a single path or multiple file paths.
            This argument also accepts a list of numpy arrays or torch tensors.
        label_key: The name of the internal dataset containing the label data.
            Set to None for regular image files, numpy arrays, or torch tensors.
        patch_shape: The patch shape for the training samples.
        label_transform: Transformation applied to the label data of a sample,
            before applying augmentations via `transform`.
        label_transform2: Transformation applied to the label data of a sample,
            after applying augmentations via `transform`.
        raw_transform: Transformation applied to the raw data of a sample,
            before applying augmentations via `transform`.
            If None, the default from `get_raw_transform` is used.
        transform: Transformation applied to both the raw data and label data of a sample.
            This can be used to implement data augmentations.
            If None, default augmentations are used.
        dtype: The return data type of the raw data.
        label_dtype: The return data type of the label data.
        rois: Regions of interest in the data.  If given, the data will only be loaded from the corresponding area.
        n_samples: The length of the dataset. If None, the length will be set to `len(raw_paths)`.
        sampler: Sampler for rejecting samples according to a defined criterion.
            The sampler must be a callable that accepts the raw data (as numpy arrays) as input.
        ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
        is_seg_dataset: Whether this is a segmentation dataset or an image collection dataset.
            If None, the type of dataset will be derived from the data.
        with_channels: Whether the raw data has channels.
        with_label_channels: Whether the label data has channels.
        verify_paths: Whether to verify all paths before creating the dataset.
        with_padding: Whether to pad samples to `patch_shape` if their shape is smaller.
        z_ext: Extra bounding box for loading the data across z.

    Returns:
        The torch dataset.
    """
    # A single path may be given as a path-like object (e.g. `pathlib.Path`). The code below and the
    # dataset helpers identify a single path via `isinstance(..., str)`, so it is converted to a string here.
    if isinstance(raw_paths, os.PathLike):
        raw_paths = os.fspath(raw_paths)
    if isinstance(label_paths, os.PathLike):
        label_paths = os.fspath(label_paths)

    if verify_paths:
        check_paths(raw_paths, label_paths)

    if is_seg_dataset is None:
        is_seg_dataset = is_segmentation_dataset(raw_paths, raw_key, label_paths, label_key)

    # We always use a raw transform in the convenience function.
    if raw_transform is None:
        raw_transform = get_raw_transform()

    # We always use augmentations in the convenience function.
    if transform is None:
        # The first path is used to derive the dimensionality of the data, if necessary.
        first_raw_path = raw_paths if isinstance(raw_paths, str) else raw_paths[0]
        transform = _get_default_transform(first_raw_path, raw_key, is_seg_dataset, ndim)

    if is_seg_dataset:
        ds = _load_segmentation_dataset(
            raw_paths,
            raw_key,
            label_paths,
            label_key,
            patch_shape=patch_shape,
            raw_transform=raw_transform,
            label_transform=label_transform,
            label_transform2=label_transform2,
            transform=transform,
            rois=rois,
            n_samples=n_samples,
            sampler=sampler,
            ndim=ndim,
            dtype=dtype,
            label_dtype=label_dtype,
            with_channels=with_channels,
            with_label_channels=with_label_channels,
            with_padding=with_padding,
            z_ext=z_ext,
        )

    else:
        # NOTE: `ndim`, `with_label_channels` and `z_ext` are not passed on for image collection datasets.
        ds = _load_image_collection_dataset(
            raw_paths,
            raw_key,
            label_paths,
            label_key,
            roi=rois,
            patch_shape=patch_shape,
            label_transform=label_transform,
            raw_transform=raw_transform,
            label_transform2=label_transform2,
            transform=transform,
            n_samples=n_samples,
            sampler=sampler,
            dtype=dtype,
            label_dtype=label_dtype,
            with_padding=with_padding,
            with_channels=with_channels,
        )

    return ds
440
441
def get_data_loader(dataset: torch.utils.data.Dataset, batch_size: int, **loader_kwargs) -> torch.utils.data.DataLoader:
    """@private

    Create a data loader for the dataset, with `pin_memory` enabled unless specified otherwise.

    Args:
        dataset: The dataset.
        batch_size: The batch size.
        loader_kwargs: Further keyword arguments for `torch.utils.data.DataLoader`.

    Returns:
        The data loader, with the additional attribute `shuffle`.
    """
    # Memory pinning is the default, an explicit value from the caller takes precedence.
    options = {"pin_memory": True, **loader_kwargs}
    data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, **options)
    # The shuffle setting is attached to the loader as an extra attribute.
    data_loader.shuffle = options.get("shuffle", False)
    return data_loader
450
451
452#
453# convenience functions for segmentation trainers
454#
455
456
def default_segmentation_trainer(
    name: str,
    model: torch.nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    loss: Optional[torch.nn.Module] = None,
    metric: Optional[Callable] = None,
    learning_rate: float = 1e-3,
    device: Optional[Union[str, torch.device]] = None,
    log_image_interval: int = 100,
    mixed_precision: bool = True,
    early_stopping: Optional[int] = None,
    logger=TensorboardLogger,
    logger_kwargs: Optional[Dict[str, Any]] = None,
    scheduler_kwargs: Dict[str, Any] = DEFAULT_SCHEDULER_KWARGS,
    optimizer_kwargs: Dict[str, Any] = {},
    trainer_class=DefaultTrainer,
    id_: Optional[str] = None,
    save_root: Optional[str] = None,
    compile_model: Optional[Union[bool, str]] = None,
    rank: Optional[int] = None,
):
    """Get a trainer for a segmentation network.

    It creates a `torch.optim.AdamW` optimizer and learning rate scheduler that reduces the learning rate on plateau.
    By default, it uses the dice score as loss and metric.
    This can be changed by passing arguments for `loss` and/or `metric`.
    See `torch_em.trainer.DefaultTrainer` for additional details on how to configure and use the trainer.

    Here's an example for training a 2D U-Net with this function:
    ```python
    from torch_em.segmentation import default_segmentation_trainer
    from torch_em.model import UNet2d
    from torch_em.data.datasets.light_microscopy import get_dsb_loader

    # The training data will be downloaded to this location.
    data_root = "/path/to/save/the/training/data"
    patch_shape = (256, 256)
    trainer = default_segmentation_trainer(
        name="unet-training",
        model=UNet2d(in_channels=1, out_channels=1),
        train_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="train"),
        val_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="test"),
    )
    trainer.fit(iterations=int(2.5e4))  # Train for 25.000 iterations.
    ```

    Args:
        name: The name of the checkpoint that will be created by the trainer.
        model: The model to train.
        train_loader: The data loader containing the training data.
        val_loader: The data loader containing the validation data.
        loss: The loss function for training. If None, `DiceLoss` is used.
        metric: The metric for validation. If None, `DiceLoss` is used.
        learning_rate: The initial learning rate for the AdamW optimizer.
        device: The torch device to use for training. If None, will use a GPU if available.
        log_image_interval: The interval for saving images during logging, in training iterations.
        mixed_precision: Whether to train with mixed precision. Is disabled when training on the CPU.
        early_stopping: The patience for early stopping in epochs. If None, early stopping will not be used.
        logger: The logger class. Will be instantiated for logging.
            By default uses `torch_em.trainer.tensorboard_logger.TensorboardLogger`.
        logger_kwargs: The keyword arguments for the logger class.
        scheduler_kwargs: The keyword arguments for ReduceLROnPlateau.
        optimizer_kwargs: The keyword arguments for the AdamW optimizer.
        trainer_class: The trainer class. Uses `torch_em.trainer.DefaultTrainer` by default,
            but can be set to a custom trainer class to enable custom training procedures.
        id_: Unique identifier for the trainer. If None then `name` will be used.
        save_root: The root folder for saving the checkpoint and logs.
        compile_model: Whether to compile the model before training.
        rank: Rank argument for distributed training. See `torch_em.multi_gpu_training` for details.

    Returns:
        The trainer.
    """
    # The default dictionaries for the keyword arguments are only unpacked here, never modified.
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, **optimizer_kwargs)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, **scheduler_kwargs)

    loss = DiceLoss() if loss is None else loss
    metric = DiceLoss() if metric is None else metric

    if device is None:
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    else:
        # Also accepts the device as a string, e.g. "cpu".
        device = torch.device(device)

    # CPU does not support mixed precision training.
    if device.type == "cpu":
        mixed_precision = False

    return trainer_class(
        name=name,
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        loss=loss,
        metric=metric,
        optimizer=optimizer,
        device=device,
        lr_scheduler=scheduler,
        mixed_precision=mixed_precision,
        early_stopping=early_stopping,
        log_image_interval=log_image_interval,
        logger=logger,
        logger_kwargs=logger_kwargs,
        id_=id_,
        save_root=save_root,
        compile_model=compile_model,
        rank=rank,
    )
def default_segmentation_loader( raw_paths: Union[List[Any], str, os.PathLike], raw_key: Optional[str], label_paths: Union[List[Any], str, os.PathLike], label_key: Optional[str], batch_size: int, patch_shape: Tuple[int, ...], label_transform: Optional[Callable] = None, label_transform2: Optional[Callable] = None, raw_transform: Optional[Callable] = None, transform: Optional[Callable] = None, dtype: torch.device = torch.float32, label_dtype: torch.device = torch.float32, rois: Union[slice, Tuple[slice, ...], NoneType] = None, n_samples: Optional[int] = None, sampler: Optional[Callable] = None, ndim: Optional[int] = None, is_seg_dataset: Optional[bool] = None, with_channels: bool = False, with_label_channels: bool = False, verify_paths: bool = True, with_padding: bool = True, z_ext: Optional[int] = None, **loader_kwargs) -> torch.utils.data.dataloader.DataLoader:
def default_segmentation_loader(
    raw_paths: Union[List[Any], str, os.PathLike],
    raw_key: Optional[str],
    label_paths: Union[List[Any], str, os.PathLike],
    label_key: Optional[str],
    batch_size: int,
    patch_shape: Tuple[int, ...],
    label_transform: Optional[Callable] = None,
    label_transform2: Optional[Callable] = None,
    raw_transform: Optional[Callable] = None,
    transform: Optional[Callable] = None,
    dtype: torch.dtype = torch.float32,
    label_dtype: torch.dtype = torch.float32,
    rois: Optional[Union[slice, Tuple[slice, ...]]] = None,
    n_samples: Optional[int] = None,
    sampler: Optional[Callable] = None,
    ndim: Optional[int] = None,
    is_seg_dataset: Optional[bool] = None,
    with_channels: bool = False,
    with_label_channels: bool = False,
    verify_paths: bool = True,
    with_padding: bool = True,
    z_ext: Optional[int] = None,
    **loader_kwargs,
) -> torch.utils.data.DataLoader:
    """Get data loader for training a segmentation network.

    The dataset is created with `default_segmentation_dataset` and wrapped into a data loader
    with `get_data_loader`.

    See `torch_em.data.SegmentationDataset` and `torch_em.data.ImageCollectionDataset` for details
    on the data formats that are supported.

    Args:
        raw_paths: The file path(s) to the raw data. Can either be a single path or multiple file paths.
            This argument also accepts a list of numpy arrays or torch tensors.
        raw_key: The name of the internal dataset containing the raw data.
            Set to None for regular image files, numpy arrays, or torch tensors.
        label_paths: The file path(s) to the label data. Can either be a single path or multiple file paths.
            This argument also accepts a list of numpy arrays or torch tensors.
        label_key: The name of the internal dataset containing the label data.
            Set to None for regular image files, numpy arrays, or torch tensors.
        batch_size: The batch size for the data loader.
        patch_shape: The patch shape for the training samples.
        label_transform: Transformation applied to the label data of a sample,
            before applying augmentations via `transform`.
        label_transform2: Transformation applied to the label data of a sample,
            after applying augmentations via `transform`.
        raw_transform: Transformation applied to the raw data of a sample,
            before applying augmentations via `transform`.
        transform: Transformation applied to both the raw data and label data of a sample.
            This can be used to implement data augmentations.
        dtype: The return data type of the raw data.
        label_dtype: The return data type of the label data.
        rois: Regions of interest in the data.  If given, the data will only be loaded from the corresponding area.
        n_samples: The length of the underlying dataset. If None, the length will be set to `len(raw_paths)`.
        sampler: Sampler for rejecting samples according to a defined criterion.
            The sampler must be a callable that accepts the raw data (as numpy arrays) as input.
        ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
        is_seg_dataset: Whether this is a segmentation dataset or an image collection dataset.
            If None, the type of dataset will be derived from the data.
        with_channels: Whether the raw data has channels.
        with_label_channels: Whether the label data has channels.
        verify_paths: Whether to verify all paths before creating the dataset.
        with_padding: Whether to pad samples to `patch_shape` if their shape is smaller.
        z_ext: Extra bounding box for loading the data across z.
        loader_kwargs: Keyword arguments for `torch.utils.data.DataLoader`.

    Returns:
        The torch data loader.
    """
    ds = default_segmentation_dataset(
        raw_paths=raw_paths,
        raw_key=raw_key,
        label_paths=label_paths,
        label_key=label_key,
        patch_shape=patch_shape,
        label_transform=label_transform,
        label_transform2=label_transform2,
        raw_transform=raw_transform,
        transform=transform,
        dtype=dtype,
        label_dtype=label_dtype,
        rois=rois,
        n_samples=n_samples,
        sampler=sampler,
        ndim=ndim,
        is_seg_dataset=is_seg_dataset,
        with_channels=with_channels,
        with_label_channels=with_label_channels,
        with_padding=with_padding,
        z_ext=z_ext,
        verify_paths=verify_paths,
    )
    return get_data_loader(ds, batch_size=batch_size, **loader_kwargs)

Get data loader for training a segmentation network.

See torch_em.data.SegmentationDataset and torch_em.data.ImageCollectionDataset for details on the data formats that are supported.

Arguments:
  • raw_paths: The file path(s) to the raw data. Can either be a single path or multiple file paths. This argument also accepts a list of numpy arrays or torch tensors.
  • raw_key: The name of the internal dataset containing the raw data. Set to None for regular image files, numpy arrays, or torch tensors.
  • label_paths: The file path(s) to the label data. Can either be a single path or multiple file paths. This argument also accepts a list of numpy arrays or torch tensors.
  • label_key: The name of the internal dataset containing the label data. Set to None for regular image files, numpy arrays, or torch tensors.
  • batch_size: The batch size for the data loader.
  • patch_shape: The patch shape for the training samples.
  • label_transform: Transformation applied to the label data of a sample, before applying augmentations via transform.
  • label_transform2: Transformation applied to the label data of a sample, after applying augmentations via transform.
  • raw_transform: Transformation applied to the raw data of a sample, before applying augmentations via transform.
  • transform: Transformation applied to both the raw data and label data of a sample. This can be used to implement data augmentations.
  • dtype: The return data type of the raw data.
  • label_dtype: The return data type of the label data.
  • rois: Regions of interest in the data. If given, the data will only be loaded from the corresponding area.
  • n_samples: The length of the underlying dataset. If None, the length will be set to len(raw_paths).
  • sampler: Sampler for rejecting samples according to a defined criterion. The sampler must be a callable that accepts the raw data (as numpy arrays) as input.
  • ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
  • is_seg_dataset: Whether this is a segmentation dataset or an image collection dataset. If None, the type of dataset will be derived from the data.
  • with_channels: Whether the raw data has channels.
  • with_label_channels: Whether the label data has channels.
  • verify_paths: Whether to verify all paths before creating the dataset.
  • with_padding: Whether to pad samples to patch_shape if their shape is smaller.
  • z_ext: Extra bounding box for loading the data across z.
  • loader_kwargs: Keyword arguments for torch.utils.data.DataLoader.
Returns:

The torch data loader.

def default_segmentation_dataset( raw_paths: Union[List[Any], str, os.PathLike], raw_key: Optional[str], label_paths: Union[List[Any], str, os.PathLike], label_key: Optional[str], patch_shape: Tuple[int, ...], label_transform: Optional[Callable] = None, label_transform2: Optional[Callable] = None, raw_transform: Optional[Callable] = None, transform: Optional[Callable] = None, dtype: torch.dtype = torch.float32, label_dtype: torch.dtype = torch.float32, rois: Union[slice, Tuple[slice, ...], NoneType] = None, n_samples: Optional[int] = None, sampler: Optional[Callable] = None, ndim: Optional[int] = None, is_seg_dataset: Optional[bool] = None, with_channels: bool = False, with_label_channels: bool = False, verify_paths: bool = True, with_padding: bool = True, z_ext: Optional[int] = None) -> torch.utils.data.dataset.Dataset:
def default_segmentation_dataset(
    raw_paths: Union[List[Any], str, os.PathLike],
    raw_key: Optional[str],
    label_paths: Union[List[Any], str, os.PathLike],
    label_key: Optional[str],
    patch_shape: Tuple[int, ...],
    label_transform: Optional[Callable] = None,
    label_transform2: Optional[Callable] = None,
    raw_transform: Optional[Callable] = None,
    transform: Optional[Callable] = None,
    dtype: torch.dtype = torch.float32,
    label_dtype: torch.dtype = torch.float32,
    rois: Optional[Union[slice, Tuple[slice, ...]]] = None,
    n_samples: Optional[int] = None,
    sampler: Optional[Callable] = None,
    ndim: Optional[int] = None,
    is_seg_dataset: Optional[bool] = None,
    with_channels: bool = False,
    with_label_channels: bool = False,
    verify_paths: bool = True,
    with_padding: bool = True,
    z_ext: Optional[int] = None,
) -> torch.utils.data.Dataset:
    """Get data set for training a segmentation network.

    See `torch_em.data.SegmentationDataset` and `torch_em.data.ImageCollectionDataset` for details
    on the data formats that are supported.

    Args:
        raw_paths: The file path(s) to the raw data. Can either be a single path or multiple file paths.
            This argument also accepts a list of numpy arrays or torch tensors.
        raw_key: The name of the internal dataset containing the raw data.
            Set to None for regular image files, numpy arrays, or torch tensors.
        label_paths: The file path(s) to the label data. Can either be a single path or multiple file paths.
            This argument also accepts a list of numpy arrays or torch tensors.
        label_key: The name of the internal dataset containing the label data.
            Set to None for regular image files, numpy arrays, or torch tensors.
        patch_shape: The patch shape for the training samples.
        label_transform: Transformation applied to the label data of a sample,
            before applying augmentations via `transform`.
        label_transform2: Transformation applied to the label data of a sample,
            after applying augmentations via `transform`.
        raw_transform: Transformation applied to the raw data of a sample,
            before applying augmentations via `transform`.
            If None, the transform returned by `get_raw_transform()` is used.
        transform: Transformation applied to both the raw data and label data of a sample.
            This can be used to implement data augmentations.
            If None, default augmentations are selected based on the first raw input.
        dtype: The return data type of the raw data.
        label_dtype: The return data type of the label data.
        rois: Regions of interest in the data. If given, the data will only be loaded from the corresponding area.
        n_samples: The length of the dataset. If None, the length will be set to `len(raw_paths)`.
        sampler: Sampler for rejecting samples according to a defined criterion.
            The sampler must be a callable that accepts the raw data (as numpy arrays) as input.
        ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
        is_seg_dataset: Whether this is a segmentation dataset or an image collection dataset.
            If None, the type of dataset will be derived from the data.
        with_channels: Whether the raw data has channels.
        with_label_channels: Whether the label data has channels.
        verify_paths: Whether to verify all paths before creating the dataset.
        with_padding: Whether to pad samples to `patch_shape` if their shape is smaller.
        z_ext: Extra bounding box for loading the data across z.

    Returns:
        The torch dataset.
    """
    if verify_paths:
        check_paths(raw_paths, label_paths)

    if is_seg_dataset is None:
        is_seg_dataset = is_segmentation_dataset(raw_paths, raw_key, label_paths, label_key)

    # We always use a raw transform in the convenience function.
    if raw_transform is None:
        raw_transform = get_raw_transform()

    # We always use augmentations in the convenience function.
    if transform is None:
        # A single input may be given as `str` or as `os.PathLike` (e.g. `pathlib.Path`).
        # Path objects are not subscriptable, so both must be treated as "one path"
        # rather than indexed like a sequence of inputs.
        is_single_path = isinstance(raw_paths, (str, os.PathLike))
        first_raw = raw_paths if is_single_path else raw_paths[0]
        transform = _get_default_transform(first_raw, raw_key, is_seg_dataset, ndim)

    if is_seg_dataset:
        ds = _load_segmentation_dataset(
            raw_paths,
            raw_key,
            label_paths,
            label_key,
            patch_shape=patch_shape,
            raw_transform=raw_transform,
            label_transform=label_transform,
            label_transform2=label_transform2,
            transform=transform,
            rois=rois,
            n_samples=n_samples,
            sampler=sampler,
            ndim=ndim,
            dtype=dtype,
            label_dtype=label_dtype,
            with_channels=with_channels,
            with_label_channels=with_label_channels,
            with_padding=with_padding,
            z_ext=z_ext,
        )

    else:
        # NOTE: `ndim`, `with_label_channels` and `z_ext` are not forwarded to the
        # image collection loader; the ROI is passed under the keyword `roi`.
        ds = _load_image_collection_dataset(
            raw_paths,
            raw_key,
            label_paths,
            label_key,
            roi=rois,
            patch_shape=patch_shape,
            label_transform=label_transform,
            raw_transform=raw_transform,
            label_transform2=label_transform2,
            transform=transform,
            n_samples=n_samples,
            sampler=sampler,
            dtype=dtype,
            label_dtype=label_dtype,
            with_padding=with_padding,
            with_channels=with_channels,
        )

    return ds

Get data set for training a segmentation network.

See torch_em.data.SegmentationDataset and torch_em.data.ImageCollectionDataset for details on the data formats that are supported.

Arguments:
  • raw_paths: The file path(s) to the raw data. Can either be a single path or multiple file paths. This argument also accepts a list of numpy arrays or torch tensors.
  • raw_key: The name of the internal dataset containing the raw data. Set to None for regular image files, numpy arrays, or torch tensors.
  • label_paths: The file path(s) to the label data. Can either be a single path or multiple file paths. This argument also accepts a list of numpy arrays or torch tensors.
  • label_key: The name of the internal dataset containing the label data. Set to None for regular image files, numpy arrays, or torch tensors.
  • patch_shape: The patch shape for the training samples.
  • label_transform: Transformation applied to the label data of a sample, before applying augmentations via transform.
  • label_transform2: Transformation applied to the label data of a sample, after applying augmentations via transform.
  • raw_transform: Transformation applied to the raw data of a sample, before applying augmentations via transform.
  • transform: Transformation applied to both the raw data and label data of a sample. This can be used to implement data augmentations.
  • dtype: The return data type of the raw data.
  • label_dtype: The return data type of the label data.
  • rois: Regions of interest in the data. If given, the data will only be loaded from the corresponding area.
  • n_samples: The length of the dataset. If None, the length will be set to len(raw_paths).
  • sampler: Sampler for rejecting samples according to a defined criterion. The sampler must be a callable that accepts the raw data (as numpy arrays) as input.
  • ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
  • is_seg_dataset: Whether this is a segmentation dataset or an image collection dataset. If None, the type of dataset will be derived from the data.
  • with_channels: Whether the raw data has channels.
  • with_label_channels: Whether the label data has channels.
  • verify_paths: Whether to verify all paths before creating the dataset.
  • with_padding: Whether to pad samples to patch_shape if their shape is smaller.
  • z_ext: Extra bounding box for loading the data across z.
  • loader_kwargs: Keyword arguments for torch.utils.data.DataLoader.
Returns:

The torch dataset.

def default_segmentation_trainer( name: str, model: torch.nn.modules.module.Module, train_loader: torch.utils.data.dataloader.DataLoader, val_loader: torch.utils.data.dataloader.DataLoader, loss: Optional[torch.nn.modules.module.Module] = None, metric: Optional[Callable] = None, learning_rate: float = 0.001, device: Union[str, torch.device, NoneType] = None, log_image_interval: int = 100, mixed_precision: bool = True, early_stopping: Optional[int] = None, logger=<class 'torch_em.trainer.tensorboard_logger.TensorboardLogger'>, logger_kwargs: Optional[Dict[str, Any]] = None, scheduler_kwargs: Dict[str, Any] = {'mode': 'min', 'factor': 0.5, 'patience': 5}, optimizer_kwargs: Dict[str, Any] = {}, trainer_class=<class 'torch_em.trainer.default_trainer.DefaultTrainer'>, id_: Optional[str] = None, save_root: Optional[str] = None, compile_model: Union[bool, str, NoneType] = None, rank: Optional[int] = None):
def default_segmentation_trainer(
    name: str,
    model: torch.nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    loss: Optional[torch.nn.Module] = None,
    metric: Optional[Callable] = None,
    learning_rate: float = 1e-3,
    device: Optional[Union[str, torch.device]] = None,
    log_image_interval: int = 100,
    mixed_precision: bool = True,
    early_stopping: Optional[int] = None,
    logger=TensorboardLogger,
    logger_kwargs: Optional[Dict[str, Any]] = None,
    scheduler_kwargs: Optional[Dict[str, Any]] = DEFAULT_SCHEDULER_KWARGS,
    optimizer_kwargs: Optional[Dict[str, Any]] = {},
    trainer_class=DefaultTrainer,
    id_: Optional[str] = None,
    save_root: Optional[str] = None,
    compile_model: Optional[Union[bool, str]] = None,
    rank: Optional[int] = None,
):
    """Get a trainer for a segmentation network.

    It creates a `torch.optim.AdamW` optimizer and learning rate scheduler that reduces the learning rate on plateau.
    By default, it uses the dice score as loss and metric.
    This can be changed by passing arguments for `loss` and/or `metric`.
    See `torch_em.trainer.DefaultTrainer` for additional details on how to configure and use the trainer.

    Here's an example for training a 2D U-Net with this function:
    ```python
    import torch_em
    from torch_em.model import UNet2d
    from torch_em.data.datasets.light_microscopy import get_dsb_loader

    # The training data will be downloaded to this location.
    data_root = "/path/to/save/the/training/data"
    patch_shape = (256, 256)
    trainer = torch_em.default_segmentation_trainer(
        name="unet-training",
        model=UNet2d(in_channels=1, out_channels=1),
        train_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="train"),
        val_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="test"),
    )
    trainer.fit(iterations=int(2.5e4))  # Train for 25,000 iterations.
    ```

    Args:
        name: The name of the checkpoint that will be created by the trainer.
        model: The model to train.
        train_loader: The data loader containing the training data.
        val_loader: The data loader containing the validation data.
        loss: The loss function for training. If None, `DiceLoss` is used.
        metric: The metric for validation. If None, `DiceLoss` is used.
        learning_rate: The initial learning rate for the AdamW optimizer.
        device: The torch device to use for training. If None, will use a GPU if available.
        log_image_interval: The interval for saving images during logging, in training iterations.
        mixed_precision: Whether to train with mixed precision.
            This is always disabled when training on the CPU.
        early_stopping: The patience for early stopping in epochs. If None, early stopping will not be used.
        logger: The logger class. Will be instantiated for logging.
            By default uses `torch_em.trainer.tensorboard_logger.TensorboardLogger`.
        logger_kwargs: The keyword arguments for the logger class.
        scheduler_kwargs: The keyword arguments for ReduceLROnPlateau.
            If None, `DEFAULT_SCHEDULER_KWARGS` is used.
        optimizer_kwargs: The keyword arguments for the AdamW optimizer.
            If None, no additional arguments are passed.
        trainer_class: The trainer class. Uses `torch_em.trainer.DefaultTrainer` by default,
            but can be set to a custom trainer class to enable custom training procedures.
        id_: Unique identifier for the trainer. If None then `name` will be used.
        save_root: The root folder for saving the checkpoint and logs.
        compile_model: Whether to compile the model before training.
        rank: Rank argument for distributed training. See `torch_em.multi_gpu_training` for details.

    Returns:
        The trainer.
    """
    # Accept an explicit None for the keyword dictionaries (mirrors `logger_kwargs`).
    # The dictionaries are only unpacked, never modified, so the shared defaults stay intact.
    if optimizer_kwargs is None:
        optimizer_kwargs = {}
    if scheduler_kwargs is None:
        scheduler_kwargs = DEFAULT_SCHEDULER_KWARGS

    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, **optimizer_kwargs)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, **scheduler_kwargs)

    loss = DiceLoss() if loss is None else loss
    metric = DiceLoss() if metric is None else metric

    if device is None:
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    else:
        # Normalize strings to a torch.device so that `device.type` can be checked below.
        device = torch.device(device)

    # CPU does not support mixed precision training.
    if device.type == "cpu":
        mixed_precision = False

    return trainer_class(
        name=name,
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        loss=loss,
        metric=metric,
        optimizer=optimizer,
        device=device,
        lr_scheduler=scheduler,
        mixed_precision=mixed_precision,
        early_stopping=early_stopping,
        log_image_interval=log_image_interval,
        logger=logger,
        logger_kwargs=logger_kwargs,
        id_=id_,
        save_root=save_root,
        compile_model=compile_model,
        rank=rank,
    )

Get a trainer for a segmentation network.

It creates a torch.optim.AdamW optimizer and learning rate scheduler that reduces the learning rate on plateau. By default, it uses the dice score as loss and metric. This can be changed by passing arguments for loss and/or metric. See torch_em.trainer.DefaultTrainer for additional details on how to configure and use the trainer.

Here's an example for training a 2D U-Net with this function:

import torch_em
from torch_em.model import UNet2d
from torch_em.data.datasets.light_microscopy import get_dsb_loader

# The training data will be downloaded to this location.
data_root = "/path/to/save/the/training/data"
patch_shape = (256, 256)
trainer = torch_em.default_segmentation_trainer(
    name="unet-training",
    model=UNet2d(in_channels=1, out_channels=1),
    train_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="train"),
    val_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="test"),
)
trainer.fit(iterations=int(2.5e4))  # Train for 25,000 iterations.
Arguments:
  • name: The name of the checkpoint that will be created by the trainer.
  • model: The model to train.
  • train_loader: The data loader containing the training data.
  • val_loader: The data loader containing the validation data.
  • loss: The loss function for training.
  • metric: The metric for validation.
  • learning_rate: The initial learning rate for the AdamW optimizer.
  • device: The torch device to use for training. If None, will use a GPU if available.
  • log_image_interval: The interval for saving images during logging, in training iterations.
  • mixed_precision: Whether to train with mixed precision.
  • early_stopping: The patience for early stopping in epochs. If None, early stopping will not be used.
  • logger: The logger class. Will be instantiated for logging. By default uses torch_em.trainer.tensorboard_logger.TensorboardLogger.
  • logger_kwargs: The keyword arguments for the logger class.
  • scheduler_kwargs: The keyword arguments for ReduceLROnPlateau.
  • optimizer_kwargs: The keyword arguments for the AdamW optimizer.
  • trainer_class: The trainer class. Uses torch_em.trainer.DefaultTrainer by default, but can be set to a custom trainer class to enable custom training procedures.
  • id_: Unique identifier for the trainer. If None then name will be used.
  • save_root: The root folder for saving the checkpoint and logs.
  • compile_model: Whether to compile the model before training.
  • rank: Rank argument for distributed training. See torch_em.multi_gpu_training for details.
Returns:

The trainer.