torch_em.shallow2deep.pseudolabel_training
"""Convenience functions to build datasets / loaders whose labels come from a shallow2deep pseudo labeler.

Only raw data is passed in; labels are produced by a ``Shallow2DeepModel`` that is handed to
``PseudoLabelDataset`` as ``pseudo_labeler``.
"""
import os
from torch_em.data import ConcatDataset, PseudoLabelDataset
from torch_em.segmentation import (get_data_loader, get_raw_transform,
                                   is_segmentation_dataset,
                                   samples_to_datasets, _get_default_transform)
from .shallow2deep_model import Shallow2DeepModel


def check_paths(raw_paths):
    """Check that every given raw data path exists.

    Args:
        raw_paths: a single path (``str``) or an iterable of paths.

    Raises:
        ValueError: if any of the paths does not exist.
    """
    def _check_path(path):
        if not os.path.exists(path):
            raise ValueError(f"Could not find path {path}")

    if isinstance(raw_paths, str):
        _check_path(raw_paths)
    else:
        for rp in raw_paths:
            _check_path(rp)


def _load_pseudolabel_dataset(raw_paths, raw_key, **kwargs):
    """Build a ``PseudoLabelDataset`` for one path, or a ``ConcatDataset`` of them for several paths.

    Args:
        raw_paths: a single path (``str``) or a non-empty sequence of paths.
        raw_key: key of the raw data inside the file(s).
        kwargs: forwarded to ``PseudoLabelDataset``. ``rois`` is consumed here and passed on as
            ``roi``. For several paths, ``n_samples`` is consumed here as well and split over the
            individual datasets via ``samples_to_datasets``.

    Returns:
        The single dataset, or a ``ConcatDataset`` over one dataset per path.
    """
    rois = kwargs.pop("rois", None)

    if isinstance(raw_paths, str):
        if rois is not None:
            # NOTE(review): this requires exactly three slices, so 2d rois are rejected — confirm intended.
            assert len(rois) == 3 and all(isinstance(roi, slice) for roi in rois), \
                f"Expected a tuple of three slices as roi, got {rois}"
        # NOTE(review): labeler_device is hard-coded to "cpu", independent of the pseudo labeler's
        # device argument; presumably intentional — confirm.
        return PseudoLabelDataset(raw_paths, raw_key, roi=rois, labeler_device="cpu", **kwargs)

    assert len(raw_paths) > 0, "Expected at least one raw path"
    if rois is not None:
        assert len(rois) == len(raw_paths), f"{len(rois)}, {len(raw_paths)}"
        assert all(isinstance(roi, tuple) for roi in rois)
    n_samples = kwargs.pop("n_samples", None)

    # Split the requested total number of samples over the per-file datasets.
    samples_per_ds = (
        [None] * len(raw_paths) if n_samples is None else samples_to_datasets(n_samples, raw_paths, raw_key)
    )
    datasets = [
        PseudoLabelDataset(
            raw_path, raw_key, roi=None if rois is None else rois[i],
            labeler_device="cpu", n_samples=samples_per_ds[i], **kwargs
        )
        for i, raw_path in enumerate(raw_paths)
    ]
    return ConcatDataset(*datasets)


def get_pseudolabel_dataset(
    raw_paths,
    raw_key,
    checkpoint,
    rf_config,
    patch_shape,
    raw_transform=None,
    transform=None,
    rois=None,
    n_samples=None,
    ndim=None,
    is_raw_dataset=None,
    pseudo_labeler_device="cpu",
):
    """Get a dataset over raw data whose labels are predicted by a ``Shallow2DeepModel``.

    Args:
        raw_paths: path or list of paths to the raw data.
        raw_key: key of the raw data inside the file(s).
        checkpoint: passed to ``Shallow2DeepModel`` to build the pseudo labeler.
        rf_config: passed to ``Shallow2DeepModel`` to build the pseudo labeler.
        patch_shape: shape of the patches sampled from the raw data.
        raw_transform: transform applied to the raw data; defaults to ``get_raw_transform()``.
        transform: augmentation transform; defaults to ``_get_default_transform(...)``.
        rois: region(s) of interest, one per path when several paths are given.
        n_samples: number of samples; split over the datasets when several paths are given.
        ndim: dimensionality of the data, forwarded to the dataset and the default transform.
        is_raw_dataset: whether the data is a raw (non image-collection) dataset; inferred if None.
        pseudo_labeler_device: device passed to ``Shallow2DeepModel``.

    Returns:
        The pseudo-label dataset.

    Raises:
        ValueError: if one of the raw paths does not exist.
        NotImplementedError: if the data is not a raw dataset (image collections are unsupported).
    """
    check_paths(raw_paths)
    if is_raw_dataset is None:
        # There are no label inputs here, so the raw data is passed in place of the labels.
        is_raw_dataset = is_segmentation_dataset(raw_paths, raw_key, raw_paths, raw_key)

    # Fail before building transforms and loading the pseudo labeler, which would be wasted work.
    if not is_raw_dataset:
        raise NotImplementedError("Image collection dataset for shallow2deep not implemented yet.")

    # we always use a raw transform in the convenience function
    if raw_transform is None:
        raw_transform = get_raw_transform()

    # we always use augmentations in the convenience function
    if transform is None:
        transform = _get_default_transform(
            raw_paths if isinstance(raw_paths, str) else raw_paths[0], raw_key, is_raw_dataset, ndim
        )

    pseudo_labeler = Shallow2DeepModel(checkpoint, rf_config, pseudo_labeler_device)
    return _load_pseudolabel_dataset(
        raw_paths, raw_key,
        patch_shape=patch_shape,
        pseudo_labeler=pseudo_labeler,
        raw_transform=raw_transform,
        transform=transform,
        rois=rois, n_samples=n_samples, ndim=ndim,
    )


# TODO add options for confidence module and consistency
def get_pseudolabel_loader(
    raw_paths,
    raw_key,
    checkpoint,
    rf_config,
    batch_size,
    patch_shape,
    raw_transform=None,
    transform=None,
    rois=None,
    n_samples=None,
    ndim=None,
    is_raw_dataset=None,
    pseudo_labeler_device="cpu",
    **loader_kwargs,
):
    """Get a data loader over the dataset returned by ``get_pseudolabel_dataset``.

    All arguments except ``batch_size`` and ``loader_kwargs`` are forwarded to
    ``get_pseudolabel_dataset``; ``batch_size`` and ``loader_kwargs`` go to ``get_data_loader``.
    """
    ds = get_pseudolabel_dataset(
        raw_paths=raw_paths, raw_key=raw_key,
        checkpoint=checkpoint, rf_config=rf_config, patch_shape=patch_shape,
        raw_transform=raw_transform, transform=transform, rois=rois,
        n_samples=n_samples, ndim=ndim, is_raw_dataset=is_raw_dataset,
        pseudo_labeler_device=pseudo_labeler_device,
    )
    return get_data_loader(ds, batch_size=batch_size, **loader_kwargs)
def
check_paths(raw_paths):
def get_pseudolabel_dataset(
    raw_paths,
    raw_key,
    checkpoint,
    rf_config,
    patch_shape,
    raw_transform=None,
    transform=None,
    rois=None,
    n_samples=None,
    ndim=None,
    is_raw_dataset=None,
    pseudo_labeler_device="cpu",
):
    """Get a dataset over raw data whose labels are predicted by a ``Shallow2DeepModel``.

    Args:
        raw_paths: path or list of paths to the raw data.
        raw_key: key of the raw data inside the file(s).
        checkpoint: passed to ``Shallow2DeepModel`` to build the pseudo labeler.
        rf_config: passed to ``Shallow2DeepModel`` to build the pseudo labeler.
        patch_shape: shape of the patches sampled from the raw data.
        raw_transform: transform applied to the raw data; defaults to ``get_raw_transform()``.
        transform: augmentation transform; defaults to ``_get_default_transform(...)``.
        rois: region(s) of interest, one per path when several paths are given.
        n_samples: number of samples; split over the datasets when several paths are given.
        ndim: dimensionality of the data, forwarded to the dataset and the default transform.
        is_raw_dataset: whether the data is a raw (non image-collection) dataset; inferred if None.
        pseudo_labeler_device: device passed to ``Shallow2DeepModel``.

    Returns:
        The pseudo-label dataset.

    Raises:
        ValueError: if one of the raw paths does not exist.
        NotImplementedError: if the data is not a raw dataset (image collections are unsupported).
    """
    check_paths(raw_paths)
    if is_raw_dataset is None:
        # There are no label inputs here, so the raw data is passed in place of the labels.
        is_raw_dataset = is_segmentation_dataset(raw_paths, raw_key, raw_paths, raw_key)

    # Fail before building transforms and loading the pseudo labeler, which would be wasted work.
    if not is_raw_dataset:
        raise NotImplementedError("Image collection dataset for shallow2deep not implemented yet.")

    # we always use a raw transform in the convenience function
    if raw_transform is None:
        raw_transform = get_raw_transform()

    # we always use augmentations in the convenience function
    if transform is None:
        transform = _get_default_transform(
            raw_paths if isinstance(raw_paths, str) else raw_paths[0], raw_key, is_raw_dataset, ndim
        )

    pseudo_labeler = Shallow2DeepModel(checkpoint, rf_config, pseudo_labeler_device)
    return _load_pseudolabel_dataset(
        raw_paths, raw_key,
        patch_shape=patch_shape,
        pseudo_labeler=pseudo_labeler,
        raw_transform=raw_transform,
        transform=transform,
        rois=rois, n_samples=n_samples, ndim=ndim,
    )
def get_pseudolabel_loader(
    raw_paths,
    raw_key,
    checkpoint,
    rf_config,
    batch_size,
    patch_shape,
    raw_transform=None,
    transform=None,
    rois=None,
    n_samples=None,
    ndim=None,
    is_raw_dataset=None,
    pseudo_labeler_device="cpu",
    **loader_kwargs,
):
    """Return the loader that ``get_data_loader`` builds over a pseudo-label dataset.

    The dataset comes from ``get_pseudolabel_dataset``, which receives every argument except
    ``batch_size`` and ``loader_kwargs``; those two are handed to ``get_data_loader`` instead.
    """
    pseudo_dataset = get_pseudolabel_dataset(
        raw_paths, raw_key, checkpoint, rf_config, patch_shape,
        raw_transform=raw_transform,
        transform=transform,
        rois=rois,
        n_samples=n_samples,
        ndim=ndim,
        is_raw_dataset=is_raw_dataset,
        pseudo_labeler_device=pseudo_labeler_device,
    )
    return get_data_loader(pseudo_dataset, batch_size=batch_size, **loader_kwargs)