torch_em.segmentation
import os
from glob import glob
from typing import Any, Dict, Optional, Union, Tuple, List, Callable

import torch
import torch.utils.data
from torch.utils.data import DataLoader

from .loss import DiceLoss
from .util import load_data
from .trainer import DefaultTrainer
from .trainer.tensorboard_logger import TensorboardLogger
from .transform import get_augmentations, get_raw_transform
from .data import ConcatDataset, ImageCollectionDataset, SegmentationDataset


# TODO add a heuristic to estimate this from the number of epochs
DEFAULT_SCHEDULER_KWARGS = {"mode": "min", "factor": 0.5, "patience": 5}
"""@private
"""


#
# convenience functions for segmentation loaders
#

# Distribute `n_samples` over the datasets given by `raw_paths` and return one sample count per dataset.
# For split="uniform" every dataset gets `n_samples // n_datasets` samples and the first
# `n_samples % n_datasets` datasets get one extra, so that the counts sum up to `n_samples`.
# `raw_key` is currently unused; split="balanced" is accepted but not implemented yet.
# TODO implement balanced and make it the default
# def samples_to_datasets(n_samples, raw_paths, raw_key, split="balanced"):
def samples_to_datasets(n_samples, raw_paths, raw_key, split="uniform"):
    """@private
    """
    assert split in ("balanced", "uniform")
    n_datasets = len(raw_paths)
    if split == "uniform":
        # even distribution of samples to datasets
        samples_per_ds = n_samples // n_datasets
        divider = n_samples % n_datasets
        return [samples_per_ds + 1 if ii < divider else samples_per_ds for ii in range(n_datasets)]
    else:
        # distribution of samples to dataset based on the dataset lens
        raise NotImplementedError


# Validate that the raw and label paths have the same type, the same number of entries
# and that every referenced path exists on disk. Raises a ValueError otherwise.
def check_paths(raw_paths, label_paths):
    """@private
    """
    if not isinstance(raw_paths, type(label_paths)):
        raise ValueError(f"Expect raw and label paths of same type, got {type(raw_paths)}, {type(label_paths)}")

    def _check_path(path):
        if isinstance(path, str):
            if not os.path.exists(path):
                raise ValueError(f"Could not find path {path}")
        else:
            # check for single path or multiple paths (for same volume - supports multi-modal inputs)
            for per_path in path:
                if not os.path.exists(per_path):
                    raise ValueError(f"Could not find path {per_path}")

    if isinstance(raw_paths, str):
        _check_path(raw_paths)
        _check_path(label_paths)
    else:
        if len(raw_paths) != len(label_paths):
            raise ValueError(f"Expect same number of raw and label paths, got {len(raw_paths)}, {len(label_paths)}")
        for rp, lp in zip(raw_paths, label_paths):
            _check_path(rp)
            _check_path(lp)


# Check if we can load the data as SegmentationDataset.
# Returns True if `load_data` succeeds for the raw and label paths, False if it fails for all of them.
# Raises a ValueError if the paths disagree, i.e. only some of them can be opened.
def is_segmentation_dataset(raw_paths, raw_key, label_paths, label_key):
    """@private
    """
    def _can_open(path, key):
        # Any failure of `load_data` is deliberately interpreted as "not a segmentation dataset".
        try:
            load_data(path, key)
            return True
        except Exception:
            return False

    if isinstance(raw_paths, str):
        can_open_raw = _can_open(raw_paths, raw_key)
        can_open_label = _can_open(label_paths, label_key)
    else:
        # All raw paths must agree on whether they can be opened.
        can_open_raw = [_can_open(rp, raw_key) for rp in raw_paths]
        if not can_open_raw.count(can_open_raw[0]) == len(can_open_raw):
            raise ValueError("Inconsistent raw data")
        can_open_raw = can_open_raw[0]

        # The same holds for the label paths.
        can_open_label = [_can_open(lp, label_key) for lp in label_paths]
        if not can_open_label.count(can_open_label[0]) == len(can_open_label):
            raise ValueError("Inconsistent label data")
        can_open_label = can_open_label[0]

    if can_open_raw != can_open_label:
        raise ValueError("Inconsistent raw and label data")

    return can_open_raw


def _load_segmentation_dataset(raw_paths, raw_key, label_paths, label_key, **kwargs):
    """Create a `SegmentationDataset` for a single path or a `ConcatDataset` of them for multiple paths.

    The optional keyword argument `rois` is a slice or tuple of slices for a single path
    and a sequence with one tuple per path for multiple paths.
    For multiple paths, the optional keyword argument `n_samples` is distributed over
    the individual datasets via `samples_to_datasets`.
    All remaining keyword arguments are passed on to `SegmentationDataset`.
    """
    rois = kwargs.pop("rois", None)
    if isinstance(raw_paths, str):
        if rois is not None:
            assert isinstance(rois, (tuple, slice))
            if isinstance(rois, tuple):
                assert all(isinstance(roi, slice) for roi in rois)
        ds = SegmentationDataset(raw_paths, raw_key, label_paths, label_key, roi=rois, **kwargs)
    else:
        assert len(raw_paths) > 0
        if rois is not None:
            assert len(rois) == len(label_paths)
            assert all(isinstance(roi, tuple) for roi in rois), f"{rois}"
        # The total number of samples is split up, so it must not reach the datasets via kwargs.
        n_samples = kwargs.pop("n_samples", None)

        samples_per_ds = (
            [None] * len(raw_paths) if n_samples is None else samples_to_datasets(n_samples, raw_paths, raw_key)
        )
        ds = []
        for i, (raw_path, label_path) in enumerate(zip(raw_paths, label_paths)):
            roi = None if rois is None else rois[i]
            dset = SegmentationDataset(
                raw_path, raw_key, label_path, label_key, roi=roi, n_samples=samples_per_ds[i], **kwargs
            )
            ds.append(dset)
        ds = ConcatDataset(*ds)
    return ds


def _load_image_collection_dataset(raw_paths, raw_key, label_paths, label_key, roi, **kwargs):
    """Create an `ImageCollectionDataset`, or a `ConcatDataset` of them for multiple folders.

    Three input layouts are supported:
    - `raw_paths` is a single folder: `raw_key` and `label_key` are glob patterns inside of the folders.
    - `raw_key` is None: `raw_paths` and `label_paths` are lists of image files. `roi` is not used in this case.
    - Otherwise: `raw_paths` and `label_paths` are lists of folders, with the keys as glob patterns.
      In this case `roi` must be None or contain one entry per folder.

    The keyword argument `patch_shape` is required. A 3d patch shape with a leading singleton
    is reduced to 2d. All remaining keyword arguments are passed on to `ImageCollectionDataset`.
    """
    def _get_paths(rpath, rkey, lpath, lkey, this_roi):
        # Keep the pattern in its own variable, so that it is still available for the error message.
        raw_pattern = os.path.join(rpath, rkey)
        rpath = sorted(glob(raw_pattern))
        if len(rpath) == 0:
            raise ValueError(f"Could not find any images for pattern {raw_pattern}")

        lpath = sorted(glob(os.path.join(lpath, lkey)))
        if len(rpath) != len(lpath):
            raise ValueError(f"Expect same number of raw and label images, got {len(rpath)}, {len(lpath)}")

        # Restrict the image lists with the roi for this folder (not the roi for all folders).
        if this_roi is not None:
            rpath, lpath = rpath[this_roi], lpath[this_roi]

        return rpath, lpath

    patch_shape = kwargs.pop("patch_shape")
    if patch_shape is not None:
        if len(patch_shape) == 3:
            if patch_shape[0] != 1:
                raise ValueError(f"Image collection dataset expects 2d patch shape, got {patch_shape}")
            patch_shape = patch_shape[1:]
        assert len(patch_shape) == 2

    if isinstance(raw_paths, str):
        raw_paths, label_paths = _get_paths(raw_paths, raw_key, label_paths, label_key, roi)
        ds = ImageCollectionDataset(raw_paths, label_paths, patch_shape=patch_shape, **kwargs)

    elif raw_key is None:
        assert label_key is None
        assert isinstance(raw_paths, (list, tuple)) and isinstance(label_paths, (list, tuple))
        assert len(raw_paths) == len(label_paths)
        ds = ImageCollectionDataset(raw_paths, label_paths, patch_shape=patch_shape, **kwargs)

    else:
        ds = []
        # The total number of samples is split up, so it must not reach the datasets via kwargs.
        n_samples = kwargs.pop("n_samples", None)
        samples_per_ds = (
            [None] * len(raw_paths) if n_samples is None else samples_to_datasets(n_samples, raw_paths, raw_key)
        )
        if roi is None:
            roi = len(raw_paths) * [None]
        assert len(roi) == len(raw_paths)
        for i, (raw_path, label_path, this_roi) in enumerate(zip(raw_paths, label_paths, roi)):
            rpath, lpath = _get_paths(raw_path, raw_key, label_path, label_key, this_roi)
            dset = ImageCollectionDataset(rpath, lpath, patch_shape=patch_shape, n_samples=samples_per_ds[i], **kwargs)
            ds.append(dset)
        ds = ConcatDataset(*ds)

    return ds


def _get_default_transform(path, key, is_seg_dataset, ndim):
    """Get the default augmentations for the given data.

    For segmentation datasets without a given `ndim` the dimensionality is derived from
    the shape of the data at `path`. Image collection datasets always use 2d augmentations.
    """
    if is_seg_dataset and ndim is None:
        shape = load_data(path, key).shape
        if len(shape) == 2:
            ndim = 2
        else:
            # heuristics to figure out whether to use default 3d
            # or default anisotropic augmentations
            ndim = "anisotropic" if shape[0] < shape[1] // 2 else 3

    elif is_seg_dataset and ndim is not None:
        # Use the dimensionality that was passed explicitly.
        pass

    else:
        ndim = 2

    return get_augmentations(ndim)


def default_segmentation_loader(
    raw_paths: Union[List[Any], str, os.PathLike],
    raw_key: Optional[str],
    label_paths: Union[List[Any], str, os.PathLike],
    label_key: Optional[str],
    batch_size: int,
    patch_shape: Tuple[int, ...],
    label_transform: Optional[Callable] = None,
    label_transform2: Optional[Callable] = None,
    raw_transform: Optional[Callable] = None,
    transform: Optional[Callable] = None,
    dtype: torch.dtype = torch.float32,
    label_dtype: torch.dtype = torch.float32,
    rois: Optional[Union[slice, Tuple[slice, ...]]] = None,
    n_samples: Optional[int] = None,
    sampler: Optional[Callable] = None,
    ndim: Optional[int] = None,
    is_seg_dataset: Optional[bool] = None,
    with_channels: bool = False,
    with_label_channels: bool = False,
    verify_paths: bool = True,
    with_padding: bool = True,
    z_ext: Optional[int] = None,
    **loader_kwargs,
) -> torch.utils.data.DataLoader:
    """Get data loader for training a segmentation network.

    See `torch_em.data.SegmentationDataset` and `torch_em.data.ImageCollectionDataset` for details
    on the data formats that are supported.

    Args:
        raw_paths: The file path(s) to the raw data. Can either be a single path or multiple file paths.
        raw_key: The name of the internal dataset containing the raw data. Set to None for regular image files.
        label_paths: The file path(s) to the label data. Can either be a single path or multiple file paths.
        label_key: The name of the internal dataset containing the label data. Set to None for regular image files.
        batch_size: The batch size for the data loader.
        patch_shape: The patch shape for the training samples.
        label_transform: Transformation applied to the label data of a sample,
            before applying augmentations via `transform`.
        label_transform2: Transformation applied to the label data of a sample,
            after applying augmentations via `transform`.
        raw_transform: Transformation applied to the raw data of a sample.
            If None, a default raw transform will be used.
        transform: Transformation applied to both the raw data and label data of a sample.
            This can be used to implement data augmentations.
            If None, default augmentations will be used.
        dtype: The return data type of the raw data.
        label_dtype: The return data type of the label data.
        rois: Regions of interest in the data. If given, the data will only be loaded from the corresponding area.
            For multiple paths, one region of interest per path is expected.
        n_samples: The length of the underlying dataset. If None, the length will be set to `len(raw_paths)`.
        sampler: Sampler for rejecting samples according to a defined criterion.
            The sampler must be a callable that accepts the raw data (as numpy arrays) as input.
        ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
        is_seg_dataset: Whether this is a segmentation dataset or an image collection dataset.
            If None, the type of dataset will be derived from the data.
        with_channels: Whether the raw data has channels.
        with_label_channels: Whether the label data has channels.
        verify_paths: Whether to verify all paths before creating the dataset.
        with_padding: Whether to pad samples to `patch_shape` if their shape is smaller.
        z_ext: Extra bounding box for loading the data across z.
        loader_kwargs: Keyword arguments for `torch.utils.data.DataLoader`.

    Returns:
        The torch data loader.
    """
    ds = default_segmentation_dataset(
        raw_paths=raw_paths,
        raw_key=raw_key,
        label_paths=label_paths,
        label_key=label_key,
        patch_shape=patch_shape,
        label_transform=label_transform,
        label_transform2=label_transform2,
        raw_transform=raw_transform,
        transform=transform,
        dtype=dtype,
        label_dtype=label_dtype,
        rois=rois,
        n_samples=n_samples,
        sampler=sampler,
        ndim=ndim,
        is_seg_dataset=is_seg_dataset,
        with_channels=with_channels,
        with_label_channels=with_label_channels,
        with_padding=with_padding,
        z_ext=z_ext,
        verify_paths=verify_paths,
    )
    return get_data_loader(ds, batch_size=batch_size, **loader_kwargs)


def default_segmentation_dataset(
    raw_paths: Union[List[Any], str, os.PathLike],
    raw_key: Optional[str],
    label_paths: Union[List[Any], str, os.PathLike],
    label_key: Optional[str],
    patch_shape: Tuple[int, ...],
    label_transform: Optional[Callable] = None,
    label_transform2: Optional[Callable] = None,
    raw_transform: Optional[Callable] = None,
    transform: Optional[Callable] = None,
    dtype: torch.dtype = torch.float32,
    label_dtype: torch.dtype = torch.float32,
    rois: Optional[Union[slice, Tuple[slice, ...]]] = None,
    n_samples: Optional[int] = None,
    sampler: Optional[Callable] = None,
    ndim: Optional[int] = None,
    is_seg_dataset: Optional[bool] = None,
    with_channels: bool = False,
    with_label_channels: bool = False,
    verify_paths: bool = True,
    with_padding: bool = True,
    z_ext: Optional[int] = None,
) -> torch.utils.data.Dataset:
    """Get data set for training a segmentation network.

    See `torch_em.data.SegmentationDataset` and `torch_em.data.ImageCollectionDataset` for details
    on the data formats that are supported.

    Args:
        raw_paths: The file path(s) to the raw data. Can either be a single path or multiple file paths.
        raw_key: The name of the internal dataset containing the raw data. Set to None for regular image files.
        label_paths: The file path(s) to the label data. Can either be a single path or multiple file paths.
        label_key: The name of the internal dataset containing the label data. Set to None for regular image files.
        patch_shape: The patch shape for the training samples.
        label_transform: Transformation applied to the label data of a sample,
            before applying augmentations via `transform`.
        label_transform2: Transformation applied to the label data of a sample,
            after applying augmentations via `transform`.
        raw_transform: Transformation applied to the raw data of a sample.
            If None, the default raw transform from `get_raw_transform` will be used.
        transform: Transformation applied to both the raw data and label data of a sample.
            This can be used to implement data augmentations.
            If None, default augmentations will be selected based on the data.
        dtype: The return data type of the raw data.
        label_dtype: The return data type of the label data.
        rois: Regions of interest in the data. If given, the data will only be loaded from the corresponding area.
            For multiple paths, one region of interest per path is expected.
        n_samples: The length of the dataset. If None, the length will be set to `len(raw_paths)`.
        sampler: Sampler for rejecting samples according to a defined criterion.
            The sampler must be a callable that accepts the raw data (as numpy arrays) as input.
        ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
        is_seg_dataset: Whether this is a segmentation dataset or an image collection dataset.
            If None, the type of dataset will be derived from the data.
        with_channels: Whether the raw data has channels.
            Only used for segmentation datasets.
        with_label_channels: Whether the label data has channels.
            Only used for segmentation datasets.
        verify_paths: Whether to verify all paths before creating the dataset.
        with_padding: Whether to pad samples to `patch_shape` if their shape is smaller.
        z_ext: Extra bounding box for loading the data across z.
            Only used for segmentation datasets.

    Returns:
        The torch data set.
    """
    # The type hints allow os.PathLike for single paths, but all helper functions tell
    # single paths and path collections apart via `isinstance(..., str)`.
    if isinstance(raw_paths, os.PathLike):
        raw_paths = os.fspath(raw_paths)
    if isinstance(label_paths, os.PathLike):
        label_paths = os.fspath(label_paths)

    if verify_paths:
        check_paths(raw_paths, label_paths)

    if is_seg_dataset is None:
        is_seg_dataset = is_segmentation_dataset(raw_paths, raw_key, label_paths, label_key)

    # We always use a raw transform in the convenience function.
    if raw_transform is None:
        raw_transform = get_raw_transform()

    # We always use augmentations in the convenience function.
    if transform is None:
        transform = _get_default_transform(
            raw_paths if isinstance(raw_paths, str) else raw_paths[0], raw_key, is_seg_dataset, ndim
        )

    if is_seg_dataset:
        ds = _load_segmentation_dataset(
            raw_paths,
            raw_key,
            label_paths,
            label_key,
            patch_shape=patch_shape,
            raw_transform=raw_transform,
            label_transform=label_transform,
            label_transform2=label_transform2,
            transform=transform,
            rois=rois,
            n_samples=n_samples,
            sampler=sampler,
            ndim=ndim,
            dtype=dtype,
            label_dtype=label_dtype,
            with_channels=with_channels,
            with_label_channels=with_label_channels,
            with_padding=with_padding,
            z_ext=z_ext,
        )

    else:
        ds = _load_image_collection_dataset(
            raw_paths,
            raw_key,
            label_paths,
            label_key,
            roi=rois,
            patch_shape=patch_shape,
            label_transform=label_transform,
            raw_transform=raw_transform,
            label_transform2=label_transform2,
            transform=transform,
            n_samples=n_samples,
            sampler=sampler,
            dtype=dtype,
            label_dtype=label_dtype,
            with_padding=with_padding,
        )

    return ds


# Wrap the dataset in a torch DataLoader. `pin_memory` is enabled unless it is passed explicitly.
def get_data_loader(dataset: torch.utils.data.Dataset, batch_size: int, **loader_kwargs) -> torch.utils.data.DataLoader:
    """@private
    """
    pin_memory = loader_kwargs.pop("pin_memory", True)
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, pin_memory=pin_memory, **loader_kwargs)
    # monkey patch shuffle attribute to the loader
    loader.shuffle = loader_kwargs.get("shuffle", False)
    return loader


#
# convenience functions for segmentation trainers
#


def default_segmentation_trainer(
    name: str,
    model: torch.nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    loss: Optional[torch.nn.Module] = None,
    metric: Optional[Callable] = None,
    learning_rate: float = 1e-3,
    device: Optional[Union[str, torch.device]] = None,
    log_image_interval: int = 100,
    mixed_precision: bool = True,
    early_stopping: Optional[int] = None,
    logger=TensorboardLogger,
    logger_kwargs: Optional[Dict[str, Any]] = None,
    scheduler_kwargs: Dict[str, Any] = DEFAULT_SCHEDULER_KWARGS,
    optimizer_kwargs: Dict[str, Any] = {},
    trainer_class=DefaultTrainer,
    id_: Optional[str] = None,
    save_root: Optional[str] = None,
    compile_model: Optional[Union[bool, str]] = None,
    rank: Optional[int] = None,
):
    """Get a trainer for a segmentation network.

    It creates a `torch.optim.AdamW` optimizer and learning rate scheduler that reduces the learning rate on plateau.
    By default, it uses the dice score as loss and metric.
    This can be changed by passing arguments for `loss` and/or `metric`.
    See `torch_em.trainer.DefaultTrainer` for additional details on how to configure and use the trainer.

    Here's an example for training a 2D U-Net with this function:
    ```python
    from torch_em.segmentation import default_segmentation_trainer
    from torch_em.model import UNet2d
    from torch_em.data.datasets.light_microscopy import get_dsb_loader

    # The training data will be downloaded to this location.
    data_root = "/path/to/save/the/training/data"
    patch_shape = (256, 256)
    trainer = default_segmentation_trainer(
        name="unet-training",
        model=UNet2d(in_channels=1, out_channels=1),
        train_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="train"),
        val_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="test"),
    )
    trainer.fit(iterations=int(2.5e4))  # Train for 25.000 iterations.
    ```

    Args:
        name: The name of the checkpoint that will be created by the trainer.
        model: The model to train.
        train_loader: The data loader containing the training data.
        val_loader: The data loader containing the validation data.
        loss: The loss function for training.
        metric: The metric for validation.
        learning_rate: The initial learning rate for the AdamW optimizer.
        device: The torch device to use for training. If None, will use a GPU if available.
        log_image_interval: The interval for saving images during logging, in training iterations.
        mixed_precision: Whether to train with mixed precision. Will be disabled when training on the CPU.
        early_stopping: The patience for early stopping in epochs. If None, early stopping will not be used.
        logger: The logger class. Will be instantiated for logging.
            By default uses `torch_em.trainer.tensorboard_logger.TensorboardLogger`.
        logger_kwargs: The keyword arguments for the logger class.
        scheduler_kwargs: The keyword arguments for ReduceLROnPlateau.
        optimizer_kwargs: The keyword arguments for the AdamW optimizer.
        trainer_class: The trainer class. Uses `torch_em.trainer.DefaultTrainer` by default,
            but can be set to a custom trainer class to enable custom training procedures.
        id_: Unique identifier for the trainer. If None then `name` will be used.
        save_root: The root folder for saving the checkpoint and logs.
        compile_model: Whether to compile the model before training.
        rank: Rank argument for distributed training. See `torch_em.multi_gpu_training` for details.

    Returns:
        The trainer.
    """
    # The keyword argument dicts are only unpacked here and never modified,
    # so sharing the default dicts between calls is safe.
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, **optimizer_kwargs)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, **scheduler_kwargs)

    loss = DiceLoss() if loss is None else loss
    metric = DiceLoss() if metric is None else metric

    if device is None:
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    else:
        device = torch.device(device)

    # CPU does not support mixed precision training.
    if device.type == "cpu":
        mixed_precision = False

    return trainer_class(
        name=name,
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        loss=loss,
        metric=metric,
        optimizer=optimizer,
        device=device,
        lr_scheduler=scheduler,
        mixed_precision=mixed_precision,
        early_stopping=early_stopping,
        log_image_interval=log_image_interval,
        logger=logger,
        logger_kwargs=logger_kwargs,
        id_=id_,
        save_root=save_root,
        compile_model=compile_model,
        rank=rank,
    )
def default_segmentation_loader(
    raw_paths: Union[List[Any], str, os.PathLike],
    raw_key: Optional[str],
    label_paths: Union[List[Any], str, os.PathLike],
    label_key: Optional[str],
    batch_size: int,
    patch_shape: Tuple[int, ...],
    label_transform: Optional[Callable] = None,
    label_transform2: Optional[Callable] = None,
    raw_transform: Optional[Callable] = None,
    transform: Optional[Callable] = None,
    dtype: torch.dtype = torch.float32,
    label_dtype: torch.dtype = torch.float32,
    rois: Optional[Union[slice, Tuple[slice, ...]]] = None,
    n_samples: Optional[int] = None,
    sampler: Optional[Callable] = None,
    ndim: Optional[int] = None,
    is_seg_dataset: Optional[bool] = None,
    with_channels: bool = False,
    with_label_channels: bool = False,
    verify_paths: bool = True,
    with_padding: bool = True,
    z_ext: Optional[int] = None,
    **loader_kwargs,
) -> torch.utils.data.DataLoader:
    """Get data loader for training a segmentation network.

    See `torch_em.data.SegmentationDataset` and `torch_em.data.ImageCollectionDataset` for details
    on the data formats that are supported.

    Args:
        raw_paths: The file path(s) to the raw data. Can either be a single path or multiple file paths.
        raw_key: The name of the internal dataset containing the raw data. Set to None for regular image files.
        label_paths: The file path(s) to the label data. Can either be a single path or multiple file paths.
        label_key: The name of the internal dataset containing the label data. Set to None for regular image files.
        batch_size: The batch size for the data loader.
        patch_shape: The patch shape for the training samples.
        label_transform: Transformation applied to the label data of a sample,
            before applying augmentations via `transform`.
        label_transform2: Transformation applied to the label data of a sample,
            after applying augmentations via `transform`.
        raw_transform: Transformation applied to the raw data of a sample.
        transform: Transformation applied to both the raw data and label data of a sample.
            This can be used to implement data augmentations.
        dtype: The return data type of the raw data.
        label_dtype: The return data type of the label data.
        rois: Regions of interest in the data. If given, the data will only be loaded from the corresponding area.
        n_samples: The length of the underlying dataset. If None, the length will be set to `len(raw_paths)`.
        sampler: Sampler for rejecting samples according to a defined criterion.
            The sampler must be a callable that accepts the raw data (as numpy arrays) as input.
        ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
        is_seg_dataset: Whether this is a segmentation dataset or an image collection dataset.
            If None, the type of dataset will be derived from the data.
        with_channels: Whether the raw data has channels.
        with_label_channels: Whether the label data has channels.
        verify_paths: Whether to verify all paths before creating the dataset.
        with_padding: Whether to pad samples to `patch_shape` if their shape is smaller.
        z_ext: Extra bounding box for loading the data across z.
        loader_kwargs: Keyword arguments for `torch.utils.data.DataLoader`.

    Returns:
        The torch data loader.
    """
    # Build the dataset first; all dataset-related arguments are forwarded unchanged.
    ds = default_segmentation_dataset(
        raw_paths=raw_paths,
        raw_key=raw_key,
        label_paths=label_paths,
        label_key=label_key,
        patch_shape=patch_shape,
        label_transform=label_transform,
        label_transform2=label_transform2,
        raw_transform=raw_transform,
        transform=transform,
        dtype=dtype,
        label_dtype=label_dtype,
        rois=rois,
        n_samples=n_samples,
        sampler=sampler,
        ndim=ndim,
        is_seg_dataset=is_seg_dataset,
        with_channels=with_channels,
        with_label_channels=with_label_channels,
        with_padding=with_padding,
        z_ext=z_ext,
        verify_paths=verify_paths,
    )
    # Only `batch_size` and the extra `loader_kwargs` configure the data loader itself.
    return get_data_loader(ds, batch_size=batch_size, **loader_kwargs)
Get data loader for training a segmentation network.
See `torch_em.data.SegmentationDataset` and `torch_em.data.ImageCollectionDataset` for details on the data formats that are supported.
Arguments:
- raw_paths: The file path(s) to the raw data. Can either be a single path or multiple file paths.
- raw_key: The name of the internal dataset containing the raw data. Set to None for regular image files.
- label_paths: The file path(s) to the label data. Can either be a single path or multiple file paths.
- label_key: The name of the internal dataset containing the label data. Set to None for regular image files.
- batch_size: The batch size for the data loader.
- patch_shape: The patch shape for the training samples.
- label_transform: Transformation applied to the label data of a sample, before applying augmentations via `transform`.
- label_transform2: Transformation applied to the label data of a sample, after applying augmentations via `transform`.
- transform: Transformation applied to both the raw data and label data of a sample. This can be used to implement data augmentations.
- dtype: The return data type of the raw data.
- label_dtype: The return data type of the label data.
- rois: Regions of interest in the data. If given, the data will only be loaded from the corresponding area.
- n_samples: The length of the underlying dataset. If None, the length will be set to `len(raw_paths)`.
- sampler: Sampler for rejecting samples according to a defined criterion. The sampler must be a callable that accepts the raw data (as numpy arrays) as input.
- ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
- is_seg_dataset: Whether this is a segmentation dataset or an image collection dataset. If None, the type of dataset will be derived from the data.
- with_channels: Whether the raw data has channels.
- with_label_channels: Whether the label data has channels.
- verify_paths: Whether to verify all paths before creating the dataset.
- with_padding: Whether to pad samples to `patch_shape` if their shape is smaller.
- z_ext: Extra bounding box for loading the data across z.
- loader_kwargs: Keyword arguments for `torch.utils.data.DataLoader`.
Returns:
The torch data loader.
def default_segmentation_dataset(
    raw_paths: Union[List[Any], str, os.PathLike],
    raw_key: Optional[str],
    label_paths: Union[List[Any], str, os.PathLike],
    label_key: Optional[str],
    patch_shape: Tuple[int, ...],
    label_transform: Optional[Callable] = None,
    label_transform2: Optional[Callable] = None,
    raw_transform: Optional[Callable] = None,
    transform: Optional[Callable] = None,
    dtype: torch.dtype = torch.float32,
    label_dtype: torch.dtype = torch.float32,
    rois: Optional[Union[slice, Tuple[slice, ...]]] = None,
    n_samples: Optional[int] = None,
    sampler: Optional[Callable] = None,
    ndim: Optional[int] = None,
    is_seg_dataset: Optional[bool] = None,
    with_channels: bool = False,
    with_label_channels: bool = False,
    verify_paths: bool = True,
    with_padding: bool = True,
    z_ext: Optional[int] = None,
) -> torch.utils.data.Dataset:
    """Get data set for training a segmentation network.

    See `torch_em.data.SegmentationDataset` and `torch_em.data.ImageCollectionDataset` for details
    on the data formats that are supported.

    Args:
        raw_paths: The file path(s) to the raw data. Can either be a single path or multiple file paths.
        raw_key: The name of the internal dataset containing the raw data. Set to None for regular image files.
        label_paths: The file path(s) to the label data. Can either be a single path or multiple file paths.
        label_key: The name of the internal dataset containing the label data. Set to None for regular image files.
        patch_shape: The patch shape for the training samples.
        label_transform: Transformation applied to the label data of a sample,
            before applying augmentations via `transform`.
        label_transform2: Transformation applied to the label data of a sample,
            after applying augmentations via `transform`.
        raw_transform: Transformation applied to the raw data of a sample.
            If None, the default raw transform from `get_raw_transform` will be used.
        transform: Transformation applied to both the raw data and label data of a sample.
            This can be used to implement data augmentations.
            If None, default augmentations will be selected based on the data.
        dtype: The return data type of the raw data.
        label_dtype: The return data type of the label data.
        rois: Regions of interest in the data. If given, the data will only be loaded from the corresponding area.
        n_samples: The length of the dataset. If None, the length will be set to `len(raw_paths)`.
        sampler: Sampler for rejecting samples according to a defined criterion.
            The sampler must be a callable that accepts the raw data (as numpy arrays) as input.
        ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
        is_seg_dataset: Whether this is a segmentation dataset or an image collection dataset.
            If None, the type of dataset will be derived from the data.
        with_channels: Whether the raw data has channels.
            Only used for segmentation datasets.
        with_label_channels: Whether the label data has channels.
            Only used for segmentation datasets.
        verify_paths: Whether to verify all paths before creating the dataset.
        with_padding: Whether to pad samples to `patch_shape` if their shape is smaller.
        z_ext: Extra bounding box for loading the data across z.
            Only used for segmentation datasets.

    Returns:
        The torch data set.
    """
    # The type hints allow os.PathLike for single paths, but the code below tells
    # single paths and path collections apart via `isinstance(..., str)`.
    if isinstance(raw_paths, os.PathLike):
        raw_paths = os.fspath(raw_paths)
    if isinstance(label_paths, os.PathLike):
        label_paths = os.fspath(label_paths)

    if verify_paths:
        check_paths(raw_paths, label_paths)

    if is_seg_dataset is None:
        is_seg_dataset = is_segmentation_dataset(raw_paths, raw_key, label_paths, label_key)

    # We always use a raw transform in the convenience function.
    if raw_transform is None:
        raw_transform = get_raw_transform()

    # We always use augmentations in the convenience function.
    if transform is None:
        transform = _get_default_transform(
            raw_paths if isinstance(raw_paths, str) else raw_paths[0], raw_key, is_seg_dataset, ndim
        )

    if is_seg_dataset:
        ds = _load_segmentation_dataset(
            raw_paths,
            raw_key,
            label_paths,
            label_key,
            patch_shape=patch_shape,
            raw_transform=raw_transform,
            label_transform=label_transform,
            label_transform2=label_transform2,
            transform=transform,
            rois=rois,
            n_samples=n_samples,
            sampler=sampler,
            ndim=ndim,
            dtype=dtype,
            label_dtype=label_dtype,
            with_channels=with_channels,
            with_label_channels=with_label_channels,
            with_padding=with_padding,
            z_ext=z_ext,
        )

    else:
        # `ndim`, `with_channels`, `with_label_channels` and `z_ext` are not forwarded in this case.
        ds = _load_image_collection_dataset(
            raw_paths,
            raw_key,
            label_paths,
            label_key,
            roi=rois,
            patch_shape=patch_shape,
            label_transform=label_transform,
            raw_transform=raw_transform,
            label_transform2=label_transform2,
            transform=transform,
            n_samples=n_samples,
            sampler=sampler,
            dtype=dtype,
            label_dtype=label_dtype,
            with_padding=with_padding,
        )

    return ds
Get data set for training a segmentation network.
See `torch_em.data.SegmentationDataset` and `torch_em.data.ImageCollectionDataset` for details on the data formats that are supported.
Arguments:
- raw_paths: The file path(s) to the raw data. Can either be a single path or multiple file paths.
- raw_key: The name of the internal dataset containing the raw data. Set to None for regular image files.
- label_paths: The file path(s) to the label data. Can either be a single path or multiple file paths.
- label_key: The name of the internal dataset containing the label data. Set to None for regular image files.
- patch_shape: The patch shape for the training samples.
- label_transform: Transformation applied to the label data of a sample, before applying augmentations via `transform`.
- label_transform2: Transformation applied to the label data of a sample, after applying augmentations via `transform`.
- transform: Transformation applied to both the raw data and label data of a sample. This can be used to implement data augmentations.
- dtype: The return data type of the raw data.
- label_dtype: The return data type of the label data.
- rois: Regions of interest in the data. If given, the data will only be loaded from the corresponding area.
- n_samples: The length of the dataset. If None, the length will be set to `len(raw_paths)`.
- sampler: Sampler for rejecting samples according to a defined criterion. The sampler must be a callable that accepts the raw data (as numpy arrays) as input.
- ndim: The spatial dimensionality of the data. If None, will be derived from the raw data.
- is_seg_dataset: Whether this is a segmentation dataset or an image collection dataset. If None, the type of dataset will be derived from the data.
- with_channels: Whether the raw data has channels.
- with_label_channels: Whether the label data has channels.
- verify_paths: Whether to verify all paths before creating the dataset.
- with_padding: Whether to pad samples to `patch_shape` if their shape is smaller.
- z_ext: Extra bounding box for loading the data across z.
- loader_kwargs: Keyword arguments for `torch.utils.data.DataLoader`.
Returns:
The torch data set.
def default_segmentation_trainer(
    name: str,
    model: torch.nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    loss: Optional[torch.nn.Module] = None,
    metric: Optional[Callable] = None,
    learning_rate: float = 1e-3,
    device: Optional[Union[str, torch.device]] = None,
    log_image_interval: int = 100,
    mixed_precision: bool = True,
    early_stopping: Optional[int] = None,
    logger=TensorboardLogger,
    logger_kwargs: Optional[Dict[str, Any]] = None,
    scheduler_kwargs: Optional[Dict[str, Any]] = DEFAULT_SCHEDULER_KWARGS,
    optimizer_kwargs: Optional[Dict[str, Any]] = None,
    trainer_class=DefaultTrainer,
    id_: Optional[str] = None,
    save_root: Optional[str] = None,
    compile_model: Optional[Union[bool, str]] = None,
    rank: Optional[int] = None,
):
    """Get a trainer for a segmentation network.

    It creates a `torch.optim.AdamW` optimizer and learning rate scheduler that reduces the learning rate on plateau.
    By default, it uses the dice score as loss and metric.
    This can be changed by passing arguments for `loss` and/or `metric`.
    See `torch_em.trainer.DefaultTrainer` for additional details on how to configure and use the trainer.

    Here's an example for training a 2D U-Net with this function:
    ```python
    from torch_em.segmentation import default_segmentation_trainer
    from torch_em.model import UNet2d
    from torch_em.data.datasets.light_microscopy import get_dsb_loader

    # The training data will be downloaded to this location.
    data_root = "/path/to/save/the/training/data"
    patch_shape = (256, 256)
    trainer = default_segmentation_trainer(
        name="unet-training",
        model=UNet2d(in_channels=1, out_channels=1),
        train_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="train"),
        val_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="test"),
    )
    trainer.fit(iterations=int(2.5e4))  # Train for 25,000 iterations.
    ```

    Args:
        name: The name of the checkpoint that will be created by the trainer.
        model: The model to train.
        train_loader: The data loader containing the training data.
        val_loader: The data loader containing the validation data.
        loss: The loss function for training. If None, `DiceLoss` is used.
        metric: The metric for validation. If None, `DiceLoss` is used.
        learning_rate: The initial learning rate for the AdamW optimizer.
        device: The torch device to use for training. If None, will use a GPU if available.
        log_image_interval: The interval for saving images during logging, in training iterations.
        mixed_precision: Whether to train with mixed precision. Always disabled when training on the CPU.
        early_stopping: The patience for early stopping in epochs. If None, early stopping will not be used.
        logger: The logger class. Will be instantiated for logging.
            By default uses `torch_em.trainer.tensorboard_logger.TensorboardLogger`.
        logger_kwargs: The keyword arguments for the logger class.
        scheduler_kwargs: The keyword arguments for ReduceLROnPlateau.
            If None, `DEFAULT_SCHEDULER_KWARGS` is used.
        optimizer_kwargs: The keyword arguments for the AdamW optimizer.
            If None, no additional keyword arguments are passed.
        trainer_class: The trainer class. Uses `torch_em.trainer.DefaultTrainer` by default,
            but can be set to a custom trainer class to enable custom training procedures.
        id_: Unique identifier for the trainer. If None then `name` will be used.
        save_root: The root folder for saving the checkpoint and logs.
        compile_model: Whether to compile the model before training.
        rank: Rank argument for distributed training. See `torch_em.multi_gpu_training` for details.

    Returns:
        The trainer.
    """
    # Resolve the keyword dictionaries here instead of relying on a mutable `{}` default,
    # and accept an explicit None so that callers forwarding optional settings do not crash on `**None`.
    if optimizer_kwargs is None:
        optimizer_kwargs = {}
    if scheduler_kwargs is None:
        scheduler_kwargs = DEFAULT_SCHEDULER_KWARGS

    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, **optimizer_kwargs)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, **scheduler_kwargs)

    loss = DiceLoss() if loss is None else loss
    metric = DiceLoss() if metric is None else metric

    if device is None:
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    else:
        # Normalize string inputs so that `device.type` can be inspected below.
        device = torch.device(device)

    # CPU does not support mixed precision training.
    if device.type == "cpu":
        mixed_precision = False

    return trainer_class(
        name=name,
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        loss=loss,
        metric=metric,
        optimizer=optimizer,
        device=device,
        lr_scheduler=scheduler,
        mixed_precision=mixed_precision,
        early_stopping=early_stopping,
        log_image_interval=log_image_interval,
        logger=logger,
        logger_kwargs=logger_kwargs,
        id_=id_,
        save_root=save_root,
        compile_model=compile_model,
        rank=rank,
    )
Get a trainer for a segmentation network.
It creates a `torch.optim.AdamW` optimizer and learning rate scheduler that reduces the learning rate on plateau.
By default, it uses the dice score as loss and metric.
This can be changed by passing arguments for `loss` and/or `metric`.
See `torch_em.trainer.DefaultTrainer` for additional details on how to configure and use the trainer.
Here's an example for training a 2D U-Net with this function:
from torch_em.segmentation import default_segmentation_trainer
from torch_em.model import UNet2d
from torch_em.data.datasets.light_microscopy import get_dsb_loader
# The training data will be downloaded to this location.
data_root = "/path/to/save/the/training/data"
patch_shape = (256, 256)
trainer = default_segmentation_trainer(
    name="unet-training",
    model=UNet2d(in_channels=1, out_channels=1),
    train_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="train"),
    val_loader=get_dsb_loader(path=data_root, patch_shape=patch_shape, split="test"),
)
trainer.fit(iterations=int(2.5e4))  # Train for 25,000 iterations.
Arguments:
- name: The name of the checkpoint that will be created by the trainer.
- model: The model to train.
- train_loader: The data loader containing the training data.
- val_loader: The data loader containing the validation data.
- loss: The loss function for training.
- metric: The metric for validation.
- learning_rate: The initial learning rate for the AdamW optimizer.
- device: The torch device to use for training. If None, will use a GPU if available.
- log_image_interval: The interval for saving images during logging, in training iterations.
- mixed_precision: Whether to train with mixed precision.
- early_stopping: The patience for early stopping in epochs. If None, early stopping will not be used.
- logger: The logger class. Will be instantiated for logging. By default uses `torch_em.trainer.tensorboard_logger.TensorboardLogger`.
- logger_kwargs: The keyword arguments for the logger class.
- scheduler_kwargs: The keyword arguments for ReduceLROnPlateau.
- optimizer_kwargs: The keyword arguments for the AdamW optimizer.
- trainer_class: The trainer class. Uses `torch_em.trainer.DefaultTrainer` by default, but can be set to a custom trainer class to enable custom training procedures.
- id_: Unique identifier for the trainer. If None then `name` will be used.
- save_root: The root folder for saving the checkpoint and logs.
- compile_model: Whether to compile the model before training.
- rank: Rank argument for distributed training. See `torch_em.multi_gpu_training` for details.
Returns:
The trainer.