As a neural network accumulates more parameters, it becomes more prone to overfitting. Training for too many epochs can lead to overfitting of the training dataset, whereas too few may result in an underfit model.
A viable solution is to train on the training dataset but stop training when performance on the validation dataset begins to deteriorate. Early stopping is one such technique; it also helps avoid wasting training resources on epochs that no longer improve the model. The Keras module contains a built-in callback designed for this purpose, called the EarlyStopping callback.
This callback is exposed as tf.keras.callbacks.EarlyStopping in the Keras API, the high-level API of TensorFlow.
Keras callbacks help monitor and track the model while it is being trained, and they can be passed as an extra argument to the fit(), evaluate(), and predict() methods of the Keras model.
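As a minimal sketch of this mechanism (the model, x_train, and y_train are placeholders assumed to be defined elsewhere), a callback object is constructed once and then handed to fit() through the callbacks argument:

```python
import tensorflow as tf

# Build the callback; by default it monitors the validation loss ('val_loss').
early_stop = tf.keras.callbacks.EarlyStopping(monitor='val_loss')

# Pass the callback to fit() via the callbacks argument.
# model, x_train, and y_train are assumed to be defined elsewhere.
history = model.fit(x_train, y_train,
                    validation_split=0.2,
                    epochs=100,
                    callbacks=[early_stop])
```

The same callbacks argument is also accepted by evaluate() and predict().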
Patience is an important parameter of the EarlyStopping callback. If patience is set to X epochs, training terminates only when there is no improvement in the monitored performance measure for X epochs in a row.
For further understanding, please refer to the explanation of the code implementation below.
Early stopping is a technique in which the training of the model is stopped once the validation error is no longer decreasing.
Source: Overfitting and Underfitting
TensorFlow
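Below is a minimal sketch of the TensorFlow implementation described in the rest of this section; the filter count, kernel size, input shape, number of epochs, and the randomly generated dummy data are illustrative assumptions.

```python
import numpy as np
import tensorflow as tf

# Dummy data purely for illustration: 1,000 sequences of length 128 with one
# channel, and one-hot labels for 10 classes.
x_train = np.random.rand(1000, 128, 1).astype('float32')
y_train = tf.keras.utils.to_categorical(
    np.random.randint(0, 10, size=1000), num_classes=10)

# Convolutional network: Conv1D -> MaxPooling1D -> Flatten -> 10-way softmax.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(128, 1)),
    tf.keras.layers.Conv1D(16, kernel_size=3, activation='relu'),
    tf.keras.layers.MaxPooling1D(pool_size=2),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10, activation='softmax')
])

# Adam optimizer, categorical cross-entropy loss, categorical accuracy metric.
model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['categorical_accuracy'])

# Early stopping callback: monitor validation categorical accuracy and stop
# training if it fails to improve for five epochs in a row.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_categorical_accuracy',
    mode='max',        # higher accuracy is better
    patience=5)

# validation_split reserves a fraction of the training data as the
# validation set that the callback monitors.
history = model.fit(x_train, y_train,
                    validation_split=0.15,
                    epochs=100,
                    callbacks=[early_stopping])
```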
In the above TensorFlow code implementation, there is a convolutional network starting with a Conv1D layer, followed by a MaxPooling1D layer. The output is then flattened, and a dense layer with a 10-unit softmax is added. The model is compiled with the Adam optimizer and the categorical cross-entropy loss, and the categorical accuracy metric is tracked.
Then, we use the EarlyStopping callback from the tensorflow.keras.callbacks module. It monitors the performance of the network on the validation set, which has been created here with the validation_split keyword argument of fit(), and it stops the training depending on how that performance progresses.
The EarlyStopping constructor takes a keyword argument called monitor, which sets the performance metric to track. The default is the validation loss. In this example, validation accuracy is set as the performance measure that decides when to terminate the training.
Next is the patience argument, which is set to zero by default. Setting it to zero means that as soon as the performance measure gets worse from one epoch to the next, the training is terminated. This might not be ideal, since the model's performance is noisy and might go up or down from one epoch to the next. What we really care about is that the general trend is improving.
In this code example, the patience is set to five epochs, which means the training will terminate only if there is no improvement in the monitored performance measure for five epochs in a row.