Warm-up is a way to reduce the primacy effect of the early training examples, which matters especially for adaptive optimizers like Adam or AdamW. It gives them time to build reliable estimates of the gradient statistics before large updates are applied. Without it, you may need to run a few extra epochs to reach the desired convergence.
Using too large a learning rate may result in numerical instability, especially at the very beginning of training, when the parameters are randomly initialized. The warmup strategy increases the learning rate linearly from 0 to the initial learning rate over the first N epochs or m batches.
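The linear ramp described above can be sketched in a few lines of plain Python. This is only an illustration: the function name `linear_warmup_lr` and its parameters are hypothetical, and steps are assumed to be counted from 0.

```python
def linear_warmup_lr(step, base_lr, warmup_steps):
    """Return the learning rate at a 0-indexed step under linear warmup.

    Hypothetical helper for illustration: the rate grows linearly from 0
    and reaches base_lr once `warmup_steps` updates have been made.
    """
    if warmup_steps <= 0 or step >= warmup_steps:
        # Warmup disabled or already finished: use the full rate.
        return base_lr
    return base_lr * step / warmup_steps


# Learning rate for the first few updates with a 5-step warmup.
schedule = [linear_warmup_lr(s, base_lr=0.1, warmup_steps=5) for s in range(8)]
print(schedule)
```

The same function works whether a "step" is a batch or an epoch; only the unit of `warmup_steps` changes.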
In some cases, initializing the parameters is not sufficient to guarantee a good solution. This is particularly a problem for some advanced network designs, which can lead to unstable optimization problems. We could address this by choosing a learning rate small enough to prevent divergence at the beginning. Unfortunately, this means that progress is slow. Conversely, a large initial learning rate leads to divergence.
A rather simple fix for this dilemma is to use a warmup period during which the learning rate increases to its initial maximum, and then to cool the rate down until the end of the optimization process. Warmup steps are just a few updates with a low learning rate at the beginning of training. After the warmup, you use the regular learning rate (or schedule) to train your model to convergence.
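As a rough sketch of "warm up, then cool down", the function below ramps the rate linearly to its maximum and then decays it. The cosine shape of the cool-down is an assumption chosen for illustration (the text above does not prescribe one), and the name `warmup_cosine_lr` and its parameters are hypothetical.

```python
import math


def warmup_cosine_lr(step, total_steps, warmup_steps, max_lr, final_lr=0.0):
    """Linear warmup to max_lr, then cosine cool-down to final_lr.

    Hypothetical helper; steps are 0-indexed. The cosine decay is one
    possible cool-down shape, assumed here for illustration.
    """
    if step < warmup_steps:
        # Warmup phase: grow linearly from 0 toward max_lr.
        return max_lr * step / warmup_steps
    # Cool-down phase: progress runs from 0 (end of warmup) to 1 (end of training).
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(progress, 1.0)
    return final_lr + 0.5 * (max_lr - final_lr) * (1.0 + math.cos(math.pi * progress))


# Sample the schedule every 10 steps for a 100-step run with 10 warmup steps.
for step in range(0, 101, 10):
    print(step, round(warmup_cosine_lr(step, 100, 10, max_lr=0.1), 5))
```

The rate peaks exactly when the warmup ends and falls smoothly afterwards, so there is no sudden jump between the two phases.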
In Hasty's Model Playground, if you set Last Epoch to 1000 for a training run of 10,000 iterations, then for the first 1000 iterations the model trains with a learning rate lower than the one you specified, reduced according to the Warmup factor value. From the 1001st iteration on, the model uses the previously defined base learning rate.
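The behaviour described above can be mimicked with a small sketch. This is not Hasty's implementation: it assumes that the Warmup factor simply multiplies the base learning rate during the warmup window, and the function name `playground_style_lr` and its parameters are hypothetical.

```python
def playground_style_lr(iteration, base_lr, warmup_factor, last_epoch):
    """Learning rate at a 1-indexed iteration.

    Assumption for illustration only: during the first `last_epoch`
    iterations the base rate is scaled by `warmup_factor`; afterwards
    the base rate is used unchanged.
    """
    if iteration <= last_epoch:
        return base_lr * warmup_factor
    return base_lr


# Settings mirroring the example: warmup for the first 1000 iterations.
for it in (1, 1000, 1001, 10000):
    print(it, playground_style_lr(it, base_lr=0.01, warmup_factor=0.001, last_epoch=1000))
```

With these example values, the rate stays at the reduced level through iteration 1000 and switches to the base rate at iteration 1001.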
    import torch
    from torch.optim.lr_scheduler import StepLR
    from torch.optim.sgd import SGD

    # GradualWarmupScheduler comes from the third-party
    # pytorch-gradual-warmup-lr package (import name: warmup_scheduler).
    from warmup_scheduler import GradualWarmupScheduler

    if __name__ == '__main__':
        # A dummy "model": a single 2x2 parameter tensor.
        model = [torch.nn.Parameter(torch.randn(2, 2, requires_grad=True))]
        optim = SGD(model, 0.1)

        # scheduler_warmup is chained with scheduler_steplr:
        # 5 warmup epochs, then StepLR takes over.
        scheduler_steplr = StepLR(optim, step_size=10, gamma=0.1)
        scheduler_warmup = GradualWarmupScheduler(optim, multiplier=1, total_epoch=5, after_scheduler=scheduler_steplr)

        # this zero gradient update is needed to avoid a warning message, issue #8.
        optim.zero_grad()
        optim.step()

        for epoch in range(1, 20):
            scheduler_warmup.step(epoch)
            # param_groups is a list of dicts, so index the first group.
            print(epoch, optim.param_groups[0]['lr'])
            optim.step()  # optimizer update (follows loss.backward() in real training)