torch_em.segmentation
import os
from glob import glob
from typing import Any, Dict, Optional, Union, Tuple, List, Callable

import torch
import torch.utils.data
from torch.utils.data import DataLoader

from .loss import DiceLoss
from .util import load_data
from .trainer import DefaultTrainer
from .trainer.tensorboard_logger import TensorboardLogger
from .transform import get_augmentations, get_raw_transform
from .data import ConcatDataset, ImageCollectionDataset, SegmentationDataset


# TODO add a heuristic to estimate this from the number of epochs
DEFAULT_SCHEDULER_KWARGS = {"mode": "min", "factor": 0.5, "patience": 5}
"""@private
"""


#
# convenience functions for segmentation loaders
#

# TODO implement balanced and make it the default
# def samples_to_datasets(n_samples, raw_paths, raw_key, split="balanced"):
def samples_to_datasets(n_samples, raw_paths, raw_key, split="uniform"):
    """@private

    Distribute `n_samples` over the datasets given by `raw_paths`.

    With `split="uniform"` every dataset gets `n_samples // len(raw_paths)` samples,
    and the first `n_samples % len(raw_paths)` datasets get one extra sample.
    `split="balanced"` is not implemented yet and raises NotImplementedError.
    """
    assert split in ("balanced", "uniform")
    n_datasets = len(raw_paths)
    if split == "uniform":
        # even distribution of samples to datasets
        samples_per_ds = n_samples // n_datasets
        divider = n_samples % n_datasets
        return [samples_per_ds + 1 if ii < divider else samples_per_ds for ii in range(n_datasets)]
    else:
        # distribution of samples to dataset based on the dataset lens
        raise NotImplementedError


def check_paths(raw_paths, label_paths):
    """@private

    Check that raw and label paths have the same type and that every referenced path exists.

    Raises:
        ValueError: If the types or lengths of raw and label paths differ, or a path does not exist.
    """
    if not isinstance(raw_paths, type(label_paths)):
        raise ValueError(f"Expect raw and label paths of same type, got {type(raw_paths)}, {type(label_paths)}")

    def _check_path(path):
        if isinstance(path, str):
            if not os.path.exists(path):
                raise ValueError(f"Could not find path {path}")
        else:
            # check for single path or multiple paths (for same volume - supports multi-modal inputs)
            for per_path in path:
                if not os.path.exists(per_path):
                    raise ValueError(f"Could not find path {per_path}")

    if isinstance(raw_paths, str):
        _check_path(raw_paths)
        _check_path(label_paths)
    else:
        if len(raw_paths) != len(label_paths):
            raise ValueError(f"Expect same number of raw and label paths, got {len(raw_paths)}, {len(label_paths)}")
        for rp, lp in zip(raw_paths, label_paths):
            _check_path(rp)
            _check_path(lp)


# Check if we can load the data as SegmentationDataset.
def is_segmentation_dataset(raw_paths, raw_key, label_paths, label_key):
    """@private

    Return whether the data can be opened with `load_data`, i.e. loaded as a SegmentationDataset.

    Raises:
        ValueError: If only some raw (or label) paths can be opened, or raw and label disagree.
    """
    def _can_open(path, key):
        # Best-effort probe: any failure to load means "not a segmentation dataset".
        try:
            load_data(path, key)
            return True
        except Exception:
            return False

    if isinstance(raw_paths, str):
        can_open_raw = _can_open(raw_paths, raw_key)
        can_open_label = _can_open(label_paths, label_key)
    else:
        can_open_raw = [_can_open(rp, raw_key) for rp in raw_paths]
        if not can_open_raw.count(can_open_raw[0]) == len(can_open_raw):
            raise ValueError("Inconsistent raw data")
        can_open_raw = can_open_raw[0]

        can_open_label = [_can_open(lp, label_key) for lp in label_paths]
        if not can_open_label.count(can_open_label[0]) == len(can_open_label):
            raise ValueError("Inconsistent label data")
        can_open_label = can_open_label[0]

    if can_open_raw != can_open_label:
        raise ValueError("Inconsistent raw and label data")

    return can_open_raw


def _load_segmentation_dataset(raw_paths, raw_key, label_paths, label_key, **kwargs):
    """Create a SegmentationDataset, or a ConcatDataset of them if multiple paths are given."""
    rois = kwargs.pop("rois", None)
    if isinstance(raw_paths, str):
        if rois is not None:
            assert isinstance(rois, (tuple, slice))
            if isinstance(rois, tuple):
                assert all(isinstance(roi, slice) for roi in rois)
        ds = SegmentationDataset(raw_paths, raw_key, label_paths, label_key, roi=rois, **kwargs)
    else:
        assert len(raw_paths) > 0
        if rois is not None:
            assert len(rois) == len(label_paths)
            assert all(isinstance(roi, tuple) for roi in rois), f"{rois}"
        n_samples = kwargs.pop("n_samples", None)

        samples_per_ds = (
            [None] * len(raw_paths) if n_samples is None else samples_to_datasets(n_samples, raw_paths, raw_key)
        )
        ds = []
        for i, (raw_path, label_path) in enumerate(zip(raw_paths, label_paths)):
            roi = None if rois is None else rois[i]
            dset = SegmentationDataset(
                raw_path, raw_key, label_path, label_key, roi=roi, n_samples=samples_per_ds[i], **kwargs
            )
            ds.append(dset)
        ds = ConcatDataset(*ds)
    return ds


def _load_image_collection_dataset(raw_paths, raw_key, label_paths, label_key, roi, **kwargs):
    """Create an ImageCollectionDataset, or a ConcatDataset of them for multiple folders.

    If `raw_key` / `label_key` are given they are used as glob patterns inside the folder(s) `raw_paths` /
    `label_paths`. If `raw_key` is None, `raw_paths` and `label_paths` are lists of image files.
    """
    def _get_paths(rpath, rkey, lpath, lkey, this_roi):
        # Keep the pattern so the error message can report it (`rpath` is rebound to the match list below).
        raw_pattern = os.path.join(rpath, rkey)
        raw_files = sorted(glob(raw_pattern))
        if len(raw_files) == 0:
            raise ValueError(f"Could not find any images for pattern {raw_pattern}")

        label_files = sorted(glob(os.path.join(lpath, lkey)))
        if len(raw_files) != len(label_files):
            raise ValueError(f"Expect same number of raw and label images, got {len(raw_files)}, {len(label_files)}")

        # Use the ROI belonging to this folder, not the enclosing (possibly per-folder list) `roi`.
        if this_roi is not None:
            raw_files, label_files = raw_files[this_roi], label_files[this_roi]

        return raw_files, label_files

    patch_shape = kwargs.pop("patch_shape")
    if patch_shape is not None:
        if len(patch_shape) == 3:
            if patch_shape[0] != 1:
                raise ValueError(f"Image collection dataset expects 2d patch shape, got {patch_shape}")
            patch_shape = patch_shape[1:]
        assert len(patch_shape) == 2

    if isinstance(raw_paths, str):
        raw_paths, label_paths = _get_paths(raw_paths, raw_key, label_paths, label_key, roi)
        ds = ImageCollectionDataset(raw_paths, label_paths, patch_shape=patch_shape, **kwargs)

    elif raw_key is None:
        # NOTE(review): `roi` is not applied in this branch.
        assert label_key is None
        assert isinstance(raw_paths, (list, tuple)) and isinstance(label_paths, (list, tuple))
        assert len(raw_paths) == len(label_paths)
        ds = ImageCollectionDataset(raw_paths, label_paths, patch_shape=patch_shape, **kwargs)

    else:
        ds = []
        n_samples = kwargs.pop("n_samples", None)
        samples_per_ds = (
            [None] * len(raw_paths) if n_samples is None else samples_to_datasets(n_samples, raw_paths, raw_key)
        )
        if roi is None:
            roi = len(raw_paths) * [None]
        assert len(roi) == len(raw_paths)
        for i, (raw_path, label_path, this_roi) in enumerate(zip(raw_paths, label_paths, roi)):
            rpath, lpath = _get_paths(raw_path, raw_key, label_path, label_key, this_roi)
            dset = ImageCollectionDataset(rpath, lpath, patch_shape=patch_shape, n_samples=samples_per_ds[i], **kwargs)
            ds.append(dset)
        ds = ConcatDataset(*ds)

    return ds


def _get_default_transform(path, key, is_seg_dataset, ndim):
    """Return the default augmentations, choosing 2d, 3d or anisotropic from `ndim` or the data shape."""
    if is_seg_dataset and ndim is None:
        shape = load_data(path, key).shape
        if len(shape) == 2:
            ndim = 2
        else:
            # heuristics to figure out whether to use default 3d
            # or default anisotropic augmentations
            ndim = "anisotropic" if shape[0] < shape[1] // 2 else 3

    elif is_seg_dataset and ndim is not None:
        pass

    else:
        # Image collections are always 2d.
        ndim = 2

    return get_augmentations(ndim)


def default_segmentation_loader(
    raw_paths: Union[List[Any], str, os.PathLike],
    raw_key: Optional[str],
    label_paths: Union[List[Any], str, os.PathLike],
    label_key: Optional[str],
    batch_size: int,
    patch_shape: Tuple[int, ...],
    label_transform: Optional[Callable] = None,
    label_transform2: Optional[Callable] = None,
    raw_transform: Optional[Callable] = None,
    transform: Optional[Callable] = None,
    dtype: torch.dtype = torch.float32,
    label_dtype: torch.dtype = torch.float32,
    rois: Optional[Union[slice, Tuple[slice, ...]]] = None,
    n_samples: Optional[int] = None,
    sampler: Optional[Callable] = None,
    ndim: Optional[int] = None,
    is_seg_dataset: Optional[bool] = None,
    with_channels: bool = False,
    with_label_channels: bool = False,
    verify_paths: bool = True,
    with_padding: bool = True,
    z_ext: Optional[int] = None,
    **loader_kwargs,
) -> torch.utils.data.DataLoader:
    """Get data loader for training a segmentation network.

    See `torch_em.data.SegmentationDataset` and `torch_em.data.ImageCollectionDataset` for details
    on the data formats that are supported.

    Args:
        raw_paths: The file path(s) to the raw data. Can either be a single path or multiple file paths.
        raw_key: The name of the internal dataset containing the raw data. Set to None for regular image files.
        label_paths: The file path(s) to the label data. Can either be a single path or multiple file paths.
        label_key: The name of the internal dataset containing the label data. Set to None for regular image files.
        batch_size: The batch size for the data loader.
        patch_shape: The patch shape for the training samples.
        label_transform: Transformation applied to the label data of a sample,
            before applying augmentations via `transform`.
        label_transform2: Transformation applied to the label data of a sample,
            after applying augmentations via `transform`.
        raw_transform: Transformation applied to the raw data of a sample,
            before applying augmentations via `transform`.
        transform: Transformation applied to both the raw data and label data of a sample.
            This can be used to implement data augmentations.
        dtype: The return data type of the raw data.
        label_dtype: The return data type of the label data.
        rois: Regions of interest in the data. If given, the data will only be loaded from the corresponding area.
        n_samples: The length of the underlying dataset. If None, the length will be set to `len(raw_paths)`.
        sampler: Sampler for rejecting samples according to a defined criterion.
            The sampler must be a callable that accepts the raw data (as numpy arrays) as input.
        ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
        is_seg_dataset: Whether this is a segmentation dataset or an image collection dataset.
            If None, the type of dataset will be derived from the data.
        with_channels: Whether the raw data has channels.
        with_label_channels: Whether the label data has channels.
        verify_paths: Whether to verify all paths before creating the dataset.
        with_padding: Whether to pad samples to `patch_shape` if their shape is smaller.
        z_ext: Extra bounding box for loading the data across z.
        loader_kwargs: Keyword arguments for `torch.utils.data.DataLoader`.

    Returns:
        The torch data loader.
    """
    ds = default_segmentation_dataset(
        raw_paths=raw_paths,
        raw_key=raw_key,
        label_paths=label_paths,
        label_key=label_key,
        patch_shape=patch_shape,
        label_transform=label_transform,
        label_transform2=label_transform2,
        raw_transform=raw_transform,
        transform=transform,
        dtype=dtype,
        label_dtype=label_dtype,
        rois=rois,
        n_samples=n_samples,
        sampler=sampler,
        ndim=ndim,
        is_seg_dataset=is_seg_dataset,
        with_channels=with_channels,
        with_label_channels=with_label_channels,
        with_padding=with_padding,
        z_ext=z_ext,
        verify_paths=verify_paths,
    )
    return get_data_loader(ds, batch_size=batch_size, **loader_kwargs)


def default_segmentation_dataset(
    raw_paths: Union[List[Any], str, os.PathLike],
    raw_key: Optional[str],
    label_paths: Union[List[Any], str, os.PathLike],
    label_key: Optional[str],
    patch_shape: Tuple[int, ...],
    label_transform: Optional[Callable] = None,
    label_transform2: Optional[Callable] = None,
    raw_transform: Optional[Callable] = None,
    transform: Optional[Callable] = None,
    dtype: torch.dtype = torch.float32,
    label_dtype: torch.dtype = torch.float32,
    rois: Optional[Union[slice, Tuple[slice, ...]]] = None,
    n_samples: Optional[int] = None,
    sampler: Optional[Callable] = None,
    ndim: Optional[int] = None,
    is_seg_dataset: Optional[bool] = None,
    with_channels: bool = False,
    with_label_channels: bool = False,
    verify_paths: bool = True,
    with_padding: bool = True,
    z_ext: Optional[int] = None,
) -> torch.utils.data.Dataset:
    """Get data set for training a segmentation network.

    See `torch_em.data.SegmentationDataset` and `torch_em.data.ImageCollectionDataset` for details
    on the data formats that are supported.

    Args:
        raw_paths: The file path(s) to the raw data. Can either be a single path or multiple file paths.
        raw_key: The name of the internal dataset containing the raw data. Set to None for regular image files.
        label_paths: The file path(s) to the label data. Can either be a single path or multiple file paths.
        label_key: The name of the internal dataset containing the label data. Set to None for regular image files.
        patch_shape: The patch shape for the training samples.
        label_transform: Transformation applied to the label data of a sample,
            before applying augmentations via `transform`.
        label_transform2: Transformation applied to the label data of a sample,
            after applying augmentations via `transform`.
        raw_transform: Transformation applied to the raw data of a sample,
            before applying augmentations via `transform`.
        transform: Transformation applied to both the raw data and label data of a sample.
            This can be used to implement data augmentations.
        dtype: The return data type of the raw data.
        label_dtype: The return data type of the label data.
        rois: Regions of interest in the data. If given, the data will only be loaded from the corresponding area.
        n_samples: The length of the dataset. If None, the length will be set to `len(raw_paths)`.
        sampler: Sampler for rejecting samples according to a defined criterion.
            The sampler must be a callable that accepts the raw data (as numpy arrays) as input.
        ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
        is_seg_dataset: Whether this is a segmentation dataset or an image collection dataset.
            If None, the type of dataset will be derived from the data.
        with_channels: Whether the raw data has channels.
        with_label_channels: Whether the label data has channels.
        verify_paths: Whether to verify all paths before creating the dataset.
        with_padding: Whether to pad samples to `patch_shape` if their shape is smaller.
        z_ext: Extra bounding box for loading the data across z.

    Returns:
        The torch dataset.
    """
    if verify_paths:
        check_paths(raw_paths, label_paths)

    if is_seg_dataset is None:
        is_seg_dataset = is_segmentation_dataset(raw_paths, raw_key, label_paths, label_key)

    # We always use a raw transform in the convenience function.
    if raw_transform is None:
        raw_transform = get_raw_transform()

    # We always use augmentations in the convenience function.
    if transform is None:
        transform = _get_default_transform(
            raw_paths if isinstance(raw_paths, str) else raw_paths[0], raw_key, is_seg_dataset, ndim
        )

    if is_seg_dataset:
        ds = _load_segmentation_dataset(
            raw_paths,
            raw_key,
            label_paths,
            label_key,
            patch_shape=patch_shape,
            raw_transform=raw_transform,
            label_transform=label_transform,
            label_transform2=label_transform2,
            transform=transform,
            rois=rois,
            n_samples=n_samples,
            sampler=sampler,
            ndim=ndim,
            dtype=dtype,
            label_dtype=label_dtype,
            with_channels=with_channels,
            with_label_channels=with_label_channels,
            with_padding=with_padding,
            z_ext=z_ext,
        )

    else:
        ds = _load_image_collection_dataset(
            raw_paths,
            raw_key,
            label_paths,
            label_key,
            roi=rois,
            patch_shape=patch_shape,
            label_transform=label_transform,
            raw_transform=raw_transform,
            label_transform2=label_transform2,
            transform=transform,
            n_samples=n_samples,
            sampler=sampler,
            dtype=dtype,
            label_dtype=label_dtype,
            with_padding=with_padding,
        )

    return ds


def get_data_loader(dataset: torch.utils.data.Dataset, batch_size: int, **loader_kwargs) -> torch.utils.data.DataLoader:
    """@private

    Wrap `dataset` in a DataLoader (with `pin_memory=True` unless overridden) and record the
    requested `shuffle` flag as a `shuffle` attribute on the returned loader.
    """
    pin_memory = loader_kwargs.pop("pin_memory", True)
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, pin_memory=pin_memory, **loader_kwargs)
    # monkey patch shuffle attribute to the loader
    loader.shuffle = loader_kwargs.get("shuffle", False)
    return loader


#
# convenience functions for segmentation trainers
#


def default_segmentation_trainer(
    name: str,
    model: torch.nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    loss: Optional[torch.nn.Module] = None,
    metric: Optional[Callable] = None,
    learning_rate: float = 1e-3,
    device: Optional[Union[str, torch.device]] = None,
    log_image_interval: int = 100,
    mixed_precision: bool = True,
    early_stopping: Optional[int] = None,
    logger=TensorboardLogger,
    logger_kwargs: Optional[Dict[str, Any]] = None,
    scheduler_kwargs: Optional[Dict[str, Any]] = None,
    optimizer_kwargs: Optional[Dict[str, Any]] = None,
    trainer_class=DefaultTrainer,
    id_: Optional[str] = None,
    save_root: Optional[str] = None,
    compile_model: Optional[Union[bool, str]] = None,
    rank: Optional[int] = None,
):
    """Get a trainer for a segmentation network.

    It creates a `torch.optim.AdamW` optimizer and learning rate scheduler that reduces the learning rate on plateau.
    By default, it uses the dice score as loss and metric.
    This can be changed by passing arguments for `loss` and/or `metric`.
    See `torch_em.trainer.DefaultTrainer` for additional details on how to configure and use the trainer.

    Here's an example for training a 2D U-Net with this function:
    ```python
    from torch_em.segmentation import default_segmentation_trainer
    from torch_em.model import UNet2d
    from torch_em.data.datasets.light_microscopy import get_dsb_loader

    # The training data will be downloaded to this location.
    data_root = "/path/to/save/the/training/data"
    patch_shape = (256, 256)
    trainer = default_segmentation_trainer(
        name="unet-training",
        model=UNet2d(in_channels=1, out_channels=1),
        train_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="train"),
        val_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="test"),
    )
    trainer.fit(iterations=int(2.5e4))  # Train for 25,000 iterations.
    ```

    Args:
        name: The name of the checkpoint that will be created by the trainer.
        model: The model to train.
        train_loader: The data loader containing the training data.
        val_loader: The data loader containing the validation data.
        loss: The loss function for training. If None, `DiceLoss` is used.
        metric: The metric for validation. If None, `DiceLoss` is used.
        learning_rate: The initial learning rate for the AdamW optimizer.
        device: The torch device to use for training. If None, will use a GPU if available.
        log_image_interval: The interval for saving images during logging, in training iterations.
        mixed_precision: Whether to train with mixed precision. Always disabled on CPU.
        early_stopping: The patience for early stopping in epochs. If None, early stopping will not be used.
        logger: The logger class. Will be instantiated for logging.
            By default uses `torch_em.trainer.tensorboard_logger.TensorboardLogger`.
        logger_kwargs: The keyword arguments for the logger class.
        scheduler_kwargs: The keyword arguments for ReduceLROnPlateau.
            If None, uses mode="min", factor=0.5, patience=5.
        optimizer_kwargs: The keyword arguments for the AdamW optimizer. If None, no extra arguments are passed.
        trainer_class: The trainer class. Uses `torch_em.trainer.DefaultTrainer` by default,
            but can be set to a custom trainer class to enable custom training procedures.
        id_: Unique identifier for the trainer. If None then `name` will be used.
        save_root: The root folder for saving the checkpoint and logs.
        compile_model: Whether to compile the model before training.
        rank: Rank argument for distributed training. See `torch_em.multi_gpu_training` for details.

    Returns:
        The trainer.
    """
    # None defaults avoid sharing mutable dicts between calls; copy the module default for the same reason.
    if scheduler_kwargs is None:
        scheduler_kwargs = dict(DEFAULT_SCHEDULER_KWARGS)
    if optimizer_kwargs is None:
        optimizer_kwargs = {}

    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, **optimizer_kwargs)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, **scheduler_kwargs)

    loss = DiceLoss() if loss is None else loss
    metric = DiceLoss() if metric is None else metric

    if device is None:
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    else:
        device = torch.device(device)

    # CPU does not support mixed precision training.
    if device.type == "cpu":
        mixed_precision = False

    return trainer_class(
        name=name,
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        loss=loss,
        metric=metric,
        optimizer=optimizer,
        device=device,
        lr_scheduler=scheduler,
        mixed_precision=mixed_precision,
        early_stopping=early_stopping,
        log_image_interval=log_image_interval,
        logger=logger,
        logger_kwargs=logger_kwargs,
        id_=id_,
        save_root=save_root,
        compile_model=compile_model,
        rank=rank,
    )
def default_segmentation_loader(
    raw_paths: Union[List[Any], str, os.PathLike],
    raw_key: Optional[str],
    label_paths: Union[List[Any], str, os.PathLike],
    label_key: Optional[str],
    batch_size: int,
    patch_shape: Tuple[int, ...],
    label_transform: Optional[Callable] = None,
    label_transform2: Optional[Callable] = None,
    raw_transform: Optional[Callable] = None,
    transform: Optional[Callable] = None,
    dtype: torch.dtype = torch.float32,
    label_dtype: torch.dtype = torch.float32,
    rois: Optional[Union[slice, Tuple[slice, ...]]] = None,
    n_samples: Optional[int] = None,
    sampler: Optional[Callable] = None,
    ndim: Optional[int] = None,
    is_seg_dataset: Optional[bool] = None,
    with_channels: bool = False,
    with_label_channels: bool = False,
    verify_paths: bool = True,
    with_padding: bool = True,
    z_ext: Optional[int] = None,
    **loader_kwargs,
) -> torch.utils.data.DataLoader:
    """Get data loader for training a segmentation network.

    See `torch_em.data.SegmentationDataset` and `torch_em.data.ImageCollectionDataset` for details
    on the data formats that are supported.

    Args:
        raw_paths: The file path(s) to the raw data. Can either be a single path or multiple file paths.
        raw_key: The name of the internal dataset containing the raw data. Set to None for regular image files.
        label_paths: The file path(s) to the label data. Can either be a single path or multiple file paths.
        label_key: The name of the internal dataset containing the label data. Set to None for regular image files.
        batch_size: The batch size for the data loader.
        patch_shape: The patch shape for the training samples.
        label_transform: Transformation applied to the label data of a sample,
            before applying augmentations via `transform`.
        label_transform2: Transformation applied to the label data of a sample,
            after applying augmentations via `transform`.
        raw_transform: Transformation applied to the raw data of a sample,
            before applying augmentations via `transform`.
        transform: Transformation applied to both the raw data and label data of a sample.
            This can be used to implement data augmentations.
        dtype: The return data type of the raw data.
        label_dtype: The return data type of the label data.
        rois: Regions of interest in the data. If given, the data will only be loaded from the corresponding area.
        n_samples: The length of the underlying dataset. If None, the length will be set to `len(raw_paths)`.
        sampler: Sampler for rejecting samples according to a defined criterion.
            The sampler must be a callable that accepts the raw data (as numpy arrays) as input.
        ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
        is_seg_dataset: Whether this is a segmentation dataset or an image collection dataset.
            If None, the type of dataset will be derived from the data.
        with_channels: Whether the raw data has channels.
        with_label_channels: Whether the label data has channels.
        verify_paths: Whether to verify all paths before creating the dataset.
        with_padding: Whether to pad samples to `patch_shape` if their shape is smaller.
        z_ext: Extra bounding box for loading the data across z.
        loader_kwargs: Keyword arguments for `torch.utils.data.DataLoader`.

    Returns:
        The torch data loader.
    """
    # Build the dataset with the shared convenience function, then wrap it in a loader.
    ds = default_segmentation_dataset(
        raw_paths=raw_paths,
        raw_key=raw_key,
        label_paths=label_paths,
        label_key=label_key,
        patch_shape=patch_shape,
        label_transform=label_transform,
        label_transform2=label_transform2,
        raw_transform=raw_transform,
        transform=transform,
        dtype=dtype,
        label_dtype=label_dtype,
        rois=rois,
        n_samples=n_samples,
        sampler=sampler,
        ndim=ndim,
        is_seg_dataset=is_seg_dataset,
        with_channels=with_channels,
        with_label_channels=with_label_channels,
        with_padding=with_padding,
        z_ext=z_ext,
        verify_paths=verify_paths,
    )
    return get_data_loader(ds, batch_size=batch_size, **loader_kwargs)
Get data loader for training a segmentation network.
See `torch_em.data.SegmentationDataset` and `torch_em.data.ImageCollectionDataset` for details on the data formats that are supported.
Arguments:
- raw_paths: The file path(s) to the raw data. Can either be a single path or multiple file paths.
- raw_key: The name of the internal dataset containing the raw data. Set to None for regular image files.
- label_paths: The file path(s) to the label data. Can either be a single path or multiple file paths.
- label_key: The name of the internal dataset containing the label data. Set to None for regular image files.
- batch_size: The batch size for the data loader.
- patch_shape: The patch shape for the training samples.
- label_transform: Transformation applied to the label data of a sample, before applying augmentations via `transform`.
- label_transform2: Transformation applied to the label data of a sample, after applying augmentations via `transform`.
- raw_transform: Transformation applied to the raw data of a sample, before applying augmentations via `transform`.
- transform: Transformation applied to both the raw data and label data of a sample. This can be used to implement data augmentations.
- dtype: The return data type of the raw data.
- label_dtype: The return data type of the label data.
- rois: Regions of interest in the data. If given, the data will only be loaded from the corresponding area.
- n_samples: The length of the underlying dataset. If None, the length will be set to `len(raw_paths)`.
- sampler: Sampler for rejecting samples according to a defined criterion. The sampler must be a callable that accepts the raw data (as numpy arrays) as input.
- ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
- is_seg_dataset: Whether this is a segmentation dataset or an image collection dataset. If None, the type of dataset will be derived from the data.
- with_channels: Whether the raw data has channels.
- with_label_channels: Whether the label data has channels.
- verify_paths: Whether to verify all paths before creating the dataset.
- with_padding: Whether to pad samples to `patch_shape` if their shape is smaller.
- z_ext: Extra bounding box for loading the data across z.
- loader_kwargs: Keyword arguments for `torch.utils.data.DataLoader`.
Returns:
The torch data loader.
def default_segmentation_dataset(
    raw_paths: Union[List[Any], str, os.PathLike],
    raw_key: Optional[str],
    label_paths: Union[List[Any], str, os.PathLike],
    label_key: Optional[str],
    patch_shape: Tuple[int, ...],
    label_transform: Optional[Callable] = None,
    label_transform2: Optional[Callable] = None,
    raw_transform: Optional[Callable] = None,
    transform: Optional[Callable] = None,
    dtype: torch.dtype = torch.float32,
    label_dtype: torch.dtype = torch.float32,
    rois: Optional[Union[slice, Tuple[slice, ...]]] = None,
    n_samples: Optional[int] = None,
    sampler: Optional[Callable] = None,
    ndim: Optional[int] = None,
    is_seg_dataset: Optional[bool] = None,
    with_channels: bool = False,
    with_label_channels: bool = False,
    verify_paths: bool = True,
    with_padding: bool = True,
    z_ext: Optional[int] = None,
) -> torch.utils.data.Dataset:
    """Get data set for training a segmentation network.

    See `torch_em.data.SegmentationDataset` and `torch_em.data.ImageCollectionDataset` for details
    on the data formats that are supported.

    Args:
        raw_paths: The file path(s) to the raw data. Can either be a single path or multiple file paths.
        raw_key: The name of the internal dataset containing the raw data. Set to None for regular image files.
        label_paths: The file path(s) to the label data. Can either be a single path or multiple file paths.
        label_key: The name of the internal dataset containing the label data. Set to None for regular image files.
        patch_shape: The patch shape for the training samples.
        label_transform: Transformation applied to the label data of a sample,
            before applying augmentations via `transform`.
        label_transform2: Transformation applied to the label data of a sample,
            after applying augmentations via `transform`.
        raw_transform: Transformation applied to the raw data of a sample,
            before applying augmentations via `transform`. If None, `get_raw_transform()` is used.
        transform: Transformation applied to both the raw data and label data of a sample.
            This can be used to implement data augmentations. If None, default augmentations are used.
        dtype: The return data type of the raw data.
        label_dtype: The return data type of the label data.
        rois: Regions of interest in the data. If given, the data will only be loaded from the corresponding area.
        n_samples: The length of the dataset. If None, the length will be set to `len(raw_paths)`.
        sampler: Sampler for rejecting samples according to a defined criterion.
            The sampler must be a callable that accepts the raw data (as numpy arrays) as input.
        ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
        is_seg_dataset: Whether this is a segmentation dataset or an image collection dataset.
            If None, the type of dataset will be derived from the data.
        with_channels: Whether the raw data has channels.
        with_label_channels: Whether the label data has channels.
        verify_paths: Whether to verify all paths before creating the dataset.
        with_padding: Whether to pad samples to `patch_shape` if their shape is smaller.
        z_ext: Extra bounding box for loading the data across z.

    Returns:
        The torch dataset.
    """
    if verify_paths:
        check_paths(raw_paths, label_paths)

    if is_seg_dataset is None:
        is_seg_dataset = is_segmentation_dataset(raw_paths, raw_key, label_paths, label_key)

    # We always use a raw transform in the convenience function.
    if raw_transform is None:
        raw_transform = get_raw_transform()

    # We always use augmentations in the convenience function.
    # The dimensionality heuristic inspects only the first path when several are given.
    if transform is None:
        transform = _get_default_transform(
            raw_paths if isinstance(raw_paths, str) else raw_paths[0], raw_key, is_seg_dataset, ndim
        )

    if is_seg_dataset:
        ds = _load_segmentation_dataset(
            raw_paths,
            raw_key,
            label_paths,
            label_key,
            patch_shape=patch_shape,
            raw_transform=raw_transform,
            label_transform=label_transform,
            label_transform2=label_transform2,
            transform=transform,
            rois=rois,
            n_samples=n_samples,
            sampler=sampler,
            ndim=ndim,
            dtype=dtype,
            label_dtype=label_dtype,
            with_channels=with_channels,
            with_label_channels=with_label_channels,
            with_padding=with_padding,
            z_ext=z_ext,
        )

    else:
        ds = _load_image_collection_dataset(
            raw_paths,
            raw_key,
            label_paths,
            label_key,
            roi=rois,
            patch_shape=patch_shape,
            label_transform=label_transform,
            raw_transform=raw_transform,
            label_transform2=label_transform2,
            transform=transform,
            n_samples=n_samples,
            sampler=sampler,
            dtype=dtype,
            label_dtype=label_dtype,
            with_padding=with_padding,
        )

    return ds
Get a dataset for training a segmentation network.

See `torch_em.data.SegmentationDataset` and `torch_em.data.ImageCollectionDataset` for details on the data formats that are supported.
Arguments:
- raw_paths: The file path(s) to the raw data. Can either be a single path or multiple file paths.
- raw_key: The name of the internal dataset containing the raw data. Set to None for regular image files.
- label_paths: The file path(s) to the label data. Can either be a single path or multiple file paths.
- label_key: The name of the internal dataset containing the label data. Set to None for regular image files.
- patch_shape: The patch shape for the training samples.
- label_transform: Transformation applied to the label data of a sample, before applying augmentations via `transform`.
- label_transform2: Transformation applied to the label data of a sample, after applying augmentations via `transform`.
- raw_transform: Transformation applied to the raw data of a sample, before applying augmentations via `transform`.
- transform: Transformation applied to both the raw data and label data of a sample. This can be used to implement data augmentations.
- dtype: The return data type of the raw data.
- label_dtype: The return data type of the label data.
- rois: Regions of interest in the data. If given, the data will only be loaded from the corresponding area.
- n_samples: The length of the dataset. If None, the length will be set to `len(raw_paths)`.
- sampler: Sampler for rejecting samples according to a defined criterion. The sampler must be a callable that accepts the raw data (as numpy arrays) as input.
- ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
- is_seg_dataset: Whether this is a segmentation dataset or an image collection dataset. If None, the type of dataset will be derived from the data.
- with_channels: Whether the raw data has channels.
- with_label_channels: Whether the label data has channels.
- verify_paths: Whether to verify all paths before creating the dataset.
- with_padding: Whether to pad samples to `patch_shape` if their shape is smaller.
- z_ext: Extra bounding box for loading the data across z.
- loader_kwargs: Keyword arguments for `torch.utils.data.DataLoader`.
Returns:
The torch dataset.
def default_segmentation_trainer(
    name: str,
    model: torch.nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    loss: Optional[torch.nn.Module] = None,
    metric: Optional[Callable] = None,
    learning_rate: float = 1e-3,
    device: Optional[Union[str, torch.device]] = None,
    log_image_interval: int = 100,
    mixed_precision: bool = True,
    early_stopping: Optional[int] = None,
    logger=TensorboardLogger,
    logger_kwargs: Optional[Dict[str, Any]] = None,
    scheduler_kwargs: Optional[Dict[str, Any]] = DEFAULT_SCHEDULER_KWARGS,
    optimizer_kwargs: Optional[Dict[str, Any]] = None,
    trainer_class=DefaultTrainer,
    id_: Optional[str] = None,
    save_root: Optional[str] = None,
    compile_model: Optional[Union[bool, str]] = None,
    rank: Optional[int] = None,
):
    """Get a trainer for a segmentation network.

    It creates a `torch.optim.AdamW` optimizer and a learning rate scheduler that reduces the learning rate on plateau.
    By default, it uses the dice score as loss and metric.
    This can be changed by passing arguments for `loss` and/or `metric`.
    See `torch_em.trainer.DefaultTrainer` for additional details on how to configure and use the trainer.

    Here's an example for training a 2D U-Net with this function:
    ```python
    from torch_em.model import UNet2d
    from torch_em.segmentation import default_segmentation_trainer
    from torch_em.data.datasets.light_microscopy import get_dsb_loader

    # The training data will be downloaded to this location.
    data_root = "/path/to/save/the/training/data"
    patch_shape = (256, 256)
    trainer = default_segmentation_trainer(
        name="unet-training",
        model=UNet2d(in_channels=1, out_channels=1),
        train_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="train"),
        val_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="test"),
    )
    trainer.fit(iterations=int(2.5e4))  # Train for 25,000 iterations.
    ```

    Args:
        name: The name of the checkpoint that will be created by the trainer.
        model: The model to train.
        train_loader: The data loader containing the training data.
        val_loader: The data loader containing the validation data.
        loss: The loss function for training. If None, `DiceLoss` is used.
        metric: The metric for validation. If None, `DiceLoss` is used.
        learning_rate: The initial learning rate for the AdamW optimizer.
        device: The torch device to use for training. If None, will use a GPU if available.
        log_image_interval: The interval for saving images during logging, in training iterations.
        mixed_precision: Whether to train with mixed precision. Always disabled when training on the CPU.
        early_stopping: The patience for early stopping in epochs. If None, early stopping will not be used.
        logger: The logger class. Will be instantiated for logging.
            By default uses `torch_em.trainer.tensorboard_logger.TensorboardLogger`.
        logger_kwargs: The keyword arguments for the logger class.
        scheduler_kwargs: The keyword arguments for ReduceLROnPlateau.
            If None, the module-level default scheduler settings are used.
        optimizer_kwargs: The keyword arguments for the AdamW optimizer.
            If None, no additional arguments are passed.
        trainer_class: The trainer class. Uses `torch_em.trainer.DefaultTrainer` by default,
            but can be set to a custom trainer class to enable custom training procedures.
        id_: Unique identifier for the trainer. If None then `name` will be used.
        save_root: The root folder for saving the checkpoint and logs.
        compile_model: Whether to compile the model before training.
        rank: Rank argument for distributed training. See `torch_em.multi_gpu_training` for details.

    Returns:
        The trainer.
    """
    # Use a None sentinel instead of a shared mutable `{}` default.
    if optimizer_kwargs is None:
        optimizer_kwargs = {}
    if scheduler_kwargs is None:
        scheduler_kwargs = DEFAULT_SCHEDULER_KWARGS

    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, **optimizer_kwargs)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, **scheduler_kwargs)

    loss = DiceLoss() if loss is None else loss
    metric = DiceLoss() if metric is None else metric

    if device is None:
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    else:
        device = torch.device(device)

    # CPU does not support mixed precision training.
    if device.type == "cpu":
        mixed_precision = False

    return trainer_class(
        name=name,
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        loss=loss,
        metric=metric,
        optimizer=optimizer,
        device=device,
        lr_scheduler=scheduler,
        mixed_precision=mixed_precision,
        early_stopping=early_stopping,
        log_image_interval=log_image_interval,
        logger=logger,
        logger_kwargs=logger_kwargs,
        id_=id_,
        save_root=save_root,
        compile_model=compile_model,
        rank=rank,
    )
Get a trainer for a segmentation network.

It creates a `torch.optim.AdamW` optimizer and a learning rate scheduler that reduces the learning rate on plateau.
By default, it uses the dice score as loss and metric.
This can be changed by passing arguments for `loss` and/or `metric`.
See `torch_em.trainer.DefaultTrainer` for additional details on how to configure and use the trainer.
Here's an example for training a 2D U-Net with this function:
```python
from torch_em.model import UNet2d
from torch_em.segmentation import default_segmentation_trainer
from torch_em.data.datasets.light_microscopy import get_dsb_loader

# The training data will be downloaded to this location.
data_root = "/path/to/save/the/training/data"
patch_shape = (256, 256)
trainer = default_segmentation_trainer(
    name="unet-training",
    model=UNet2d(in_channels=1, out_channels=1),
    train_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="train"),
    val_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="test"),
)
trainer.fit(iterations=int(2.5e4))  # Train for 25,000 iterations.
```
Arguments:
- name: The name of the checkpoint that will be created by the trainer.
- model: The model to train.
- train_loader: The data loader containing the training data.
- val_loader: The data loader containing the validation data.
- loss: The loss function for training.
- metric: The metric for validation.
- learning_rate: The initial learning rate for the AdamW optimizer.
- device: The torch device to use for training. If None, will use a GPU if available.
- log_image_interval: The interval for saving images during logging, in training iterations.
- mixed_precision: Whether to train with mixed precision.
- early_stopping: The patience for early stopping in epochs. If None, early stopping will not be used.
- logger: The logger class. Will be instantiated for logging. By default uses `torch_em.trainer.tensorboard_logger.TensorboardLogger`.
- logger_kwargs: The keyword arguments for the logger class.
- scheduler_kwargs: The keyword arguments for ReduceLROnPlateau.
- optimizer_kwargs: The keyword arguments for the AdamW optimizer.
- trainer_class: The trainer class. Uses `torch_em.trainer.DefaultTrainer` by default, but can be set to a custom trainer class to enable custom training procedures.
- id_: Unique identifier for the trainer. If None then `name` will be used.
- save_root: The root folder for saving the checkpoint and logs.
- compile_model: Whether to compile the model before training.
- rank: Rank argument for distributed training. See `torch_em.multi_gpu_training` for details.
Returns:
The trainer.