torch_em.multi_gpu_training
````python
import os
from functools import partial
from typing import Dict, Any, Optional, Callable

import torch
import torch.utils.data
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel

import torch_em


def setup(rank, world_size):
    """@private
    """
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    dist.init_process_group("nccl", rank=rank, world_size=world_size)


def cleanup():
    """@private
    """
    dist.destroy_process_group()


def _create_data_loader(ds_callable, ds_kwargs, loader_kwargs, world_size, rank):
    # Create the dataset.
    ds = ds_callable(**ds_kwargs)

    # Create the sampler.
    # Set shuffle on the sampler instead of the loader.
    shuffle = loader_kwargs.pop("shuffle", False)
    sampler = torch.utils.data.distributed.DistributedSampler(ds, num_replicas=world_size, rank=rank, shuffle=shuffle)

    # Create the loader.
    loader = torch.utils.data.DataLoader(ds, sampler=sampler, **loader_kwargs)
    loader.shuffle = shuffle

    return loader


class DDP(DistributedDataParallel):
    """Wrapper for the DistributedDataParallel class.

    Overrides the `__getattr__` method so that attribute access is forwarded to the model wrapped by DDP.
    """
    def __getattr__(self, name):
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.module, name)


def _train_impl(
    rank: int,
    world_size: int,
    model_callable: Callable[[Any], torch.nn.Module],
    model_kwargs: Dict[str, Any],
    train_dataset_callable: Callable[[Any], torch.utils.data.Dataset],
    train_dataset_kwargs: Dict[str, Any],
    val_dataset_callable: Callable[[Any], torch.utils.data.Dataset],
    val_dataset_kwargs: Dict[str, Any],
    loader_kwargs: Dict[str, Any],
    iterations: int,
    find_unused_parameters: bool = True,
    optimizer_callable: Optional[Callable[[Any], torch.optim.Optimizer]] = None,
    optimizer_kwargs: Optional[Dict[str, Any]] = None,
    lr_scheduler_callable: Optional[Callable[[Any], torch.optim.lr_scheduler._LRScheduler]] = None,
    lr_scheduler_kwargs: Optional[Dict[str, Any]] = None,
    trainer_callable: Optional[Callable] = None,
    **kwargs
):
    assert "device" not in kwargs
    print(f"Running DDP on rank {rank}.")
    setup(rank, world_size)

    model = model_callable(**model_kwargs).to(rank)
    ddp_model = DDP(model, device_ids=[rank], find_unused_parameters=find_unused_parameters)

    if optimizer_callable is not None:
        optimizer = optimizer_callable(model.parameters(), **optimizer_kwargs)
        kwargs["optimizer"] = optimizer
        if lr_scheduler_callable is not None:
            lr_scheduler = lr_scheduler_callable(optimizer, **lr_scheduler_kwargs)
            kwargs["lr_scheduler"] = lr_scheduler

    train_loader = _create_data_loader(train_dataset_callable, train_dataset_kwargs, loader_kwargs, world_size, rank)
    val_loader = _create_data_loader(val_dataset_callable, val_dataset_kwargs, loader_kwargs, world_size, rank)

    if trainer_callable is None:
        trainer_callable = torch_em.default_segmentation_trainer

    trainer = trainer_callable(
        model=ddp_model,
        train_loader=train_loader,
        val_loader=val_loader,
        device=rank,
        rank=rank,
        **kwargs
    )
    trainer.fit(iterations=iterations)

    cleanup()


def train_multi_gpu(
    model_callable: Callable[[Any], torch.nn.Module],
    model_kwargs: Dict[str, Any],
    train_dataset_callable: Callable[[Any], torch.utils.data.Dataset],
    train_dataset_kwargs: Dict[str, Any],
    val_dataset_callable: Callable[[Any], torch.utils.data.Dataset],
    val_dataset_kwargs: Dict[str, Any],
    loader_kwargs: Dict[str, Any],
    iterations: int,
    find_unused_parameters: bool = True,
    optimizer_callable: Optional[Callable[[Any], torch.optim.Optimizer]] = None,
    optimizer_kwargs: Optional[Dict[str, Any]] = None,
    lr_scheduler_callable: Optional[Callable[[Any], torch.optim.lr_scheduler._LRScheduler]] = None,
    lr_scheduler_kwargs: Optional[Dict[str, Any]] = None,
    trainer_callable: Optional[Callable] = None,
    **kwargs
) -> None:
    """Run data parallel training on multiple local GPUs via torch.distributed.

    This function runs training on all available local GPUs in parallel.
    To use it, the functions / classes and keyword arguments for the model and data loaders must be given.
    Optionally, functions / classes and keyword arguments for the optimizer, learning rate scheduler and trainer class
    may be given, so that they can be instantiated in each training child process.

    Here is an example for training a 2D U-Net on the DSB dataset:
    ```python
    from torch_em.model import UNet2d
    from torch_em.multi_gpu_training import train_multi_gpu
    from torch_em.data.datasets.light_microscopy.dsb import get_dsb_dataset, get_dsb_data

    # Make sure the data is downloaded before starting multi-gpu training.
    data_root = "/path/to/save/the/training/data"
    get_dsb_data(data_root, source="reduced", download=True)

    patch_shape = (256, 256)
    train_multi_gpu(
        model_callable=UNet2d,
        model_kwargs={"in_channels": 1, "out_channels": 1},
        train_dataset_callable=get_dsb_dataset,
        train_dataset_kwargs={"path": data_root, "patch_shape": patch_shape, "split": "train"},
        val_dataset_callable=get_dsb_dataset,
        val_dataset_kwargs={"path": data_root, "patch_shape": patch_shape, "split": "test"},
        loader_kwargs={"shuffle": True},
        iterations=int(5e4),  # Train for 50,000 iterations.
    )
    ```

    Args:
        model_callable: Function or class to create the model.
        model_kwargs: Keyword arguments for `model_callable`.
        train_dataset_callable: Function or class to create the training dataset.
        train_dataset_kwargs: Keyword arguments for `train_dataset_callable`.
        val_dataset_callable: Function or class to create the validation dataset.
        val_dataset_kwargs: Keyword arguments for `val_dataset_callable`.
        loader_kwargs: Keyword arguments for the torch data loader.
        iterations: Number of iterations to train for.
        find_unused_parameters: Whether to find unused parameters of the model to exclude from the optimization.
        optimizer_callable: Function or class to create the optimizer.
        optimizer_kwargs: Keyword arguments for `optimizer_callable`.
        lr_scheduler_callable: Function or class to create the learning rate scheduler.
        lr_scheduler_kwargs: Keyword arguments for `lr_scheduler_callable`.
        trainer_callable: Function or class to create the trainer.
        kwargs: Keyword arguments for `trainer_callable`.
    """
    world_size = torch.cuda.device_count()
    train = partial(
        _train_impl,
        model_callable=model_callable,
        model_kwargs=model_kwargs,
        train_dataset_callable=train_dataset_callable,
        train_dataset_kwargs=train_dataset_kwargs,
        val_dataset_callable=val_dataset_callable,
        val_dataset_kwargs=val_dataset_kwargs,
        loader_kwargs=loader_kwargs,
        iterations=iterations,
        find_unused_parameters=find_unused_parameters,
        optimizer_callable=optimizer_callable,
        optimizer_kwargs=optimizer_kwargs,
        lr_scheduler_callable=lr_scheduler_callable,
        lr_scheduler_kwargs=lr_scheduler_kwargs,
        trainer_callable=trainer_callable,
        **kwargs
    )
    torch.multiprocessing.spawn(train, args=(world_size,), nprocs=world_size, join=True)
````
```python
class DDP(torch.nn.parallel.distributed.DistributedDataParallel):
```
Wrapper for the DistributedDataParallel class.

Overrides the `__getattr__` method so that attribute access is forwarded to the model wrapped by DDP.
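The effect is that custom attributes and methods of the wrapped model stay reachable through the DDP object. This can be checked in a single process with the CPU-only `gloo` backend, so no GPU is needed; the `Net` class and its `depth` attribute below are made up for illustration, only `DDP` itself comes from this module:

```python
import os
import torch
import torch.distributed as dist
from torch_em.multi_gpu_training import DDP


class Net(torch.nn.Module):
    """Toy model with a custom attribute that plain DistributedDataParallel would not expose."""
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 1, 3)
        self.depth = 4  # lives on the wrapped module, not on the DDP wrapper


# A single-process group is enough to construct a DDP wrapper for this check.
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "12356")
dist.init_process_group("gloo", rank=0, world_size=1)

model = DDP(Net())
print(model.depth)  # forwarded to the wrapped module via __getattr__ -> prints 4

dist.destroy_process_group()
```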
```python
def train_multi_gpu(
    model_callable: Callable[[Any], torch.nn.modules.module.Module],
    model_kwargs: Dict[str, Any],
    train_dataset_callable: Callable[[Any], torch.utils.data.dataset.Dataset],
    train_dataset_kwargs: Dict[str, Any],
    val_dataset_callable: Callable[[Any], torch.utils.data.dataset.Dataset],
    val_dataset_kwargs: Dict[str, Any],
    loader_kwargs: Dict[str, Any],
    iterations: int,
    find_unused_parameters: bool = True,
    optimizer_callable: Optional[Callable[[Any], torch.optim.optimizer.Optimizer]] = None,
    optimizer_kwargs: Optional[Dict[str, Any]] = None,
    lr_scheduler_callable: Optional[Callable[[Any], torch.optim.lr_scheduler._LRScheduler]] = None,
    lr_scheduler_kwargs: Optional[Dict[str, Any]] = None,
    trainer_callable: Optional[Callable] = None,
    **kwargs
) -> None:
```
Run data parallel training on multiple local GPUs via torch.distributed.
This function runs training on all available local GPUs in parallel. To use it, the functions / classes and keyword arguments for the model and data loaders must be given. Optionally, functions / classes and keyword arguments for the optimizer, learning rate scheduler and trainer class may be given, so that they can be instantiated in each training child process.
Here is an example for training a 2D U-Net on the DSB dataset:
```python
from torch_em.model import UNet2d
from torch_em.multi_gpu_training import train_multi_gpu
from torch_em.data.datasets.light_microscopy.dsb import get_dsb_dataset, get_dsb_data

# Make sure the data is downloaded before starting multi-gpu training.
data_root = "/path/to/save/the/training/data"
get_dsb_data(data_root, source="reduced", download=True)

patch_shape = (256, 256)
train_multi_gpu(
    model_callable=UNet2d,
    model_kwargs={"in_channels": 1, "out_channels": 1},
    train_dataset_callable=get_dsb_dataset,
    train_dataset_kwargs={"path": data_root, "patch_shape": patch_shape, "split": "train"},
    val_dataset_callable=get_dsb_dataset,
    val_dataset_kwargs={"path": data_root, "patch_shape": patch_shape, "split": "test"},
    loader_kwargs={"shuffle": True},
    iterations=int(5e4),  # Train for 50,000 iterations.
)
```
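A note on launching: `train_multi_gpu` starts one child process per visible GPU via `torch.multiprocessing.spawn`, which re-imports the launching script in every child. When the example above is placed in a script, the call should therefore sit behind an `if __name__ == "__main__":` guard, and restricting `CUDA_VISIBLE_DEVICES` controls how many processes are spawned. A minimal sketch; the script name `train_dsb.py` and the function name `main` are made up:

```python
# train_dsb.py -- hypothetical wrapper around the example above.
def main():
    # Put the train_multi_gpu(...) call from the example above in here.
    pass


if __name__ == "__main__":  # needed because child processes re-import this file
    main()
```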
Arguments:

- model_callable: Function or class to create the model.
- model_kwargs: Keyword arguments for `model_callable`.
- train_dataset_callable: Function or class to create the training dataset.
- train_dataset_kwargs: Keyword arguments for `train_dataset_callable`.
- val_dataset_callable: Function or class to create the validation dataset.
- val_dataset_kwargs: Keyword arguments for `val_dataset_callable`.
- loader_kwargs: Keyword arguments for the torch data loader.
- iterations: Number of iterations to train for.
- find_unused_parameters: Whether to find unused parameters of the model to exclude from the optimization.
- optimizer_callable: Function or class to create the optimizer (see the sketch after this list for how it is passed together with its keyword arguments).
- optimizer_kwargs: Keyword arguments for `optimizer_callable`.
- lr_scheduler_callable: Function or class to create the learning rate scheduler.
- lr_scheduler_kwargs: Keyword arguments for `lr_scheduler_callable`.
- trainer_callable: Function or class to create the trainer.
- kwargs: Keyword arguments for `trainer_callable`.
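As a sketch of how the optional callables are used: they are passed uninstantiated together with their keyword argument dicts, and each child process builds its own optimizer and learning rate scheduler from them (see `_train_impl` in the module source above). The concrete choices below, `torch.optim.AdamW` with a learning rate of 1e-4 and `ReduceLROnPlateau`, are only assumptions for illustration, not defaults of this module:

```python
import torch
from torch_em.model import UNet2d
from torch_em.multi_gpu_training import train_multi_gpu
from torch_em.data.datasets.light_microscopy.dsb import get_dsb_dataset

data_root = "/path/to/save/the/training/data"  # same placeholder path as in the example above
patch_shape = (256, 256)

train_multi_gpu(
    model_callable=UNet2d,
    model_kwargs={"in_channels": 1, "out_channels": 1},
    train_dataset_callable=get_dsb_dataset,
    train_dataset_kwargs={"path": data_root, "patch_shape": patch_shape, "split": "train"},
    val_dataset_callable=get_dsb_dataset,
    val_dataset_kwargs={"path": data_root, "patch_shape": patch_shape, "split": "test"},
    loader_kwargs={"shuffle": True},
    iterations=int(5e4),
    # Optimizer and scheduler are given as callable + kwargs and instantiated per child process.
    optimizer_callable=torch.optim.AdamW,
    optimizer_kwargs={"lr": 1e-4},
    lr_scheduler_callable=torch.optim.lr_scheduler.ReduceLROnPlateau,
    lr_scheduler_kwargs={"mode": "min", "factor": 0.5, "patience": 5},
)
```

As noted above, when this call lives in a script it should be placed behind an `if __name__ == "__main__":` guard.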