torch_em.segmentation
import os
from glob import glob
from typing import Any, Dict, Optional

import torch
import torch.utils.data

from .data import ConcatDataset, ImageCollectionDataset, SegmentationDataset
from .loss import DiceLoss
from .trainer import DefaultTrainer
from .trainer.tensorboard_logger import TensorboardLogger
from .transform import get_augmentations, get_raw_transform
from .util import load_data


# TODO add a heuristic to estimate this from the number of epochs
DEFAULT_SCHEDULER_KWARGS = {"mode": "min", "factor": 0.5, "patience": 5}


#
# convenience functions for segmentation loaders
#

# TODO implement balanced and make it the default
# def samples_to_datasets(n_samples, raw_paths, raw_key, split="balanced"):
def samples_to_datasets(n_samples, raw_paths, raw_key, split="uniform"):
    """Distribute a total number of samples over the given datasets.

    Args:
        n_samples: the total number of samples to distribute.
        raw_paths: the raw data paths; one dataset per path.
        raw_key: the key for the raw data (unused for the "uniform" split).
        split: the distribution strategy; only "uniform" is implemented.

    Returns:
        A list with the number of samples per dataset.
    """
    assert split in ("balanced", "uniform")
    n_datasets = len(raw_paths)
    if split == "uniform":
        # even distribution of samples to datasets;
        # the first (n_samples % n_datasets) datasets get one extra sample
        samples_per_ds = n_samples // n_datasets
        divider = n_samples % n_datasets
        return [samples_per_ds + 1 if ii < divider else samples_per_ds for ii in range(n_datasets)]
    else:
        # distribution of samples to dataset based on the dataset lens
        raise NotImplementedError


def check_paths(raw_paths, label_paths):
    """Validate that raw and label paths have the same type, matching lengths and exist on disc.

    Raises:
        ValueError: if types mismatch, lengths differ, or any path does not exist.
    """
    if not isinstance(raw_paths, type(label_paths)):
        raise ValueError(f"Expect raw and label paths of same type, got {type(raw_paths)}, {type(label_paths)}")

    def _check_path(path):
        if isinstance(path, str):
            if not os.path.exists(path):
                raise ValueError(f"Could not find path {path}")
        else:
            # check for single path or multiple paths (for same volume - supports multi-modal inputs)
            for per_path in path:
                if not os.path.exists(per_path):
                    raise ValueError(f"Could not find path {per_path}")

    if isinstance(raw_paths, str):
        _check_path(raw_paths)
        _check_path(label_paths)
    else:
        if len(raw_paths) != len(label_paths):
            raise ValueError(f"Expect same number of raw and label paths, got {len(raw_paths)}, {len(label_paths)}")
        for rp, lp in zip(raw_paths, label_paths):
            _check_path(rp)
            _check_path(lp)


def is_segmentation_dataset(raw_paths, raw_key, label_paths, label_key):
    """ Check if we can load the data as SegmentationDataset
    """

    def _can_open(path, key):
        try:
            load_data(path, key)
            return True
        except Exception:
            return False

    if isinstance(raw_paths, str):
        can_open_raw = _can_open(raw_paths, raw_key)
        can_open_label = _can_open(label_paths, label_key)
    else:
        # with multiple paths all of them must behave consistently
        can_open_raw = [_can_open(rp, raw_key) for rp in raw_paths]
        if not can_open_raw.count(can_open_raw[0]) == len(can_open_raw):
            raise ValueError("Inconsistent raw data")
        can_open_raw = can_open_raw[0]

        can_open_label = [_can_open(lp, label_key) for lp in label_paths]
        if not can_open_label.count(can_open_label[0]) == len(can_open_label):
            raise ValueError("Inconsistent label data")
        can_open_label = can_open_label[0]

    if can_open_raw != can_open_label:
        raise ValueError("Inconsistent raw and label data")

    return can_open_raw


def _load_segmentation_dataset(raw_paths, raw_key, label_paths, label_key, **kwargs):
    """Construct a SegmentationDataset (or a ConcatDataset of them for multiple paths)."""
    rois = kwargs.pop("rois", None)
    if isinstance(raw_paths, str):
        if rois is not None:
            assert isinstance(rois, (tuple, slice))
            if isinstance(rois, tuple):
                assert all(isinstance(roi, slice) for roi in rois)
        ds = SegmentationDataset(raw_paths, raw_key, label_paths, label_key, roi=rois, **kwargs)
    else:
        assert len(raw_paths) > 0
        if rois is not None:
            # for multiple paths we expect one roi (a tuple of slices) per dataset
            assert len(rois) == len(label_paths)
            assert all(isinstance(roi, tuple) for roi in rois), f"{rois}"
        n_samples = kwargs.pop("n_samples", None)

        samples_per_ds = (
            [None] * len(raw_paths) if n_samples is None else samples_to_datasets(n_samples, raw_paths, raw_key)
        )
        ds = []
        for i, (raw_path, label_path) in enumerate(zip(raw_paths, label_paths)):
            roi = None if rois is None else rois[i]
            dset = SegmentationDataset(
                raw_path, raw_key, label_path, label_key, roi=roi, n_samples=samples_per_ds[i], **kwargs
            )
            ds.append(dset)
        ds = ConcatDataset(*ds)
    return ds


def _load_image_collection_dataset(raw_paths, raw_key, label_paths, label_key, roi, **kwargs):
    """Construct an ImageCollectionDataset (or a ConcatDataset of them for multiple paths)."""
    def _get_paths(rpath, rkey, lpath, lkey, this_roi):
        # BUGFIX: remember the glob pattern before rebinding rpath, so the
        # error message below does not crash with os.path.join on a list
        raw_pattern = os.path.join(rpath, rkey)
        rpath = glob(raw_pattern)
        rpath.sort()
        if len(rpath) == 0:
            raise ValueError(f"Could not find any images for pattern {raw_pattern}")
        lpath = glob(os.path.join(lpath, lkey))
        lpath.sort()
        if len(rpath) != len(lpath):
            raise ValueError(f"Expect same number of raw and label images, got {len(rpath)}, {len(lpath)}")

        if this_roi is not None:
            # BUGFIX: index with this_roi (the per-dataset roi); the previous
            # code used the outer 'roi', which is a list for multiple datasets
            rpath, lpath = rpath[this_roi], lpath[this_roi]

        return rpath, lpath

    patch_shape = kwargs.pop("patch_shape")
    # the image collection dataset is 2d only; a 3d patch shape is accepted
    # if the first axis is a singleton, and is then squeezed to 2d
    if len(patch_shape) == 3:
        if patch_shape[0] != 1:
            raise ValueError(f"Image collection dataset expects 2d patch shape, got {patch_shape}")
        patch_shape = patch_shape[1:]
    assert len(patch_shape) == 2

    if isinstance(raw_paths, str):
        raw_paths, label_paths = _get_paths(raw_paths, raw_key, label_paths, label_key, roi)
        ds = ImageCollectionDataset(raw_paths, label_paths, patch_shape=patch_shape, **kwargs)
    elif raw_key is None:
        # paths are already lists of individual image files
        assert label_key is None
        assert isinstance(raw_paths, (list, tuple)) and isinstance(label_paths, (list, tuple))
        assert len(raw_paths) == len(label_paths)
        ds = ImageCollectionDataset(raw_paths, label_paths, patch_shape=patch_shape, **kwargs)
    else:
        ds = []
        n_samples = kwargs.pop("n_samples", None)
        samples_per_ds = (
            [None] * len(raw_paths) if n_samples is None else samples_to_datasets(n_samples, raw_paths, raw_key)
        )
        if roi is None:
            roi = len(raw_paths) * [None]
        assert len(roi) == len(raw_paths)
        for i, (raw_path, label_path, this_roi) in enumerate(zip(raw_paths, label_paths, roi)):
            rpath, lpath = _get_paths(raw_path, raw_key, label_path, label_key, this_roi)
            dset = ImageCollectionDataset(rpath, lpath, patch_shape=patch_shape, n_samples=samples_per_ds[i], **kwargs)
            ds.append(dset)
        ds = ConcatDataset(*ds)
    return ds


def _get_default_transform(path, key, is_seg_dataset, ndim):
    """Pick default augmentations based on the data dimensionality."""
    if is_seg_dataset and ndim is None:
        shape = load_data(path, key).shape
        if len(shape) == 2:
            ndim = 2
        else:
            # heuristics to figure out whether to use default 3d
            # or default anisotropic augmentations
            ndim = "anisotropic" if shape[0] < shape[1] // 2 else 3
    elif is_seg_dataset and ndim is not None:
        pass
    else:
        ndim = 2
    return get_augmentations(ndim)


def default_segmentation_loader(
    raw_paths,
    raw_key,
    label_paths,
    label_key,
    batch_size,
    patch_shape,
    label_transform=None,
    label_transform2=None,
    raw_transform=None,
    transform=None,
    dtype=torch.float32,
    label_dtype=torch.float32,
    rois=None,
    n_samples=None,
    sampler=None,
    ndim=None,
    is_seg_dataset=None,
    with_channels=False,
    with_label_channels=False,
    **loader_kwargs,
):
    """Get a data loader for the default segmentation dataset.

    Convenience wrapper around `default_segmentation_dataset` + `get_data_loader`;
    see `default_segmentation_dataset` for the dataset arguments. Additional
    keyword arguments are passed on to `torch.utils.data.DataLoader`.
    """
    ds = default_segmentation_dataset(
        raw_paths=raw_paths,
        raw_key=raw_key,
        label_paths=label_paths,
        label_key=label_key,
        patch_shape=patch_shape,
        label_transform=label_transform,
        label_transform2=label_transform2,
        raw_transform=raw_transform,
        transform=transform,
        dtype=dtype,
        label_dtype=label_dtype,
        rois=rois,
        n_samples=n_samples,
        sampler=sampler,
        ndim=ndim,
        is_seg_dataset=is_seg_dataset,
        with_channels=with_channels,
        with_label_channels=with_label_channels,
    )
    return get_data_loader(ds, batch_size=batch_size, **loader_kwargs)


def default_segmentation_dataset(
    raw_paths,
    raw_key,
    label_paths,
    label_key,
    patch_shape,
    label_transform=None,
    label_transform2=None,
    raw_transform=None,
    transform=None,
    dtype=torch.float32,
    label_dtype=torch.float32,
    rois=None,
    n_samples=None,
    sampler=None,
    ndim=None,
    is_seg_dataset=None,
    with_channels=False,
    with_label_channels=False,
):
    """Get a dataset for segmentation training.

    Dispatches to SegmentationDataset or ImageCollectionDataset based on
    whether the data can be opened via `load_data` (see `is_segmentation_dataset`);
    pass `is_seg_dataset` to override the automatic detection.
    """
    check_paths(raw_paths, label_paths)
    if is_seg_dataset is None:
        is_seg_dataset = is_segmentation_dataset(raw_paths, raw_key, label_paths, label_key)

    # we always use a raw transform in the convenience function
    if raw_transform is None:
        raw_transform = get_raw_transform()

    # we always use augmentations in the convenience function
    if transform is None:
        transform = _get_default_transform(
            raw_paths if isinstance(raw_paths, str) else raw_paths[0], raw_key, is_seg_dataset, ndim
        )

    if is_seg_dataset:
        ds = _load_segmentation_dataset(
            raw_paths,
            raw_key,
            label_paths,
            label_key,
            patch_shape=patch_shape,
            raw_transform=raw_transform,
            label_transform=label_transform,
            label_transform2=label_transform2,
            transform=transform,
            rois=rois,
            n_samples=n_samples,
            sampler=sampler,
            ndim=ndim,
            dtype=dtype,
            label_dtype=label_dtype,
            with_channels=with_channels,
            with_label_channels=with_label_channels,
        )
    else:
        ds = _load_image_collection_dataset(
            raw_paths,
            raw_key,
            label_paths,
            label_key,
            roi=rois,
            patch_shape=patch_shape,
            label_transform=label_transform,
            raw_transform=raw_transform,
            label_transform2=label_transform2,
            transform=transform,
            n_samples=n_samples,
            sampler=sampler,
            dtype=dtype,
            label_dtype=label_dtype,
        )

    return ds


def get_data_loader(dataset: torch.utils.data.Dataset, batch_size, **loader_kwargs) -> torch.utils.data.DataLoader:
    """Wrap a dataset in a DataLoader, exposing the shuffle setting on the loader."""
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, **loader_kwargs)
    # monkey patch shuffle attribute to the loader
    loader.shuffle = loader_kwargs.get("shuffle", False)
    return loader


#
# convenience functions for segmentation trainers
#


def default_segmentation_trainer(
    name,
    model,
    train_loader,
    val_loader,
    loss=None,
    metric=None,
    learning_rate=1e-3,
    device=None,
    log_image_interval=100,
    mixed_precision=True,
    early_stopping=None,
    logger=TensorboardLogger,
    logger_kwargs: Optional[Dict[str, Any]] = None,
    scheduler_kwargs=DEFAULT_SCHEDULER_KWARGS,
    optimizer_kwargs=None,
    trainer_class=DefaultTrainer,
    id_=None,
    save_root=None,
    compile_model=None,
):
    """Get a trainer for segmentation training.

    Uses Adam + ReduceLROnPlateau and the dice score as default loss/metric.
    Mixed precision is disabled automatically when training on the cpu.
    """
    # BUGFIX: avoid a mutable default argument for optimizer_kwargs
    optimizer_kwargs = {} if optimizer_kwargs is None else optimizer_kwargs
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, **optimizer_kwargs)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, **scheduler_kwargs)

    loss = DiceLoss() if loss is None else loss
    metric = DiceLoss() if metric is None else metric

    if device is None:
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    else:
        device = torch.device(device)

    # cpu does not support mixed precision training
    if device.type == "cpu":
        mixed_precision = False

    trainer = trainer_class(
        name=name,
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        loss=loss,
        metric=metric,
        optimizer=optimizer,
        device=device,
        lr_scheduler=scheduler,
        mixed_precision=mixed_precision,
        early_stopping=early_stopping,
        log_image_interval=log_image_interval,
        logger=logger,
        logger_kwargs=logger_kwargs,
        id_=id_,
        save_root=save_root,
        compile_model=compile_model,
    )
    return trainer
DEFAULT_SCHEDULER_KWARGS =
{'mode': 'min', 'factor': 0.5, 'patience': 5}
def
samples_to_datasets(n_samples, raw_paths, raw_key, split='uniform'):
27def samples_to_datasets(n_samples, raw_paths, raw_key, split="uniform"): 28 assert split in ("balanced", "uniform") 29 n_datasets = len(raw_paths) 30 if split == "uniform": 31 # even distribution of samples to datasets 32 samples_per_ds = n_samples // n_datasets 33 divider = n_samples % n_datasets 34 return [samples_per_ds + 1 if ii < divider else samples_per_ds for ii in range(n_datasets)] 35 else: 36 # distribution of samples to dataset based on the dataset lens 37 raise NotImplementedError
def
check_paths(raw_paths, label_paths):
40def check_paths(raw_paths, label_paths): 41 if not isinstance(raw_paths, type(label_paths)): 42 raise ValueError(f"Expect raw and label paths of same type, got {type(raw_paths)}, {type(label_paths)}") 43 44 def _check_path(path): 45 if isinstance(path, str): 46 if not os.path.exists(path): 47 raise ValueError(f"Could not find path {path}") 48 else: 49 # check for single path or multiple paths (for same volume - supports multi-modal inputs) 50 for per_path in path: 51 if not os.path.exists(per_path): 52 raise ValueError(f"Could not find path {per_path}") 53 54 if isinstance(raw_paths, str): 55 _check_path(raw_paths) 56 _check_path(label_paths) 57 else: 58 if len(raw_paths) != len(label_paths): 59 raise ValueError(f"Expect same number of raw and label paths, got {len(raw_paths)}, {len(label_paths)}") 60 for rp, lp in zip(raw_paths, label_paths): 61 _check_path(rp) 62 _check_path(lp)
def
is_segmentation_dataset(raw_paths, raw_key, label_paths, label_key):
65def is_segmentation_dataset(raw_paths, raw_key, label_paths, label_key): 66 """ Check if we can load the data as SegmentationDataset 67 """ 68 69 def _can_open(path, key): 70 try: 71 load_data(path, key) 72 return True 73 except Exception: 74 return False 75 76 if isinstance(raw_paths, str): 77 can_open_raw = _can_open(raw_paths, raw_key) 78 can_open_label = _can_open(label_paths, label_key) 79 else: 80 can_open_raw = [_can_open(rp, raw_key) for rp in raw_paths] 81 if not can_open_raw.count(can_open_raw[0]) == len(can_open_raw): 82 raise ValueError("Inconsistent raw data") 83 can_open_raw = can_open_raw[0] 84 85 can_open_label = [_can_open(lp, label_key) for lp in label_paths] 86 if not can_open_label.count(can_open_label[0]) == len(can_open_label): 87 raise ValueError("Inconsistent label data") 88 can_open_label = can_open_label[0] 89 90 if can_open_raw != can_open_label: 91 raise ValueError("Inconsistent raw and label data") 92 93 return can_open_raw
Check if we can load the data as SegmentationDataset
def
default_segmentation_loader( raw_paths, raw_key, label_paths, label_key, batch_size, patch_shape, label_transform=None, label_transform2=None, raw_transform=None, transform=None, dtype=torch.float32, label_dtype=torch.float32, rois=None, n_samples=None, sampler=None, ndim=None, is_seg_dataset=None, with_channels=False, with_label_channels=False, **loader_kwargs):
189def default_segmentation_loader( 190 raw_paths, 191 raw_key, 192 label_paths, 193 label_key, 194 batch_size, 195 patch_shape, 196 label_transform=None, 197 label_transform2=None, 198 raw_transform=None, 199 transform=None, 200 dtype=torch.float32, 201 label_dtype=torch.float32, 202 rois=None, 203 n_samples=None, 204 sampler=None, 205 ndim=None, 206 is_seg_dataset=None, 207 with_channels=False, 208 with_label_channels=False, 209 **loader_kwargs, 210): 211 ds = default_segmentation_dataset( 212 raw_paths=raw_paths, 213 raw_key=raw_key, 214 label_paths=label_paths, 215 label_key=label_key, 216 patch_shape=patch_shape, 217 label_transform=label_transform, 218 label_transform2=label_transform2, 219 raw_transform=raw_transform, 220 transform=transform, 221 dtype=dtype, 222 label_dtype=label_dtype, 223 rois=rois, 224 n_samples=n_samples, 225 sampler=sampler, 226 ndim=ndim, 227 is_seg_dataset=is_seg_dataset, 228 with_channels=with_channels, 229 with_label_channels=with_label_channels, 230 ) 231 return get_data_loader(ds, batch_size=batch_size, **loader_kwargs)
def
default_segmentation_dataset( raw_paths, raw_key, label_paths, label_key, patch_shape, label_transform=None, label_transform2=None, raw_transform=None, transform=None, dtype=torch.float32, label_dtype=torch.float32, rois=None, n_samples=None, sampler=None, ndim=None, is_seg_dataset=None, with_channels=False, with_label_channels=False):
234def default_segmentation_dataset( 235 raw_paths, 236 raw_key, 237 label_paths, 238 label_key, 239 patch_shape, 240 label_transform=None, 241 label_transform2=None, 242 raw_transform=None, 243 transform=None, 244 dtype=torch.float32, 245 label_dtype=torch.float32, 246 rois=None, 247 n_samples=None, 248 sampler=None, 249 ndim=None, 250 is_seg_dataset=None, 251 with_channels=False, 252 with_label_channels=False, 253): 254 check_paths(raw_paths, label_paths) 255 if is_seg_dataset is None: 256 is_seg_dataset = is_segmentation_dataset(raw_paths, raw_key, label_paths, label_key) 257 258 # we always use a raw transform in the convenience function 259 if raw_transform is None: 260 raw_transform = get_raw_transform() 261 262 # we always use augmentations in the convenience function 263 if transform is None: 264 transform = _get_default_transform( 265 raw_paths if isinstance(raw_paths, str) else raw_paths[0], raw_key, is_seg_dataset, ndim 266 ) 267 268 if is_seg_dataset: 269 ds = _load_segmentation_dataset( 270 raw_paths, 271 raw_key, 272 label_paths, 273 label_key, 274 patch_shape=patch_shape, 275 raw_transform=raw_transform, 276 label_transform=label_transform, 277 label_transform2=label_transform2, 278 transform=transform, 279 rois=rois, 280 n_samples=n_samples, 281 sampler=sampler, 282 ndim=ndim, 283 dtype=dtype, 284 label_dtype=label_dtype, 285 with_channels=with_channels, 286 with_label_channels=with_label_channels, 287 ) 288 else: 289 ds = _load_image_collection_dataset( 290 raw_paths, 291 raw_key, 292 label_paths, 293 label_key, 294 roi=rois, 295 patch_shape=patch_shape, 296 label_transform=label_transform, 297 raw_transform=raw_transform, 298 label_transform2=label_transform2, 299 transform=transform, 300 n_samples=n_samples, 301 sampler=sampler, 302 dtype=dtype, 303 label_dtype=label_dtype, 304 ) 305 306 return ds
def
get_data_loader( dataset: torch.utils.data.dataset.Dataset, batch_size, **loader_kwargs) -> torch.utils.data.dataloader.DataLoader:
309def get_data_loader(dataset: torch.utils.data.Dataset, batch_size, **loader_kwargs) -> torch.utils.data.DataLoader: 310 loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, **loader_kwargs) 311 # monkey patch shuffle attribute to the loader 312 loader.shuffle = loader_kwargs.get("shuffle", False) 313 return loader
def
default_segmentation_trainer( name, model, train_loader, val_loader, loss=None, metric=None, learning_rate=0.001, device=None, log_image_interval=100, mixed_precision=True, early_stopping=None, logger=<class 'torch_em.trainer.tensorboard_logger.TensorboardLogger'>, logger_kwargs: Optional[Dict[str, Any]] = None, scheduler_kwargs={'mode': 'min', 'factor': 0.5, 'patience': 5}, optimizer_kwargs={}, trainer_class=<class 'torch_em.trainer.default_trainer.DefaultTrainer'>, id_=None, save_root=None, compile_model=None):
321def default_segmentation_trainer( 322 name, 323 model, 324 train_loader, 325 val_loader, 326 loss=None, 327 metric=None, 328 learning_rate=1e-3, 329 device=None, 330 log_image_interval=100, 331 mixed_precision=True, 332 early_stopping=None, 333 logger=TensorboardLogger, 334 logger_kwargs: Optional[Dict[str, Any]] = None, 335 scheduler_kwargs=DEFAULT_SCHEDULER_KWARGS, 336 optimizer_kwargs={}, 337 trainer_class=DefaultTrainer, 338 id_=None, 339 save_root=None, 340 compile_model=None, 341): 342 optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, **optimizer_kwargs) 343 scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, **scheduler_kwargs) 344 345 loss = DiceLoss() if loss is None else loss 346 metric = DiceLoss() if metric is None else metric 347 348 if device is None: 349 device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 350 else: 351 device = torch.device(device) 352 353 # cpu does not support mixed precision training 354 if device.type == "cpu": 355 mixed_precision = False 356 357 trainer = trainer_class( 358 name=name, 359 model=model, 360 train_loader=train_loader, 361 val_loader=val_loader, 362 loss=loss, 363 metric=metric, 364 optimizer=optimizer, 365 device=device, 366 lr_scheduler=scheduler, 367 mixed_precision=mixed_precision, 368 early_stopping=early_stopping, 369 log_image_interval=log_image_interval, 370 logger=logger, 371 logger_kwargs=logger_kwargs, 372 id_=id_, 373 save_root=save_root, 374 compile_model=compile_model, 375 ) 376 return trainer