torch_em.segmentation

  1import os
  2from glob import glob
  3from typing import Any, Dict, Optional, Union, Tuple, List, Callable
  4
  5import torch
  6import torch.utils.data
  7from torch.utils.data import DataLoader
  8
  9from .loss import DiceLoss
 10from .util import load_data
 11from .trainer import DefaultTrainer
 12from .trainer.tensorboard_logger import TensorboardLogger
 13from .transform import get_augmentations, get_raw_transform
 14from .data import ConcatDataset, ImageCollectionDataset, SegmentationDataset
 15
 16
 17# TODO add a heuristic to estimate this from the number of epochs
 18DEFAULT_SCHEDULER_KWARGS = {"mode": "min", "factor": 0.5, "patience": 5}
 19"""@private
 20"""
 21
 22
 23#
 24# convenience functions for segmentation loaders
 25#
 26
 27# TODO implement balanced and make it the default
 28# def samples_to_datasets(n_samples, raw_paths, raw_key, split="balanced"):
 29def samples_to_datasets(n_samples, raw_paths, raw_key, split="uniform"):
 30    """@private
 31    """
 32    assert split in ("balanced", "uniform")
 33    n_datasets = len(raw_paths)
 34    if split == "uniform":
 35        # even distribution of samples to datasets
 36        samples_per_ds = n_samples // n_datasets
 37        divider = n_samples % n_datasets
 38        return [samples_per_ds + 1 if ii < divider else samples_per_ds for ii in range(n_datasets)]
 39    else:
 40        # distribution of samples to dataset based on the dataset lens
 41        raise NotImplementedError
 42
 43
 44def check_paths(raw_paths, label_paths):
 45    """@private
 46    """
 47    if not isinstance(raw_paths, type(label_paths)):
 48        raise ValueError(f"Expect raw and label paths of same type, got {type(raw_paths)}, {type(label_paths)}")
 49
 50    def _check_path(path):
 51        if isinstance(path, str):
 52            if not os.path.exists(path):
 53                raise ValueError(f"Could not find path {path}")
 54        else:
 55            # check for single path or multiple paths (for same volume - supports multi-modal inputs)
 56            for per_path in path:
 57                if not os.path.exists(per_path):
 58                    raise ValueError(f"Could not find path {per_path}")
 59
 60    if isinstance(raw_paths, str):
 61        _check_path(raw_paths)
 62        _check_path(label_paths)
 63    else:
 64        if len(raw_paths) != len(label_paths):
 65            raise ValueError(f"Expect same number of raw and label paths, got {len(raw_paths)}, {len(label_paths)}")
 66        for rp, lp in zip(raw_paths, label_paths):
 67            _check_path(rp)
 68            _check_path(lp)
 69
 70
 71# Check if we can load the data as SegmentationDataset.
 72def is_segmentation_dataset(raw_paths, raw_key, label_paths, label_key):
 73    """@private
 74    """
 75    def _can_open(path, key):
 76        try:
 77            load_data(path, key)
 78            return True
 79        except Exception:
 80            return False
 81
 82    if isinstance(raw_paths, str):
 83        can_open_raw = _can_open(raw_paths, raw_key)
 84        can_open_label = _can_open(label_paths, label_key)
 85    else:
 86        can_open_raw = [_can_open(rp, raw_key) for rp in raw_paths]
 87        if not can_open_raw.count(can_open_raw[0]) == len(can_open_raw):
 88            raise ValueError("Inconsistent raw data")
 89        can_open_raw = can_open_raw[0]
 90
 91        can_open_label = [_can_open(lp, label_key) for lp in label_paths]
 92        if not can_open_label.count(can_open_label[0]) == len(can_open_label):
 93            raise ValueError("Inconsistent label data")
 94        can_open_label = can_open_label[0]
 95
 96    if can_open_raw != can_open_label:
 97        raise ValueError("Inconsistent raw and label data")
 98
 99    return can_open_raw
100
101
def _load_segmentation_dataset(raw_paths, raw_key, label_paths, label_key, **kwargs):
    """Create a `SegmentationDataset` for a single path, or a `ConcatDataset` with one
    `SegmentationDataset` per path for multiple paths.

    `rois` (popped from kwargs) is a slice / tuple of slices for a single path, or one tuple
    of slices per path. For multiple paths `n_samples` (popped from kwargs) is split over the
    datasets via `samples_to_datasets`.
    """
    rois = kwargs.pop("rois", None)

    if isinstance(raw_paths, str):
        if rois is not None:
            assert isinstance(rois, (tuple, slice))
            if isinstance(rois, tuple):
                assert all(isinstance(roi, slice) for roi in rois)
        return SegmentationDataset(raw_paths, raw_key, label_paths, label_key, roi=rois, **kwargs)

    assert len(raw_paths) > 0
    if rois is not None:
        assert len(rois) == len(label_paths)
        assert all(isinstance(roi, tuple) for roi in rois), f"{rois}"

    n_samples = kwargs.pop("n_samples", None)
    if n_samples is None:
        samples_per_ds = [None] * len(raw_paths)
    else:
        samples_per_ds = samples_to_datasets(n_samples, raw_paths, raw_key)

    datasets = [
        SegmentationDataset(
            raw_path,
            raw_key,
            label_path,
            label_key,
            roi=None if rois is None else rois[idx],
            n_samples=samples_per_ds[idx],
            **kwargs,
        )
        for idx, (raw_path, label_path) in enumerate(zip(raw_paths, label_paths))
    ]
    return ConcatDataset(*datasets)
129
130
def _load_image_collection_dataset(raw_paths, raw_key, label_paths, label_key, roi, **kwargs):
    """Create an `ImageCollectionDataset` (or a `ConcatDataset` of them) for image data.

    Supported inputs:
    - `raw_paths` / `label_paths` are folders and `raw_key` / `label_key` are glob patterns
      inside them; `roi` (if given) slices the sorted list of matched files.
    - `raw_key` and `label_key` are None and the paths are lists/tuples of image files.
    - `raw_paths` / `label_paths` are lists of folders with glob patterns; `roi` (if given) has
      one entry per folder and `n_samples` (popped from kwargs) is split over the folders.

    A 3d `patch_shape` is accepted only if its first entry is 1; it is reduced to 2d.
    """
    def _get_paths(raw_root, raw_pattern, label_root, label_pattern, this_roi):
        # Keep the pattern path separate from the glob result so the error message can use it.
        raw_pattern_path = os.path.join(raw_root, raw_pattern)
        raw_files = sorted(glob(raw_pattern_path))
        if len(raw_files) == 0:
            raise ValueError(f"Could not find any images for pattern {raw_pattern_path}")

        # Raw and label files are paired by position after sorting.
        label_files = sorted(glob(os.path.join(label_root, label_pattern)))
        if len(raw_files) != len(label_files):
            raise ValueError(
                f"Expect same number of raw and label images, got {len(raw_files)}, {len(label_files)}"
            )

        # Apply the roi that belongs to this folder (not the roi list for all folders).
        if this_roi is not None:
            raw_files, label_files = raw_files[this_roi], label_files[this_roi]

        return raw_files, label_files

    patch_shape = kwargs.pop("patch_shape")
    if patch_shape is not None:
        if len(patch_shape) == 3:
            if patch_shape[0] != 1:
                raise ValueError(f"Image collection dataset expects 2d patch shape, got {patch_shape}")
            patch_shape = patch_shape[1:]
        assert len(patch_shape) == 2

    if isinstance(raw_paths, str):
        raw_paths, label_paths = _get_paths(raw_paths, raw_key, label_paths, label_key, roi)
        ds = ImageCollectionDataset(raw_paths, label_paths, patch_shape=patch_shape, **kwargs)

    elif raw_key is None:
        # Explicit lists of image files.
        assert label_key is None
        assert isinstance(raw_paths, (list, tuple)) and isinstance(label_paths, (list, tuple))
        assert len(raw_paths) == len(label_paths)
        ds = ImageCollectionDataset(raw_paths, label_paths, patch_shape=patch_shape, **kwargs)

    else:
        # Multiple folders with glob patterns: one dataset per folder.
        n_samples = kwargs.pop("n_samples", None)
        samples_per_ds = (
            [None] * len(raw_paths) if n_samples is None else samples_to_datasets(n_samples, raw_paths, raw_key)
        )
        if roi is None:
            roi = len(raw_paths) * [None]
        assert len(roi) == len(raw_paths)
        datasets = []
        for i, (raw_path, label_path, this_roi) in enumerate(zip(raw_paths, label_paths, roi)):
            rpath, lpath = _get_paths(raw_path, raw_key, label_path, label_key, this_roi)
            datasets.append(
                ImageCollectionDataset(rpath, lpath, patch_shape=patch_shape, n_samples=samples_per_ds[i], **kwargs)
            )
        ds = ConcatDataset(*datasets)

    return ds
183
184
def _get_default_transform(path, key, is_seg_dataset, ndim):
    """Return the default augmentations used by the convenience functions.

    Image collection datasets always get 2d augmentations. Segmentation datasets use `ndim`
    if given; otherwise it is derived from the shape of the data at `path` / `key`:
    2 for 2d data, "anisotropic" if the first axis is shorter than half of the second one,
    and 3 otherwise.
    """
    if not is_seg_dataset:
        return get_augmentations(2)

    if ndim is None:
        data_shape = load_data(path, key).shape
        if len(data_shape) == 2:
            ndim = 2
        elif data_shape[0] < data_shape[1] // 2:
            # Heuristic: comparatively few slices along the first axis -> anisotropic augmentations.
            ndim = "anisotropic"
        else:
            ndim = 3

    return get_augmentations(ndim)
202
203
def default_segmentation_loader(
    raw_paths: Union[List[Any], str, os.PathLike],
    raw_key: Optional[str],
    label_paths: Union[List[Any], str, os.PathLike],
    label_key: Optional[str],
    batch_size: int,
    patch_shape: Tuple[int, ...],
    label_transform: Optional[Callable] = None,
    label_transform2: Optional[Callable] = None,
    raw_transform: Optional[Callable] = None,
    transform: Optional[Callable] = None,
    dtype: torch.dtype = torch.float32,
    label_dtype: torch.dtype = torch.float32,
    rois: Optional[Union[slice, Tuple[slice, ...]]] = None,
    n_samples: Optional[int] = None,
    sampler: Optional[Callable] = None,
    ndim: Optional[int] = None,
    is_seg_dataset: Optional[bool] = None,
    with_channels: bool = False,
    with_label_channels: bool = False,
    verify_paths: bool = True,
    with_padding: bool = True,
    z_ext: Optional[int] = None,
    **loader_kwargs,
) -> torch.utils.data.DataLoader:
    """Get data loader for training a segmentation network.

    See `torch_em.data.SegmentationDataset` and `torch_em.data.ImageCollectionDataset` for details
    on the data formats that are supported.

    Args:
        raw_paths: The file path(s) to the raw data. Can either be a single path or multiple file paths.
        raw_key: The name of the internal dataset containing the raw data. Set to None for regular image files.
        label_paths: The file path(s) to the label data. Can either be a single path or multiple file paths.
        label_key: The name of the internal dataset containing the label data. Set to None for regular image files.
        batch_size: The batch size for the data loader.
        patch_shape: The patch shape for the training samples.
        label_transform: Transformation applied to the label data of a sample,
            before applying augmentations via `transform`.
        label_transform2: Transformation applied to the label data of a sample,
            after applying augmentations via `transform`.
        raw_transform: Transformation applied to the raw data of a sample.
            If None, the default raw transform is used.
        transform: Transformation applied to both the raw data and label data of a sample.
            This can be used to implement data augmentations. If None, default augmentations are used.
        dtype: The return data type of the raw data.
        label_dtype: The return data type of the label data.
        rois: Regions of interest in the data.  If given, the data will only be loaded from the corresponding area.
        n_samples: The length of the underlying dataset. If None, the length will be set to `len(raw_paths)`.
        sampler: Sampler for rejecting samples according to a defined criterion.
            The sampler must be a callable that accepts the raw data (as numpy arrays) as input.
        ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
        is_seg_dataset: Whether this is a segmentation dataset or an image collection dataset.
            If None, the type of dataset will be derived from the data.
        with_channels: Whether the raw data has channels.
        with_label_channels: Whether the label data has channels.
        verify_paths: Whether to verify all paths before creating the dataset.
        with_padding: Whether to pad samples to `patch_shape` if their shape is smaller.
        z_ext: Extra bounding box for loading the data across z.
        loader_kwargs: Keyword arguments for `torch.utils.data.DataLoader`.
            `pin_memory` defaults to True if it is not passed.

    Returns:
        The torch data loader.
    """
    ds = default_segmentation_dataset(
        raw_paths=raw_paths,
        raw_key=raw_key,
        label_paths=label_paths,
        label_key=label_key,
        patch_shape=patch_shape,
        label_transform=label_transform,
        label_transform2=label_transform2,
        raw_transform=raw_transform,
        transform=transform,
        dtype=dtype,
        label_dtype=label_dtype,
        rois=rois,
        n_samples=n_samples,
        sampler=sampler,
        ndim=ndim,
        is_seg_dataset=is_seg_dataset,
        with_channels=with_channels,
        with_label_channels=with_label_channels,
        with_padding=with_padding,
        z_ext=z_ext,
        verify_paths=verify_paths,
    )
    return get_data_loader(ds, batch_size=batch_size, **loader_kwargs)
290
291
def default_segmentation_dataset(
    raw_paths: Union[List[Any], str, os.PathLike],
    raw_key: Optional[str],
    label_paths: Union[List[Any], str, os.PathLike],
    label_key: Optional[str],
    patch_shape: Tuple[int, ...],
    label_transform: Optional[Callable] = None,
    label_transform2: Optional[Callable] = None,
    raw_transform: Optional[Callable] = None,
    transform: Optional[Callable] = None,
    dtype: torch.dtype = torch.float32,
    label_dtype: torch.dtype = torch.float32,
    rois: Optional[Union[slice, Tuple[slice, ...]]] = None,
    n_samples: Optional[int] = None,
    sampler: Optional[Callable] = None,
    ndim: Optional[int] = None,
    is_seg_dataset: Optional[bool] = None,
    with_channels: bool = False,
    with_label_channels: bool = False,
    verify_paths: bool = True,
    with_padding: bool = True,
    z_ext: Optional[int] = None,
) -> torch.utils.data.Dataset:
    """Get data set for training a segmentation network.

    See `torch_em.data.SegmentationDataset` and `torch_em.data.ImageCollectionDataset` for details
    on the data formats that are supported.

    Args:
        raw_paths: The file path(s) to the raw data. Can either be a single path or multiple file paths.
        raw_key: The name of the internal dataset containing the raw data. Set to None for regular image files.
        label_paths: The file path(s) to the label data. Can either be a single path or multiple file paths.
        label_key: The name of the internal dataset containing the label data. Set to None for regular image files.
        patch_shape: The patch shape for the training samples.
        label_transform: Transformation applied to the label data of a sample,
            before applying augmentations via `transform`.
        label_transform2: Transformation applied to the label data of a sample,
            after applying augmentations via `transform`.
        raw_transform: Transformation applied to the raw data of a sample.
            If None, the default raw transform from `torch_em.transform.get_raw_transform` is used.
        transform: Transformation applied to both the raw data and label data of a sample.
            This can be used to implement data augmentations. If None, default augmentations
            from `torch_em.transform.get_augmentations` are used.
        dtype: The return data type of the raw data.
        label_dtype: The return data type of the label data.
        rois: Regions of interest in the data.  If given, the data will only be loaded from the corresponding area.
        n_samples: The length of the dataset. If None, the length will be set to `len(raw_paths)`.
        sampler: Sampler for rejecting samples according to a defined criterion.
            The sampler must be a callable that accepts the raw data (as numpy arrays) as input.
        ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
            Only used for segmentation datasets.
        is_seg_dataset: Whether this is a segmentation dataset or an image collection dataset.
            If None, the type of dataset will be derived from the data.
        with_channels: Whether the raw data has channels. Only used for segmentation datasets.
        with_label_channels: Whether the label data has channels. Only used for segmentation datasets.
        verify_paths: Whether to verify all paths before creating the dataset.
        with_padding: Whether to pad samples to `patch_shape` if their shape is smaller.
        z_ext: Extra bounding box for loading the data across z. Only used for segmentation datasets.

    Returns:
        The torch data set.
    """
    if verify_paths:
        check_paths(raw_paths, label_paths)

    if is_seg_dataset is None:
        is_seg_dataset = is_segmentation_dataset(raw_paths, raw_key, label_paths, label_key)

    # We always use a raw transform in the convenience function.
    if raw_transform is None:
        raw_transform = get_raw_transform()

    # We always use augmentations in the convenience function.
    if transform is None:
        transform = _get_default_transform(
            raw_paths if isinstance(raw_paths, str) else raw_paths[0], raw_key, is_seg_dataset, ndim
        )

    if is_seg_dataset:
        ds = _load_segmentation_dataset(
            raw_paths,
            raw_key,
            label_paths,
            label_key,
            patch_shape=patch_shape,
            raw_transform=raw_transform,
            label_transform=label_transform,
            label_transform2=label_transform2,
            transform=transform,
            rois=rois,
            n_samples=n_samples,
            sampler=sampler,
            ndim=ndim,
            dtype=dtype,
            label_dtype=label_dtype,
            with_channels=with_channels,
            with_label_channels=with_label_channels,
            with_padding=with_padding,
            z_ext=z_ext,
        )

    else:
        # ndim, with_channels, with_label_channels and z_ext are not forwarded for image collections.
        ds = _load_image_collection_dataset(
            raw_paths,
            raw_key,
            label_paths,
            label_key,
            roi=rois,
            patch_shape=patch_shape,
            label_transform=label_transform,
            raw_transform=raw_transform,
            label_transform2=label_transform2,
            transform=transform,
            n_samples=n_samples,
            sampler=sampler,
            dtype=dtype,
            label_dtype=label_dtype,
            with_padding=with_padding,
        )

    return ds
410
411
def get_data_loader(dataset: torch.utils.data.Dataset, batch_size: int, **loader_kwargs) -> torch.utils.data.DataLoader:
    """@private

    Wrap `dataset` in a `torch.utils.data.DataLoader`, with `pin_memory` defaulting to True.
    The value of `shuffle` (False if not given) is additionally stored as `loader.shuffle`.
    """
    will_shuffle = loader_kwargs.get("shuffle", False)
    use_pinned_memory = loader_kwargs.pop("pin_memory", True)
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, pin_memory=use_pinned_memory, **loader_kwargs
    )
    # Attach the shuffle flag so downstream code can query it from the loader.
    data_loader.shuffle = will_shuffle
    return data_loader
420
421
422#
423# convenience functions for segmentation trainers
424#
425
426
def default_segmentation_trainer(
    name: str,
    model: torch.nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    loss: Optional[torch.nn.Module] = None,
    metric: Optional[Callable] = None,
    learning_rate: float = 1e-3,
    device: Optional[Union[str, torch.device]] = None,
    log_image_interval: int = 100,
    mixed_precision: bool = True,
    early_stopping: Optional[int] = None,
    logger=TensorboardLogger,
    logger_kwargs: Optional[Dict[str, Any]] = None,
    scheduler_kwargs: Optional[Dict[str, Any]] = DEFAULT_SCHEDULER_KWARGS,
    optimizer_kwargs: Optional[Dict[str, Any]] = None,
    trainer_class=DefaultTrainer,
    id_: Optional[str] = None,
    save_root: Optional[str] = None,
    compile_model: Optional[Union[bool, str]] = None,
    rank: Optional[int] = None,
):
    """Get a trainer for a segmentation network.

    It creates a `torch.optim.AdamW` optimizer and learning rate scheduler that reduces the learning rate on plateau.
    By default, it uses the dice score as loss and metric.
    This can be changed by passing arguments for `loss` and/or `metric`.
    See `torch_em.trainer.DefaultTrainer` for additional details on how to configure and use the trainer.

    Here's an example for training a 2D U-Net with this function:
    ```python
    from torch_em.model import UNet2d
    from torch_em.segmentation import default_segmentation_trainer
    from torch_em.data.datasets.light_microscopy import get_dsb_loader

    # The training data will be downloaded to this location.
    data_root = "/path/to/save/the/training/data"
    patch_shape = (256, 256)
    trainer = default_segmentation_trainer(
        name="unet-training",
        model=UNet2d(in_channels=1, out_channels=1),
        train_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="train"),
        val_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="test"),
    )
    trainer.fit(iterations=int(2.5e4))  # Train for 25,000 iterations.
    ```

    Args:
        name: The name of the checkpoint that will be created by the trainer.
        model: The model to train.
        train_loader: The data loader containing the training data.
        val_loader: The data loader containing the validation data.
        loss: The loss function for training. If None, `DiceLoss` is used.
        metric: The metric for validation. If None, `DiceLoss` is used.
        learning_rate: The initial learning rate for the AdamW optimizer.
        device: The torch device to use for training. If None, will use a GPU if available.
        log_image_interval: The interval for saving images during logging, in training iterations.
        mixed_precision: Whether to train with mixed precision. Always disabled when training on the CPU.
        early_stopping: The patience for early stopping in epochs. If None, early stopping will not be used.
        logger: The logger class. Will be instantiated for logging.
            By default uses `torch_em.trainer.tensorboard_logger.TensorboardLogger`.
        logger_kwargs: The keyword arguments for the logger class.
        scheduler_kwargs: The keyword arguments for ReduceLROnPlateau.
            If None, `DEFAULT_SCHEDULER_KWARGS` are used.
        optimizer_kwargs: The keyword arguments for the AdamW optimizer (in addition to the learning rate).
            If None, no additional arguments are passed.
        trainer_class: The trainer class. Uses `torch_em.trainer.DefaultTrainer` by default,
            but can be set to a custom trainer class to enable custom training procedures.
        id_: Unique identifier for the trainer. If None then `name` will be used.
        save_root: The root folder for saving the checkpoint and logs.
        compile_model: Whether to compile the model before training.
        rank: Rank argument for distributed training. See `torch_em.multi_gpu_training` for details.

    Returns:
        The trainer.
    """
    # None defaults avoid sharing a mutable dict between calls.
    optimizer_kwargs = {} if optimizer_kwargs is None else optimizer_kwargs
    scheduler_kwargs = DEFAULT_SCHEDULER_KWARGS if scheduler_kwargs is None else scheduler_kwargs

    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, **optimizer_kwargs)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, **scheduler_kwargs)

    loss = DiceLoss() if loss is None else loss
    metric = DiceLoss() if metric is None else metric

    if device is None:
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    else:
        device = torch.device(device)

    # CPU does not support mixed precision training.
    if device.type == "cpu":
        mixed_precision = False

    return trainer_class(
        name=name,
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        loss=loss,
        metric=metric,
        optimizer=optimizer,
        device=device,
        lr_scheduler=scheduler,
        mixed_precision=mixed_precision,
        early_stopping=early_stopping,
        log_image_interval=log_image_interval,
        logger=logger,
        logger_kwargs=logger_kwargs,
        id_=id_,
        save_root=save_root,
        compile_model=compile_model,
        rank=rank,
    )
def default_segmentation_loader( raw_paths: Union[List[Any], str, os.PathLike], raw_key: Optional[str], label_paths: Union[List[Any], str, os.PathLike], label_key: Optional[str], batch_size: int, patch_shape: Tuple[int, ...], label_transform: Optional[Callable] = None, label_transform2: Optional[Callable] = None, raw_transform: Optional[Callable] = None, transform: Optional[Callable] = None, dtype: torch.device = torch.float32, label_dtype: torch.device = torch.float32, rois: Union[slice, Tuple[slice, ...], NoneType] = None, n_samples: Optional[int] = None, sampler: Optional[Callable] = None, ndim: Optional[int] = None, is_seg_dataset: Optional[bool] = None, with_channels: bool = False, with_label_channels: bool = False, verify_paths: bool = True, with_padding: bool = True, z_ext: Optional[int] = None, **loader_kwargs) -> torch.utils.data.dataloader.DataLoader:
def default_segmentation_loader(
    raw_paths: Union[List[Any], str, os.PathLike],
    raw_key: Optional[str],
    label_paths: Union[List[Any], str, os.PathLike],
    label_key: Optional[str],
    batch_size: int,
    patch_shape: Tuple[int, ...],
    label_transform: Optional[Callable] = None,
    label_transform2: Optional[Callable] = None,
    raw_transform: Optional[Callable] = None,
    transform: Optional[Callable] = None,
    dtype: torch.dtype = torch.float32,
    label_dtype: torch.dtype = torch.float32,
    rois: Optional[Union[slice, Tuple[slice, ...]]] = None,
    n_samples: Optional[int] = None,
    sampler: Optional[Callable] = None,
    ndim: Optional[int] = None,
    is_seg_dataset: Optional[bool] = None,
    with_channels: bool = False,
    with_label_channels: bool = False,
    verify_paths: bool = True,
    with_padding: bool = True,
    z_ext: Optional[int] = None,
    **loader_kwargs,
) -> torch.utils.data.DataLoader:
    """Get data loader for training a segmentation network.

    See `torch_em.data.SegmentationDataset` and `torch_em.data.ImageCollectionDataset` for details
    on the data formats that are supported.

    Args:
        raw_paths: The file path(s) to the raw data. Can either be a single path or multiple file paths.
        raw_key: The name of the internal dataset containing the raw data. Set to None for regular image files.
        label_paths: The file path(s) to the label data. Can either be a single path or multiple file paths.
        label_key: The name of the internal dataset containing the label data. Set to None for regular image files.
        batch_size: The batch size for the data loader.
        patch_shape: The patch shape for the training samples.
        label_transform: Transformation applied to the label data of a sample,
            before applying augmentations via `transform`.
        label_transform2: Transformation applied to the label data of a sample,
            after applying augmentations via `transform`.
        raw_transform: Transformation applied to the raw data of a sample.
            If None, the default raw transform is used.
        transform: Transformation applied to both the raw data and label data of a sample.
            This can be used to implement data augmentations. If None, default augmentations are used.
        dtype: The return data type of the raw data.
        label_dtype: The return data type of the label data.
        rois: Regions of interest in the data.  If given, the data will only be loaded from the corresponding area.
        n_samples: The length of the underlying dataset. If None, the length will be set to `len(raw_paths)`.
        sampler: Sampler for rejecting samples according to a defined criterion.
            The sampler must be a callable that accepts the raw data (as numpy arrays) as input.
        ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
        is_seg_dataset: Whether this is a segmentation dataset or an image collection dataset.
            If None, the type of dataset will be derived from the data.
        with_channels: Whether the raw data has channels.
        with_label_channels: Whether the label data has channels.
        verify_paths: Whether to verify all paths before creating the dataset.
        with_padding: Whether to pad samples to `patch_shape` if their shape is smaller.
        z_ext: Extra bounding box for loading the data across z.
        loader_kwargs: Keyword arguments for `torch.utils.data.DataLoader`.
            `pin_memory` defaults to True if it is not passed.

    Returns:
        The torch data loader.
    """
    ds = default_segmentation_dataset(
        raw_paths=raw_paths,
        raw_key=raw_key,
        label_paths=label_paths,
        label_key=label_key,
        patch_shape=patch_shape,
        label_transform=label_transform,
        label_transform2=label_transform2,
        raw_transform=raw_transform,
        transform=transform,
        dtype=dtype,
        label_dtype=label_dtype,
        rois=rois,
        n_samples=n_samples,
        sampler=sampler,
        ndim=ndim,
        is_seg_dataset=is_seg_dataset,
        with_channels=with_channels,
        with_label_channels=with_label_channels,
        with_padding=with_padding,
        z_ext=z_ext,
        verify_paths=verify_paths,
    )
    return get_data_loader(ds, batch_size=batch_size, **loader_kwargs)

Get data loader for training a segmentation network.

See torch_em.data.SegmentationDataset and torch_em.data.ImageCollectionDataset for details on the data formats that are supported.

Arguments:
  • raw_paths: The file path(s) to the raw data. Can either be a single path or multiple file paths.
  • raw_key: The name of the internal dataset containing the raw data. Set to None for regular image files.
  • label_paths: The file path(s) to the label data. Can either be a single path or multiple file paths.
  • label_key: The name of the internal dataset containing the label data. Set to None for regular image files.
  • batch_size: The batch size for the data loader.
  • patch_shape: The patch shape for the training samples.
  • label_transform: Transformation applied to the label data of a sample, before applying augmentations via transform.
  • label_transform2: Transformation applied to the label data of a sample, after applying augmentations via transform.
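  • raw_transform: Transformation applied to the raw data of a sample. If None, the default raw transform is used.
  • transform: Transformation applied to both the raw data and label data of a sample. This can be used to implement data augmentations.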
  • dtype: The return data type of the raw data.
  • label_dtype: The return data type of the label data.
  • rois: Regions of interest in the data. If given, the data will only be loaded from the corresponding area.
  • n_samples: The length of the underlying dataset. If None, the length will be set to len(raw_paths).
  • sampler: Sampler for rejecting samples according to a defined criterion. The sampler must be a callable that accepts the raw data (as numpy arrays) as input.
  • ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
  • is_seg_dataset: Whether this is a segmentation dataset or an image collection dataset. If None, the type of dataset will be derived from the data.
  • with_channels: Whether the raw data has channels.
  • with_label_channels: Whether the label data has channels.
  • verify_paths: Whether to verify all paths before creating the dataset.
  • with_padding: Whether to pad samples to patch_shape if their shape is smaller.
  • z_ext: Extra bounding box for loading the data across z.
  • loader_kwargs: Keyword arguments for torch.utils.data.DataLoader.
Returns:

The torch data loader.

def default_segmentation_dataset( raw_paths: Union[List[Any], str, os.PathLike], raw_key: Optional[str], label_paths: Union[List[Any], str, os.PathLike], label_key: Optional[str], patch_shape: Tuple[int, ...], label_transform: Optional[Callable] = None, label_transform2: Optional[Callable] = None, raw_transform: Optional[Callable] = None, transform: Optional[Callable] = None, dtype: torch.dtype = torch.float32, label_dtype: torch.dtype = torch.float32, rois: Union[slice, Tuple[slice, ...], NoneType] = None, n_samples: Optional[int] = None, sampler: Optional[Callable] = None, ndim: Optional[int] = None, is_seg_dataset: Optional[bool] = None, with_channels: bool = False, with_label_channels: bool = False, verify_paths: bool = True, with_padding: bool = True, z_ext: Optional[int] = None) -> torch.utils.data.dataset.Dataset:
def default_segmentation_dataset(
    raw_paths: Union[List[Any], str, os.PathLike],
    raw_key: Optional[str],
    label_paths: Union[List[Any], str, os.PathLike],
    label_key: Optional[str],
    patch_shape: Tuple[int, ...],
    label_transform: Optional[Callable] = None,
    label_transform2: Optional[Callable] = None,
    raw_transform: Optional[Callable] = None,
    transform: Optional[Callable] = None,
    dtype: torch.dtype = torch.float32,
    label_dtype: torch.dtype = torch.float32,
    rois: Optional[Union[slice, Tuple[slice, ...]]] = None,
    n_samples: Optional[int] = None,
    sampler: Optional[Callable] = None,
    ndim: Optional[int] = None,
    is_seg_dataset: Optional[bool] = None,
    with_channels: bool = False,
    with_label_channels: bool = False,
    verify_paths: bool = True,
    with_padding: bool = True,
    z_ext: Optional[int] = None,
) -> torch.utils.data.Dataset:
    """Get data set for training a segmentation network.

    See `torch_em.data.SegmentationDataset` and `torch_em.data.ImageCollectionDataset` for details
    on the data formats that are supported.

    Args:
        raw_paths: The file path(s) to the raw data. Can either be a single path or multiple file paths.
        raw_key: The name of the internal dataset containing the raw data. Set to None for regular image files.
        label_paths: The file path(s) to the label data. Can either be a single path or multiple file paths.
        label_key: The name of the internal dataset containing the label data. Set to None for regular image files.
        patch_shape: The patch shape for the training samples.
        label_transform: Transformation applied to the label data of a sample,
            before applying augmentations via `transform`.
        label_transform2: Transformation applied to the label data of a sample,
            after applying augmentations via `transform`.
        raw_transform: Transformation applied to the raw data of a sample.
            If None, `torch_em.transform.get_raw_transform()` is used.
        transform: Transformation applied to both the raw data and label data of a sample.
            This can be used to implement data augmentations. If None, default augmentations are used,
            derived from the (first) raw path.
        dtype: The return data type of the raw data.
        label_dtype: The return data type of the label data.
        rois: Regions of interest in the data. If given, the data will only be loaded from the corresponding area.
        n_samples: The length of the dataset. If None, the length will be set to `len(raw_paths)`.
        sampler: Sampler for rejecting samples according to a defined criterion.
            The sampler must be a callable that accepts the raw data (as numpy arrays) as input.
        ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
        is_seg_dataset: Whether this is a segmentation dataset or an image collection dataset.
            If None, the type of dataset will be derived from the data.
        with_channels: Whether the raw data has channels. Only forwarded for segmentation datasets.
        with_label_channels: Whether the label data has channels. Only forwarded for segmentation datasets.
        verify_paths: Whether to verify all paths before creating the dataset.
        with_padding: Whether to pad samples to `patch_shape` if their shape is smaller.
        z_ext: Extra bounding box for loading the data across z. Only forwarded for segmentation datasets.

    Returns:
        The torch data set.
    """
    if verify_paths:
        check_paths(raw_paths, label_paths)

    if is_seg_dataset is None:
        is_seg_dataset = is_segmentation_dataset(raw_paths, raw_key, label_paths, label_key)

    # We always use a raw transform in the convenience function.
    if raw_transform is None:
        raw_transform = get_raw_transform()

    # We always use augmentations in the convenience function.
    if transform is None:
        # The default augmentations are derived from one representative raw path.
        # A single path may be a str or an os.PathLike (e.g. pathlib.Path). A PathLike is not
        # subscriptable, so it must not be indexed like a list of paths; normalize it to str instead.
        if isinstance(raw_paths, (str, os.PathLike)):
            first_raw_path = os.fspath(raw_paths)
        else:
            first_raw_path = raw_paths[0]
        transform = _get_default_transform(first_raw_path, raw_key, is_seg_dataset, ndim)

    if is_seg_dataset:
        ds = _load_segmentation_dataset(
            raw_paths,
            raw_key,
            label_paths,
            label_key,
            patch_shape=patch_shape,
            raw_transform=raw_transform,
            label_transform=label_transform,
            label_transform2=label_transform2,
            transform=transform,
            rois=rois,
            n_samples=n_samples,
            sampler=sampler,
            ndim=ndim,
            dtype=dtype,
            label_dtype=label_dtype,
            with_channels=with_channels,
            with_label_channels=with_label_channels,
            with_padding=with_padding,
            z_ext=z_ext,
        )

    else:
        # Note: with_channels, with_label_channels and z_ext are not forwarded to the image collection dataset.
        ds = _load_image_collection_dataset(
            raw_paths,
            raw_key,
            label_paths,
            label_key,
            roi=rois,
            patch_shape=patch_shape,
            label_transform=label_transform,
            raw_transform=raw_transform,
            label_transform2=label_transform2,
            transform=transform,
            n_samples=n_samples,
            sampler=sampler,
            dtype=dtype,
            label_dtype=label_dtype,
            with_padding=with_padding,
        )

    return ds

Get data set for training a segmentation network.

See torch_em.data.SegmentationDataset and torch_em.data.ImageCollectionDataset for details on the data formats that are supported.

Arguments:
  • raw_paths: The file path(s) to the raw data. Can either be a single path or multiple file paths.
  • raw_key: The name of the internal dataset containing the raw data. Set to None for regular image files.
  • label_paths: The file path(s) to the label data. Can either be a single path or multiple file paths.
  • label_key: The name of the internal dataset containing the label data. Set to None for regular image files.
  • patch_shape: The patch shape for the training samples.
  • label_transform: Transformation applied to the label data of a sample, before applying augmentations via transform.
  • label_transform2: Transformation applied to the label data of a sample, after applying augmentations via transform.
  • transform: Transformation applied to both the raw data and label data of a sample. This can be used to implement data augmentations.
  • dtype: The return data type of the raw data.
  • label_dtype: The return data type of the label data.
  • rois: Regions of interest in the data. If given, the data will only be loaded from the corresponding area.
  • n_samples: The length of the dataset. If None, the length will be set to len(raw_paths).
  • sampler: Sampler for rejecting samples according to a defined criterion. The sampler must be a callable that accepts the raw data (as numpy arrays) as input.
  • ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
  • is_seg_dataset: Whether this is a segmentation dataset or an image collection dataset. If None, the type of dataset will be derived from the data.
  • with_channels: Whether the raw data has channels.
  • with_label_channels: Whether the label data has channels.
  • verify_paths: Whether to verify all paths before creating the dataset.
  • with_padding: Whether to pad samples to patch_shape if their shape is smaller.
  • z_ext: Extra bounding box for loading the data across z.
  • loader_kwargs: Keyword arguments for torch.utils.data.DataLoder.
Returns:

The torch data set.

def default_segmentation_trainer( name: str, model: torch.nn.modules.module.Module, train_loader: torch.utils.data.dataloader.DataLoader, val_loader: torch.utils.data.dataloader.DataLoader, loss: Optional[torch.nn.modules.module.Module] = None, metric: Optional[Callable] = None, learning_rate: float = 0.001, device: Union[str, torch.device, NoneType] = None, log_image_interval: int = 100, mixed_precision: bool = True, early_stopping: Optional[int] = None, logger=<class 'torch_em.trainer.tensorboard_logger.TensorboardLogger'>, logger_kwargs: Optional[Dict[str, Any]] = None, scheduler_kwargs: Dict[str, Any] = {'mode': 'min', 'factor': 0.5, 'patience': 5}, optimizer_kwargs: Dict[str, Any] = {}, trainer_class=<class 'torch_em.trainer.default_trainer.DefaultTrainer'>, id_: Optional[str] = None, save_root: Optional[str] = None, compile_model: Union[bool, str, NoneType] = None, rank: Optional[int] = None):
def default_segmentation_trainer(
    name: str,
    model: torch.nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    loss: Optional[torch.nn.Module] = None,
    metric: Optional[Callable] = None,
    learning_rate: float = 1e-3,
    device: Optional[Union[str, torch.device]] = None,
    log_image_interval: int = 100,
    mixed_precision: bool = True,
    early_stopping: Optional[int] = None,
    logger=TensorboardLogger,
    logger_kwargs: Optional[Dict[str, Any]] = None,
    scheduler_kwargs: Dict[str, Any] = DEFAULT_SCHEDULER_KWARGS,
    optimizer_kwargs: Dict[str, Any] = {},
    trainer_class=DefaultTrainer,
    id_: Optional[str] = None,
    save_root: Optional[str] = None,
    compile_model: Optional[Union[bool, str]] = None,
    rank: Optional[int] = None,
):
    """Get a trainer for a segmentation network.

    It creates a `torch.optim.AdamW` optimizer and learning rate scheduler that reduces the learning rate on plateau.
    By default, it uses the dice score as loss and metric.
    This can be changed by passing arguments for `loss` and/or `metric`.
    See `torch_em.trainer.DefaultTrainer` for additional details on how to configure and use the trainer.

    Here's an example for training a 2D U-Net with this function:
    ```python
    from torch_em.model import UNet2d
    from torch_em.segmentation import default_segmentation_trainer
    from torch_em.data.datasets.light_microscopy import get_dsb_loader

    # The training data will be downloaded to this location.
    data_root = "/path/to/save/the/training/data"
    patch_shape = (256, 256)
    trainer = default_segmentation_trainer(
        name="unet-training",
        model=UNet2d(in_channels=1, out_channels=1),
        train_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="train"),
        val_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="test"),
    )
    trainer.fit(iterations=int(2.5e4))  # Train for 25,000 iterations.
    ```

    Args:
        name: The name of the checkpoint that will be created by the trainer.
        model: The model to train.
        train_loader: The data loader containing the training data.
        val_loader: The data loader containing the validation data.
        loss: The loss function for training. If None, `DiceLoss` is used.
        metric: The metric for validation. If None, `DiceLoss` is used.
        learning_rate: The initial learning rate for the AdamW optimizer.
        device: The torch device to use for training. If None, will use a GPU if available.
        log_image_interval: The interval for saving images during logging, in training iterations.
        mixed_precision: Whether to train with mixed precision. Always disabled when training on the CPU.
        early_stopping: The patience for early stopping in epochs. If None, early stopping will not be used.
        logger: The logger class. Will be instantiated for logging.
            By default uses `torch_em.trainer.tensorboard_logger.TensorboardLogger`.
        logger_kwargs: The keyword arguments for the logger class.
        scheduler_kwargs: The keyword arguments for ReduceLROnPlateau.
        optimizer_kwargs: The keyword arguments for the AdamW optimizer.
        trainer_class: The trainer class. Uses `torch_em.trainer.DefaultTrainer` by default,
            but can be set to a custom trainer class to enable custom training procedures.
        id_: Unique identifier for the trainer. If None then `name` will be used.
        save_root: The root folder for saving the checkpoint and logs.
        compile_model: Whether to compile the model before training.
        rank: Rank argument for distributed training. See `torch_em.multi_gpu_training` for details.

    Returns:
        The trainer.
    """
    # The dict defaults for optimizer_kwargs / scheduler_kwargs are shared between calls;
    # they are only unpacked here and never mutated, so sharing them is safe.
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, **optimizer_kwargs)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, **scheduler_kwargs)

    loss = DiceLoss() if loss is None else loss
    metric = DiceLoss() if metric is None else metric

    if device is None:
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    else:
        device = torch.device(device)

    # CPU does not support mixed precision training.
    if device.type == "cpu":
        mixed_precision = False

    return trainer_class(
        name=name,
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        loss=loss,
        metric=metric,
        optimizer=optimizer,
        device=device,
        lr_scheduler=scheduler,
        mixed_precision=mixed_precision,
        early_stopping=early_stopping,
        log_image_interval=log_image_interval,
        logger=logger,
        logger_kwargs=logger_kwargs,
        id_=id_,
        save_root=save_root,
        compile_model=compile_model,
        rank=rank,
    )

Get a trainer for a segmentation network.

It creates a torch.optim.AdamW optimizer and learning rate scheduler that reduces the learning rate on plateau. By default, it uses the dice score as loss and metric. This can be changed by passing arguments for loss and/or metric. See torch_em.trainer.DefaultTrainer for additional details on how to configure and use the trainer.

Here's an example for training a 2D U-Net with this function:

from torch_em.model import UNet2d
from torch_em.segmentation import default_segmentation_trainer
from torch_em.data.datasets.light_microscopy import get_dsb_loader

# The training data will be downloaded to this location.
data_root = "/path/to/save/the/training/data"
patch_shape = (256, 256)
trainer = default_segmentation_trainer(
    name="unet-training",
    model=UNet2d(in_channels=1, out_channels=1),
    train_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="train"),
    val_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="test"),
)
trainer.fit(iterations=int(2.5e4))  # Train for 25,000 iterations.
Arguments:
  • name: The name of the checkpoint that will be created by the trainer.
  • model: The model to train.
  • train_loader: The data loader containing the training data.
  • val_loader: The data loader containing the validation data.
  • loss: The loss function for training.
  • metric: The metric for validation.
  • learning_rate: The initial learning rate for the AdamW optimizer.
  • device: The torch device to use for training. If None, will use a GPU if available.
  • log_image_interval: The interval for saving images during logging, in training iterations.
  • mixed_precision: Whether to train with mixed precision.
  • early_stopping: The patience for early stopping in epochs. If None, early stopping will not be used.
  • logger: The logger class. Will be instantiated for logging. By default uses torch_em.trainer.tensorboard_logger.TensorboardLogger.
  • logger_kwargs: The keyword arguments for the logger class.
  • scheduler_kwargs: The keyword arguments for ReduceLROnPlateau.
  • optimizer_kwargs: The keyword arguments for the AdamW optimizer.
  • trainer_class: The trainer class. Uses torch_em.trainer.DefaultTrainer by default, but can be set to a custom trainer class to enable custom training procedures.
  • id_: Unique identifier for the trainer. If None then name will be used.
  • save_root: The root folder for saving the checkpoint and logs.
  • compile_model: Whether to compile the model before training.
  • rank: Rank argument for distributed training. See torch_em.multi_gpu_training for details.
Returns:

The trainer.