torch_em.trainer.flashoptim_trainer
import time

import torch

try:
    from flashoptim import FlashAdamW, cast_model
except ImportError:
    FlashAdamW = None
    cast_model = None

from .default_trainer import DefaultTrainer


class FlashOptimTrainer(DefaultTrainer):
    """Trainer for training models with FlashOptim optimizers for memory-efficiency.

    The trainer adapts the `DefaultTrainer` in the following ways:
    1. Casts the model parameters and input data to bf16 precision (see `torch.bfloat16` for details).
    2. Sets `mixed_precision` and `compile_model` to `False`.

    NOTE: There are a couple of things to keep in mind:
    1. Multi-GPU training (eg. using DDP) is currently not supported.
    2. Gradient clipping cannot be applied to the parameters.
    3. Gradient scaling (eg. using `torch.amp.GradScaler`) is currently not supported.
    4. Microbatch accumulation (gradient accumulation) is not possible.

    For details, check out the official repository: https://github.com/databricks/flashoptim.
    And please cite https://doi.org/10.48550/arXiv.2602.23349 if you use this trainer for your research.
    """
    def __init__(self, **kwargs):
        if FlashAdamW is None:
            raise ImportError(
                "flashoptim is required for `FlashOptimTrainer`. Please install it using `pip install flashoptim`."
            )

        optimizer = kwargs["optimizer"]
        if not isinstance(optimizer, torch.optim.AdamW):
            raise ValueError(
                f"FlashOptimTrainer is currently tested with the AdamW optimizer, got '{type(optimizer).__name__}'. "
                "FlashAdamW is a drop-in replacement for AdamW only."
            )

        # Cast the model parameters to bf16 precision and rebuild the optimizer on the casted parameters.
        # Carry over the hyperparameters of the user-provided AdamW (not just the learning rate),
        # so that FlashAdamW really acts as a drop-in replacement. The .get fallbacks are the
        # AdamW defaults and only apply if a key is missing from the param group.
        group = optimizer.param_groups[0]
        cast_model(kwargs["model"], dtype=torch.bfloat16)
        kwargs["optimizer"] = FlashAdamW(
            kwargs["model"].parameters(),
            lr=group["lr"],
            betas=group.get("betas", (0.9, 0.999)),
            eps=group.get("eps", 1e-8),
            weight_decay=group.get("weight_decay", 1e-2),
        )
        kwargs["lr_scheduler"] = torch.optim.lr_scheduler.ReduceLROnPlateau(
            kwargs["optimizer"], mode="min", factor=0.5, patience=5
        )

        # Pinning the values for 'mixed_precision' and 'compile_model' both to 'False'.
        kwargs["mixed_precision"] = False
        kwargs["compile_model"] = False  # TODO: We should explore compiling the model if it brings an advantage.

        super().__init__(**kwargs)
        self._kwargs = {}  # Required by the serializer.

    def _train_epoch_impl(self, progress, forward_context, backprop):
        """Run one training epoch and return the average wall-clock time per iteration."""
        self.model.train()

        n_iter = 0
        t_per_iter = time.time()
        for x, y in self.train_loader:
            # Casts inputs to bf16 precision, to match the casted model parameters.
            x = x.to(self.device, non_blocking=True).to(torch.bfloat16)
            y = y.to(self.device, non_blocking=True).to(torch.bfloat16)

            self.optimizer.zero_grad()

            with forward_context():
                pred, loss = self._forward_and_loss(x, y)

            backprop(loss)

            # Log the learning rate of the first param group (there is only one for FlashAdamW).
            lr = self.optimizer.param_groups[0]["lr"]
            if self.logger is not None:
                self.logger.log_train(self._iteration, loss, lr, x, y, pred, log_gradients=True)

            self._iteration += 1
            n_iter += 1
            if self._iteration >= self.max_iteration:
                break
            progress.update(1)

        t_per_iter = (time.time() - t_per_iter) / n_iter
        return t_per_iter

    def _validate_impl(self, forward_context):
        """Run validation over the full val loader and return the mean metric value."""
        self.model.eval()

        metric_val = 0.0
        loss_val = 0.0

        with torch.no_grad():
            for x, y in self.val_loader:
                # Casts inputs to bf16 precision, to match the casted model parameters.
                x = x.to(self.device, non_blocking=True).to(torch.bfloat16)
                y = y.to(self.device, non_blocking=True).to(torch.bfloat16)

                with forward_context():
                    pred, loss = self._forward_and_loss(x, y)
                    metric = self.metric(pred, y)

                loss_val += loss.item()
                metric_val += metric.item()

        metric_val /= len(self.val_loader)
        loss_val /= len(self.val_loader)
        if self.logger is not None:
            # Logs the last batch of the loop as the example images.
            self.logger.log_validation(self._iteration, metric_val, loss_val, x, y, pred)
        return metric_val
class FlashOptimTrainer(DefaultTrainer):
    """Trainer that wires FlashOptim's memory-efficient optimizers into `DefaultTrainer`.

    Compared to `DefaultTrainer` it:
    1. Casts the model parameters and the input batches to bf16 precision (`torch.bfloat16`).
    2. Forces `mixed_precision` and `compile_model` to `False`.

    NOTE: The following are not supported at the moment:
    1. Multi-GPU training (eg. using DDP).
    2. Gradient clipping on the parameters.
    3. Gradient scaling (eg. using `torch.amp.GradScaler`).
    4. Microbatch accumulation (gradient accumulation).

    For details, check out the official repository: https://github.com/databricks/flashoptim.
    And please cite https://doi.org/10.48550/arXiv.2602.23349 if you use this trainer for your research.
    """
    def __init__(self, **kwargs):
        if FlashAdamW is None:
            raise ImportError(
                "flashoptim is required for `FlashOptimTrainer`. Please install it using `pip install flashoptim`."
            )

        user_optimizer = kwargs["optimizer"]
        if not isinstance(user_optimizer, torch.optim.AdamW):
            raise ValueError(
                f"FlashOptimTrainer is currently tested with the AdamW optimizer, got '{type(user_optimizer).__name__}'. "
                "FlashAdamW is a drop-in replacement for AdamW only."
            )

        # Move the model weights to bf16 and rebuild the optimizer on the casted parameters,
        # keeping the learning rate of the optimizer the user passed in.
        base_lr = user_optimizer.param_groups[0]["lr"]
        cast_model(kwargs["model"], dtype=torch.bfloat16)
        flash_optimizer = FlashAdamW(kwargs["model"].parameters(), lr=base_lr)
        kwargs["optimizer"] = flash_optimizer
        kwargs["lr_scheduler"] = torch.optim.lr_scheduler.ReduceLROnPlateau(
            flash_optimizer, mode="min", factor=0.5, patience=5
        )

        # Both settings are incompatible with FlashOptim, hence force them off.
        kwargs["mixed_precision"] = False
        kwargs["compile_model"] = False  # TODO: Explore whether compiling the model brings an advantage.

        super().__init__(**kwargs)
        self._kwargs = {}  # Required by the serializer.

    def _train_epoch_impl(self, progress, forward_context, backprop):
        """Train for one epoch and return the mean wall-clock time per iteration."""
        self.model.train()

        epoch_start = time.time()
        iterations_done = 0
        for inputs, targets in self.train_loader:
            # The input batches must match the bf16 model weights.
            inputs = inputs.to(self.device, non_blocking=True).to(torch.bfloat16)
            targets = targets.to(self.device, non_blocking=True).to(torch.bfloat16)

            self.optimizer.zero_grad()

            with forward_context():
                prediction, loss = self._forward_and_loss(inputs, targets)

            backprop(loss)

            current_lr = self.optimizer.param_groups[0]["lr"]
            if self.logger is not None:
                self.logger.log_train(
                    self._iteration, loss, current_lr, inputs, targets, prediction, log_gradients=True
                )

            self._iteration += 1
            iterations_done += 1
            if self._iteration >= self.max_iteration:
                break
            progress.update(1)

        return (time.time() - epoch_start) / iterations_done

    def _validate_impl(self, forward_context):
        """Evaluate on the validation loader and return the mean metric value."""
        self.model.eval()

        total_metric, total_loss = 0.0, 0.0

        with torch.no_grad():
            for inputs, targets in self.val_loader:
                # The input batches must match the bf16 model weights.
                inputs = inputs.to(self.device, non_blocking=True).to(torch.bfloat16)
                targets = targets.to(self.device, non_blocking=True).to(torch.bfloat16)

                with forward_context():
                    prediction, loss = self._forward_and_loss(inputs, targets)
                    metric = self.metric(prediction, targets)

                total_loss += loss.item()
                total_metric += metric.item()

        n_batches = len(self.val_loader)
        total_metric /= n_batches
        total_loss /= n_batches
        if self.logger is not None:
            self.logger.log_validation(self._iteration, total_metric, total_loss, inputs, targets, prediction)
        return total_metric
Trainer for training models with FlashOptim optimizers for memory-efficiency.
The trainer adapts the DefaultTrainer in the following ways:
- Casts the model parameters and input data to bf16 precision (see `torch.bfloat16` for details).
- Sets `mixed_precision` and `compile_model` to `False`.
NOTE: There are a couple of things to keep in mind:
- Multi-GPU training (eg. using DDP) is currently not supported.
- Gradient clipping cannot be applied to the parameters.
- Gradient scaling (eg. using `torch.amp.GradScaler`) is currently not supported.
- Microbatch accumulation (gradient accumulation) is not possible.
For details, check out the official repository: https://github.com/databricks/flashoptim. And please cite https://doi.org/10.48550/arXiv.2602.23349 if you use this trainer for your research.
FlashOptimTrainer(**kwargs)
31 def __init__(self, **kwargs): 32 if FlashAdamW is None: 33 raise ImportError( 34 "flashoptim is required for `FlashOptimTrainer`. Please install it using `pip install flashoptim`." 35 ) 36 37 optimizer = kwargs["optimizer"] 38 if not isinstance(optimizer, torch.optim.AdamW): 39 raise ValueError( 40 f"FlashOptimTrainer is currently tested with the AdamW optimizer, got '{type(optimizer).__name__}'. " 41 "FlashAdamW is a drop-in replacement for AdamW only." 42 ) 43 44 # Cast the model parameters to bf16 precision. 45 lr = optimizer.param_groups[0]["lr"] 46 cast_model(kwargs["model"], dtype=torch.bfloat16) 47 kwargs["optimizer"] = FlashAdamW(kwargs["model"].parameters(), lr=lr) 48 kwargs["lr_scheduler"] = torch.optim.lr_scheduler.ReduceLROnPlateau( 49 kwargs["optimizer"], mode="min", factor=0.5, patience=5 50 ) 51 52 # Pinning the values for 'mixed_precision' and 'compile_model' both to 'False'. 53 kwargs["mixed_precision"] = False 54 kwargs["compile_model"] = False # TODO: We should explore compiling the model if it brings an advantange. 55 56 super().__init__(**kwargs) 57 self._kwargs = {} # Required by the serializer.
Inherited Members
- torch_em.trainer.default_trainer.DefaultTrainer
- name
- id_
- train_loader
- val_loader
- model
- loss
- optimizer
- metric
- device
- lr_scheduler
- log_image_interval
- save_root
- compile_model
- rank
- mixed_precision
- early_stopping
- train_time
- logger_class
- logger_kwargs
- checkpoint_folder
- iteration
- epoch
- Deserializer
- Serializer
- fit