torch_em.shallow2deep.pseudolabel_training
import os
from typing import Callable, Dict, Optional, Sequence, Tuple, Union

import torch
from torch_em.data import ConcatDataset, PseudoLabelDataset
from torch_em.segmentation import (
    get_data_loader, get_raw_transform, is_segmentation_dataset, samples_to_datasets, _get_default_transform
)
from .shallow2deep_model import Shallow2DeepModel


def check_paths(raw_paths):
    """@private

    Check that a single path, or every path of a sequence of paths, exists.

    Args:
        raw_paths: Either one path given as a string or an iterable of paths.

    Raises:
        ValueError: If one of the paths does not exist.
    """
    # A lone string is treated as a one-element collection so that both cases share the same check.
    paths = [raw_paths] if isinstance(raw_paths, str) else raw_paths
    for path in paths:
        if not os.path.exists(path):
            raise ValueError(f"Could not find path {path}")


def _load_pseudolabel_dataset(raw_paths, raw_key, **kwargs):
    """Create the pseudo-label dataset for a single path or for several paths.

    Args:
        raw_paths: Either one path given as a string or a non-empty sequence of paths.
        raw_key: The internal dataset name for the raw data.
        kwargs: Keyword arguments forwarded to `PseudoLabelDataset`. The entries `rois` and
            (for several paths) `n_samples` are consumed here: `rois` is passed on per dataset as `roi`,
            and `n_samples` is distributed over the datasets via `samples_to_datasets`.

    Returns:
        A `PseudoLabelDataset` for a single path, otherwise a `ConcatDataset` with one
        `PseudoLabelDataset` per path.
    """
    rois = kwargs.pop("rois", None)

    # Single path: one dataset; all remaining keyword arguments (including n_samples) are forwarded as is.
    if isinstance(raw_paths, str):
        if rois is not None:
            assert len(rois) == 3 and all(isinstance(roi, slice) for roi in rois)
        # NOTE(review): the labeler device is fixed to "cpu" here, independent of the device the
        # pseudo labeler itself was created with - presumably intentional, confirm before changing.
        return PseudoLabelDataset(raw_paths, raw_key, roi=rois, labeler_device="cpu", **kwargs)

    # Several paths: one dataset per path, each with its own roi and its share of the samples.
    assert len(raw_paths) > 0
    if rois is not None:
        assert len(rois) == len(raw_paths), f"{len(rois)}, {len(raw_paths)}"
        assert all(isinstance(roi, tuple) for roi in rois)

    n_samples = kwargs.pop("n_samples", None)
    if n_samples is None:
        samples_per_ds = [None] * len(raw_paths)
    else:
        samples_per_ds = samples_to_datasets(n_samples, raw_paths, raw_key)

    datasets = [
        PseudoLabelDataset(
            raw_path, raw_key,
            roi=None if rois is None else rois[i],
            labeler_device="cpu", n_samples=samples_per_ds[i], **kwargs
        )
        for i, raw_path in enumerate(raw_paths)
    ]
    return ConcatDataset(*datasets)


def get_pseudolabel_dataset(
    raw_paths: Union[str, Sequence[str]],
    raw_key: Optional[str],
    checkpoint: str,
    rf_config: Dict,
    patch_shape: Tuple[int, ...],
    raw_transform: Optional[Callable] = None,
    transform: Optional[Callable] = None,
    rois: Optional[Union[Tuple[slice, ...], Sequence[Tuple[slice, ...]]]] = None,
    n_samples: Optional[int] = None,
    ndim: Optional[int] = None,
    is_raw_dataset: Optional[bool] = None,
    pseudo_labeler_device: str = "cpu",
) -> torch.utils.data.Dataset:
    """Get a pseudo-label dataset for training from a Shallow2Deep model.

    Args:
        raw_paths: The raw paths for training the model. May also be a single file.
        raw_key: The internal dataset name for the raw data. Set to None for a regular image file like tif.
        checkpoint: The checkpoint for the trained Shallow2Deep model.
        rf_config: The configuration for the random forest used for the Shallow2Deep model.
        patch_shape: The patch shape for training.
        raw_transform: The transformation to apply to the raw data.
            If None, the transform returned by `get_raw_transform` is used.
        transform: The transformation to implement augmentations.
            If None, a default transform is derived from the (first) raw path.
        rois: The region of interest for the training data. One tuple of slices for a single path,
            or a sequence with one tuple of slices per path.
        n_samples: The length of this dataset.
        ndim: The dimensionality of the dataset.
        is_raw_dataset: Whether this is a segmentation dataset (True) or an image collection dataset (False).
            If None, will be derived from the data.
        pseudo_labeler_device: The device for the pseudo labeling model.

    Returns:
        The dataset with pseudo-labeler.

    Raises:
        ValueError: If one of the raw paths does not exist.
        NotImplementedError: If the data is an image collection dataset.
    """
    check_paths(raw_paths)
    if is_raw_dataset is None:
        is_raw_dataset = is_segmentation_dataset(raw_paths, raw_key, raw_paths, raw_key)

    # Reject the unsupported case right away, before any transform or the pseudo labeler is created.
    if not is_raw_dataset:
        raise NotImplementedError("Image collection dataset for shallow2deep not implemented yet.")

    # we always use a raw transform in the convenience function
    if raw_transform is None:
        raw_transform = get_raw_transform()

    # we always use augmentations in the convenience function
    if transform is None:
        first_path = raw_paths if isinstance(raw_paths, str) else raw_paths[0]
        transform = _get_default_transform(first_path, raw_key, is_raw_dataset, ndim)

    pseudo_labeler = Shallow2DeepModel(checkpoint, rf_config, pseudo_labeler_device)
    return _load_pseudolabel_dataset(
        raw_paths, raw_key,
        patch_shape=patch_shape,
        pseudo_labeler=pseudo_labeler,
        raw_transform=raw_transform,
        transform=transform,
        rois=rois, n_samples=n_samples, ndim=ndim,
    )


# TODO add options for confidence module and consistency
def get_pseudolabel_loader(
    raw_paths: Union[str, Sequence[str]],
    raw_key: Optional[str],
    checkpoint: str,
    rf_config: Dict,
    batch_size: int,
    patch_shape: Tuple[int, ...],
    raw_transform: Optional[Callable] = None,
    transform: Optional[Callable] = None,
    rois: Optional[Union[Tuple[slice, ...], Sequence[Tuple[slice, ...]]]] = None,
    n_samples: Optional[int] = None,
    ndim: Optional[int] = None,
    is_raw_dataset: Optional[bool] = None,
    pseudo_labeler_device: str = "cpu",
    **loader_kwargs,
) -> torch.utils.data.DataLoader:
    """Get a pseudo-label dataloader for training from a Shallow2Deep model.

    Args:
        raw_paths: The raw paths for training the model. May also be a single file.
        raw_key: The internal dataset name for the raw data. Set to None for a regular image file like tif.
        checkpoint: The checkpoint for the trained Shallow2Deep model.
        rf_config: The configuration for the random forest used for the Shallow2Deep model.
        batch_size: The batch size for the data loader.
        patch_shape: The patch shape for training.
        raw_transform: The transformation to apply to the raw data.
        transform: The transformation to implement augmentations.
        rois: The region of interest for the training data. One tuple of slices for a single path,
            or a sequence with one tuple of slices per path.
        n_samples: The length of this dataset.
        ndim: The dimensionality of the dataset.
        is_raw_dataset: Whether this is a segmentation dataset (True) or an image collection dataset (False).
            If None, will be derived from the data.
        pseudo_labeler_device: The device for the pseudo labeling model.
        loader_kwargs: Keyword arguments for the data loader.

    Returns:
        The dataloader with pseudo-labeler.
    """
    ds = get_pseudolabel_dataset(
        raw_paths=raw_paths, raw_key=raw_key,
        checkpoint=checkpoint, rf_config=rf_config, patch_shape=patch_shape,
        raw_transform=raw_transform, transform=transform, rois=rois,
        n_samples=n_samples, ndim=ndim, is_raw_dataset=is_raw_dataset,
        pseudo_labeler_device=pseudo_labeler_device,
    )
    return get_data_loader(ds, batch_size=batch_size, **loader_kwargs)
def
get_pseudolabel_dataset( raw_paths: Union[str, Sequence[str]], raw_key: Optional[str], checkpoint: str, rf_config: Dict, patch_shape: Tuple[int, ...], raw_transform: Optional[Callable] = None, transform: Optional[Callable] = None, rois: Union[Tuple[slice, ...], Sequence[Tuple[slice, ...]], NoneType] = None, n_samples: Optional[int] = None, ndim: Optional[int] = None, is_raw_dataset: Optional[bool] = None, pseudo_labeler_device: str = 'cpu') -> torch.utils.data.dataset.Dataset:
def get_pseudolabel_dataset(
    raw_paths: Union[str, Sequence[str]],
    raw_key: Optional[str],
    checkpoint: str,
    rf_config: Dict,
    patch_shape: Tuple[int, ...],
    raw_transform: Optional[Callable] = None,
    transform: Optional[Callable] = None,
    rois: Optional[Union[Tuple[slice, ...], Sequence[Tuple[slice, ...]]]] = None,
    n_samples: Optional[int] = None,
    ndim: Optional[int] = None,
    is_raw_dataset: Optional[bool] = None,
    pseudo_labeler_device: str = "cpu",
) -> torch.utils.data.Dataset:
    """Get a pseudo-label dataset for training from a Shallow2Deep model.

    Args:
        raw_paths: The raw paths for training the model. May also be a single file.
        raw_key: The internal dataset name for the raw data. Set to None for a regular image file like tif.
        checkpoint: The checkpoint for the trained Shallow2Deep model.
        rf_config: The configuration for the random forest used for the Shallow2Deep model.
        patch_shape: The patch shape for training.
        raw_transform: The transformation to apply to the raw data.
            If None, the transform returned by `get_raw_transform` is used.
        transform: The transformation to implement augmentations.
            If None, a default transform is derived from the (first) raw path.
        rois: The region of interest for the training data. One tuple of slices for a single path,
            or a sequence with one tuple of slices per path.
        n_samples: The length of this dataset.
        ndim: The dimensionality of the dataset.
        is_raw_dataset: Whether this is a segmentation dataset (True) or an image collection dataset (False).
            If None, will be derived from the data.
        pseudo_labeler_device: The device for the pseudo labeling model.

    Returns:
        The dataset with pseudo-labeler.

    Raises:
        ValueError: If one of the raw paths does not exist.
        NotImplementedError: If the data is an image collection dataset.
    """
    check_paths(raw_paths)
    if is_raw_dataset is None:
        is_raw_dataset = is_segmentation_dataset(raw_paths, raw_key, raw_paths, raw_key)

    # Reject the unsupported case right away, before any transform or the pseudo labeler is created.
    if not is_raw_dataset:
        raise NotImplementedError("Image collection dataset for shallow2deep not implemented yet.")

    # we always use a raw transform in the convenience function
    if raw_transform is None:
        raw_transform = get_raw_transform()

    # we always use augmentations in the convenience function
    if transform is None:
        first_path = raw_paths if isinstance(raw_paths, str) else raw_paths[0]
        transform = _get_default_transform(first_path, raw_key, is_raw_dataset, ndim)

    pseudo_labeler = Shallow2DeepModel(checkpoint, rf_config, pseudo_labeler_device)
    return _load_pseudolabel_dataset(
        raw_paths, raw_key,
        patch_shape=patch_shape,
        pseudo_labeler=pseudo_labeler,
        raw_transform=raw_transform,
        transform=transform,
        rois=rois, n_samples=n_samples, ndim=ndim,
    )
Get a pseudo-label dataset for training from a Shallow2Deep model.
Arguments:
- raw_paths: The raw paths for training the model. May also be a single file.
- raw_key: The internal dataset name for the raw data. Set to None for a regular image file like tif.
- checkpoint: The checkpoint for the trained Shallow2Deep model.
- rf_config: The configuration for the random forest used for the Shallow2Deep model.
- patch_shape: The patch shape for training.
- raw_transform: The transformation to apply to the raw data.
- transform: The transformation to implement augmentations.
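- rois: The region of interest for the training data. For a single path this is one tuple of slices; for multiple paths it is a sequence with one tuple of slices per path.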
- n_samples: The length of this dataset.
- ndim: The dimensionality of the dataset.
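- is_raw_dataset: Whether this is a segmentation dataset (True) or an image collection dataset (False). If None, will be derived from the data. Image collection datasets are not implemented yet and raise a NotImplementedError.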
- pseudo_labeler_device: The device for the pseudo labeling model.
Returns:
The dataset with pseudo-labeler.
def
get_pseudolabel_loader( raw_paths: Union[str, Sequence[str]], raw_key: Optional[str], checkpoint: str, rf_config: Dict, batch_size: int, patch_shape: Tuple[int, ...], raw_transform: Optional[Callable] = None, transform: Optional[Callable] = None, rois: Union[Tuple[slice, ...], Sequence[Tuple[slice, ...]], NoneType] = None, n_samples: Optional[int] = None, ndim: Optional[int] = None, is_raw_dataset: Optional[bool] = None, pseudo_labeler_device: str = 'cpu', **loader_kwargs) -> torch.utils.data.dataloader.DataLoader:
def get_pseudolabel_loader(
    raw_paths: Union[str, Sequence[str]],
    raw_key: Optional[str],
    checkpoint: str,
    rf_config: Dict,
    batch_size: int,
    patch_shape: Tuple[int, ...],
    raw_transform: Optional[Callable] = None,
    transform: Optional[Callable] = None,
    rois: Optional[Union[Tuple[slice, ...], Sequence[Tuple[slice, ...]]]] = None,
    n_samples: Optional[int] = None,
    ndim: Optional[int] = None,
    is_raw_dataset: Optional[bool] = None,
    pseudo_labeler_device: str = "cpu",
    **loader_kwargs,
) -> torch.utils.data.DataLoader:
    """Create a dataloader that yields pseudo-labels predicted by a Shallow2Deep model.

    The dataset is created with `get_pseudolabel_dataset` and then wrapped with `get_data_loader`.

    Args:
        raw_paths: The raw paths for training the model. May also be a single file.
        raw_key: The internal dataset name for the raw data. Set to None for a regular image file like tif.
        checkpoint: The checkpoint for the trained Shallow2Deep model.
        rf_config: The configuration for the random forest used for the Shallow2Deep model.
        batch_size: The batch size for the data loader.
        patch_shape: The patch shape for training.
        raw_transform: The transformation to apply to the raw data.
        transform: The transformation to implement augmentations.
        rois: The region of interest for the training data.
        n_samples: The length of this dataset.
        ndim: The dimensionality of the dataset.
        is_raw_dataset: Whether this is a segmentation or image collection dataset.
            If None, will be derived from the data.
        pseudo_labeler_device: The device for the pseudo labeling model.
        loader_kwargs: Keyword arguments for the data loader.

    Returns:
        The dataloader with pseudo-labeler.
    """
    # Everything except the batch size and the loader options configures the dataset.
    dataset_kwargs = dict(
        raw_paths=raw_paths,
        raw_key=raw_key,
        checkpoint=checkpoint,
        rf_config=rf_config,
        patch_shape=patch_shape,
        raw_transform=raw_transform,
        transform=transform,
        rois=rois,
        n_samples=n_samples,
        ndim=ndim,
        is_raw_dataset=is_raw_dataset,
        pseudo_labeler_device=pseudo_labeler_device,
    )
    dataset = get_pseudolabel_dataset(**dataset_kwargs)
    return get_data_loader(dataset, batch_size=batch_size, **loader_kwargs)
Get a pseudo-label dataloader for training from a Shallow2Deep model.
Arguments:
- raw_paths: The raw paths for training the model. May also be a single file.
- raw_key: The internal dataset name for the raw data. Set to None for a regular image file like tif.
- checkpoint: The checkpoint for the trained Shallow2Deep model.
- rf_config: The configuration for the random forest used for the Shallow2Deep model.
- batch_size: The batch size for the data loader.
- patch_shape: The patch shape for training.
- raw_transform: The transformation to apply to the raw data.
- transform: The transformation to implement augmentations.
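- rois: The region of interest for the training data. For a single path this is one tuple of slices; for multiple paths it is a sequence with one tuple of slices per path.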
- n_samples: The length of this dataset.
- ndim: The dimensionality of the dataset.
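- is_raw_dataset: Whether this is a segmentation dataset (True) or an image collection dataset (False). If None, will be derived from the data. Image collection datasets are not implemented yet and raise a NotImplementedError.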
- pseudo_labeler_device: The device for the pseudo labeling model.
- loader_kwargs: Keyword arguments for the data loader.
Returns:
The dataloader with pseudo-labeler.