diff --git a/delira/training/callbacks/__init__.py b/delira/training/callbacks/__init__.py
index 5ab2b9e8..74495507 100644
--- a/delira/training/callbacks/__init__.py
+++ b/delira/training/callbacks/__init__.py
@@ -20,3 +20,5 @@
         ReduceLROnPlateauCallback as ReduceLROnPlateauCallbackPyTorch
     from delira.training.callbacks.pytorch_schedulers import StepLRCallback \
         as StepLRCallbackPyTorch
+    from delira.training.callbacks.pytorch_schedulers import \
+        OneCycleLRCallback as OneCycleLRCallbackPyTorch
diff --git a/delira/training/callbacks/pytorch_schedulers.py b/delira/training/callbacks/pytorch_schedulers.py
index 05e1164c..b569a792 100644
--- a/delira/training/callbacks/pytorch_schedulers.py
+++ b/delira/training/callbacks/pytorch_schedulers.py
@@ -3,7 +3,8 @@
 
 if 'TORCH' in get_backends():
     from torch.optim.lr_scheduler import ReduceLROnPlateau, \
-        CosineAnnealingLR, ExponentialLR, LambdaLR, MultiStepLR, StepLR
+        CosineAnnealingLR, ExponentialLR, LambdaLR, MultiStepLR, StepLR, \
+        OneCycleLR
 
     class DefaultPyTorchSchedulerCallback(AbstractCallback):
         """
@@ -47,6 +48,125 @@ def at_epoch_end(self, trainer, **kwargs):
             self.scheduler.step(epoch=kwargs.get("curr_epoch", None))
             return {}
 
+    class OneCycleLRCallback(DefaultPyTorchSchedulerCallback):
+        """
+        Wraps PyTorch's `OneCycleLR` Scheduler as Callback
+
+        """
+
+        def __init__(
+                self,
+                optimizer,
+                max_lr,
+                total_steps=None,
+                epochs=None,
+                steps_per_epoch=None,
+                pct_start=0.3,
+                anneal_strategy='cos',
+                cycle_momentum=True,
+                base_momentum=0.85,
+                max_momentum=0.95,
+                div_factor=25.0,
+                final_div_factor=10000.0,
+                last_epoch=-1):
+            """
+
+            Parameters
+            ----------
+            optimizer (Optimizer): Wrapped optimizer.
+            max_lr (float or list): Upper learning rate boundaries in the
+                cycle for each parameter group.
+            total_steps (int): The total number of steps in the cycle. Note
+                that if a value is not provided here, then it must be
+                inferred by providing a value for epochs and
+                steps_per_epoch.
+                Default: None
+            epochs (int): The number of epochs to train for. This is used
+                along with steps_per_epoch in order to infer the total
+                number of steps in the cycle if a value for total_steps is
+                not provided.
+                Default: None
+            steps_per_epoch (int): The number of steps per epoch to train
+                for. This is used along with epochs in order to infer the
+                total number of steps in the cycle if a value for
+                total_steps is not provided.
+                Default: None
+            pct_start (float): The percentage of the cycle (in number of
+                steps) spent increasing the learning rate.
+                Default: 0.3
+            anneal_strategy (str): {'cos', 'linear'}
+                Specifies the annealing strategy.
+                Default: 'cos'
+            cycle_momentum (bool): If ``True``, momentum is cycled inversely
+                to learning rate between 'base_momentum' and 'max_momentum'.
+                Default: True
+            base_momentum (float or list): Lower momentum boundaries in the
+                cycle for each parameter group. Note that momentum is cycled
+                inversely to learning rate; at the peak of a cycle, momentum
+                is 'base_momentum' and learning rate is 'max_lr'.
+                Default: 0.85
+            max_momentum (float or list): Upper momentum boundaries in the
+                cycle for each parameter group. Functionally, it defines the
+                cycle amplitude (max_momentum - base_momentum).
+                Note that momentum is cycled inversely to learning rate;
+                at the start of a cycle, momentum is 'max_momentum' and
+                learning rate is 'base_lr'.
+                Default: 0.95
+            div_factor (float): Determines the initial learning rate via
+                initial_lr = max_lr/div_factor
+                Default: 25
+            final_div_factor (float): Determines the minimum learning rate
+                via min_lr = initial_lr/final_div_factor
+                Default: 1e4
+            last_epoch (int): The index of the last batch. This parameter is
+                used when resuming a training job. Since `step()` should be
+                invoked after each batch instead of after each epoch, this
+                number represents the total number of *batches* computed,
+                not the total number of epochs computed.
+                When last_epoch=-1, the schedule is started from the
+                beginning.
+                Default: -1
+            """
+            super().__init__()
+            self.scheduler = OneCycleLR(
+                optimizer,
+                max_lr,
+                total_steps=total_steps,
+                epochs=epochs,
+                steps_per_epoch=steps_per_epoch,
+                pct_start=pct_start,
+                anneal_strategy=anneal_strategy,
+                cycle_momentum=cycle_momentum,
+                base_momentum=base_momentum,
+                max_momentum=max_momentum,
+                div_factor=div_factor,
+                final_div_factor=final_div_factor,
+                last_epoch=last_epoch)
+
+        def at_iter_begin(self, trainer, train, **kwargs):
+            """
+            Executes a single scheduling step
+
+            Parameters
+            ----------
+            trainer : :class:`PyTorchNetworkTrainer`
+                the trainer class, which can be changed
+            train : bool
+                whether the current iteration is a training iteration;
+                the scheduler is only stepped during training
+            kwargs :
+                additional keyword arguments
+
+            Returns
+            -------
+            dict
+                an empty dict, since no trainer attributes are modified
+
+            """
+            if train:
+                self.scheduler.step()
+
+            return {}
+
+        def at_epoch_end(self, trainer, **kwargs):
+            # OneCycleLR is stepped once per training iteration in
+            # at_iter_begin, so the default epoch-end step is disabled
+            return {}
+
     class ReduceLROnPlateauCallback(DefaultPyTorchSchedulerCallback):
         """
         Wraps PyTorch's `ReduceLROnPlateau` Scheduler as Callback
diff --git a/delira/training/predictor.py b/delira/training/predictor.py
index 84436cfa..da0cba2f 100644
--- a/delira/training/predictor.py
+++ b/delira/training/predictor.py
@@ -275,7 +275,7 @@ def predict_data_mgr(
         batch_list = []
 
         for i, batch in iterable:
-            self._at_iter_begin(iter_num=i)
+            Predictor._at_iter_begin(self, iter_num=i)
 
             if not batch_list and (n_batches - i) < batchsize:
                 batchsize = n_batches - i
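
Usage note: the callback steps the wrapped OneCycleLR scheduler once per
training iteration via at_iter_begin (the epoch-end hook is a no-op), so
either total_steps or both epochs and steps_per_epoch must cover the whole
run. A minimal standalone sketch is shown below, assuming the TORCH backend
is available; the Linear model and SGD optimizer are placeholders, and in a
real run the trainer invokes the hook itself rather than user code:

    import torch
    from delira.training.callbacks import OneCycleLRCallbackPyTorch

    # toy model/optimizer for illustration; OneCycleLR cycles momentum by
    # default, so the optimizer needs a momentum setting
    model = torch.nn.Linear(10, 1)
    optim = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

    # schedule spanning 100 epochs * 50 iterations = 5000 scheduler steps
    callback = OneCycleLRCallbackPyTorch(optim, max_lr=0.1,
                                         epochs=100, steps_per_epoch=50)

    # the trainer calls this hook once per batch; train=False (validation
    # iterations) leaves the learning rate untouched
    callback.at_iter_begin(trainer=None, train=True)
    print(optim.param_groups[0]["lr"])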