torch_em.segmentation
import os
from glob import glob
from typing import Any, Dict, Optional, Union, Tuple, List, Callable

import numpy as np
import torch
import torch.utils.data
from torch.utils.data import DataLoader

from .loss import DiceLoss
from .util import load_data
from .trainer import DefaultTrainer
from .trainer.tensorboard_logger import TensorboardLogger
from .transform import get_augmentations, get_raw_transform
from .data import ConcatDataset, ImageCollectionDataset, SegmentationDataset, TensorDataset


# TODO add a heuristic to estimate this from the number of epochs
DEFAULT_SCHEDULER_KWARGS = {"mode": "min", "factor": 0.5, "patience": 5}
"""@private
"""


#
# convenience functions for segmentation loaders
#

# TODO implement balanced and make it the default
# def samples_to_datasets(n_samples, raw_paths, raw_key, split="balanced"):
def samples_to_datasets(n_samples, raw_paths, raw_key, split="uniform"):
    """@private
    """
    assert split in ("balanced", "uniform")
    n_datasets = len(raw_paths)
    if split == "uniform":
        # even distribution of samples to datasets
        samples_per_ds = n_samples // n_datasets
        divider = n_samples % n_datasets
        return [samples_per_ds + 1 if ii < divider else samples_per_ds for ii in range(n_datasets)]
    else:
        # distribution of samples to datasets based on the dataset lengths
        raise NotImplementedError


def check_paths(raw_paths, label_paths):
    """@private
    """
    if not isinstance(raw_paths, type(label_paths)):
        raise ValueError(f"Expect raw and label paths of same type, got {type(raw_paths)}, {type(label_paths)}")

    # This is a tensor dataset and we don't need to verify the paths.
    if isinstance(raw_paths, list) and isinstance(raw_paths[0], (torch.Tensor, np.ndarray)):
        return

    def _check_path(path):
        if isinstance(path, str):
            if not os.path.exists(path):
                raise ValueError(f"Could not find path {path}")
        else:
            # check for single path or multiple paths (for the same volume - supports multi-modal inputs)
            for per_path in path:
                if not os.path.exists(per_path):
                    raise ValueError(f"Could not find path {per_path}")

    if isinstance(raw_paths, str):
        _check_path(raw_paths)
        _check_path(label_paths)
    else:
        if len(raw_paths) != len(label_paths):
            raise ValueError(f"Expect same number of raw and label paths, got {len(raw_paths)}, {len(label_paths)}")
        for rp, lp in zip(raw_paths, label_paths):
            _check_path(rp)
            _check_path(lp)


# Check if we can load the data as SegmentationDataset.
def is_segmentation_dataset(raw_paths, raw_key, label_paths, label_key):
    """@private
    """
    if isinstance(raw_paths, list) and isinstance(raw_paths[0], (np.ndarray, torch.Tensor)):
        if not all(isinstance(rp, (np.ndarray, torch.Tensor)) for rp in raw_paths):
            raise ValueError("Inconsistent raw data")
        if not all(isinstance(lp, (np.ndarray, torch.Tensor)) for lp in label_paths):
            raise ValueError("Inconsistent label data")
        return False

    def _can_open(path, key):
        try:
            load_data(path, key)
            return True
        except Exception:
            return False

    if isinstance(raw_paths, str):
        can_open_raw = _can_open(raw_paths, raw_key)
        can_open_label = _can_open(label_paths, label_key)
    else:
        can_open_raw = [_can_open(rp, raw_key) for rp in raw_paths]
        if not can_open_raw.count(can_open_raw[0]) == len(can_open_raw):
            raise ValueError("Inconsistent raw data")
        can_open_raw = can_open_raw[0]

        can_open_label = [_can_open(lp, label_key) for lp in label_paths]
        if not can_open_label.count(can_open_label[0]) == len(can_open_label):
            raise ValueError("Inconsistent label data")
        can_open_label = can_open_label[0]

    if can_open_raw != can_open_label:
        raise ValueError("Inconsistent raw and label data")

    return can_open_raw


def _load_segmentation_dataset(raw_paths, raw_key, label_paths, label_key, **kwargs):
    rois = kwargs.pop("rois", None)
    if isinstance(raw_paths, str):
        if rois is not None:
            assert isinstance(rois, (tuple, slice))
            if isinstance(rois, tuple):
                assert all(isinstance(roi, slice) for roi in rois)
        ds = SegmentationDataset(raw_paths, raw_key, label_paths, label_key, roi=rois, **kwargs)
    else:
        assert len(raw_paths) > 0
        if rois is not None:
            assert len(rois) == len(label_paths)
            assert all(isinstance(roi, tuple) for roi in rois), f"{rois}"
        n_samples = kwargs.pop("n_samples", None)

        samples_per_ds = (
            [None] * len(raw_paths) if n_samples is None else samples_to_datasets(n_samples, raw_paths, raw_key)
        )
        ds = []
        for i, (raw_path, label_path) in enumerate(zip(raw_paths, label_paths)):
            roi = None if rois is None else rois[i]
            dset = SegmentationDataset(
                raw_path, raw_key, label_path, label_key, roi=roi, n_samples=samples_per_ds[i], **kwargs
            )
            ds.append(dset)
        ds = ConcatDataset(*ds)
    return ds


def _load_image_collection_dataset(raw_paths, raw_key, label_paths, label_key, roi, with_channels, **kwargs):
    if isinstance(raw_paths[0], (torch.Tensor, np.ndarray)):
        assert raw_key is None and label_key is None
        assert roi is None
        kwargs.pop("pre_label_transform")  # NOTE: The 'TensorDataset' currently does not support samplers.
        return TensorDataset(raw_paths, label_paths, with_channels=with_channels, **kwargs)

    def _get_paths(rpath, rkey, lpath, lkey, this_roi):
        rpath = glob(os.path.join(rpath, rkey))
        rpath.sort()
        if len(rpath) == 0:
            raise ValueError(f"Could not find any images for pattern {os.path.join(rpath, rkey)}")

        lpath = glob(os.path.join(lpath, lkey))
        lpath.sort()
        if len(rpath) != len(lpath):
            raise ValueError(f"Expect same number of raw and label images, got {len(rpath)}, {len(lpath)}")

        if this_roi is not None:
            rpath, lpath = rpath[this_roi], lpath[this_roi]

        return rpath, lpath

    patch_shape = kwargs.pop("patch_shape")
    if patch_shape is not None:
        if len(patch_shape) == 3:
            if patch_shape[0] != 1:
                raise ValueError(f"Image collection dataset expects 2d patch shape, got {patch_shape}")
            patch_shape = patch_shape[1:]
        assert len(patch_shape) == 2

    if isinstance(raw_paths, str):
        raw_paths, label_paths = _get_paths(raw_paths, raw_key, label_paths, label_key, roi)
        ds = ImageCollectionDataset(raw_paths, label_paths, patch_shape=patch_shape, **kwargs)

    elif raw_key is None:
        assert label_key is None
        assert isinstance(raw_paths, (list, tuple)) and isinstance(label_paths, (list, tuple))
        assert len(raw_paths) == len(label_paths)
        ds = ImageCollectionDataset(raw_paths, label_paths, patch_shape=patch_shape, **kwargs)

    else:
        ds = []
        n_samples = kwargs.pop("n_samples", None)
        samples_per_ds = (
            [None] * len(raw_paths) if n_samples is None else samples_to_datasets(n_samples, raw_paths, raw_key)
        )
        if roi is None:
            roi = len(raw_paths) * [None]
        assert len(roi) == len(raw_paths)
        for i, (raw_path, label_path, this_roi) in enumerate(zip(raw_paths, label_paths, roi)):
            print(raw_path, label_path, this_roi)
            rpath, lpath = _get_paths(raw_path, raw_key, label_path, label_key, this_roi)
            dset = ImageCollectionDataset(rpath, lpath, patch_shape=patch_shape, n_samples=samples_per_ds[i], **kwargs)
            ds.append(dset)
        ds = ConcatDataset(*ds)

    return ds


def _get_default_transform(path, key, is_seg_dataset, ndim):
    if is_seg_dataset and ndim is None:
        shape = load_data(path, key).shape
        if len(shape) == 2:
            ndim = 2
        else:
            # heuristics to figure out whether to use default 3d
            # or default anisotropic augmentations
            ndim = "anisotropic" if shape[0] < shape[1] // 2 else 3

    elif is_seg_dataset and ndim is not None:
        pass

    else:
        ndim = 2

    return get_augmentations(ndim)


def default_segmentation_loader(
    raw_paths: Union[List[Any], str, os.PathLike],
    raw_key: Optional[str],
    label_paths: Union[List[Any], str, os.PathLike],
    label_key: Optional[str],
    batch_size: int,
    patch_shape: Tuple[int, ...],
    label_transform: Optional[Callable] = None,
    label_transform2: Optional[Callable] = None,
    raw_transform: Optional[Callable] = None,
    transform: Optional[Callable] = None,
    dtype: torch.dtype = torch.float32,
    label_dtype: torch.dtype = torch.float32,
    rois: Optional[Union[slice, Tuple[slice, ...]]] = None,
    n_samples: Optional[int] = None,
    sampler: Optional[Callable] = None,
    ndim: Optional[int] = None,
    is_seg_dataset: Optional[bool] = None,
    with_channels: bool = False,
    with_label_channels: bool = False,
    verify_paths: bool = True,
    with_padding: bool = True,
    z_ext: Optional[int] = None,
    pre_label_transform: Optional[Callable] = None,
    **loader_kwargs,
) -> torch.utils.data.DataLoader:
    """Get a data loader for training a segmentation network.

    See `torch_em.data.SegmentationDataset` and `torch_em.data.ImageCollectionDataset` for details
    on the data formats that are supported.

    Args:
        raw_paths: The file path(s) to the raw data. Can either be a single path or multiple file paths.
            This argument also accepts a list of numpy arrays or torch tensors.
        raw_key: The name of the internal dataset containing the raw data.
            Set to None for regular image files, numpy arrays, or torch tensors.
        label_paths: The file path(s) to the label data. Can either be a single path or multiple file paths.
            This argument also accepts a list of numpy arrays or torch tensors.
        label_key: The name of the internal dataset containing the label data.
            Set to None for regular image files, numpy arrays, or torch tensors.
        batch_size: The batch size for the data loader.
        patch_shape: The patch shape for the training samples.
        label_transform: Transformation applied to the label data of a sample,
            before applying augmentations via `transform`.
        label_transform2: Transformation applied to the label data of a sample,
            after applying augmentations via `transform`.
        raw_transform: Transformation applied to the raw data of a sample,
            before applying augmentations via `transform`.
        transform: Transformation applied to both the raw data and label data of a sample.
            This can be used to implement data augmentations.
        dtype: The return data type of the raw data.
        label_dtype: The return data type of the label data.
        rois: Regions of interest in the data. If given, the data will only be loaded from the corresponding area.
        n_samples: The length of the underlying dataset. If None, the length will be set to `len(raw_paths)`.
        sampler: Sampler for rejecting samples according to a defined criterion.
            The sampler must be a callable that accepts the raw data (as numpy arrays) as input.
        ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
        is_seg_dataset: Whether this is a segmentation dataset or an image collection dataset.
            If None, the type of dataset will be derived from the data.
        with_channels: Whether the raw data has channels.
        with_label_channels: Whether the label data has channels.
        verify_paths: Whether to verify all paths before creating the dataset.
        with_padding: Whether to pad samples to `patch_shape` if their shape is smaller.
        z_ext: Extra bounding box for loading the data across z.
        pre_label_transform: Transformation applied to the label data of a chosen random sample,
            before checking sample validity via the `sampler`.
        loader_kwargs: Keyword arguments for `torch.utils.data.DataLoader`.

    Returns:
        The torch data loader.
    """
    ds = default_segmentation_dataset(
        raw_paths=raw_paths,
        raw_key=raw_key,
        label_paths=label_paths,
        label_key=label_key,
        patch_shape=patch_shape,
        label_transform=label_transform,
        label_transform2=label_transform2,
        raw_transform=raw_transform,
        transform=transform,
        dtype=dtype,
        label_dtype=label_dtype,
        rois=rois,
        n_samples=n_samples,
        sampler=sampler,
        ndim=ndim,
        is_seg_dataset=is_seg_dataset,
        with_channels=with_channels,
        with_label_channels=with_label_channels,
        with_padding=with_padding,
        z_ext=z_ext,
        verify_paths=verify_paths,
        pre_label_transform=pre_label_transform,
    )
    return get_data_loader(ds, batch_size=batch_size, **loader_kwargs)


def default_segmentation_dataset(
    raw_paths: Union[List[Any], str, os.PathLike],
    raw_key: Optional[str],
    label_paths: Union[List[Any], str, os.PathLike],
    label_key: Optional[str],
    patch_shape: Tuple[int, ...],
    label_transform: Optional[Callable] = None,
    label_transform2: Optional[Callable] = None,
    raw_transform: Optional[Callable] = None,
    transform: Optional[Callable] = None,
    dtype: torch.dtype = torch.float32,
    label_dtype: torch.dtype = torch.float32,
    rois: Optional[Union[slice, Tuple[slice, ...]]] = None,
    n_samples: Optional[int] = None,
    sampler: Optional[Callable] = None,
    ndim: Optional[int] = None,
    is_seg_dataset: Optional[bool] = None,
    with_channels: bool = False,
    with_label_channels: bool = False,
    verify_paths: bool = True,
    with_padding: bool = True,
    z_ext: Optional[int] = None,
    pre_label_transform: Optional[Callable] = None,
) -> torch.utils.data.Dataset:
    """Get a dataset for training a segmentation network.

    See `torch_em.data.SegmentationDataset` and `torch_em.data.ImageCollectionDataset` for details
    on the data formats that are supported.

    Args:
        raw_paths: The file path(s) to the raw data. Can either be a single path or multiple file paths.
            This argument also accepts a list of numpy arrays or torch tensors.
        raw_key: The name of the internal dataset containing the raw data.
            Set to None for regular image files, numpy arrays, or torch tensors.
        label_paths: The file path(s) to the label data. Can either be a single path or multiple file paths.
            This argument also accepts a list of numpy arrays or torch tensors.
        label_key: The name of the internal dataset containing the label data.
            Set to None for regular image files, numpy arrays, or torch tensors.
        patch_shape: The patch shape for the training samples.
        label_transform: Transformation applied to the label data of a sample,
            before applying augmentations via `transform`.
        label_transform2: Transformation applied to the label data of a sample,
            after applying augmentations via `transform`.
        raw_transform: Transformation applied to the raw data of a sample,
            before applying augmentations via `transform`.
        transform: Transformation applied to both the raw data and label data of a sample.
            This can be used to implement data augmentations.
        dtype: The return data type of the raw data.
        label_dtype: The return data type of the label data.
        rois: Regions of interest in the data. If given, the data will only be loaded from the corresponding area.
        n_samples: The length of the dataset. If None, the length will be set to `len(raw_paths)`.
        sampler: Sampler for rejecting samples according to a defined criterion.
            The sampler must be a callable that accepts the raw data (as numpy arrays) as input.
        ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
        is_seg_dataset: Whether this is a segmentation dataset or an image collection dataset.
            If None, the type of dataset will be derived from the data.
        with_channels: Whether the raw data has channels.
        with_label_channels: Whether the label data has channels.
        verify_paths: Whether to verify all paths before creating the dataset.
        with_padding: Whether to pad samples to `patch_shape` if their shape is smaller.
        z_ext: Extra bounding box for loading the data across z.
        pre_label_transform: Transformation applied to the label data of a chosen random sample,
            before checking sample validity via the `sampler`.

    Returns:
        The torch dataset.
    """
    if verify_paths:
        check_paths(raw_paths, label_paths)

    if is_seg_dataset is None:
        is_seg_dataset = is_segmentation_dataset(raw_paths, raw_key, label_paths, label_key)

    # We always use a raw transform in the convenience function.
    if raw_transform is None:
        raw_transform = get_raw_transform()

    # We always use augmentations in the convenience function.
    if transform is None:
        transform = _get_default_transform(
            raw_paths if isinstance(raw_paths, str) else raw_paths[0], raw_key, is_seg_dataset, ndim
        )

    if is_seg_dataset:
        ds = _load_segmentation_dataset(
            raw_paths,
            raw_key,
            label_paths,
            label_key,
            patch_shape=patch_shape,
            raw_transform=raw_transform,
            label_transform=label_transform,
            label_transform2=label_transform2,
            transform=transform,
            rois=rois,
            n_samples=n_samples,
            sampler=sampler,
            ndim=ndim,
            dtype=dtype,
            label_dtype=label_dtype,
            with_channels=with_channels,
            with_label_channels=with_label_channels,
            with_padding=with_padding,
            z_ext=z_ext,
            pre_label_transform=pre_label_transform,
        )

    else:
        ds = _load_image_collection_dataset(
            raw_paths,
            raw_key,
            label_paths,
            label_key,
            roi=rois,
            patch_shape=patch_shape,
            label_transform=label_transform,
            raw_transform=raw_transform,
            label_transform2=label_transform2,
            transform=transform,
            n_samples=n_samples,
            sampler=sampler,
            dtype=dtype,
            label_dtype=label_dtype,
            with_padding=with_padding,
            with_channels=with_channels,
            pre_label_transform=pre_label_transform,
        )

    return ds


def get_data_loader(dataset: torch.utils.data.Dataset, batch_size: int, **loader_kwargs) -> torch.utils.data.DataLoader:
    """@private
    """
    pin_memory = loader_kwargs.pop("pin_memory", True)
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, pin_memory=pin_memory, **loader_kwargs)
    # monkey patch shuffle attribute to the loader
    loader.shuffle = loader_kwargs.get("shuffle", False)
    return loader


#
# convenience functions for segmentation trainers
#


def default_segmentation_trainer(
    name: str,
    model: torch.nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    loss: Optional[torch.nn.Module] = None,
    metric: Optional[Callable] = None,
    learning_rate: float = 1e-3,
    device: Optional[Union[str, torch.device]] = None,
    log_image_interval: int = 100,
    mixed_precision: bool = True,
    early_stopping: Optional[int] = None,
    logger=TensorboardLogger,
    logger_kwargs: Optional[Dict[str, Any]] = None,
    scheduler_kwargs: Dict[str, Any] = DEFAULT_SCHEDULER_KWARGS,
    optimizer_kwargs: Dict[str, Any] = {},
    trainer_class=DefaultTrainer,
    id_: Optional[str] = None,
    save_root: Optional[str] = None,
    compile_model: Optional[Union[bool, str]] = None,
    rank: Optional[int] = None,
):
    """Get a trainer for a segmentation network.

    It creates a `torch.optim.AdamW` optimizer and a learning rate scheduler that reduces the learning rate on plateau.
    By default, it uses the dice score as loss and metric.
    This can be changed by passing arguments for `loss` and/or `metric`.
    See `torch_em.trainer.DefaultTrainer` for additional details on how to configure and use the trainer.

    Here's an example for training a 2D U-Net with this function:
    ```python
    import torch_em
    from torch_em.model import UNet2d
    from torch_em.data.datasets.light_microscopy import get_dsb_loader

    # The training data will be downloaded to this location.
    data_root = "/path/to/save/the/training/data"
    patch_shape = (256, 256)
    trainer = default_segmentation_trainer(
        name="unet-training",
        model=UNet2d(in_channels=1, out_channels=1),
        train_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="train"),
        val_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="test"),
    )
    trainer.fit(iterations=int(2.5e4))  # Train for 25,000 iterations.
    ```

    Args:
        name: The name of the checkpoint that will be created by the trainer.
        model: The model to train.
        train_loader: The data loader containing the training data.
        val_loader: The data loader containing the validation data.
        loss: The loss function for training.
        metric: The metric for validation.
        learning_rate: The initial learning rate for the AdamW optimizer.
        device: The torch device to use for training. If None, will use a GPU if available.
        log_image_interval: The interval for saving images during logging, in training iterations.
        mixed_precision: Whether to train with mixed precision.
        early_stopping: The patience for early stopping in epochs. If None, early stopping will not be used.
        logger: The logger class. Will be instantiated for logging.
            By default uses `torch_em.trainer.tensorboard_logger.TensorboardLogger`.
        logger_kwargs: The keyword arguments for the logger class.
        scheduler_kwargs: The keyword arguments for ReduceLROnPlateau.
        optimizer_kwargs: The keyword arguments for the AdamW optimizer.
        trainer_class: The trainer class. Uses `torch_em.trainer.DefaultTrainer` by default,
            but can be set to a custom trainer class to enable custom training procedures.
        id_: Unique identifier for the trainer. If None, then `name` will be used.
        save_root: The root folder for saving the checkpoint and logs.
        compile_model: Whether to compile the model before training.
        rank: Rank argument for distributed training. See `torch_em.multi_gpu_training` for details.

    Returns:
        The trainer.
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, **optimizer_kwargs)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, **scheduler_kwargs)

    loss = DiceLoss() if loss is None else loss
    metric = DiceLoss() if metric is None else metric

    if device is None:
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    else:
        device = torch.device(device)

    # CPU does not support mixed precision training.
    if device.type == "cpu":
        mixed_precision = False

    return trainer_class(
        name=name,
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        loss=loss,
        metric=metric,
        optimizer=optimizer,
        device=device,
        lr_scheduler=scheduler,
        mixed_precision=mixed_precision,
        early_stopping=early_stopping,
        log_image_interval=log_image_interval,
        logger=logger,
        logger_kwargs=logger_kwargs,
        id_=id_,
        save_root=save_root,
        compile_model=compile_model,
        rank=rank,
    )
default_segmentation_loader

Get a data loader for training a segmentation network.

See `torch_em.data.SegmentationDataset` and `torch_em.data.ImageCollectionDataset` for details
on the data formats that are supported.

Arguments:
- raw_paths: The file path(s) to the raw data. Can either be a single path or multiple file paths. This argument also accepts a list of numpy arrays or torch tensors.
- raw_key: The name of the internal dataset containing the raw data. Set to None for regular image files, numpy arrays, or torch tensors.
- label_paths: The file path(s) to the label data. Can either be a single path or multiple file paths. This argument also accepts a list of numpy arrays or torch tensors.
- label_key: The name of the internal dataset containing the label data. Set to None for regular image files, numpy arrays, or torch tensors.
- batch_size: The batch size for the data loader.
- patch_shape: The patch shape for the training samples.
- label_transform: Transformation applied to the label data of a sample, before applying augmentations via `transform`.
- label_transform2: Transformation applied to the label data of a sample, after applying augmentations via `transform`.
- raw_transform: Transformation applied to the raw data of a sample, before applying augmentations via `transform`.
- transform: Transformation applied to both the raw data and label data of a sample. This can be used to implement data augmentations.
- dtype: The return data type of the raw data.
- label_dtype: The return data type of the label data.
- rois: Regions of interest in the data. If given, the data will only be loaded from the corresponding area.
- n_samples: The length of the underlying dataset. If None, the length will be set to `len(raw_paths)`.
- sampler: Sampler for rejecting samples according to a defined criterion. The sampler must be a callable that accepts the raw data (as numpy arrays) as input.
- ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
- is_seg_dataset: Whether this is a segmentation dataset or an image collection dataset. If None, the type of dataset will be derived from the data.
- with_channels: Whether the raw data has channels.
- with_label_channels: Whether the label data has channels.
- verify_paths: Whether to verify all paths before creating the dataset.
- with_padding: Whether to pad samples to `patch_shape` if their shape is smaller.
- z_ext: Extra bounding box for loading the data across z.
- pre_label_transform: Transformation applied to the label data of a chosen random sample, before checking sample validity via the `sampler`.
- loader_kwargs: Keyword arguments for `torch.utils.data.DataLoader`.

Returns:
The torch data loader.
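A usage sketch for this function, assuming a single hypothetical HDF5 container that stores the raw data under the internal key "raw" and the labels under "labels"; additional keyword arguments such as `num_workers` and `shuffle` are forwarded to `torch.utils.data.DataLoader`:

```python
from torch_em.segmentation import default_segmentation_loader

loader = default_segmentation_loader(
    raw_paths="/path/to/volume.h5", raw_key="raw",        # placeholder path and keys
    label_paths="/path/to/volume.h5", label_key="labels",
    batch_size=2, patch_shape=(32, 256, 256), ndim=3,
    num_workers=4, shuffle=True,  # forwarded to torch.utils.data.DataLoader
)

# Each batch is a pair of raw and label tensors with the requested patch shape.
for raw, labels in loader:
    print(raw.shape, labels.shape)
    break
```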
default_segmentation_dataset

Get a dataset for training a segmentation network.

See `torch_em.data.SegmentationDataset` and `torch_em.data.ImageCollectionDataset` for details
on the data formats that are supported.

Arguments:
- raw_paths: The file path(s) to the raw data. Can either be a single path or multiple file paths. This argument also accepts a list of numpy arrays or torch tensors.
- raw_key: The name of the internal dataset containing the raw data. Set to None for regular image files, numpy arrays, or torch tensors.
- label_paths: The file path(s) to the label data. Can either be a single path or multiple file paths. This argument also accepts a list of numpy arrays or torch tensors.
- label_key: The name of the internal dataset containing the label data. Set to None for regular image files, numpy arrays, or torch tensors.
- patch_shape: The patch shape for the training samples.
- label_transform: Transformation applied to the label data of a sample, before applying augmentations via `transform`.
- label_transform2: Transformation applied to the label data of a sample, after applying augmentations via `transform`.
- raw_transform: Transformation applied to the raw data of a sample, before applying augmentations via `transform`.
- transform: Transformation applied to both the raw data and label data of a sample. This can be used to implement data augmentations.
- dtype: The return data type of the raw data.
- label_dtype: The return data type of the label data.
- rois: Regions of interest in the data. If given, the data will only be loaded from the corresponding area.
- n_samples: The length of the dataset. If None, the length will be set to `len(raw_paths)`.
- sampler: Sampler for rejecting samples according to a defined criterion. The sampler must be a callable that accepts the raw data (as numpy arrays) as input.
- ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
- is_seg_dataset: Whether this is a segmentation dataset or an image collection dataset. If None, the type of dataset will be derived from the data.
- with_channels: Whether the raw data has channels.
- with_label_channels: Whether the label data has channels.
- verify_paths: Whether to verify all paths before creating the dataset.
- with_padding: Whether to pad samples to `patch_shape` if their shape is smaller.
- z_ext: Extra bounding box for loading the data across z.
- pre_label_transform: Transformation applied to the label data of a chosen random sample, before checking sample validity via the `sampler`.

Returns:
The torch dataset.
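A minimal sketch with in-memory data: for lists of numpy arrays (or torch tensors), `raw_key` and `label_key` must be None, and the function takes the tensor/image-collection route shown in the module source above. The array shapes and label values below are made up for illustration:

```python
import numpy as np
from torch_em.segmentation import default_segmentation_dataset

# Four 2D images with matching label maps, generated as placeholder data.
images = [np.random.rand(512, 512).astype("float32") for _ in range(4)]
labels = [np.random.randint(0, 3, size=(512, 512)) for _ in range(4)]

ds = default_segmentation_dataset(
    raw_paths=images, raw_key=None,
    label_paths=labels, label_key=None,
    patch_shape=(256, 256),
)

# Index the dataset to get an augmented (raw, label) patch pair.
raw, target = ds[0]
print(raw.shape, target.shape)
```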
default_segmentation_trainer

Get a trainer for a segmentation network.

It creates a `torch.optim.AdamW` optimizer and a learning rate scheduler that reduces the learning rate on plateau.
By default, it uses the dice score as loss and metric.
This can be changed by passing arguments for `loss` and/or `metric`.
See `torch_em.trainer.DefaultTrainer` for additional details on how to configure and use the trainer.

Here's an example for training a 2D U-Net with this function:
```python
import torch_em
from torch_em.model import UNet2d
from torch_em.data.datasets.light_microscopy import get_dsb_loader

# The training data will be downloaded to this location.
data_root = "/path/to/save/the/training/data"
patch_shape = (256, 256)
trainer = default_segmentation_trainer(
    name="unet-training",
    model=UNet2d(in_channels=1, out_channels=1),
    train_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="train"),
    val_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="test"),
)
trainer.fit(iterations=int(2.5e4))  # Train for 25,000 iterations.
```

Arguments:
- name: The name of the checkpoint that will be created by the trainer.
- model: The model to train.
- train_loader: The data loader containing the training data.
- val_loader: The data loader containing the validation data.
- loss: The loss function for training.
- metric: The metric for validation.
- learning_rate: The initial learning rate for the AdamW optimizer.
- device: The torch device to use for training. If None, will use a GPU if available.
- log_image_interval: The interval for saving images during logging, in training iterations.
- mixed_precision: Whether to train with mixed precision.
- early_stopping: The patience for early stopping in epochs. If None, early stopping will not be used.
- logger: The logger class. Will be instantiated for logging. By default uses `torch_em.trainer.tensorboard_logger.TensorboardLogger`.
- logger_kwargs: The keyword arguments for the logger class.
- scheduler_kwargs: The keyword arguments for ReduceLROnPlateau.
- optimizer_kwargs: The keyword arguments for the AdamW optimizer.
- trainer_class: The trainer class. Uses `torch_em.trainer.DefaultTrainer` by default, but can be set to a custom trainer class to enable custom training procedures.
- id_: Unique identifier for the trainer. If None, then `name` will be used.
- save_root: The root folder for saving the checkpoint and logs.
- compile_model: Whether to compile the model before training.
- rank: Rank argument for distributed training. See `torch_em.multi_gpu_training` for details.

Returns:
The trainer.
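Building on the example above, here is a sketch that overrides some of the defaults: an explicit dice loss and metric, weight decay for the AdamW optimizer, a more patient `ReduceLROnPlateau` schedule, and early stopping. The hyperparameter values are illustrative only, and the `get_dsb_loader` calls mirror the docstring example:

```python
from torch_em.loss import DiceLoss
from torch_em.model import UNet2d
from torch_em.segmentation import default_segmentation_trainer
from torch_em.data.datasets.light_microscopy import get_dsb_loader

data_root = "/path/to/save/the/training/data"  # placeholder download location
patch_shape = (256, 256)
train_loader = get_dsb_loader(path=data_root, patch_shape=patch_shape, split="train")
val_loader = get_dsb_loader(path=data_root, patch_shape=patch_shape, split="test")

trainer = default_segmentation_trainer(
    name="unet-custom",
    model=UNet2d(in_channels=1, out_channels=1),
    train_loader=train_loader,
    val_loader=val_loader,
    loss=DiceLoss(),                          # the default, made explicit here
    metric=DiceLoss(),
    learning_rate=5e-4,
    optimizer_kwargs={"weight_decay": 1e-4},  # passed to torch.optim.AdamW
    scheduler_kwargs={"mode": "min", "factor": 0.5, "patience": 10},  # passed to ReduceLROnPlateau
    early_stopping=10,                        # stop after 10 epochs without improvement
)
trainer.fit(iterations=int(1e4))
```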