torch_em.multi_gpu_training

  1import os
  2from functools import partial
  3from typing import Dict, Any, Optional, Callable
  4
  5import torch
  6import torch.utils.data
  7import torch.distributed as dist
  8from torch.nn.parallel import DistributedDataParallel
  9
 10import torch_em
 11
 12
 13def setup(rank, world_size):
 14    """@private
 15    """
 16    os.environ["MASTER_ADDR"] = "localhost"
 17    os.environ["MASTER_PORT"] = "12355"
 18    dist.init_process_group("nccl", rank=rank, world_size=world_size)
 19
 20
 21def cleanup():
 22    """@private
 23    """
 24    dist.destroy_process_group()
 25
 26
 27def _create_data_loader(ds_callable, ds_kwargs, loader_kwargs, world_size, rank):
 28    # Create the dataset.
 29    ds = ds_callable(**ds_kwargs)
 30
 31    # Create the sampler
 32    # Set shuffle on the sampler instead of the loader
 33    shuffle = loader_kwargs.pop("shuffle", False)
 34    sampler = torch.utils.data.distributed.DistributedSampler(ds, num_replicas=world_size, rank=rank, shuffle=shuffle)
 35
 36    # Create the loader.
 37    loader = torch.utils.data.DataLoader(ds, sampler=sampler, **loader_kwargs)
 38    loader.shuffle = shuffle
 39
 40    return loader
 41
 42
 43class DDP(DistributedDataParallel):
 44    """Wrapper for the DistributedDataParallel class.
 45
 46    Oerrides the `__getattr__` method to handle access from the "model" object module wrapped by DDP.
 47    """
 48    def __getattr__(self, name):
 49        try:
 50            return super().__getattr__(name)
 51        except AttributeError:
 52            return getattr(self.module, name)
 53
 54
 55def _train_impl(
 56    rank: int,
 57    world_size: int,
 58    model_callable: Callable[[Any], torch.nn.Module],
 59    model_kwargs: Dict[str, Any],
 60    train_dataset_callable: Callable[[Any], torch.utils.data.Dataset],
 61    train_dataset_kwargs: Dict[str, Any],
 62    val_dataset_callable: Callable[[Any], torch.utils.data.Dataset],
 63    val_dataset_kwargs: Dict[str, Any],
 64    loader_kwargs: Dict[str, Any],
 65    iterations: int,
 66    find_unused_parameters: bool = True,
 67    optimizer_callable: Optional[Callable[[Any], torch.optim.Optimizer]] = None,
 68    optimizer_kwargs: Optional[Dict[str, Any]] = None,
 69    lr_scheduler_callable: Optional[Callable[[Any], torch.optim.lr_scheduler._LRScheduler]] = None,
 70    lr_scheduler_kwargs: Optional[Dict[str, Any]] = None,
 71    trainer_callable: Optional[Callable] = None,
 72    **kwargs
 73):
 74    assert "device" not in kwargs
 75    print(f"Running DDP on rank {rank}.")
 76    setup(rank, world_size)
 77
 78    model = model_callable(**model_kwargs).to(rank)
 79    ddp_model = DDP(model, device_ids=[rank], find_unused_parameters=find_unused_parameters)
 80
 81    if optimizer_callable is not None:
 82        optimizer = optimizer_callable(model.parameters(), **optimizer_kwargs)
 83        kwargs["optimizer"] = optimizer
 84        if lr_scheduler_callable is not None:
 85            lr_scheduler = lr_scheduler_callable(optimizer, **lr_scheduler_kwargs)
 86            kwargs["lr_scheduler"] = lr_scheduler
 87
 88    train_loader = _create_data_loader(train_dataset_callable, train_dataset_kwargs, loader_kwargs, world_size, rank)
 89    val_loader = _create_data_loader(val_dataset_callable, val_dataset_kwargs, loader_kwargs, world_size, rank)
 90
 91    if trainer_callable is None:
 92        trainer_callable = torch_em.default_segmentation_trainer
 93
 94    trainer = trainer_callable(
 95        model=ddp_model,
 96        train_loader=train_loader,
 97        val_loader=val_loader,
 98        device=rank,
 99        rank=rank,
100        **kwargs
101    )
102    trainer.fit(iterations=iterations)
103
104    cleanup()
105
106
def train_multi_gpu(
    model_callable: Callable[[Any], torch.nn.Module],
    model_kwargs: Dict[str, Any],
    train_dataset_callable: Callable[[Any], torch.utils.data.Dataset],
    train_dataset_kwargs: Dict[str, Any],
    val_dataset_callable: Callable[[Any], torch.utils.data.Dataset],
    val_dataset_kwargs: Dict[str, Any],
    loader_kwargs: Dict[str, Any],
    iterations: int,
    find_unused_parameters: bool = True,
    optimizer_callable: Optional[Callable[[Any], torch.optim.Optimizer]] = None,
    optimizer_kwargs: Optional[Dict[str, Any]] = None,
    lr_scheduler_callable: Optional[Callable[[Any], torch.optim.lr_scheduler._LRScheduler]] = None,
    lr_scheduler_kwargs: Optional[Dict[str, Any]] = None,
    trainer_callable: Optional[Callable] = None,
    **kwargs
) -> None:
    """Run data parallel training on multiple local GPUs via torch.distributed.

    This function will run training on all available local GPUs in parallel, with one training process per GPU.
    To use it, the functions / classes and keyword arguments for the model and the datasets must be given.
    Optionally, functions / classes and keyword arguments for the optimizer, learning rate scheduler and trainer
    may be given, so that they can be instantiated in each training child process.

    Here is an example for training a 2D U-Net on the DSB dataset:
    ```python
    from torch_em.model import UNet2d
    from torch_em.multi_gpu_training import train_multi_gpu
    from torch_em.data.datasets.light_microscopy.dsb import get_dsb_dataset, get_dsb_data


    def main():
        # Make sure the data is downloaded before starting multi-gpu training.
        data_root = "/path/to/save/the/training/data"
        get_dsb_data(data_root, source="reduced", download=True)

        patch_shape = (256, 256)
        train_multi_gpu(
            model_callable=UNet2d,
            model_kwargs={"in_channels": 1, "out_channels": 1},
            train_dataset_callable=get_dsb_dataset,
            train_dataset_kwargs={"path": data_root, "patch_shape": patch_shape, "split": "train"},
            val_dataset_callable=get_dsb_dataset,
            val_dataset_kwargs={"path": data_root, "patch_shape": patch_shape, "split": "test"},
            loader_kwargs={"shuffle": True},
            iterations=int(5e4),  # Train for 50,000 iterations.
        )


    # The guard is required: the training processes are started with the "spawn" method,
    # which re-imports this script in every child process.
    if __name__ == "__main__":
        main()
    ```

    Args:
        model_callable: Function or class to create the model.
        model_kwargs: Keyword arguments for `model_callable`.
        train_dataset_callable: Function or class to create the training dataset.
        train_dataset_kwargs: Keyword arguments for `train_dataset_callable`.
        val_dataset_callable: Function or class to create the validation dataset.
        val_dataset_kwargs: Keyword arguments for `val_dataset_callable`.
        loader_kwargs: Keyword arguments for the torch data loader. The `shuffle` option is applied
            to the distributed sampler instead of the loader.
        iterations: Number of iterations to train for.
        find_unused_parameters: Passed to `DistributedDataParallel`. Whether to detect parameters that do not
            receive gradients in the forward pass, so that gradient synchronization does not wait for them.
        optimizer_callable: Function or class to create the optimizer.
        optimizer_kwargs: Keyword arguments for `optimizer_callable`.
        lr_scheduler_callable: Function or class to create the learning rate scheduler.
            Only used if `optimizer_callable` is given.
        lr_scheduler_kwargs: Keyword arguments for `lr_scheduler_callable`.
        trainer_callable: Function or class to create the trainer.
        kwargs: Keyword arguments for `trainer_callable`. Must not contain `device`,
            which is set to the GPU of each training process.

    Raises:
        RuntimeError: If no CUDA device is available.
    """
    world_size = torch.cuda.device_count()
    # Without this check, spawn(nprocs=0) would return immediately and silently skip training.
    if world_size == 0:
        raise RuntimeError("train_multi_gpu requires at least one CUDA device, but none is available.")

    train = partial(
        _train_impl,
        model_callable=model_callable,
        model_kwargs=model_kwargs,
        train_dataset_callable=train_dataset_callable,
        train_dataset_kwargs=train_dataset_kwargs,
        val_dataset_callable=val_dataset_callable,
        val_dataset_kwargs=val_dataset_kwargs,
        loader_kwargs=loader_kwargs,
        iterations=iterations,
        find_unused_parameters=find_unused_parameters,
        optimizer_callable=optimizer_callable,
        optimizer_kwargs=optimizer_kwargs,
        lr_scheduler_callable=lr_scheduler_callable,
        lr_scheduler_kwargs=lr_scheduler_kwargs,
        trainer_callable=trainer_callable,
        **kwargs
    )
    # spawn passes the process index as the first positional argument, which becomes `rank`.
    torch.multiprocessing.spawn(train, args=(world_size,), nprocs=world_size, join=True)
class DDP(DistributedDataParallel):
    """Wrapper for the DistributedDataParallel class.

    Overrides the `__getattr__` method so that attributes that are not found on the DDP wrapper
    itself are looked up on the wrapped model (`self.module`).
    """
    def __getattr__(self, name):
        try:
            return super().__getattr__(name)
        except AttributeError:
            # If `module` itself is not registered (e.g. on a partially constructed or unpickled
            # instance), looking it up would re-enter `__getattr__` forever; fail normally instead.
            if name == "module":
                raise
            return getattr(self.module, name)

Wrapper for the DistributedDataParallel class.

Overrides the __getattr__ method so that attributes that are not found on the DDP wrapper itself are looked up on the wrapped model (the module wrapped by DDP).

def train_multi_gpu(
    model_callable: Callable[[Any], torch.nn.Module],
    model_kwargs: Dict[str, Any],
    train_dataset_callable: Callable[[Any], torch.utils.data.Dataset],
    train_dataset_kwargs: Dict[str, Any],
    val_dataset_callable: Callable[[Any], torch.utils.data.Dataset],
    val_dataset_kwargs: Dict[str, Any],
    loader_kwargs: Dict[str, Any],
    iterations: int,
    find_unused_parameters: bool = True,
    optimizer_callable: Optional[Callable[[Any], torch.optim.Optimizer]] = None,
    optimizer_kwargs: Optional[Dict[str, Any]] = None,
    lr_scheduler_callable: Optional[Callable[[Any], torch.optim.lr_scheduler._LRScheduler]] = None,
    lr_scheduler_kwargs: Optional[Dict[str, Any]] = None,
    trainer_callable: Optional[Callable] = None,
    **kwargs
) -> None:
    """Run data parallel training on multiple local GPUs via torch.distributed.

    This function will run training on all available local GPUs in parallel, with one training process per GPU.
    To use it, the functions / classes and keyword arguments for the model and the datasets must be given.
    Optionally, functions / classes and keyword arguments for the optimizer, learning rate scheduler and trainer
    may be given, so that they can be instantiated in each training child process.

    Here is an example for training a 2D U-Net on the DSB dataset:
    ```python
    from torch_em.model import UNet2d
    from torch_em.multi_gpu_training import train_multi_gpu
    from torch_em.data.datasets.light_microscopy.dsb import get_dsb_dataset, get_dsb_data


    def main():
        # Make sure the data is downloaded before starting multi-gpu training.
        data_root = "/path/to/save/the/training/data"
        get_dsb_data(data_root, source="reduced", download=True)

        patch_shape = (256, 256)
        train_multi_gpu(
            model_callable=UNet2d,
            model_kwargs={"in_channels": 1, "out_channels": 1},
            train_dataset_callable=get_dsb_dataset,
            train_dataset_kwargs={"path": data_root, "patch_shape": patch_shape, "split": "train"},
            val_dataset_callable=get_dsb_dataset,
            val_dataset_kwargs={"path": data_root, "patch_shape": patch_shape, "split": "test"},
            loader_kwargs={"shuffle": True},
            iterations=int(5e4),  # Train for 50,000 iterations.
        )


    # The guard is required: the training processes are started with the "spawn" method,
    # which re-imports this script in every child process.
    if __name__ == "__main__":
        main()
    ```

    Args:
        model_callable: Function or class to create the model.
        model_kwargs: Keyword arguments for `model_callable`.
        train_dataset_callable: Function or class to create the training dataset.
        train_dataset_kwargs: Keyword arguments for `train_dataset_callable`.
        val_dataset_callable: Function or class to create the validation dataset.
        val_dataset_kwargs: Keyword arguments for `val_dataset_callable`.
        loader_kwargs: Keyword arguments for the torch data loader. The `shuffle` option is applied
            to the distributed sampler instead of the loader.
        iterations: Number of iterations to train for.
        find_unused_parameters: Passed to `DistributedDataParallel`. Whether to detect parameters that do not
            receive gradients in the forward pass, so that gradient synchronization does not wait for them.
        optimizer_callable: Function or class to create the optimizer.
        optimizer_kwargs: Keyword arguments for `optimizer_callable`.
        lr_scheduler_callable: Function or class to create the learning rate scheduler.
            Only used if `optimizer_callable` is given.
        lr_scheduler_kwargs: Keyword arguments for `lr_scheduler_callable`.
        trainer_callable: Function or class to create the trainer.
        kwargs: Keyword arguments for `trainer_callable`. Must not contain `device`,
            which is set to the GPU of each training process.

    Raises:
        RuntimeError: If no CUDA device is available.
    """
    world_size = torch.cuda.device_count()
    # Without this check, spawn(nprocs=0) would return immediately and silently skip training.
    if world_size == 0:
        raise RuntimeError("train_multi_gpu requires at least one CUDA device, but none is available.")

    train = partial(
        _train_impl,
        model_callable=model_callable,
        model_kwargs=model_kwargs,
        train_dataset_callable=train_dataset_callable,
        train_dataset_kwargs=train_dataset_kwargs,
        val_dataset_callable=val_dataset_callable,
        val_dataset_kwargs=val_dataset_kwargs,
        loader_kwargs=loader_kwargs,
        iterations=iterations,
        find_unused_parameters=find_unused_parameters,
        optimizer_callable=optimizer_callable,
        optimizer_kwargs=optimizer_kwargs,
        lr_scheduler_callable=lr_scheduler_callable,
        lr_scheduler_kwargs=lr_scheduler_kwargs,
        trainer_callable=trainer_callable,
        **kwargs
    )
    # spawn passes the process index as the first positional argument, which becomes `rank`.
    torch.multiprocessing.spawn(train, args=(world_size,), nprocs=world_size, join=True)

Run data parallel training on multiple local GPUs via torch.distributed.

This function will run training on all available local GPUs in parallel, with one training process per GPU. To use it, the functions / classes and keyword arguments for the model and the datasets must be given. Optionally, functions / classes and keyword arguments for the optimizer, learning rate scheduler and trainer may be given, so that they can be instantiated in each training child process.

Here is an example for training a 2D U-Net on the DSB dataset:

from torch_em.model import UNet2d
from torch_em.multi_gpu_training import train_multi_gpu
from torch_em.data.datasets.light_microscopy.dsb import get_dsb_dataset, get_dsb_data

def main():
    # Make sure the data is downloaded before starting multi-gpu training.
    data_root = "/path/to/save/the/training/data"
    get_dsb_data(data_root, source="reduced", download=True)

    patch_shape = (256, 256)
    train_multi_gpu(
        model_callable=UNet2d,
        model_kwargs={"in_channels": 1, "out_channels": 1},
        train_dataset_callable=get_dsb_dataset,
        train_dataset_kwargs={"path": data_root, "patch_shape": patch_shape, "split": "train"},
        val_dataset_callable=get_dsb_dataset,
        val_dataset_kwargs={"path": data_root, "patch_shape": patch_shape, "split": "test"},
        loader_kwargs={"shuffle": True},
        iterations=int(5e4),  # Train for 50,000 iterations.
    )


# The guard is required: the training processes are started with the "spawn" method,
# which re-imports this script in every child process.
if __name__ == "__main__":
    main()
Arguments:
  • model_callable: Function or class to create the model.
  • model_kwargs: Keyword arguments for model_callable.
  • train_dataset_callable: Function or class to create the training dataset.
  • train_dataset_kwargs: Keyword arguments for train_dataset_callable.
  • val_dataset_callable: Function or class to create the validation dataset.
  • val_dataset_kwargs: Keyword arguments for val_dataset_callable.
  • loader_kwargs: Keyword arguments for the torch data loader.
  • iterations: Number of iterations to train for.
  • find_unused_parameters: Passed to DistributedDataParallel. Whether to detect parameters that do not receive gradients in the forward pass, so that gradient synchronization does not wait for them.
  • optimizer_callable: Function or class to create the optimizer.
  • optimizer_kwargs: Keyword arguments for optimizer_callable.
  • lr_scheduler_callable: Function or class to create the learning rate scheduler.
  • lr_scheduler_kwargs: Keyword arguments for lr_scheduler_callable.
  • trainer_callable: Function or class to create the trainer.
  • kwargs: Keyword arguments for trainer_callable.