torch_em.segmentation

import os
from glob import glob
from typing import Any, Dict, Optional

import torch
import torch.utils.data

from .data import ConcatDataset, ImageCollectionDataset, SegmentationDataset
from .loss import DiceLoss
from .trainer import DefaultTrainer
from .trainer.tensorboard_logger import TensorboardLogger
from .transform import get_augmentations, get_raw_transform
from .util import load_data


# TODO add a heuristic to estimate this from the number of epochs
DEFAULT_SCHEDULER_KWARGS = {"mode": "min", "factor": 0.5, "patience": 5}
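# These default kwargs are passed to torch.optim.lr_scheduler.ReduceLROnPlateau in
# default_segmentation_trainer below: in "min" mode the learning rate is multiplied
# by 0.5 once the monitored value has not improved for 5 scheduler steps.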


#
# convenience functions for segmentation loaders
#

# TODO implement balanced and make it the default
# def samples_to_datasets(n_samples, raw_paths, raw_key, split="balanced"):
def samples_to_datasets(n_samples, raw_paths, raw_key, split="uniform"):
    assert split in ("balanced", "uniform")
    n_datasets = len(raw_paths)
    if split == "uniform":
        # even distribution of samples to datasets
        samples_per_ds = n_samples // n_datasets
        divider = n_samples % n_datasets
        return [samples_per_ds + 1 if ii < divider else samples_per_ds for ii in range(n_datasets)]
    else:
        # distribution of samples to the datasets based on the dataset lengths
        raise NotImplementedError
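
# Example of the "uniform" split (hypothetical inputs): distributing 10 samples over
# 3 datasets gives 10 // 3 = 3 samples per dataset with a remainder of 1, which is
# assigned to the first dataset:
# samples_to_datasets(10, ["a.h5", "b.h5", "c.h5"], "raw") == [4, 3, 3]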


def check_paths(raw_paths, label_paths):
    if not isinstance(raw_paths, type(label_paths)):
        raise ValueError(f"Expect raw and label paths of same type, got {type(raw_paths)}, {type(label_paths)}")

    def _check_path(path):
        if isinstance(path, str):
            if not os.path.exists(path):
                raise ValueError(f"Could not find path {path}")
        else:
            # check for single path or multiple paths (for same volume - supports multi-modal inputs)
            for per_path in path:
                if not os.path.exists(per_path):
                    raise ValueError(f"Could not find path {per_path}")

    if isinstance(raw_paths, str):
        _check_path(raw_paths)
        _check_path(label_paths)
    else:
        if len(raw_paths) != len(label_paths):
            raise ValueError(f"Expect same number of raw and label paths, got {len(raw_paths)}, {len(label_paths)}")
        for rp, lp in zip(raw_paths, label_paths):
            _check_path(rp)
            _check_path(lp)


def is_segmentation_dataset(raw_paths, raw_key, label_paths, label_key):
    """Check if we can load the data as SegmentationDataset."""

    def _can_open(path, key):
        try:
            load_data(path, key)
            return True
        except Exception:
            return False

    if isinstance(raw_paths, str):
        can_open_raw = _can_open(raw_paths, raw_key)
        can_open_label = _can_open(label_paths, label_key)
    else:
        can_open_raw = [_can_open(rp, raw_key) for rp in raw_paths]
        if can_open_raw.count(can_open_raw[0]) != len(can_open_raw):
            raise ValueError("Inconsistent raw data")
        can_open_raw = can_open_raw[0]

        can_open_label = [_can_open(lp, label_key) for lp in label_paths]
        if can_open_label.count(can_open_label[0]) != len(can_open_label):
            raise ValueError("Inconsistent label data")
        can_open_label = can_open_label[0]

    if can_open_raw != can_open_label:
        raise ValueError("Inconsistent raw and label data")

    return can_open_raw
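
# This check determines which dataset type the convenience functions below use:
# if raw_key / label_key address a dataset inside a container file (e.g. an hdf5,
# zarr or n5 key such as "raw"), the data is wrapped in a SegmentationDataset;
# otherwise the keys are treated as glob patterns for image files (e.g. "*.tif")
# and an ImageCollectionDataset is used instead.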


def _load_segmentation_dataset(raw_paths, raw_key, label_paths, label_key, **kwargs):
    rois = kwargs.pop("rois", None)
    if isinstance(raw_paths, str):
        if rois is not None:
            assert isinstance(rois, (tuple, slice))
            if isinstance(rois, tuple):
                assert all(isinstance(roi, slice) for roi in rois)
        ds = SegmentationDataset(raw_paths, raw_key, label_paths, label_key, roi=rois, **kwargs)
    else:
        assert len(raw_paths) > 0
        if rois is not None:
            assert len(rois) == len(label_paths)
            assert all(isinstance(roi, tuple) for roi in rois), f"{rois}"
        n_samples = kwargs.pop("n_samples", None)

        samples_per_ds = (
            [None] * len(raw_paths) if n_samples is None else samples_to_datasets(n_samples, raw_paths, raw_key)
        )
        ds = []
        for i, (raw_path, label_path) in enumerate(zip(raw_paths, label_paths)):
            roi = None if rois is None else rois[i]
            dset = SegmentationDataset(
                raw_path, raw_key, label_path, label_key, roi=roi, n_samples=samples_per_ds[i], **kwargs
            )
            ds.append(dset)
        ds = ConcatDataset(*ds)
    return ds
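
# The expected roi format differs between the two branches above (hypothetical values):
# a single volume takes a tuple of slices, e.g. rois=np.s_[0:50, :, :], while multiple
# volumes take one such tuple per volume, e.g. rois=[np.s_[0:50, :, :], np.s_[25:75, :, :]].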


def _load_image_collection_dataset(raw_paths, raw_key, label_paths, label_key, roi, **kwargs):
    def _get_paths(rpath, rkey, lpath, lkey, this_roi):
        raw_pattern = os.path.join(rpath, rkey)
        rpath = glob(raw_pattern)
        rpath.sort()
        if len(rpath) == 0:
            raise ValueError(f"Could not find any images for pattern {raw_pattern}")
        lpath = glob(os.path.join(lpath, lkey))
        lpath.sort()
        if len(rpath) != len(lpath):
            raise ValueError(f"Expect same number of raw and label images, got {len(rpath)}, {len(lpath)}")

        if this_roi is not None:
            rpath, lpath = rpath[this_roi], lpath[this_roi]

        return rpath, lpath

    patch_shape = kwargs.pop("patch_shape")
    if len(patch_shape) == 3:
        if patch_shape[0] != 1:
            raise ValueError(f"Image collection dataset expects 2d patch shape, got {patch_shape}")
        patch_shape = patch_shape[1:]
    assert len(patch_shape) == 2

    if isinstance(raw_paths, str):
        raw_paths, label_paths = _get_paths(raw_paths, raw_key, label_paths, label_key, roi)
        ds = ImageCollectionDataset(raw_paths, label_paths, patch_shape=patch_shape, **kwargs)
    elif raw_key is None:
        assert label_key is None
        assert isinstance(raw_paths, (list, tuple)) and isinstance(label_paths, (list, tuple))
        assert len(raw_paths) == len(label_paths)
        ds = ImageCollectionDataset(raw_paths, label_paths, patch_shape=patch_shape, **kwargs)
    else:
        ds = []
        n_samples = kwargs.pop("n_samples", None)
        samples_per_ds = (
            [None] * len(raw_paths) if n_samples is None else samples_to_datasets(n_samples, raw_paths, raw_key)
        )
        if roi is None:
            roi = len(raw_paths) * [None]
        assert len(roi) == len(raw_paths)
        for i, (raw_path, label_path, this_roi) in enumerate(zip(raw_paths, label_paths, roi)):
            rpath, lpath = _get_paths(raw_path, raw_key, label_path, label_key, this_roi)
            dset = ImageCollectionDataset(rpath, lpath, patch_shape=patch_shape, n_samples=samples_per_ds[i], **kwargs)
            ds.append(dset)
        ds = ConcatDataset(*ds)
    return ds


def _get_default_transform(path, key, is_seg_dataset, ndim):
    if is_seg_dataset and ndim is None:
        shape = load_data(path, key).shape
        if len(shape) == 2:
            ndim = 2
        else:
            # heuristics to figure out whether to use default 3d
            # or default anisotropic augmentations
            ndim = "anisotropic" if shape[0] < shape[1] // 2 else 3
    elif is_seg_dataset and ndim is not None:
        pass
    else:
        ndim = 2
    return get_augmentations(ndim)
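
# Example for the heuristic above (hypothetical shapes): a volume of shape
# (25, 512, 512) gets anisotropic augmentations (25 < 512 // 2), while a volume
# of shape (300, 512, 512) gets the default 3d augmentations.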


def default_segmentation_loader(
    raw_paths,
    raw_key,
    label_paths,
    label_key,
    batch_size,
    patch_shape,
    label_transform=None,
    label_transform2=None,
    raw_transform=None,
    transform=None,
    dtype=torch.float32,
    label_dtype=torch.float32,
    rois=None,
    n_samples=None,
    sampler=None,
    ndim=None,
    is_seg_dataset=None,
    with_channels=False,
    with_label_channels=False,
    **loader_kwargs,
):
    ds = default_segmentation_dataset(
        raw_paths=raw_paths,
        raw_key=raw_key,
        label_paths=label_paths,
        label_key=label_key,
        patch_shape=patch_shape,
        label_transform=label_transform,
        label_transform2=label_transform2,
        raw_transform=raw_transform,
        transform=transform,
        dtype=dtype,
        label_dtype=label_dtype,
        rois=rois,
        n_samples=n_samples,
        sampler=sampler,
        ndim=ndim,
        is_seg_dataset=is_seg_dataset,
        with_channels=with_channels,
        with_label_channels=with_label_channels,
    )
    return get_data_loader(ds, batch_size=batch_size, **loader_kwargs)
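
# Usage sketch with hypothetical file names and keys, assuming raw data and labels
# are stored as the datasets "raw" and "labels" in the same hdf5 file; additional
# keyword arguments such as shuffle and num_workers are forwarded to the DataLoader:
#
# train_loader = default_segmentation_loader(
#     raw_paths="train_data.h5", raw_key="raw",
#     label_paths="train_data.h5", label_key="labels",
#     batch_size=2, patch_shape=(32, 256, 256),
#     shuffle=True, num_workers=4,
# )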


def default_segmentation_dataset(
    raw_paths,
    raw_key,
    label_paths,
    label_key,
    patch_shape,
    label_transform=None,
    label_transform2=None,
    raw_transform=None,
    transform=None,
    dtype=torch.float32,
    label_dtype=torch.float32,
    rois=None,
    n_samples=None,
    sampler=None,
    ndim=None,
    is_seg_dataset=None,
    with_channels=False,
    with_label_channels=False,
):
    check_paths(raw_paths, label_paths)
    if is_seg_dataset is None:
        is_seg_dataset = is_segmentation_dataset(raw_paths, raw_key, label_paths, label_key)

    # we always use a raw transform in the convenience function
    if raw_transform is None:
        raw_transform = get_raw_transform()

    # we always use augmentations in the convenience function
    if transform is None:
        transform = _get_default_transform(
            raw_paths if isinstance(raw_paths, str) else raw_paths[0], raw_key, is_seg_dataset, ndim
        )

    if is_seg_dataset:
        ds = _load_segmentation_dataset(
            raw_paths,
            raw_key,
            label_paths,
            label_key,
            patch_shape=patch_shape,
            raw_transform=raw_transform,
            label_transform=label_transform,
            label_transform2=label_transform2,
            transform=transform,
            rois=rois,
            n_samples=n_samples,
            sampler=sampler,
            ndim=ndim,
            dtype=dtype,
            label_dtype=label_dtype,
            with_channels=with_channels,
            with_label_channels=with_label_channels,
        )
    else:
        ds = _load_image_collection_dataset(
            raw_paths,
            raw_key,
            label_paths,
            label_key,
            roi=rois,
            patch_shape=patch_shape,
            label_transform=label_transform,
            raw_transform=raw_transform,
            label_transform2=label_transform2,
            transform=transform,
            n_samples=n_samples,
            sampler=sampler,
            dtype=dtype,
            label_dtype=label_dtype,
        )

    return ds


def get_data_loader(dataset: torch.utils.data.Dataset, batch_size, **loader_kwargs) -> torch.utils.data.DataLoader:
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, **loader_kwargs)
    # monkey patch shuffle attribute to the loader
    loader.shuffle = loader_kwargs.get("shuffle", False)
    return loader
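
# Note: torch.utils.data.DataLoader does not keep the shuffle argument as an attribute
# (it is turned into a sampler internally), so it is attached here explicitly, e.g. so
# that code re-creating the loader later can recover it.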


#
# convenience functions for segmentation trainers
#


def default_segmentation_trainer(
    name,
    model,
    train_loader,
    val_loader,
    loss=None,
    metric=None,
    learning_rate=1e-3,
    device=None,
    log_image_interval=100,
    mixed_precision=True,
    early_stopping=None,
    logger=TensorboardLogger,
    logger_kwargs: Optional[Dict[str, Any]] = None,
    scheduler_kwargs=DEFAULT_SCHEDULER_KWARGS,
    optimizer_kwargs={},
    trainer_class=DefaultTrainer,
    id_=None,
    save_root=None,
    compile_model=None,
):
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, **optimizer_kwargs)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, **scheduler_kwargs)

    loss = DiceLoss() if loss is None else loss
    metric = DiceLoss() if metric is None else metric

    if device is None:
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    else:
        device = torch.device(device)

    # cpu does not support mixed precision training
    if device.type == "cpu":
        mixed_precision = False

    trainer = trainer_class(
        name=name,
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        loss=loss,
        metric=metric,
        optimizer=optimizer,
        device=device,
        lr_scheduler=scheduler,
        mixed_precision=mixed_precision,
        early_stopping=early_stopping,
        log_image_interval=log_image_interval,
        logger=logger,
        logger_kwargs=logger_kwargs,
        id_=id_,
        save_root=save_root,
        compile_model=compile_model,
    )
    return trainer
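
A minimal end-to-end sketch built from the convenience functions above. The file layout, dataset keys and model configuration are hypothetical, and it assumes that torch_em.model.UNet2d is available and that the trainer's fit method accepts the number of training iterations:

import torch_em
from torch_em.model import UNet2d

# hypothetical 2d training data: images and labels stored as tif files in two folders
train_loader = torch_em.default_segmentation_loader(
    raw_paths="data/train/images", raw_key="*.tif",
    label_paths="data/train/labels", label_key="*.tif",
    batch_size=4, patch_shape=(1, 256, 256),
    shuffle=True, num_workers=4,
)
val_loader = torch_em.default_segmentation_loader(
    raw_paths="data/val/images", raw_key="*.tif",
    label_paths="data/val/labels", label_key="*.tif",
    batch_size=4, patch_shape=(1, 256, 256),
    num_workers=4,
)

# a 2d U-Net with a single output channel, e.g. for a foreground probability map
model = UNet2d(in_channels=1, out_channels=1)

trainer = torch_em.default_segmentation_trainer(
    name="membrane-model",
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    learning_rate=1e-4,
)
trainer.fit(iterations=10000)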