torch_em.trainer.flashoptim_trainer

  1import time
  2
  3import torch
  4
  5try:
  6    from flashoptim import FlashAdamW, cast_model
  7except ImportError:
  8    FlashAdamW = None
  9    cast_model = None
 10
 11from .default_trainer import DefaultTrainer
 12
 13
 14class FlashOptimTrainer(DefaultTrainer):
 15    """Trainer for training models with FlashOptim optimizers for memory-efficiency.
 16
 17    The trainer adapts the `DefaultTrainer` for the following reasons:
 18    1. Casts the model parameters and input data to bf16 precision (see `torch.bfloat16` for details)
 19    2. Sets `mixed_precision` and `compile_model` to `False`.
 20
 21    NOTE: There are a couple of things to keep in mind:
 22    1. Multi-GPU training (eg. using DDP) is currently not supported.
 23    2. Gradient clipping cannot be applied to the parameters.
 24    3. Gradient scaling (eg. using `torch.amp.GradScaler`) is currently not supported.
 25    4. Microbatch accumulation (gradient accumulation) is not possible.
 26
 27    For details, check out the official repository: https://github.com/databricks/flashoptim.
 28    And please cite https://doi.org/10.48550/arXiv.2602.23349 if you use this trainer for your research.
 29    """
 30    def __init__(self, **kwargs):
 31        if FlashAdamW is None:
 32            raise ImportError(
 33                "flashoptim is required for `FlashOptimTrainer`. Please install it using `pip install flashoptim`."
 34            )
 35
 36        optimizer = kwargs["optimizer"]
 37        if not isinstance(optimizer, torch.optim.AdamW):
 38            raise ValueError(
 39                f"FlashOptimTrainer is currently tested with the AdamW optimizer, got '{type(optimizer).__name__}'. "
 40                "FlashAdamW is a drop-in replacement for AdamW only."
 41            )
 42
 43        # Cast the model parameters to bf16 precision.
 44        lr = optimizer.param_groups[0]["lr"]
 45        cast_model(kwargs["model"], dtype=torch.bfloat16)
 46        kwargs["optimizer"] = FlashAdamW(kwargs["model"].parameters(), lr=lr)
 47        kwargs["lr_scheduler"] = torch.optim.lr_scheduler.ReduceLROnPlateau(
 48            kwargs["optimizer"], mode="min", factor=0.5, patience=5
 49        )
 50
 51        # Pinning the values for 'mixed_precision' and 'compile_model' both to 'False'.
 52        kwargs["mixed_precision"] = False
 53        kwargs["compile_model"] = False  # TODO: We should explore compiling the model if it brings an advantange.
 54
 55        super().__init__(**kwargs)
 56        self._kwargs = {}  # Required by the serializer.
 57
 58    def _train_epoch_impl(self, progress, forward_context, backprop):
 59        self.model.train()
 60
 61        n_iter = 0
 62        t_per_iter = time.time()
 63        for x, y in self.train_loader:
 64            # Casts inputs to bf16 precision.
 65            x = x.to(self.device, non_blocking=True).to(torch.bfloat16)
 66            y = y.to(self.device, non_blocking=True).to(torch.bfloat16)
 67
 68            self.optimizer.zero_grad()
 69
 70            with forward_context():
 71                pred, loss = self._forward_and_loss(x, y)
 72
 73            backprop(loss)
 74
 75            lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
 76            if self.logger is not None:
 77                self.logger.log_train(self._iteration, loss, lr, x, y, pred, log_gradients=True)
 78
 79            self._iteration += 1
 80            n_iter += 1
 81            if self._iteration >= self.max_iteration:
 82                break
 83            progress.update(1)
 84
 85        t_per_iter = (time.time() - t_per_iter) / n_iter
 86        return t_per_iter
 87
 88    def _validate_impl(self, forward_context):
 89        self.model.eval()
 90
 91        metric_val = 0.0
 92        loss_val = 0.0
 93
 94        with torch.no_grad():
 95            for x, y in self.val_loader:
 96                # Casts inputs to bf16 precision.
 97                x = x.to(self.device, non_blocking=True).to(torch.bfloat16)
 98                y = y.to(self.device, non_blocking=True).to(torch.bfloat16)
 99
100                with forward_context():
101                    pred, loss = self._forward_and_loss(x, y)
102                    metric = self.metric(pred, y)
103
104                loss_val += loss.item()
105                metric_val += metric.item()
106
107        metric_val /= len(self.val_loader)
108        loss_val /= len(self.val_loader)
109        if self.logger is not None:
110            self.logger.log_validation(self._iteration, metric_val, loss_val, x, y, pred)
111        return metric_val
class FlashOptimTrainer(DefaultTrainer):
    """Trainer that uses FlashOptim optimizers to reduce training memory.

    Compared to `DefaultTrainer` this trainer:
    1. Casts the model parameters and the input batches to bf16 (`torch.bfloat16`).
    2. Forces `mixed_precision` and `compile_model` to `False`.

    NOTE: Known limitations:
    1. Multi-GPU training (e.g. via DDP) is currently not supported.
    2. Gradient clipping cannot be applied to the parameters.
    3. Gradient scaling (e.g. via `torch.amp.GradScaler`) is currently not supported.
    4. Microbatch accumulation (gradient accumulation) is not possible.

    See the official repository for details: https://github.com/databricks/flashoptim.
    Please cite https://doi.org/10.48550/arXiv.2602.23349 if this trainer is used for research.
    """
    def __init__(self, **kwargs):
        if FlashAdamW is None:
            raise ImportError(
                "flashoptim is required for `FlashOptimTrainer`. Please install it using `pip install flashoptim`."
            )

        # Only AdamW is supported, since FlashAdamW replaces it one-to-one.
        optimizer = kwargs["optimizer"]
        if not isinstance(optimizer, torch.optim.AdamW):
            raise ValueError(
                f"FlashOptimTrainer is currently tested with the AdamW optimizer, got '{type(optimizer).__name__}'. "
                "FlashAdamW is a drop-in replacement for AdamW only."
            )

        # Convert the model parameters to bf16, then build FlashAdamW with the
        # learning rate taken from the user's original optimizer.
        base_lr = optimizer.param_groups[0]["lr"]
        cast_model(kwargs["model"], dtype=torch.bfloat16)
        flash_optimizer = FlashAdamW(kwargs["model"].parameters(), lr=base_lr)
        kwargs["optimizer"] = flash_optimizer
        kwargs["lr_scheduler"] = torch.optim.lr_scheduler.ReduceLROnPlateau(
            flash_optimizer, mode="min", factor=0.5, patience=5
        )

        # Always disable mixed precision and model compilation for this trainer.
        kwargs["mixed_precision"] = False
        kwargs["compile_model"] = False  # TODO: explore whether compiling the model brings an advantage.

        super().__init__(**kwargs)
        self._kwargs = {}  # Required by the serializer.

    def _train_epoch_impl(self, progress, forward_context, backprop):
        """Train for one epoch and return the average wall-clock time per iteration."""
        self.model.train()

        iteration_count = 0
        start = time.time()
        for x, y in self.train_loader:
            # Move the batch to the device and cast it to bf16 precision.
            x = x.to(self.device, non_blocking=True).to(torch.bfloat16)
            y = y.to(self.device, non_blocking=True).to(torch.bfloat16)

            self.optimizer.zero_grad()
            with forward_context():
                pred, loss = self._forward_and_loss(x, y)
            backprop(loss)

            if self.logger is not None:
                current_lr = [group["lr"] for group in self.optimizer.param_groups][0]
                self.logger.log_train(self._iteration, loss, current_lr, x, y, pred, log_gradients=True)

            self._iteration += 1
            iteration_count += 1
            if self._iteration >= self.max_iteration:
                break
            progress.update(1)

        return (time.time() - start) / iteration_count

    def _validate_impl(self, forward_context):
        """Run validation over the full loader and return the mean metric value."""
        self.model.eval()

        total_metric, total_loss = 0.0, 0.0
        with torch.no_grad():
            for x, y in self.val_loader:
                # Move the batch to the device and cast it to bf16 precision.
                x = x.to(self.device, non_blocking=True).to(torch.bfloat16)
                y = y.to(self.device, non_blocking=True).to(torch.bfloat16)

                with forward_context():
                    pred, loss = self._forward_and_loss(x, y)
                    metric = self.metric(pred, y)

                total_loss += loss.item()
                total_metric += metric.item()

        n_batches = len(self.val_loader)
        mean_metric = total_metric / n_batches
        mean_loss = total_loss / n_batches
        if self.logger is not None:
            self.logger.log_validation(self._iteration, mean_metric, mean_loss, x, y, pred)
        return mean_metric

Trainer for training models with FlashOptim optimizers for memory-efficiency.

The trainer adapts the DefaultTrainer in the following ways:

  1. Casts the model parameters and input data to bf16 precision (see torch.bfloat16 for details)
  2. Sets mixed_precision and compile_model to False.

NOTE: There are a couple of things to keep in mind:

  1. Multi-GPU training (e.g. using DDP) is currently not supported.
  2. Gradient clipping cannot be applied to the parameters.
  3. Gradient scaling (e.g. using torch.amp.GradScaler) is currently not supported.
  4. Microbatch accumulation (gradient accumulation) is not possible.

For details, check out the official repository: https://github.com/databricks/flashoptim. And please cite https://doi.org/10.48550/arXiv.2602.23349 if you use this trainer for your research.

FlashOptimTrainer(**kwargs)
31    def __init__(self, **kwargs):
32        if FlashAdamW is None:
33            raise ImportError(
34                "flashoptim is required for `FlashOptimTrainer`. Please install it using `pip install flashoptim`."
35            )
36
37        optimizer = kwargs["optimizer"]
38        if not isinstance(optimizer, torch.optim.AdamW):
39            raise ValueError(
40                f"FlashOptimTrainer is currently tested with the AdamW optimizer, got '{type(optimizer).__name__}'. "
41                "FlashAdamW is a drop-in replacement for AdamW only."
42            )
43
44        # Cast the model parameters to bf16 precision.
45        lr = optimizer.param_groups[0]["lr"]
46        cast_model(kwargs["model"], dtype=torch.bfloat16)
47        kwargs["optimizer"] = FlashAdamW(kwargs["model"].parameters(), lr=lr)
48        kwargs["lr_scheduler"] = torch.optim.lr_scheduler.ReduceLROnPlateau(
49            kwargs["optimizer"], mode="min", factor=0.5, patience=5
50        )
51
52        # Pinning the values for 'mixed_precision' and 'compile_model' both to 'False'.
53        kwargs["mixed_precision"] = False
54        kwargs["compile_model"] = False  # TODO: We should explore compiling the model if it brings an advantange.
55
56        super().__init__(**kwargs)
57        self._kwargs = {}  # Required by the serializer.