torch_em.segmentation

  1import os
  2from glob import glob
  3from typing import Any, Dict, Optional, Union, Tuple, List, Callable
  4
  5import torch
  6import torch.utils.data
  7from torch.utils.data import DataLoader
  8
  9from .loss import DiceLoss
 10from .util import load_data
 11from .trainer import DefaultTrainer
 12from .trainer.tensorboard_logger import TensorboardLogger
 13from .transform import get_augmentations, get_raw_transform
 14from .data import ConcatDataset, ImageCollectionDataset, SegmentationDataset
 15
 16
 17# TODO add a heuristic to estimate this from the number of epochs
 18DEFAULT_SCHEDULER_KWARGS = {"mode": "min", "factor": 0.5, "patience": 5}
 19"""@private
 20"""
 21
 22
 23#
 24# convenience functions for segmentation loaders
 25#
 26
 27# TODO implement balanced and make it the default
 28# def samples_to_datasets(n_samples, raw_paths, raw_key, split="balanced"):
 29def samples_to_datasets(n_samples, raw_paths, raw_key, split="uniform"):
 30    """@private
 31    """
 32    assert split in ("balanced", "uniform")
 33    n_datasets = len(raw_paths)
 34    if split == "uniform":
 35        # even distribution of samples to datasets
 36        samples_per_ds = n_samples // n_datasets
 37        divider = n_samples % n_datasets
 38        return [samples_per_ds + 1 if ii < divider else samples_per_ds for ii in range(n_datasets)]
 39    else:
 40        # distribution of samples to dataset based on the dataset lens
 41        raise NotImplementedError
 42
 43
 44def check_paths(raw_paths, label_paths):
 45    """@private
 46    """
 47    if not isinstance(raw_paths, type(label_paths)):
 48        raise ValueError(f"Expect raw and label paths of same type, got {type(raw_paths)}, {type(label_paths)}")
 49
 50    def _check_path(path):
 51        if isinstance(path, str):
 52            if not os.path.exists(path):
 53                raise ValueError(f"Could not find path {path}")
 54        else:
 55            # check for single path or multiple paths (for same volume - supports multi-modal inputs)
 56            for per_path in path:
 57                if not os.path.exists(per_path):
 58                    raise ValueError(f"Could not find path {per_path}")
 59
 60    if isinstance(raw_paths, str):
 61        _check_path(raw_paths)
 62        _check_path(label_paths)
 63    else:
 64        if len(raw_paths) != len(label_paths):
 65            raise ValueError(f"Expect same number of raw and label paths, got {len(raw_paths)}, {len(label_paths)}")
 66        for rp, lp in zip(raw_paths, label_paths):
 67            _check_path(rp)
 68            _check_path(lp)
 69
 70
 71# Check if we can load the data as SegmentationDataset.
 72def is_segmentation_dataset(raw_paths, raw_key, label_paths, label_key):
 73    """@private
 74    """
 75    def _can_open(path, key):
 76        try:
 77            load_data(path, key)
 78            return True
 79        except Exception:
 80            return False
 81
 82    if isinstance(raw_paths, str):
 83        can_open_raw = _can_open(raw_paths, raw_key)
 84        can_open_label = _can_open(label_paths, label_key)
 85    else:
 86        can_open_raw = [_can_open(rp, raw_key) for rp in raw_paths]
 87        if not can_open_raw.count(can_open_raw[0]) == len(can_open_raw):
 88            raise ValueError("Inconsistent raw data")
 89        can_open_raw = can_open_raw[0]
 90
 91        can_open_label = [_can_open(lp, label_key) for lp in label_paths]
 92        if not can_open_label.count(can_open_label[0]) == len(can_open_label):
 93            raise ValueError("Inconsistent label data")
 94        can_open_label = can_open_label[0]
 95
 96    if can_open_raw != can_open_label:
 97        raise ValueError("Inconsistent raw and label data")
 98
 99    return can_open_raw
100
101
def _load_segmentation_dataset(raw_paths, raw_key, label_paths, label_key, **kwargs):
    # Build a SegmentationDataset for a single path, or a ConcatDataset with one
    # SegmentationDataset per path. `rois` is a slice / tuple of slices for a single path,
    # and one tuple of slices per path otherwise. In the multi-path case `n_samples` (if given)
    # is split between the datasets via `samples_to_datasets`.
    rois = kwargs.pop("rois", None)

    if isinstance(raw_paths, str):
        if rois is not None:
            assert isinstance(rois, (tuple, slice))
            if isinstance(rois, tuple):
                assert all(isinstance(roi, slice) for roi in rois)
        return SegmentationDataset(raw_paths, raw_key, label_paths, label_key, roi=rois, **kwargs)

    assert len(raw_paths) > 0
    if rois is None:
        rois = [None] * len(raw_paths)
    else:
        assert len(rois) == len(label_paths)
        assert all(isinstance(roi, tuple) for roi in rois), f"{rois}"

    n_samples = kwargs.pop("n_samples", None)
    if n_samples is None:
        samples_per_ds = [None] * len(raw_paths)
    else:
        samples_per_ds = samples_to_datasets(n_samples, raw_paths, raw_key)

    datasets = [
        SegmentationDataset(raw_path, raw_key, label_path, label_key, roi=roi, n_samples=n_ds, **kwargs)
        for raw_path, label_path, roi, n_ds in zip(raw_paths, label_paths, rois, samples_per_ds)
    ]
    return ConcatDataset(*datasets)
129
130
def _load_image_collection_dataset(raw_paths, raw_key, label_paths, label_key, roi, **kwargs):
    # Build an ImageCollectionDataset, or a ConcatDataset of several, from regular image files.
    # Three input layouts are supported:
    # - `raw_paths` is a str: it is a folder and `raw_key` / `label_key` are glob patterns for the
    #   raw / label images inside `raw_paths` / `label_paths`. `roi` (if given) is applied to the
    #   sorted list of matching files.
    # - `raw_key` is None: `raw_paths` / `label_paths` are lists of image files used as-is
    #   (`roi` is not used in this case).
    # - otherwise: `raw_paths` / `label_paths` are lists of folders, each globbed with
    #   `raw_key` / `label_key`; `roi` (if given) holds one file selection per folder.
    # `patch_shape` must be 2d, or 3d with a singleton first axis (which is dropped).

    def _get_paths(raw_root, rkey, label_root, lkey, this_roi):
        # Return the sorted raw and label image files matching the patterns in the two folders.
        raw_pattern = os.path.join(raw_root, rkey)
        raw_files = sorted(glob(raw_pattern))
        if len(raw_files) == 0:
            # Report the pattern itself; the file list is empty at this point.
            raise ValueError(f"Could not find any images for pattern {raw_pattern}")

        label_files = sorted(glob(os.path.join(label_root, lkey)))
        if len(raw_files) != len(label_files):
            raise ValueError(
                f"Expect same number of raw and label images, got {len(raw_files)}, {len(label_files)}"
            )

        if this_roi is not None:
            # Apply the selection for *this* folder. The outer `roi` is a per-folder list
            # when multiple folders are given, so it must not be used here.
            raw_files, label_files = raw_files[this_roi], label_files[this_roi]

        return raw_files, label_files

    patch_shape = kwargs.pop("patch_shape")
    if patch_shape is not None:
        if len(patch_shape) == 3:
            if patch_shape[0] != 1:
                raise ValueError(f"Image collection dataset expects 2d patch shape, got {patch_shape}")
            patch_shape = patch_shape[1:]
        assert len(patch_shape) == 2

    if isinstance(raw_paths, str):
        raw_paths, label_paths = _get_paths(raw_paths, raw_key, label_paths, label_key, roi)
        return ImageCollectionDataset(raw_paths, label_paths, patch_shape=patch_shape, **kwargs)

    if raw_key is None:
        assert label_key is None
        assert isinstance(raw_paths, (list, tuple)) and isinstance(label_paths, (list, tuple))
        assert len(raw_paths) == len(label_paths)
        return ImageCollectionDataset(raw_paths, label_paths, patch_shape=patch_shape, **kwargs)

    n_samples = kwargs.pop("n_samples", None)
    samples_per_ds = (
        [None] * len(raw_paths) if n_samples is None else samples_to_datasets(n_samples, raw_paths, raw_key)
    )
    if roi is None:
        roi = len(raw_paths) * [None]
    assert len(roi) == len(raw_paths)

    ds = []
    for i, (raw_path, label_path, this_roi) in enumerate(zip(raw_paths, label_paths, roi)):
        rpath, lpath = _get_paths(raw_path, raw_key, label_path, label_key, this_roi)
        dset = ImageCollectionDataset(rpath, lpath, patch_shape=patch_shape, n_samples=samples_per_ds[i], **kwargs)
        ds.append(dset)
    return ConcatDataset(*ds)
183
184
def _get_default_transform(path, key, is_seg_dataset, ndim):
    # Choose the default augmentations.
    # - Image collection datasets always use the 2d augmentations.
    # - Segmentation datasets use an explicitly given `ndim` as-is.
    # - Otherwise `ndim` is derived from the shape of the data at `path` / `key`: 2 for 2d data.
    #   For higher-dimensional data the heuristic picks "anisotropic" if the first axis is shorter
    #   than half of the second axis, and 3 otherwise.
    if not is_seg_dataset:
        return get_augmentations(2)
    if ndim is not None:
        return get_augmentations(ndim)

    shape = load_data(path, key).shape
    if len(shape) == 2:
        return get_augmentations(2)
    is_anisotropic = shape[0] < shape[1] // 2
    return get_augmentations("anisotropic" if is_anisotropic else 3)
202
203
def default_segmentation_loader(
    raw_paths: Union[List[Any], str, os.PathLike],
    raw_key: Optional[str],
    label_paths: Union[List[Any], str, os.PathLike],
    label_key: Optional[str],
    batch_size: int,
    patch_shape: Tuple[int, ...],
    label_transform: Optional[Callable] = None,
    label_transform2: Optional[Callable] = None,
    raw_transform: Optional[Callable] = None,
    transform: Optional[Callable] = None,
    dtype: torch.dtype = torch.float32,
    label_dtype: torch.dtype = torch.float32,
    rois: Optional[Union[slice, Tuple[slice, ...]]] = None,
    n_samples: Optional[int] = None,
    sampler: Optional[Callable] = None,
    ndim: Optional[int] = None,
    is_seg_dataset: Optional[bool] = None,
    with_channels: bool = False,
    with_label_channels: bool = False,
    verify_paths: bool = True,
    with_padding: bool = True,
    z_ext: Optional[int] = None,
    **loader_kwargs,
) -> torch.utils.data.DataLoader:
    """Get data loader for training a segmentation network.

    The dataset is created with `default_segmentation_dataset` and wrapped in a `torch.utils.data.DataLoader`.
    See `torch_em.data.SegmentationDataset` and `torch_em.data.ImageCollectionDataset` for details
    on the data formats that are supported.

    Args:
        raw_paths: The file path(s) to the raw data. Can either be a single path or multiple file paths.
        raw_key: The name of the internal dataset containing the raw data. Set to None for regular image files.
        label_paths: The file path(s) to the label data. Can either be a single path or multiple file paths.
        label_key: The name of the internal dataset containing the label data. Set to None for regular image files.
        batch_size: The batch size for the data loader.
        patch_shape: The patch shape for the training samples.
        label_transform: Transformation applied to the label data of a sample,
            before applying augmentations via `transform`.
        label_transform2: Transformation applied to the label data of a sample,
            after applying augmentations via `transform`.
        raw_transform: Transformation applied to the raw data of a sample,
            before applying augmentations via `transform`. If None, a default raw transform is used.
        transform: Transformation applied to both the raw data and label data of a sample.
            This can be used to implement data augmentations. If None, default augmentations are used.
        dtype: The return data type of the raw data.
        label_dtype: The return data type of the label data.
        rois: Regions of interest in the data. If given, the data will only be loaded from the corresponding area.
            For a single segmentation dataset path this is a slice or a tuple of slices; for multiple paths
            it must contain one tuple of slices per path. For image collection folders it selects a subset
            of the sorted image files (one selection per folder if multiple folders are given).
        n_samples: The number of samples in the dataset. If None, the length is determined by the
            underlying dataset(s). If multiple datasets are created, it is split evenly between them.
        sampler: Sampler for rejecting samples according to a defined criterion.
            The sampler must be a callable that accepts the raw data (as numpy arrays) as input.
        ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
        is_seg_dataset: Whether this is a segmentation dataset or an image collection dataset.
            If None, the type of dataset will be derived from the data.
        with_channels: Whether the raw data has channels.
        with_label_channels: Whether the label data has channels.
        verify_paths: Whether to verify all paths before creating the dataset.
        with_padding: Whether to pad samples to `patch_shape` if their shape is smaller.
        z_ext: Extra bounding box for loading the data across z.
        loader_kwargs: Keyword arguments for `torch.utils.data.DataLoader`. `pin_memory` defaults to True.

    Returns:
        The torch data loader.
    """
    ds = default_segmentation_dataset(
        raw_paths=raw_paths,
        raw_key=raw_key,
        label_paths=label_paths,
        label_key=label_key,
        patch_shape=patch_shape,
        label_transform=label_transform,
        label_transform2=label_transform2,
        raw_transform=raw_transform,
        transform=transform,
        dtype=dtype,
        label_dtype=label_dtype,
        rois=rois,
        n_samples=n_samples,
        sampler=sampler,
        ndim=ndim,
        is_seg_dataset=is_seg_dataset,
        with_channels=with_channels,
        with_label_channels=with_label_channels,
        with_padding=with_padding,
        z_ext=z_ext,
        verify_paths=verify_paths,
    )
    return get_data_loader(ds, batch_size=batch_size, **loader_kwargs)
292
293
def default_segmentation_dataset(
    raw_paths: Union[List[Any], str, os.PathLike],
    raw_key: Optional[str],
    label_paths: Union[List[Any], str, os.PathLike],
    label_key: Optional[str],
    patch_shape: Tuple[int, ...],
    label_transform: Optional[Callable] = None,
    label_transform2: Optional[Callable] = None,
    raw_transform: Optional[Callable] = None,
    transform: Optional[Callable] = None,
    dtype: torch.dtype = torch.float32,
    label_dtype: torch.dtype = torch.float32,
    rois: Optional[Union[slice, Tuple[slice, ...]]] = None,
    n_samples: Optional[int] = None,
    sampler: Optional[Callable] = None,
    ndim: Optional[int] = None,
    is_seg_dataset: Optional[bool] = None,
    with_channels: bool = False,
    with_label_channels: bool = False,
    verify_paths: bool = True,
    with_padding: bool = True,
    z_ext: Optional[int] = None,
) -> torch.utils.data.Dataset:
    """Get data set for training a segmentation network.

    See `torch_em.data.SegmentationDataset` and `torch_em.data.ImageCollectionDataset` for details
    on the data formats that are supported.

    Args:
        raw_paths: The file path(s) to the raw data. Can either be a single path (str or os.PathLike)
            or multiple file paths.
        raw_key: The name of the internal dataset containing the raw data. Set to None for regular image files.
        label_paths: The file path(s) to the label data. Can either be a single path (str or os.PathLike)
            or multiple file paths.
        label_key: The name of the internal dataset containing the label data. Set to None for regular image files.
        patch_shape: The patch shape for the training samples.
        label_transform: Transformation applied to the label data of a sample,
            before applying augmentations via `transform`.
        label_transform2: Transformation applied to the label data of a sample,
            after applying augmentations via `transform`.
        raw_transform: Transformation applied to the raw data of a sample,
            before applying augmentations via `transform`. If None, a default raw transform is used.
        transform: Transformation applied to both the raw data and label data of a sample.
            This can be used to implement data augmentations. If None, default augmentations are used.
        dtype: The return data type of the raw data.
        label_dtype: The return data type of the label data.
        rois: Regions of interest in the data. If given, the data will only be loaded from the corresponding area.
            For a single segmentation dataset path this is a slice or a tuple of slices; for multiple paths
            it must contain one tuple of slices per path. For image collection folders it selects a subset
            of the sorted image files (one selection per folder if multiple folders are given).
        n_samples: The number of samples in the dataset. If None, the length is determined by the
            underlying dataset(s). If multiple datasets are created, it is split evenly between them.
        sampler: Sampler for rejecting samples according to a defined criterion.
            The sampler must be a callable that accepts the raw data (as numpy arrays) as input.
        ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
        is_seg_dataset: Whether this is a segmentation dataset or an image collection dataset.
            If None, the type of dataset will be derived from the data.
        with_channels: Whether the raw data has channels.
        with_label_channels: Whether the label data has channels.
        verify_paths: Whether to verify all paths before creating the dataset.
        with_padding: Whether to pad samples to `patch_shape` if their shape is smaller.
        z_ext: Extra bounding box for loading the data across z.

    Returns:
        The torch dataset.
    """
    # A single path may be given as os.PathLike (e.g. pathlib.Path). The helpers used below
    # only recognize `str` as a single path and would otherwise try to iterate / index it.
    if isinstance(raw_paths, os.PathLike):
        raw_paths = os.fspath(raw_paths)
    if isinstance(label_paths, os.PathLike):
        label_paths = os.fspath(label_paths)

    if verify_paths:
        check_paths(raw_paths, label_paths)

    if is_seg_dataset is None:
        is_seg_dataset = is_segmentation_dataset(raw_paths, raw_key, label_paths, label_key)

    # We always use a raw transform in the convenience function.
    if raw_transform is None:
        raw_transform = get_raw_transform()

    # We always use augmentations in the convenience function.
    if transform is None:
        transform = _get_default_transform(
            raw_paths if isinstance(raw_paths, str) else raw_paths[0], raw_key, is_seg_dataset, ndim
        )

    if is_seg_dataset:
        ds = _load_segmentation_dataset(
            raw_paths,
            raw_key,
            label_paths,
            label_key,
            patch_shape=patch_shape,
            raw_transform=raw_transform,
            label_transform=label_transform,
            label_transform2=label_transform2,
            transform=transform,
            rois=rois,
            n_samples=n_samples,
            sampler=sampler,
            ndim=ndim,
            dtype=dtype,
            label_dtype=label_dtype,
            with_channels=with_channels,
            with_label_channels=with_label_channels,
            with_padding=with_padding,
            z_ext=z_ext,
        )

    else:
        ds = _load_image_collection_dataset(
            raw_paths,
            raw_key,
            label_paths,
            label_key,
            roi=rois,
            patch_shape=patch_shape,
            label_transform=label_transform,
            raw_transform=raw_transform,
            label_transform2=label_transform2,
            transform=transform,
            n_samples=n_samples,
            sampler=sampler,
            dtype=dtype,
            label_dtype=label_dtype,
            with_padding=with_padding,
        )

    return ds
414
415
def get_data_loader(dataset: torch.utils.data.Dataset, batch_size: int, **loader_kwargs) -> torch.utils.data.DataLoader:
    """@private

    Wrap `dataset` in a `torch.utils.data.DataLoader`.

    `pin_memory` defaults to True unless it is passed in `loader_kwargs`; all other keyword
    arguments are forwarded unchanged. The `shuffle` setting (False if not given) is also
    stored on the returned loader as the attribute `loader.shuffle`.
    """
    kwargs = dict(loader_kwargs)
    kwargs.setdefault("pin_memory", True)
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, **kwargs)
    # DataLoader does not expose `shuffle`, so attach it explicitly.
    loader.shuffle = kwargs.get("shuffle", False)
    return loader
424
425
426#
427# convenience functions for segmentation trainers
428#
429
430
def default_segmentation_trainer(
    name: str,
    model: torch.nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    loss: Optional[torch.nn.Module] = None,
    metric: Optional[Callable] = None,
    learning_rate: float = 1e-3,
    device: Optional[Union[str, torch.device]] = None,
    log_image_interval: int = 100,
    mixed_precision: bool = True,
    early_stopping: Optional[int] = None,
    logger=TensorboardLogger,
    logger_kwargs: Optional[Dict[str, Any]] = None,
    scheduler_kwargs: Dict[str, Any] = DEFAULT_SCHEDULER_KWARGS,
    optimizer_kwargs: Optional[Dict[str, Any]] = None,
    trainer_class=DefaultTrainer,
    id_: Optional[str] = None,
    save_root: Optional[str] = None,
    compile_model: Optional[Union[bool, str]] = None,
    rank: Optional[int] = None,
):
    """Get a trainer for a segmentation network.

    It creates a `torch.optim.AdamW` optimizer and learning rate scheduler that reduces the learning rate on plateau.
    By default, it uses the dice loss (`DiceLoss`) as loss and metric.
    This can be changed by passing arguments for `loss` and/or `metric`.
    See `torch_em.trainer.DefaultTrainer` for additional details on how to configure and use the trainer.

    Here's an example for training a 2D U-Net with this function:
    ```python
    from torch_em.segmentation import default_segmentation_trainer
    from torch_em.model import UNet2d
    from torch_em.data.datasets.light_microscopy import get_dsb_loader

    # The training data will be downloaded to this location.
    data_root = "/path/to/save/the/training/data"
    patch_shape = (256, 256)
    trainer = default_segmentation_trainer(
        name="unet-training",
        model=UNet2d(in_channels=1, out_channels=1),
        train_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="train"),
        val_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="test"),
    )
    trainer.fit(iterations=int(2.5e4))  # Train for 25,000 iterations.
    ```

    Args:
        name: The name of the checkpoint that will be created by the trainer.
        model: The model to train.
        train_loader: The data loader containing the training data.
        val_loader: The data loader containing the validation data.
        loss: The loss function for training. If None, `DiceLoss` is used.
        metric: The metric for validation. If None, `DiceLoss` is used.
        learning_rate: The initial learning rate for the AdamW optimizer.
        device: The torch device to use for training. If None, will use a GPU if available.
        log_image_interval: The interval for saving images during logging, in training iterations.
        mixed_precision: Whether to train with mixed precision. Always disabled when training on the CPU.
        early_stopping: The patience for early stopping in epochs. If None, early stopping will not be used.
        logger: The logger class. Will be instantiated for logging.
            By default uses `torch_em.trainer.tensorboard_logger.TensorboardLogger`.
        logger_kwargs: The keyword arguments for the logger class.
        scheduler_kwargs: The keyword arguments for ReduceLROnPlateau.
        optimizer_kwargs: The keyword arguments for the AdamW optimizer. None means no extra arguments.
        trainer_class: The trainer class. Uses `torch_em.trainer.DefaultTrainer` by default,
            but can be set to a custom trainer class to enable custom training procedures.
        id_: Unique identifier for the trainer. If None then `name` will be used.
        save_root: The root folder for saving the checkpoint and logs.
        compile_model: Whether to compile the model before training.
        rank: Rank argument for distributed training. See `torch_em.multi_gpu_training` for details.

    Returns:
        The trainer.
    """
    # Use None instead of a shared mutable `{}` as the default.
    optimizer_kwargs = {} if optimizer_kwargs is None else optimizer_kwargs
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, **optimizer_kwargs)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, **scheduler_kwargs)

    loss = DiceLoss() if loss is None else loss
    metric = DiceLoss() if metric is None else metric

    if device is None:
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    else:
        device = torch.device(device)

    # CPU does not support mixed precision training.
    if device.type == "cpu":
        mixed_precision = False

    return trainer_class(
        name=name,
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        loss=loss,
        metric=metric,
        optimizer=optimizer,
        device=device,
        lr_scheduler=scheduler,
        mixed_precision=mixed_precision,
        early_stopping=early_stopping,
        log_image_interval=log_image_interval,
        logger=logger,
        logger_kwargs=logger_kwargs,
        id_=id_,
        save_root=save_root,
        compile_model=compile_model,
        rank=rank,
    )
def default_segmentation_loader( raw_paths: Union[List[Any], str, os.PathLike], raw_key: Optional[str], label_paths: Union[List[Any], str, os.PathLike], label_key: Optional[str], batch_size: int, patch_shape: Tuple[int, ...], label_transform: Optional[Callable] = None, label_transform2: Optional[Callable] = None, raw_transform: Optional[Callable] = None, transform: Optional[Callable] = None, dtype: torch.dtype = torch.float32, label_dtype: torch.dtype = torch.float32, rois: Union[slice, Tuple[slice, ...], NoneType] = None, n_samples: Optional[int] = None, sampler: Optional[Callable] = None, ndim: Optional[int] = None, is_seg_dataset: Optional[bool] = None, with_channels: bool = False, with_label_channels: bool = False, verify_paths: bool = True, with_padding: bool = True, z_ext: Optional[int] = None, **loader_kwargs) -> torch.utils.data.dataloader.DataLoader:
def default_segmentation_loader(
    raw_paths: Union[List[Any], str, os.PathLike],
    raw_key: Optional[str],
    label_paths: Union[List[Any], str, os.PathLike],
    label_key: Optional[str],
    batch_size: int,
    patch_shape: Tuple[int, ...],
    label_transform: Optional[Callable] = None,
    label_transform2: Optional[Callable] = None,
    raw_transform: Optional[Callable] = None,
    transform: Optional[Callable] = None,
    dtype: torch.dtype = torch.float32,
    label_dtype: torch.dtype = torch.float32,
    rois: Optional[Union[slice, Tuple[slice, ...]]] = None,
    n_samples: Optional[int] = None,
    sampler: Optional[Callable] = None,
    ndim: Optional[int] = None,
    is_seg_dataset: Optional[bool] = None,
    with_channels: bool = False,
    with_label_channels: bool = False,
    verify_paths: bool = True,
    with_padding: bool = True,
    z_ext: Optional[int] = None,
    **loader_kwargs,
) -> torch.utils.data.DataLoader:
    """Get data loader for training a segmentation network.

    The dataset is created with `default_segmentation_dataset` and wrapped in a `torch.utils.data.DataLoader`.
    See `torch_em.data.SegmentationDataset` and `torch_em.data.ImageCollectionDataset` for details
    on the data formats that are supported.

    Args:
        raw_paths: The file path(s) to the raw data. Can either be a single path or multiple file paths.
        raw_key: The name of the internal dataset containing the raw data. Set to None for regular image files.
        label_paths: The file path(s) to the label data. Can either be a single path or multiple file paths.
        label_key: The name of the internal dataset containing the label data. Set to None for regular image files.
        batch_size: The batch size for the data loader.
        patch_shape: The patch shape for the training samples.
        label_transform: Transformation applied to the label data of a sample,
            before applying augmentations via `transform`.
        label_transform2: Transformation applied to the label data of a sample,
            after applying augmentations via `transform`.
        raw_transform: Transformation applied to the raw data of a sample,
            before applying augmentations via `transform`. If None, a default raw transform is used.
        transform: Transformation applied to both the raw data and label data of a sample.
            This can be used to implement data augmentations. If None, default augmentations are used.
        dtype: The return data type of the raw data.
        label_dtype: The return data type of the label data.
        rois: Regions of interest in the data. If given, the data will only be loaded from the corresponding area.
            For a single segmentation dataset path this is a slice or a tuple of slices; for multiple paths
            it must contain one tuple of slices per path. For image collection folders it selects a subset
            of the sorted image files (one selection per folder if multiple folders are given).
        n_samples: The number of samples in the dataset. If None, the length is determined by the
            underlying dataset(s). If multiple datasets are created, it is split evenly between them.
        sampler: Sampler for rejecting samples according to a defined criterion.
            The sampler must be a callable that accepts the raw data (as numpy arrays) as input.
        ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
        is_seg_dataset: Whether this is a segmentation dataset or an image collection dataset.
            If None, the type of dataset will be derived from the data.
        with_channels: Whether the raw data has channels.
        with_label_channels: Whether the label data has channels.
        verify_paths: Whether to verify all paths before creating the dataset.
        with_padding: Whether to pad samples to `patch_shape` if their shape is smaller.
        z_ext: Extra bounding box for loading the data across z.
        loader_kwargs: Keyword arguments for `torch.utils.data.DataLoader`. `pin_memory` defaults to True.

    Returns:
        The torch data loader.
    """
    ds = default_segmentation_dataset(
        raw_paths=raw_paths,
        raw_key=raw_key,
        label_paths=label_paths,
        label_key=label_key,
        patch_shape=patch_shape,
        label_transform=label_transform,
        label_transform2=label_transform2,
        raw_transform=raw_transform,
        transform=transform,
        dtype=dtype,
        label_dtype=label_dtype,
        rois=rois,
        n_samples=n_samples,
        sampler=sampler,
        ndim=ndim,
        is_seg_dataset=is_seg_dataset,
        with_channels=with_channels,
        with_label_channels=with_label_channels,
        with_padding=with_padding,
        z_ext=z_ext,
        verify_paths=verify_paths,
    )
    return get_data_loader(ds, batch_size=batch_size, **loader_kwargs)

Get data loader for training a segmentation network.

See torch_em.data.SegmentationDataset and torch_em.data.ImageCollectionDataset for details on the data formats that are supported.

Arguments:
  • raw_paths: The file path(s) to the raw data. Can either be a single path or multiple file paths.
  • raw_key: The name of the internal dataset containing the raw data. Set to None for regular image files.
  • label_paths: The file path(s) to the label data. Can either be a single path or multiple file paths.
  • label_key: The name of the internal dataset containing the label data. Set to None for regular image files.
  • batch_size: The batch size for the data loader.
  • patch_shape: The patch shape for the training samples.
  • label_transform: Transformation applied to the label data of a sample, before applying augmentations via transform.
  • label_transform2: Transformation applied to the label data of a sample, after applying augmentations via transform.
  • raw_transform: Transformation applied to the raw data of a sample, before applying augmentations via transform.
  • transform: Transformation applied to both the raw data and label data of a sample. This can be used to implement data augmentations.
  • dtype: The return data type of the raw data.
  • label_dtype: The return data type of the label data.
  • rois: Regions of interest in the data. If given, the data will only be loaded from the corresponding area.
  • n_samples: The length of the underlying dataset. If None, the length will be set to len(raw_paths).
  • sampler: Sampler for rejecting samples according to a defined criterion. The sampler must be a callable that accepts the raw data (as numpy arrays) as input.
  • ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
  • is_seg_dataset: Whether this is a segmentation dataset or an image collection dataset. If None, the type of dataset will be derived from the data.
  • with_channels: Whether the raw data has channels.
  • with_label_channels: Whether the label data has channels.
  • verify_paths: Whether to verify all paths before creating the dataset.
  • with_padding: Whether to pad samples to patch_shape if their shape is smaller.
  • z_ext: Extra bounding box for loading the data across z.
  • loader_kwargs: Keyword arguments for torch.utils.data.DataLoader.
Returns:

The torch data loader.

def default_segmentation_dataset( raw_paths: Union[List[Any], str, os.PathLike], raw_key: Optional[str], label_paths: Union[List[Any], str, os.PathLike], label_key: Optional[str], patch_shape: Tuple[int, ...], label_transform: Optional[Callable] = None, label_transform2: Optional[Callable] = None, raw_transform: Optional[Callable] = None, transform: Optional[Callable] = None, dtype: torch.dtype = torch.float32, label_dtype: torch.dtype = torch.float32, rois: Union[slice, Tuple[slice, ...], NoneType] = None, n_samples: Optional[int] = None, sampler: Optional[Callable] = None, ndim: Optional[int] = None, is_seg_dataset: Optional[bool] = None, with_channels: bool = False, with_label_channels: bool = False, verify_paths: bool = True, with_padding: bool = True, z_ext: Optional[int] = None) -> torch.utils.data.dataset.Dataset:
def default_segmentation_dataset(
    raw_paths: Union[List[Any], str, os.PathLike],
    raw_key: Optional[str],
    label_paths: Union[List[Any], str, os.PathLike],
    label_key: Optional[str],
    patch_shape: Tuple[int, ...],
    label_transform: Optional[Callable] = None,
    label_transform2: Optional[Callable] = None,
    raw_transform: Optional[Callable] = None,
    transform: Optional[Callable] = None,
    dtype: torch.dtype = torch.float32,
    label_dtype: torch.dtype = torch.float32,
    rois: Optional[Union[slice, Tuple[slice, ...]]] = None,
    n_samples: Optional[int] = None,
    sampler: Optional[Callable] = None,
    ndim: Optional[int] = None,
    is_seg_dataset: Optional[bool] = None,
    with_channels: bool = False,
    with_label_channels: bool = False,
    verify_paths: bool = True,
    with_padding: bool = True,
    z_ext: Optional[int] = None,
) -> torch.utils.data.Dataset:
    """Get data set for training a segmentation network.

    See `torch_em.data.SegmentationDataset` and `torch_em.data.ImageCollectionDataset` for details
    on the data formats that are supported.

    Args:
        raw_paths: The file path(s) to the raw data. Can either be a single path or multiple file paths.
        raw_key: The name of the internal dataset containing the raw data. Set to None for regular image files.
        label_paths: The file path(s) to the label data. Can either be a single path or multiple file paths.
        label_key: The name of the internal dataset containing the label data. Set to None for regular image files.
        patch_shape: The patch shape for the training samples.
        label_transform: Transformation applied to the label data of a sample,
            before applying augmentations via `transform`.
        label_transform2: Transformation applied to the label data of a sample,
            after applying augmentations via `transform`.
        raw_transform: Transformation applied to the raw data of a sample,
            before applying augmentations via `transform`. If None, `get_raw_transform()` is used.
        transform: Transformation applied to both the raw data and label data of a sample.
            This can be used to implement data augmentations. If None, default augmentations
            derived from the first raw path are used.
        dtype: The return data type of the raw data.
        label_dtype: The return data type of the label data.
        rois: Regions of interest in the data. If given, the data will only be loaded from the corresponding area.
        n_samples: The length of the dataset. If None, the length will be set to `len(raw_paths)`.
        sampler: Sampler for rejecting samples according to a defined criterion.
            The sampler must be a callable that accepts the raw data (as numpy arrays) as input.
        ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
            Used for the default augmentations and forwarded only for segmentation datasets.
        is_seg_dataset: Whether this is a segmentation dataset or an image collection dataset.
            If None, the type of dataset will be derived from the data.
        with_channels: Whether the raw data has channels. Only forwarded for segmentation datasets.
        with_label_channels: Whether the label data has channels. Only forwarded for segmentation datasets.
        verify_paths: Whether to verify all paths before creating the dataset.
        with_padding: Whether to pad samples to `patch_shape` if their shape is smaller.
        z_ext: Extra bounding box for loading the data across z. Only forwarded for segmentation datasets.

    Returns:
        The torch dataset.
    """
    if verify_paths:
        check_paths(raw_paths, label_paths)

    if is_seg_dataset is None:
        is_seg_dataset = is_segmentation_dataset(raw_paths, raw_key, label_paths, label_key)

    # We always use a raw transform in the convenience function.
    if raw_transform is None:
        raw_transform = get_raw_transform()

    # We always use augmentations in the convenience function.
    if transform is None:
        # The default augmentations are derived from one representative raw file.
        # A single path may be a str or an os.PathLike (e.g. pathlib.Path); the latter is
        # not subscriptable, so it must not be indexed as if it were a list of paths.
        if isinstance(raw_paths, (str, os.PathLike)):
            sample_path = os.fspath(raw_paths)
        else:
            sample_path = raw_paths[0]
        transform = _get_default_transform(sample_path, raw_key, is_seg_dataset, ndim)

    if is_seg_dataset:
        ds = _load_segmentation_dataset(
            raw_paths,
            raw_key,
            label_paths,
            label_key,
            patch_shape=patch_shape,
            raw_transform=raw_transform,
            label_transform=label_transform,
            label_transform2=label_transform2,
            transform=transform,
            rois=rois,
            n_samples=n_samples,
            sampler=sampler,
            ndim=ndim,
            dtype=dtype,
            label_dtype=label_dtype,
            with_channels=with_channels,
            with_label_channels=with_label_channels,
            with_padding=with_padding,
            z_ext=z_ext,
        )

    else:
        # NOTE(review): ndim, with_channels, with_label_channels and z_ext are not forwarded here.
        ds = _load_image_collection_dataset(
            raw_paths,
            raw_key,
            label_paths,
            label_key,
            roi=rois,
            patch_shape=patch_shape,
            label_transform=label_transform,
            raw_transform=raw_transform,
            label_transform2=label_transform2,
            transform=transform,
            n_samples=n_samples,
            sampler=sampler,
            dtype=dtype,
            label_dtype=label_dtype,
            with_padding=with_padding,
        )

    return ds

Get data set for training a segmentation network.

See torch_em.data.SegmentationDataset and torch_em.data.ImageCollectionDataset for details on the data formats that are supported.

Arguments:
  • raw_paths: The file path(s) to the raw data. Can either be a single path or multiple file paths.
  • raw_key: The name of the internal dataset containing the raw data. Set to None for regular image files.
  • label_paths: The file path(s) to the label data. Can either be a single path or multiple file paths.
  • label_key: The name of the internal dataset containing the label data. Set to None for regular image files.
  • patch_shape: The patch shape for the training samples.
  • label_transform: Transformation applied to the label data of a sample, before applying augmentations via transform.
  • label_transform2: Transformation applied to the label data of a sample, after applying augmentations via transform.
  • raw_transform: Transformation applied to the raw data of a sample, before applying augmentations via transform.
  • transform: Transformation applied to both the raw data and label data of a sample. This can be used to implement data augmentations.
  • dtype: The return data type of the raw data.
  • label_dtype: The return data type of the label data.
  • rois: Regions of interest in the data. If given, the data will only be loaded from the corresponding area.
  • n_samples: The length of the dataset. If None, the length will be set to len(raw_paths).
  • sampler: Sampler for rejecting samples according to a defined criterion. The sampler must be a callable that accepts the raw data (as numpy arrays) as input.
  • ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
  • is_seg_dataset: Whether this is a segmentation dataset or an image collection dataset. If None, the type of dataset will be derived from the data.
  • with_channels: Whether the raw data has channels.
  • with_label_channels: Whether the label data has channels.
  • verify_paths: Whether to verify all paths before creating the dataset.
  • with_padding: Whether to pad samples to patch_shape if their shape is smaller.
  • z_ext: Extra bounding box for loading the data across z.
Returns:

The torch dataset.

def default_segmentation_trainer( name: str, model: torch.nn.modules.module.Module, train_loader: torch.utils.data.dataloader.DataLoader, val_loader: torch.utils.data.dataloader.DataLoader, loss: Optional[torch.nn.modules.module.Module] = None, metric: Optional[Callable] = None, learning_rate: float = 0.001, device: Union[str, torch.device, NoneType] = None, log_image_interval: int = 100, mixed_precision: bool = True, early_stopping: Optional[int] = None, logger=<class 'torch_em.trainer.tensorboard_logger.TensorboardLogger'>, logger_kwargs: Optional[Dict[str, Any]] = None, scheduler_kwargs: Dict[str, Any] = {'mode': 'min', 'factor': 0.5, 'patience': 5}, optimizer_kwargs: Dict[str, Any] = {}, trainer_class=<class 'torch_em.trainer.default_trainer.DefaultTrainer'>, id_: Optional[str] = None, save_root: Optional[str] = None, compile_model: Union[bool, str, NoneType] = None, rank: Optional[int] = None):
def default_segmentation_trainer(
    name: str,
    model: torch.nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    loss: Optional[torch.nn.Module] = None,
    metric: Optional[Callable] = None,
    learning_rate: float = 1e-3,
    device: Optional[Union[str, torch.device]] = None,
    log_image_interval: int = 100,
    mixed_precision: bool = True,
    early_stopping: Optional[int] = None,
    logger=TensorboardLogger,
    logger_kwargs: Optional[Dict[str, Any]] = None,
    scheduler_kwargs: Dict[str, Any] = DEFAULT_SCHEDULER_KWARGS,
    optimizer_kwargs: Optional[Dict[str, Any]] = None,
    trainer_class=DefaultTrainer,
    id_: Optional[str] = None,
    save_root: Optional[str] = None,
    compile_model: Optional[Union[bool, str]] = None,
    rank: Optional[int] = None,
):
    """Get a trainer for a segmentation network.

    It creates a `torch.optim.AdamW` optimizer and a `ReduceLROnPlateau` learning rate scheduler.
    By default, it uses `torch_em.loss.DiceLoss` both as the training loss and as the validation metric.
    This can be changed by passing arguments for `loss` and/or `metric`.
    See `torch_em.trainer.DefaultTrainer` for additional details on how to configure and use the trainer.

    Here's an example for training a 2D U-Net with this function:
    ```python
    from torch_em.model import UNet2d
    from torch_em.segmentation import default_segmentation_trainer
    from torch_em.data.datasets.light_microscopy import get_dsb_loader

    # The training data will be downloaded to this location.
    data_root = "/path/to/save/the/training/data"
    patch_shape = (256, 256)
    trainer = default_segmentation_trainer(
        name="unet-training",
        model=UNet2d(in_channels=1, out_channels=1),
        train_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="train"),
        val_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="test"),
    )
    trainer.fit(iterations=int(2.5e4))  # Train for 25,000 iterations.
    ```

    Args:
        name: The name of the checkpoint that will be created by the trainer.
        model: The model to train.
        train_loader: The data loader containing the training data.
        val_loader: The data loader containing the validation data.
        loss: The loss function for training. Defaults to `DiceLoss()`.
        metric: The metric for validation. Defaults to `DiceLoss()`.
        learning_rate: The initial learning rate for the AdamW optimizer.
        device: The torch device to use for training. If None, uses CUDA if available and the CPU otherwise.
        log_image_interval: The interval for saving images during logging, in training iterations.
        mixed_precision: Whether to train with mixed precision. Always disabled when training on the CPU.
        early_stopping: The patience for early stopping in epochs. If None, early stopping will not be used.
        logger: The logger class. Will be instantiated for logging.
            By default uses `torch_em.trainer.tensorboard_logger.TensorboardLogger`.
        logger_kwargs: The keyword arguments for the logger class.
        scheduler_kwargs: The keyword arguments for `ReduceLROnPlateau`. Defaults to `DEFAULT_SCHEDULER_KWARGS`.
        optimizer_kwargs: Additional keyword arguments for the AdamW optimizer. None means no extra arguments.
        trainer_class: The trainer class. Uses `torch_em.trainer.DefaultTrainer` by default,
            but can be set to a custom trainer class to enable custom training procedures.
        id_: Unique identifier for the trainer. If None then `name` will be used.
        save_root: The root folder for saving the checkpoint and logs.
        compile_model: Whether to compile the model before training.
        rank: Rank argument for distributed training. See `torch_em.multi_gpu_training` for details.

    Returns:
        The trainer.
    """
    # None instead of a shared `{}` default: avoids the mutable-default-argument pitfall.
    if optimizer_kwargs is None:
        optimizer_kwargs = {}

    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, **optimizer_kwargs)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, **scheduler_kwargs)

    loss = DiceLoss() if loss is None else loss
    metric = DiceLoss() if metric is None else metric

    if device is None:
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    else:
        device = torch.device(device)

    # CPU does not support mixed precision training.
    if device.type == "cpu":
        mixed_precision = False

    return trainer_class(
        name=name,
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        loss=loss,
        metric=metric,
        optimizer=optimizer,
        device=device,
        lr_scheduler=scheduler,
        mixed_precision=mixed_precision,
        early_stopping=early_stopping,
        log_image_interval=log_image_interval,
        logger=logger,
        logger_kwargs=logger_kwargs,
        id_=id_,
        save_root=save_root,
        compile_model=compile_model,
        rank=rank,
    )

Get a trainer for a segmentation network.

It creates a torch.optim.AdamW optimizer and a learning rate scheduler that reduces the learning rate on plateau. By default, it uses torch_em.loss.DiceLoss both as the training loss and as the validation metric. This can be changed by passing arguments for loss and/or metric. See torch_em.trainer.DefaultTrainer for additional details on how to configure and use the trainer.

Here's an example for training a 2D U-Net with this function:

from torch_em.model import UNet2d
from torch_em.segmentation import default_segmentation_trainer
from torch_em.data.datasets.light_microscopy import get_dsb_loader

# The training data will be downloaded to this location.
data_root = "/path/to/save/the/training/data"
patch_shape = (256, 256)
trainer = default_segmentation_trainer(
    name="unet-training",
    model=UNet2d(in_channels=1, out_channels=1),
    train_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="train"),
    val_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="test"),
)
trainer.fit(iterations=int(2.5e4))  # Train for 25,000 iterations.
Arguments:
  • name: The name of the checkpoint that will be created by the trainer.
  • model: The model to train.
  • train_loader: The data loader containing the training data.
  • val_loader: The data loader containing the validation data.
  • loss: The loss function for training.
  • metric: The metric for validation.
  • learning_rate: The initial learning rate for the AdamW optimizer.
  • device: The torch device to use for training. If None, will use a GPU if available.
  • log_image_interval: The interval for saving images during logging, in training iterations.
  • mixed_precision: Whether to train with mixed precision.
  • early_stopping: The patience for early stopping in epochs. If None, early stopping will not be used.
  • logger: The logger class. Will be instantiated for logging. By default uses torch_em.trainer.tensorboard_logger.TensorboardLogger.
  • logger_kwargs: The keyword arguments for the logger class.
  • scheduler_kwargs: The keyword arguments for ReduceLROnPlateau.
  • optimizer_kwargs: The keyword arguments for the AdamW optimizer.
  • trainer_class: The trainer class. Uses torch_em.trainer.DefaultTrainer by default, but can be set to a custom trainer class to enable custom training procedures.
  • id_: Unique identifier for the trainer. If None then name will be used.
  • save_root: The root folder for saving the checkpoint and logs.
  • compile_model: Whether to compile the model before training.
  • rank: Rank argument for distributed training. See torch_em.multi_gpu_training for details.
Returns:

The trainer.