torch_em.shallow2deep.pseudolabel_training

  1import os
  2from torch_em.data import ConcatDataset, PseudoLabelDataset
  3from torch_em.segmentation import (get_data_loader, get_raw_transform,
  4                                   is_segmentation_dataset,
  5                                   samples_to_datasets, _get_default_transform)
  6from .shallow2deep_model import Shallow2DeepModel
  7
  8
  9def check_paths(raw_paths):
 10    def _check_path(path):
 11        if not os.path.exists(path):
 12            raise ValueError(f"Could not find path {path}")
 13
 14    if isinstance(raw_paths, str):
 15        _check_path(raw_paths)
 16    else:
 17        for rp in raw_paths:
 18            _check_path(rp)
 19
 20
 21def _load_pseudolabel_dataset(raw_paths, raw_key, **kwargs):
 22    rois = kwargs.pop("rois", None)
 23    if isinstance(raw_paths, str):
 24        if rois is not None:
 25            assert len(rois) == 3 and all(isinstance(roi, slice) for roi in rois)
 26        ds = PseudoLabelDataset(raw_paths, raw_key, roi=rois, labeler_device="cpu", **kwargs)
 27    else:
 28        assert len(raw_paths) > 0
 29        if rois is not None:
 30            assert len(rois) == len(raw_paths), f"{len(rois)}, {len(raw_paths)}"
 31            assert all(isinstance(roi, tuple) for roi in rois)
 32        n_samples = kwargs.pop("n_samples", None)
 33
 34        samples_per_ds = (
 35            [None] * len(raw_paths) if n_samples is None else samples_to_datasets(n_samples, raw_paths, raw_key)
 36        )
 37        ds = []
 38        for i, raw_path in enumerate(raw_paths):
 39            roi = None if rois is None else rois[i]
 40            dset = PseudoLabelDataset(
 41                raw_path, raw_key, roi=roi, labeler_device="cpu", n_samples=samples_per_ds[i], **kwargs
 42            )
 43            ds.append(dset)
 44        ds = ConcatDataset(*ds)
 45    return ds
 46
 47
 48def get_pseudolabel_dataset(
 49    raw_paths,
 50    raw_key,
 51    checkpoint,
 52    rf_config,
 53    patch_shape,
 54    raw_transform=None,
 55    transform=None,
 56    rois=None,
 57    n_samples=None,
 58    ndim=None,
 59    is_raw_dataset=None,
 60    pseudo_labeler_device="cpu",
 61):
 62    check_paths(raw_paths)
 63    if is_raw_dataset is None:
 64        is_raw_dataset = is_segmentation_dataset(raw_paths, raw_key, raw_paths, raw_key)
 65
 66    # we always use a raw transform in the convenience function
 67    if raw_transform is None:
 68        raw_transform = get_raw_transform()
 69
 70    # we always use augmentations in the convenience function
 71    if transform is None:
 72        transform = _get_default_transform(
 73            raw_paths if isinstance(raw_paths, str) else raw_paths[0], raw_key, is_raw_dataset, ndim
 74        )
 75
 76    pseudo_labeler = Shallow2DeepModel(checkpoint, rf_config, pseudo_labeler_device)
 77    if is_raw_dataset:
 78        ds = _load_pseudolabel_dataset(
 79            raw_paths, raw_key,
 80            patch_shape=patch_shape,
 81            pseudo_labeler=pseudo_labeler,
 82            raw_transform=raw_transform,
 83            transform=transform,
 84            rois=rois, n_samples=n_samples, ndim=ndim,
 85        )
 86    else:
 87        raise NotImplementedError("Image collection dataset for shallow2deep not implemented yet.")
 88    return ds
 89
 90
 91# TODO add options for confidence module and consistency
 92def get_pseudolabel_loader(
 93    raw_paths,
 94    raw_key,
 95    checkpoint,
 96    rf_config,
 97    batch_size,
 98    patch_shape,
 99    raw_transform=None,
100    transform=None,
101    rois=None,
102    n_samples=None,
103    ndim=None,
104    is_raw_dataset=None,
105    pseudo_labeler_device="cpu",
106    **loader_kwargs,
107):
108    ds = get_pseudolabel_dataset(
109        raw_paths=raw_paths, raw_key=raw_key,
110        checkpoint=checkpoint, rf_config=rf_config, patch_shape=patch_shape,
111        raw_transform=raw_transform, transform=transform, rois=rois,
112        n_samples=n_samples, ndim=ndim, is_raw_dataset=is_raw_dataset,
113        pseudo_labeler_device=pseudo_labeler_device,
114    )
115    return get_data_loader(ds, batch_size=batch_size, **loader_kwargs)
def check_paths(raw_paths):
    """Verify that every given path exists on the file system.

    Args:
        raw_paths: Either a single path (``str`` or ``os.PathLike``) or an
            iterable of such paths.

    Raises:
        ValueError: If a path does not exist. The message names the first
            missing path that was encountered.
    """
    # A single path must be wrapped before iterating: a str would otherwise be
    # walked character by character, and a pathlib.Path is not iterable at all.
    if isinstance(raw_paths, (str, os.PathLike)):
        raw_paths = [raw_paths]

    for path in raw_paths:
        if not os.path.exists(path):
            raise ValueError(f"Could not find path {path}")
def get_pseudolabel_dataset(
    raw_paths,
    raw_key,
    checkpoint,
    rf_config,
    patch_shape,
    raw_transform=None,
    transform=None,
    rois=None,
    n_samples=None,
    ndim=None,
    is_raw_dataset=None,
    pseudo_labeler_device="cpu",
):
    """Create a dataset whose labels are predicted by a ``Shallow2DeepModel``.

    Missing ``raw_transform`` and ``transform`` arguments are replaced by
    defaults (``get_raw_transform()`` and ``_get_default_transform(...)``), and
    ``is_raw_dataset`` is determined via ``is_segmentation_dataset`` if it is
    ``None``.

    Returns:
        The dataset created by ``_load_pseudolabel_dataset``.

    Raises:
        ValueError: If one of the raw paths does not exist.
        NotImplementedError: If the data is not a segmentation-style dataset.
    """
    check_paths(raw_paths)

    if is_raw_dataset is None:
        is_raw_dataset = is_segmentation_dataset(raw_paths, raw_key, raw_paths, raw_key)

    # The convenience function always applies a raw transform ...
    if raw_transform is None:
        raw_transform = get_raw_transform()

    # ... and always applies augmentations; the default is derived from the first path.
    if transform is None:
        first_path = raw_paths if isinstance(raw_paths, str) else raw_paths[0]
        transform = _get_default_transform(first_path, raw_key, is_raw_dataset, ndim)

    labeler = Shallow2DeepModel(checkpoint, rf_config, pseudo_labeler_device)

    if not is_raw_dataset:
        raise NotImplementedError("Image collection dataset for shallow2deep not implemented yet.")

    dataset_kwargs = dict(
        patch_shape=patch_shape,
        pseudo_labeler=labeler,
        raw_transform=raw_transform,
        transform=transform,
        rois=rois,
        n_samples=n_samples,
        ndim=ndim,
    )
    return _load_pseudolabel_dataset(raw_paths, raw_key, **dataset_kwargs)
def get_pseudolabel_loader(
    raw_paths,
    raw_key,
    checkpoint,
    rf_config,
    batch_size,
    patch_shape,
    raw_transform=None,
    transform=None,
    rois=None,
    n_samples=None,
    ndim=None,
    is_raw_dataset=None,
    pseudo_labeler_device="cpu",
    **loader_kwargs,
):
    """Create a data loader over a pseudo-label dataset.

    The dataset is built with ``get_pseudolabel_dataset`` from all arguments
    except ``batch_size`` and ``loader_kwargs``; those two are handed to
    ``get_data_loader`` together with the dataset.

    Returns:
        The result of ``get_data_loader`` for the created dataset.
    """
    dataset = get_pseudolabel_dataset(
        raw_paths,
        raw_key,
        checkpoint,
        rf_config,
        patch_shape,
        raw_transform=raw_transform,
        transform=transform,
        rois=rois,
        n_samples=n_samples,
        ndim=ndim,
        is_raw_dataset=is_raw_dataset,
        pseudo_labeler_device=pseudo_labeler_device,
    )
    loader = get_data_loader(dataset, batch_size=batch_size, **loader_kwargs)
    return loader