Too-small Initialisation — Vanishing Gradient
- Why it happens: the weights of the neural network are initialised with values that are too small
- Result: training stalls early and looks like premature convergence
- Symptoms: model performance improves very slowly during training, and the training process is likely to stop very early
If the initial weights of a neuron are too small relative to its inputs, the gradients of the hidden layers diminish exponentially as we propagate backward; in other words, they vanish through the layers. To see why, let's assume we have the three-layer fully connected neural network below, where every layer uses a sigmoid activation function and zero bias:
If we quickly recall what the sigmoid function and its derivative look like, we can see that the maximum value of the derivative (i.e. the gradient) of the sigmoid function is 0.25, attained at x = 0.
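As a quick numeric check, here is a minimal NumPy sketch (the function names are mine) that evaluates the sigmoid's derivative over a grid and confirms it peaks at 0.25 at x = 0:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def sigmoid_derivative(x):
    s = sigmoid(x)
    return s * (1.0 - s)  # sigma'(x) = sigma(x) * (1 - sigma(x))

x = np.linspace(-10, 10, 1001)
grads = sigmoid_derivative(x)
print(grads.max())        # 0.25
print(x[grads.argmax()])  # 0.0
```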
When training a model with gradient descent, we update the weights of the entire neural network using backpropagation, taking partial derivatives of the loss with respect to the weights of each layer. To take those partial derivatives for the network above, we first need its mathematical expression. Assuming
- a stands for the output of the activation function,
- σ stands for the sigmoid function (i.e. our activation function),
- z stands for the pre-activation output of the neurons of a layer,
- W stands for the weights of each layer,
the neural network above can be represented as follows:
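With $x$ denoting the network input and $\hat{y}$ the final prediction (reconstructed from the definitions above):

$$
\begin{aligned}
z_1 &= W_1 x, & a_1 &= \sigma(z_1), \\
z_2 &= W_2 a_1, & a_2 &= \sigma(z_2), \\
z_3 &= W_3 a_2, & a_3 &= \sigma(z_3) = \hat{y}
\end{aligned}
$$

or, as a single composed expression, $\hat{y} = \sigma(W_3\,\sigma(W_2\,\sigma(W_1 x)))$.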
And the partial derivatives for updating the weights in the first layer look as follows:
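Expanding via the chain rule (with $L$ denoting the loss):

$$
\frac{\partial L}{\partial W_1}
= \frac{\partial L}{\partial a_3}
\cdot \frac{\partial a_3}{\partial z_3}
\cdot \frac{\partial z_3}{\partial a_2}
\cdot \frac{\partial a_2}{\partial z_2}
\cdot \frac{\partial z_2}{\partial a_1}
\cdot \frac{\partial a_1}{\partial z_1}
\cdot \frac{\partial z_1}{\partial W_1}
= \frac{\partial L}{\partial a_3}\,\sigma'(z_3)\,W_3\,\sigma'(z_2)\,W_2\,\sigma'(z_1)\,x
$$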
If that looks like a chain of vomit to you, then congratulations, you got the gist of it! An intuitive way of understanding it: because the partial derivatives multiply together according to the chain rule, if all the weights are relatively small (e.g. 0.005), every extra layer multiplies the gradient by another tiny factor, so it becomes smaller and smaller as we propagate backward.
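To make that concrete, here is a small sketch (scalar weights, made-up numbers) that multiplies the chain-rule factors for the three-layer network above with every weight set to 0.005:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def sigmoid_derivative(x):
    s = sigmoid(x)
    return s * (1.0 - s)

w = 0.005  # every weight in the network, deliberately tiny
x = 1.0    # network input

# Forward pass through three sigmoid layers (zero bias, scalar case)
z1 = w * x
z2 = w * sigmoid(z1)
z3 = w * sigmoid(z2)

# Chain-rule factors flowing back to W1 (omitting dL/da3):
# sigma'(z3) * w * sigma'(z2) * w * sigma'(z1) * x
grad_w1 = sigmoid_derivative(z3) * w * sigmoid_derivative(z2) * w * sigmoid_derivative(z1) * x
print(grad_w1)  # ~3.9e-07, vanishingly small after just three layers
```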
As the gradient is used to adjust the weights of each layer, the weights of the earlier layers hardly change due to the vanishing gradient. And once the weights stop changing, training effectively ends, i.e. the model cannot learn anything more from the data. However, the model has not actually converged; it is simply suffering from a vanishing gradient!
Ways to alleviate the vanishing gradient include:
- In recurrent networks, use LSTM, whose gating mechanism mitigates the vanishing gradient problem
- Use activation functions like ReLU or leaky ReLU, which are both less prone to vanishing gradients
- Reduce the number of layers
- Randomly initialise the weights with a sufficiently large variance, e.g. via Xavier/Glorot initialisation (see the sketch after this list)
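As a sketch of that last point, Xavier/Glorot initialisation scales the random weights by the layer's fan-in and fan-out so that activations and gradients keep a reasonable magnitude through the layers. A minimal NumPy version (the layer sizes are made up for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)

def xavier_init(fan_in, fan_out):
    # Glorot & Bengio (2010), uniform variant: Var(W) = 2 / (fan_in + fan_out)
    limit = np.sqrt(6.0 / (fan_in + fan_out))
    return rng.uniform(-limit, limit, size=(fan_out, fan_in))

# Hypothetical layer sizes for a three-layer network
W1 = xavier_init(64, 32)
W2 = xavier_init(32, 16)
W3 = xavier_init(16, 1)
print(W1.std(), W2.std(), W3.std())  # spread scales with layer width, not a fixed tiny value
```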