torch_em.segmentation
````python
import os
from glob import glob
from typing import Any, Dict, Optional, Union, Tuple, List, Callable

import numpy as np
import torch
import torch.utils.data
from torch.utils.data import DataLoader

from .loss import DiceLoss
from .util import load_data
from .trainer import DefaultTrainer
from .trainer.tensorboard_logger import TensorboardLogger
from .transform import get_augmentations, get_raw_transform
from .data import ConcatDataset, ImageCollectionDataset, SegmentationDataset, TensorDataset


# TODO add a heuristic to estimate this from the number of epochs
DEFAULT_SCHEDULER_KWARGS = {"mode": "min", "factor": 0.5, "patience": 5}
"""@private
"""


#
# convenience functions for segmentation loaders
#

# TODO implement balanced and make it the default
# def samples_to_datasets(n_samples, raw_paths, raw_key, split="balanced"):
def samples_to_datasets(n_samples, raw_paths, raw_key, split="uniform"):
    """@private
    """
    assert split in ("balanced", "uniform")
    n_datasets = len(raw_paths)
    if split == "uniform":
        # even distribution of samples to datasets
        samples_per_ds = n_samples // n_datasets
        divider = n_samples % n_datasets
        return [samples_per_ds + 1 if ii < divider else samples_per_ds for ii in range(n_datasets)]
    else:
        # distribution of samples to dataset based on the dataset lens
        raise NotImplementedError


def check_paths(raw_paths, label_paths):
    """@private
    """
    if not isinstance(raw_paths, type(label_paths)):
        raise ValueError(f"Expect raw and label paths of same type, got {type(raw_paths)}, {type(label_paths)}")

    # This is a tensor dataset and we don't need to verify the paths.
    if isinstance(raw_paths, list) and isinstance(raw_paths[0], (torch.Tensor, np.ndarray)):
        return

    def _check_path(path):
        if isinstance(path, str):
            if not os.path.exists(path):
                raise ValueError(f"Could not find path {path}")
        else:
            # check for single path or multiple paths (for same volume - supports multi-modal inputs)
            for per_path in path:
                if not os.path.exists(per_path):
                    raise ValueError(f"Could not find path {per_path}")

    if isinstance(raw_paths, str):
        _check_path(raw_paths)
        _check_path(label_paths)
    else:
        if len(raw_paths) != len(label_paths):
            raise ValueError(f"Expect same number of raw and label paths, got {len(raw_paths)}, {len(label_paths)}")
        for rp, lp in zip(raw_paths, label_paths):
            _check_path(rp)
            _check_path(lp)


# Check if we can load the data as SegmentationDataset.
def is_segmentation_dataset(raw_paths, raw_key, label_paths, label_key):
    """@private
    """
    if isinstance(raw_paths, list) and isinstance(raw_paths[0], (np.ndarray, torch.Tensor)):
        if not all(isinstance(rp, (np.ndarray, torch.Tensor)) for rp in raw_paths):
            raise ValueError("Inconsistent raw data")
        if not all(isinstance(lp, (np.ndarray, torch.Tensor)) for lp in label_paths):
            raise ValueError("Inconsistent label data")
        return False

    def _can_open(path, key):
        try:
            load_data(path, key)
            return True
        except Exception:
            return False

    if isinstance(raw_paths, str):
        can_open_raw = _can_open(raw_paths, raw_key)
        can_open_label = _can_open(label_paths, label_key)
    else:
        can_open_raw = [_can_open(rp, raw_key) for rp in raw_paths]
        if not can_open_raw.count(can_open_raw[0]) == len(can_open_raw):
            raise ValueError("Inconsistent raw data")
        can_open_raw = can_open_raw[0]

        can_open_label = [_can_open(lp, label_key) for lp in label_paths]
        if not can_open_label.count(can_open_label[0]) == len(can_open_label):
            raise ValueError("Inconsistent label data")
        can_open_label = can_open_label[0]

    if can_open_raw != can_open_label:
        raise ValueError("Inconsistent raw and label data")

    return can_open_raw


def _load_segmentation_dataset(raw_paths, raw_key, label_paths, label_key, **kwargs):
    rois = kwargs.pop("rois", None)
    if isinstance(raw_paths, str):
        if rois is not None:
            assert isinstance(rois, (tuple, slice))
            if isinstance(rois, tuple):
                assert all(isinstance(roi, slice) for roi in rois)
        ds = SegmentationDataset(raw_paths, raw_key, label_paths, label_key, roi=rois, **kwargs)
    else:
        assert len(raw_paths) > 0
        if rois is not None:
            assert len(rois) == len(label_paths)
            assert all(isinstance(roi, tuple) for roi in rois), f"{rois}"
        n_samples = kwargs.pop("n_samples", None)

        samples_per_ds = (
            [None] * len(raw_paths) if n_samples is None else samples_to_datasets(n_samples, raw_paths, raw_key)
        )
        ds = []
        for i, (raw_path, label_path) in enumerate(zip(raw_paths, label_paths)):
            roi = None if rois is None else rois[i]
            dset = SegmentationDataset(
                raw_path, raw_key, label_path, label_key, roi=roi, n_samples=samples_per_ds[i], **kwargs
            )
            ds.append(dset)
        ds = ConcatDataset(*ds)
    return ds


def _load_image_collection_dataset(raw_paths, raw_key, label_paths, label_key, roi, with_channels, **kwargs):
    if isinstance(raw_paths[0], (torch.Tensor, np.ndarray)):
        assert raw_key is None and label_key is None
        assert roi is None
        return TensorDataset(raw_paths, label_paths, with_channels=with_channels, **kwargs)

    def _get_paths(rpath, rkey, lpath, lkey, this_roi):
        rpath = glob(os.path.join(rpath, rkey))
        rpath.sort()
        if len(rpath) == 0:
            raise ValueError(f"Could not find any images for pattern {os.path.join(rpath, rkey)}")

        lpath = glob(os.path.join(lpath, lkey))
        lpath.sort()
        if len(rpath) != len(lpath):
            raise ValueError(f"Expect same number of raw and label images, got {len(rpath)}, {len(lpath)}")

        if this_roi is not None:
            rpath, lpath = rpath[roi], lpath[roi]

        return rpath, lpath

    patch_shape = kwargs.pop("patch_shape")
    if patch_shape is not None:
        if len(patch_shape) == 3:
            if patch_shape[0] != 1:
                raise ValueError(f"Image collection dataset expects 2d patch shape, got {patch_shape}")
            patch_shape = patch_shape[1:]
        assert len(patch_shape) == 2

    if isinstance(raw_paths, str):
        raw_paths, label_paths = _get_paths(raw_paths, raw_key, label_paths, label_key, roi)
        ds = ImageCollectionDataset(raw_paths, label_paths, patch_shape=patch_shape, **kwargs)

    elif raw_key is None:
        assert label_key is None
        assert isinstance(raw_paths, (list, tuple)) and isinstance(label_paths, (list, tuple))
        assert len(raw_paths) == len(label_paths)
        ds = ImageCollectionDataset(raw_paths, label_paths, patch_shape=patch_shape, **kwargs)

    else:
        ds = []
        n_samples = kwargs.pop("n_samples", None)
        samples_per_ds = (
            [None] * len(raw_paths) if n_samples is None else samples_to_datasets(n_samples, raw_paths, raw_key)
        )
        if roi is None:
            roi = len(raw_paths) * [None]
        assert len(roi) == len(raw_paths)
        for i, (raw_path, label_path, this_roi) in enumerate(zip(raw_paths, label_paths, roi)):
            print(raw_path, label_path, this_roi)
            rpath, lpath = _get_paths(raw_path, raw_key, label_path, label_key, this_roi)
            dset = ImageCollectionDataset(rpath, lpath, patch_shape=patch_shape, n_samples=samples_per_ds[i], **kwargs)
            ds.append(dset)
        ds = ConcatDataset(*ds)

    return ds


def _get_default_transform(path, key, is_seg_dataset, ndim):
    if is_seg_dataset and ndim is None:
        shape = load_data(path, key).shape
        if len(shape) == 2:
            ndim = 2
        else:
            # heuristics to figure out whether to use default 3d
            # or default anisotropic augmentations
            ndim = "anisotropic" if shape[0] < shape[1] // 2 else 3

    elif is_seg_dataset and ndim is not None:
        pass

    else:
        ndim = 2

    return get_augmentations(ndim)


def default_segmentation_loader(
    raw_paths: Union[List[Any], str, os.PathLike],
    raw_key: Optional[str],
    label_paths: Union[List[Any], str, os.PathLike],
    label_key: Optional[str],
    batch_size: int,
    patch_shape: Tuple[int, ...],
    label_transform: Optional[Callable] = None,
    label_transform2: Optional[Callable] = None,
    raw_transform: Optional[Callable] = None,
    transform: Optional[Callable] = None,
    dtype: torch.dtype = torch.float32,
    label_dtype: torch.dtype = torch.float32,
    rois: Optional[Union[slice, Tuple[slice, ...]]] = None,
    n_samples: Optional[int] = None,
    sampler: Optional[Callable] = None,
    ndim: Optional[int] = None,
    is_seg_dataset: Optional[bool] = None,
    with_channels: bool = False,
    with_label_channels: bool = False,
    verify_paths: bool = True,
    with_padding: bool = True,
    z_ext: Optional[int] = None,
    **loader_kwargs,
) -> torch.utils.data.DataLoader:
    """Get data loader for training a segmentation network.

    See `torch_em.data.SegmentationDataset` and `torch_em.data.ImageCollectionDataset` for details
    on the data formats that are supported.

    Args:
        raw_paths: The file path(s) to the raw data. Can either be a single path or multiple file paths.
            This argument also accepts a list of numpy arrays or torch tensors.
        raw_key: The name of the internal dataset containing the raw data.
            Set to None for regular image files, numpy arrays, or torch tensors.
        label_paths: The file path(s) to the label data. Can either be a single path or multiple file paths.
            This argument also accepts a list of numpy arrays or torch tensors.
        label_key: The name of the internal dataset containing the label data.
            Set to None for regular image files, numpy arrays, or torch tensors.
        batch_size: The batch size for the data loader.
        patch_shape: The patch shape for the training samples.
        label_transform: Transformation applied to the label data of a sample,
            before applying augmentations via `transform`.
        label_transform2: Transformation applied to the label data of a sample,
            after applying augmentations via `transform`.
        raw_transform: Transformation applied to the raw data of a sample,
            before applying augmentations via `transform`.
        transform: Transformation applied to both the raw data and label data of a sample.
            This can be used to implement data augmentations.
        dtype: The return data type of the raw data.
        label_dtype: The return data type of the label data.
        rois: Regions of interest in the data. If given, the data will only be loaded from the corresponding area.
        n_samples: The length of the underlying dataset. If None, the length will be set to `len(raw_paths)`.
        sampler: Sampler for rejecting samples according to a defined criterion.
            The sampler must be a callable that accepts the raw data (as numpy arrays) as input.
        ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
        is_seg_dataset: Whether this is a segmentation dataset or an image collection dataset.
            If None, the type of dataset will be derived from the data.
        with_channels: Whether the raw data has channels.
        with_label_channels: Whether the label data has channels.
        verify_paths: Whether to verify all paths before creating the dataset.
        with_padding: Whether to pad samples to `patch_shape` if their shape is smaller.
        z_ext: Extra bounding box for loading the data across z.
        loader_kwargs: Keyword arguments for `torch.utils.data.DataLoader`.

    Returns:
        The torch data loader.
    """
    ds = default_segmentation_dataset(
        raw_paths=raw_paths,
        raw_key=raw_key,
        label_paths=label_paths,
        label_key=label_key,
        patch_shape=patch_shape,
        label_transform=label_transform,
        label_transform2=label_transform2,
        raw_transform=raw_transform,
        transform=transform,
        dtype=dtype,
        label_dtype=label_dtype,
        rois=rois,
        n_samples=n_samples,
        sampler=sampler,
        ndim=ndim,
        is_seg_dataset=is_seg_dataset,
        with_channels=with_channels,
        with_label_channels=with_label_channels,
        with_padding=with_padding,
        z_ext=z_ext,
        verify_paths=verify_paths,
    )
    return get_data_loader(ds, batch_size=batch_size, **loader_kwargs)


def default_segmentation_dataset(
    raw_paths: Union[List[Any], str, os.PathLike],
    raw_key: Optional[str],
    label_paths: Union[List[Any], str, os.PathLike],
    label_key: Optional[str],
    patch_shape: Tuple[int, ...],
    label_transform: Optional[Callable] = None,
    label_transform2: Optional[Callable] = None,
    raw_transform: Optional[Callable] = None,
    transform: Optional[Callable] = None,
    dtype: torch.dtype = torch.float32,
    label_dtype: torch.dtype = torch.float32,
    rois: Optional[Union[slice, Tuple[slice, ...]]] = None,
    n_samples: Optional[int] = None,
    sampler: Optional[Callable] = None,
    ndim: Optional[int] = None,
    is_seg_dataset: Optional[bool] = None,
    with_channels: bool = False,
    with_label_channels: bool = False,
    verify_paths: bool = True,
    with_padding: bool = True,
    z_ext: Optional[int] = None,
) -> torch.utils.data.Dataset:
    """Get data set for training a segmentation network.

    See `torch_em.data.SegmentationDataset` and `torch_em.data.ImageCollectionDataset` for details
    on the data formats that are supported.

    Args:
        raw_paths: The file path(s) to the raw data. Can either be a single path or multiple file paths.
            This argument also accepts a list of numpy arrays or torch tensors.
        raw_key: The name of the internal dataset containing the raw data.
            Set to None for regular image files, numpy arrays, or torch tensors.
        label_paths: The file path(s) to the label data. Can either be a single path or multiple file paths.
            This argument also accepts a list of numpy arrays or torch tensors.
        label_key: The name of the internal dataset containing the label data.
            Set to None for regular image files, numpy arrays, or torch tensors.
        patch_shape: The patch shape for the training samples.
        label_transform: Transformation applied to the label data of a sample,
            before applying augmentations via `transform`.
        label_transform2: Transformation applied to the label data of a sample,
            after applying augmentations via `transform`.
        raw_transform: Transformation applied to the raw data of a sample,
            before applying augmentations via `transform`.
        transform: Transformation applied to both the raw data and label data of a sample.
            This can be used to implement data augmentations.
        dtype: The return data type of the raw data.
        label_dtype: The return data type of the label data.
        rois: Regions of interest in the data. If given, the data will only be loaded from the corresponding area.
        n_samples: The length of the dataset. If None, the length will be set to `len(raw_paths)`.
        sampler: Sampler for rejecting samples according to a defined criterion.
            The sampler must be a callable that accepts the raw data (as numpy arrays) as input.
        ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
        is_seg_dataset: Whether this is a segmentation dataset or an image collection dataset.
            If None, the type of dataset will be derived from the data.
        with_channels: Whether the raw data has channels.
        with_label_channels: Whether the label data has channels.
        verify_paths: Whether to verify all paths before creating the dataset.
        with_padding: Whether to pad samples to `patch_shape` if their shape is smaller.
        z_ext: Extra bounding box for loading the data across z.

    Returns:
        The torch dataset.
    """
    if verify_paths:
        check_paths(raw_paths, label_paths)

    if is_seg_dataset is None:
        is_seg_dataset = is_segmentation_dataset(raw_paths, raw_key, label_paths, label_key)

    # We always use a raw transform in the convenience function.
    if raw_transform is None:
        raw_transform = get_raw_transform()

    # We always use augmentations in the convenience function.
    if transform is None:
        transform = _get_default_transform(
            raw_paths if isinstance(raw_paths, str) else raw_paths[0], raw_key, is_seg_dataset, ndim
        )

    if is_seg_dataset:
        ds = _load_segmentation_dataset(
            raw_paths,
            raw_key,
            label_paths,
            label_key,
            patch_shape=patch_shape,
            raw_transform=raw_transform,
            label_transform=label_transform,
            label_transform2=label_transform2,
            transform=transform,
            rois=rois,
            n_samples=n_samples,
            sampler=sampler,
            ndim=ndim,
            dtype=dtype,
            label_dtype=label_dtype,
            with_channels=with_channels,
            with_label_channels=with_label_channels,
            with_padding=with_padding,
            z_ext=z_ext,
        )

    else:
        ds = _load_image_collection_dataset(
            raw_paths,
            raw_key,
            label_paths,
            label_key,
            roi=rois,
            patch_shape=patch_shape,
            label_transform=label_transform,
            raw_transform=raw_transform,
            label_transform2=label_transform2,
            transform=transform,
            n_samples=n_samples,
            sampler=sampler,
            dtype=dtype,
            label_dtype=label_dtype,
            with_padding=with_padding,
            with_channels=with_channels,
        )

    return ds


def get_data_loader(dataset: torch.utils.data.Dataset, batch_size: int, **loader_kwargs) -> torch.utils.data.DataLoader:
    """@private
    """
    pin_memory = loader_kwargs.pop("pin_memory", True)
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, pin_memory=pin_memory, **loader_kwargs)
    # monkey patch shuffle attribute to the loader
    loader.shuffle = loader_kwargs.get("shuffle", False)
    return loader


#
# convenience functions for segmentation trainers
#


def default_segmentation_trainer(
    name: str,
    model: torch.nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    loss: Optional[torch.nn.Module] = None,
    metric: Optional[Callable] = None,
    learning_rate: float = 1e-3,
    device: Optional[Union[str, torch.device]] = None,
    log_image_interval: int = 100,
    mixed_precision: bool = True,
    early_stopping: Optional[int] = None,
    logger=TensorboardLogger,
    logger_kwargs: Optional[Dict[str, Any]] = None,
    scheduler_kwargs: Dict[str, Any] = DEFAULT_SCHEDULER_KWARGS,
    optimizer_kwargs: Dict[str, Any] = {},
    trainer_class=DefaultTrainer,
    id_: Optional[str] = None,
    save_root: Optional[str] = None,
    compile_model: Optional[Union[bool, str]] = None,
    rank: Optional[int] = None,
):
    """Get a trainer for a segmentation network.

    It creates a `torch.optim.AdamW` optimizer and a learning rate scheduler that reduces the learning rate on plateau.
    By default, it uses the dice score as loss and metric.
    This can be changed by passing arguments for `loss` and/or `metric`.
    See `torch_em.trainer.DefaultTrainer` for additional details on how to configure and use the trainer.

    Here's an example for training a 2D U-Net with this function:
    ```python
    import torch_em
    from torch_em.model import UNet2d
    from torch_em.data.datasets.light_microscopy import get_dsb_loader

    # The training data will be downloaded to this location.
    data_root = "/path/to/save/the/training/data"
    patch_shape = (256, 256)
    trainer = default_segmentation_trainer(
        name="unet-training",
        model=UNet2d(in_channels=1, out_channels=1),
        train_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="train"),
        val_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="test"),
    )
    trainer.fit(iterations=int(2.5e4))  # Train for 25,000 iterations.
    ```

    Args:
        name: The name of the checkpoint that will be created by the trainer.
        model: The model to train.
        train_loader: The data loader containing the training data.
        val_loader: The data loader containing the validation data.
        loss: The loss function for training.
        metric: The metric for validation.
        learning_rate: The initial learning rate for the AdamW optimizer.
        device: The torch device to use for training. If None, will use a GPU if available.
        log_image_interval: The interval for saving images during logging, in training iterations.
        mixed_precision: Whether to train with mixed precision.
        early_stopping: The patience for early stopping in epochs. If None, early stopping will not be used.
        logger: The logger class. Will be instantiated for logging.
            By default uses `torch_em.trainer.tensorboard_logger.TensorboardLogger`.
        logger_kwargs: The keyword arguments for the logger class.
        scheduler_kwargs: The keyword arguments for ReduceLROnPlateau.
        optimizer_kwargs: The keyword arguments for the AdamW optimizer.
        trainer_class: The trainer class. Uses `torch_em.trainer.DefaultTrainer` by default,
            but can be set to a custom trainer class to enable custom training procedures.
        id_: Unique identifier for the trainer. If None then `name` will be used.
        save_root: The root folder for saving the checkpoint and logs.
        compile_model: Whether to compile the model before training.
        rank: Rank argument for distributed training. See `torch_em.multi_gpu_training` for details.

    Returns:
        The trainer.
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, **optimizer_kwargs)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, **scheduler_kwargs)

    loss = DiceLoss() if loss is None else loss
    metric = DiceLoss() if metric is None else metric

    if device is None:
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    else:
        device = torch.device(device)

    # CPU does not support mixed precision training.
    if device.type == "cpu":
        mixed_precision = False

    return trainer_class(
        name=name,
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        loss=loss,
        metric=metric,
        optimizer=optimizer,
        device=device,
        lr_scheduler=scheduler,
        mixed_precision=mixed_precision,
        early_stopping=early_stopping,
        log_image_interval=log_image_interval,
        logger=logger,
        logger_kwargs=logger_kwargs,
        id_=id_,
        save_root=save_root,
        compile_model=compile_model,
        rank=rank,
    )
````
```python
def default_segmentation_loader(
    raw_paths: Union[List[Any], str, os.PathLike],
    raw_key: Optional[str],
    label_paths: Union[List[Any], str, os.PathLike],
    label_key: Optional[str],
    batch_size: int,
    patch_shape: Tuple[int, ...],
    label_transform: Optional[Callable] = None,
    label_transform2: Optional[Callable] = None,
    raw_transform: Optional[Callable] = None,
    transform: Optional[Callable] = None,
    dtype: torch.dtype = torch.float32,
    label_dtype: torch.dtype = torch.float32,
    rois: Optional[Union[slice, Tuple[slice, ...]]] = None,
    n_samples: Optional[int] = None,
    sampler: Optional[Callable] = None,
    ndim: Optional[int] = None,
    is_seg_dataset: Optional[bool] = None,
    with_channels: bool = False,
    with_label_channels: bool = False,
    verify_paths: bool = True,
    with_padding: bool = True,
    z_ext: Optional[int] = None,
    **loader_kwargs,
) -> torch.utils.data.DataLoader:
```
Get data loader for training a segmentation network.

See `torch_em.data.SegmentationDataset` and `torch_em.data.ImageCollectionDataset` for details
on the data formats that are supported.

Arguments:
- raw_paths: The file path(s) to the raw data. Can either be a single path or multiple file paths. This argument also accepts a list of numpy arrays or torch tensors.
- raw_key: The name of the internal dataset containing the raw data. Set to None for regular image files, numpy arrays, or torch tensors.
- label_paths: The file path(s) to the label data. Can either be a single path or multiple file paths. This argument also accepts a list of numpy arrays or torch tensors.
- label_key: The name of the internal dataset containing the label data. Set to None for regular image files, numpy arrays, or torch tensors.
- batch_size: The batch size for the data loader.
- patch_shape: The patch shape for the training samples.
- label_transform: Transformation applied to the label data of a sample, before applying augmentations via `transform`.
- label_transform2: Transformation applied to the label data of a sample, after applying augmentations via `transform`.
- raw_transform: Transformation applied to the raw data of a sample, before applying augmentations via `transform`.
- transform: Transformation applied to both the raw data and label data of a sample. This can be used to implement data augmentations.
- dtype: The return data type of the raw data.
- label_dtype: The return data type of the label data.
- rois: Regions of interest in the data. If given, the data will only be loaded from the corresponding area.
- n_samples: The length of the underlying dataset. If None, the length will be set to `len(raw_paths)`.
- sampler: Sampler for rejecting samples according to a defined criterion. The sampler must be a callable that accepts the raw data (as numpy arrays) as input.
- ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
- is_seg_dataset: Whether this is a segmentation dataset or an image collection dataset. If None, the type of dataset will be derived from the data.
- with_channels: Whether the raw data has channels.
- with_label_channels: Whether the label data has channels.
- verify_paths: Whether to verify all paths before creating the dataset.
- with_padding: Whether to pad samples to `patch_shape` if their shape is smaller.
- z_ext: Extra bounding box for loading the data across z.
- loader_kwargs: Keyword arguments for `torch.utils.data.DataLoader`.

Returns:
The torch data loader.
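For illustration, here is a hedged sketch of a typical call. The HDF5 file names and the internal dataset keys `"raw"` and `"labels"` are placeholders, not part of the API; any format supported by `load_data` works the same way.

```python
import torch_em

# Two hypothetical training volumes stored as HDF5 files, each containing
# internal datasets "raw" and "labels".
volumes = ["/data/train_volume_1.h5", "/data/train_volume_2.h5"]

loader = torch_em.default_segmentation_loader(
    raw_paths=volumes, raw_key="raw",
    label_paths=volumes, label_key="labels",
    batch_size=2,
    patch_shape=(32, 256, 256),  # 3d patches sampled from the volumes
    # Any additional keyword arguments are forwarded to torch.utils.data.DataLoader.
    shuffle=True, num_workers=4,
)
```

Because raw and label data live in the same files in this sketch, the same path list is passed twice with different keys.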
```python
def default_segmentation_dataset(
    raw_paths: Union[List[Any], str, os.PathLike],
    raw_key: Optional[str],
    label_paths: Union[List[Any], str, os.PathLike],
    label_key: Optional[str],
    patch_shape: Tuple[int, ...],
    label_transform: Optional[Callable] = None,
    label_transform2: Optional[Callable] = None,
    raw_transform: Optional[Callable] = None,
    transform: Optional[Callable] = None,
    dtype: torch.dtype = torch.float32,
    label_dtype: torch.dtype = torch.float32,
    rois: Optional[Union[slice, Tuple[slice, ...]]] = None,
    n_samples: Optional[int] = None,
    sampler: Optional[Callable] = None,
    ndim: Optional[int] = None,
    is_seg_dataset: Optional[bool] = None,
    with_channels: bool = False,
    with_label_channels: bool = False,
    verify_paths: bool = True,
    with_padding: bool = True,
    z_ext: Optional[int] = None,
) -> torch.utils.data.Dataset:
```
Get data set for training a segmentation network.

See `torch_em.data.SegmentationDataset` and `torch_em.data.ImageCollectionDataset` for details
on the data formats that are supported.

Arguments:
- raw_paths: The file path(s) to the raw data. Can either be a single path or multiple file paths. This argument also accepts a list of numpy arrays or torch tensors.
- raw_key: The name of the internal dataset containing the raw data. Set to None for regular image files, numpy arrays, or torch tensors.
- label_paths: The file path(s) to the label data. Can either be a single path or multiple file paths. This argument also accepts a list of numpy arrays or torch tensors.
- label_key: The name of the internal dataset containing the label data. Set to None for regular image files, numpy arrays, or torch tensors.
- patch_shape: The patch shape for the training samples.
- label_transform: Transformation applied to the label data of a sample, before applying augmentations via `transform`.
- label_transform2: Transformation applied to the label data of a sample, after applying augmentations via `transform`.
- raw_transform: Transformation applied to the raw data of a sample, before applying augmentations via `transform`.
- transform: Transformation applied to both the raw data and label data of a sample. This can be used to implement data augmentations.
- dtype: The return data type of the raw data.
- label_dtype: The return data type of the label data.
- rois: Regions of interest in the data. If given, the data will only be loaded from the corresponding area.
- n_samples: The length of the dataset. If None, the length will be set to `len(raw_paths)`.
- sampler: Sampler for rejecting samples according to a defined criterion. The sampler must be a callable that accepts the raw data (as numpy arrays) as input.
- ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
- is_seg_dataset: Whether this is a segmentation dataset or an image collection dataset. If None, the type of dataset will be derived from the data.
- with_channels: Whether the raw data has channels.
- with_label_channels: Whether the label data has channels.
- verify_paths: Whether to verify all paths before creating the dataset.
- with_padding: Whether to pad samples to `patch_shape` if their shape is smaller.
- z_ext: Extra bounding box for loading the data across z.

Returns:
The torch dataset.
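Since `raw_paths`/`label_paths` also accept in-memory data, here is a minimal sketch using random numpy arrays as stand-in 2d images (keys are set to None in this case; shapes and dtypes are arbitrary, and the example assumes the in-memory `TensorDataset` path accepts them):

```python
import numpy as np
import torch
import torch_em

# Stand-in data: four 2d raw images with matching binary label masks.
raw_images = [np.random.rand(512, 512).astype("float32") for _ in range(4)]
label_images = [(np.random.rand(512, 512) > 0.5).astype("uint8") for _ in range(4)]

ds = torch_em.default_segmentation_dataset(
    raw_paths=raw_images, raw_key=None,
    label_paths=label_images, label_key=None,
    patch_shape=(256, 256),
)
# Wrap the dataset in a regular torch data loader.
loader = torch.utils.data.DataLoader(ds, batch_size=2, shuffle=True)
x, y = next(iter(loader))  # batch of raw patches and (transformed) label patches
```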
```python
def default_segmentation_trainer(
    name: str,
    model: torch.nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    loss: Optional[torch.nn.Module] = None,
    metric: Optional[Callable] = None,
    learning_rate: float = 1e-3,
    device: Optional[Union[str, torch.device]] = None,
    log_image_interval: int = 100,
    mixed_precision: bool = True,
    early_stopping: Optional[int] = None,
    logger=TensorboardLogger,
    logger_kwargs: Optional[Dict[str, Any]] = None,
    scheduler_kwargs: Dict[str, Any] = DEFAULT_SCHEDULER_KWARGS,
    optimizer_kwargs: Dict[str, Any] = {},
    trainer_class=DefaultTrainer,
    id_: Optional[str] = None,
    save_root: Optional[str] = None,
    compile_model: Optional[Union[bool, str]] = None,
    rank: Optional[int] = None,
):
```
Get a trainer for a segmentation network.

It creates a `torch.optim.AdamW` optimizer and a learning rate scheduler that reduces the learning rate on plateau.
By default, it uses the dice score as loss and metric.
This can be changed by passing arguments for `loss` and/or `metric`.
See `torch_em.trainer.DefaultTrainer` for additional details on how to configure and use the trainer.

Here's an example for training a 2D U-Net with this function:
```python
import torch_em
from torch_em.model import UNet2d
from torch_em.data.datasets.light_microscopy import get_dsb_loader

# The training data will be downloaded to this location.
data_root = "/path/to/save/the/training/data"
patch_shape = (256, 256)
trainer = default_segmentation_trainer(
    name="unet-training",
    model=UNet2d(in_channels=1, out_channels=1),
    train_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="train"),
    val_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="test"),
)
trainer.fit(iterations=int(2.5e4))  # Train for 25,000 iterations.
```

Arguments:
- name: The name of the checkpoint that will be created by the trainer.
- model: The model to train.
- train_loader: The data loader containing the training data.
- val_loader: The data loader containing the validation data.
- loss: The loss function for training.
- metric: The metric for validation.
- learning_rate: The initial learning rate for the AdamW optimizer.
- device: The torch device to use for training. If None, will use a GPU if available.
- log_image_interval: The interval for saving images during logging, in training iterations.
- mixed_precision: Whether to train with mixed precision.
- early_stopping: The patience for early stopping in epochs. If None, early stopping will not be used.
- logger: The logger class. Will be instantiated for logging. By default uses `torch_em.trainer.tensorboard_logger.TensorboardLogger`.
- logger_kwargs: The keyword arguments for the logger class.
- scheduler_kwargs: The keyword arguments for ReduceLROnPlateau.
- optimizer_kwargs: The keyword arguments for the AdamW optimizer.
- trainer_class: The trainer class. Uses `torch_em.trainer.DefaultTrainer` by default, but can be set to a custom trainer class to enable custom training procedures.
- id_: Unique identifier for the trainer. If None then `name` will be used.
- save_root: The root folder for saving the checkpoint and logs.
- compile_model: Whether to compile the model before training.
- rank: Rank argument for distributed training. See `torch_em.multi_gpu_training` for details.

Returns:
The trainer.
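As a sketch of how the defaults can be overridden, assuming `train_loader` and `val_loader` were created beforehand (e.g. with `default_segmentation_loader` above) and with otherwise arbitrary hyperparameter values:

```python
import torch_em
from torch_em.loss import DiceLoss
from torch_em.model import UNet2d

trainer = torch_em.default_segmentation_trainer(
    name="unet-custom",
    model=UNet2d(in_channels=1, out_channels=1),
    train_loader=train_loader,  # assumed to exist, see default_segmentation_loader
    val_loader=val_loader,
    loss=DiceLoss(),   # explicit loss and metric (these are the defaults anyway)
    metric=DiceLoss(),
    learning_rate=5e-4,
    optimizer_kwargs={"weight_decay": 1e-4},  # forwarded to torch.optim.AdamW
    scheduler_kwargs={"mode": "min", "factor": 0.5, "patience": 10},  # forwarded to ReduceLROnPlateau
    early_stopping=10,  # stop if the validation metric does not improve for 10 epochs
)
trainer.fit(iterations=int(1e4))
```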