Mr. Ali Rahimi’s recent talk put the batch normalization paper and the term “internal covariate shift” in the spotlight. I kinda agree with Mr. Rahimi on this one: I, too, don’t understand the necessity or the benefit of using this term. In this post, I’d like to explain my understanding of batch normalization and of Xavier initialization, which I think is related to batch normalization.
Recall the structure of an artificial neuron, the building block of a deep learning model:
It performs the following two steps of computation:
- Step 1: a weighted sum of the inputs Xi.
- Step 2: an activation value, obtained by applying a nonlinear activation function to the weighted sum from step 1.
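The two steps can be sketched in a few lines of Python with NumPy. This is a minimal illustration, not a reference implementation: it assumes a sigmoid activation by default and adds an optional bias term `b` (set to 0 here, which leaves the plain weighted sum); the names `neuron`, `w`, and `b` are illustrative.

```python
import numpy as np

def sigmoid(z):
    """Logistic sigmoid: squashes any real number into (0, 1)."""
    return 1.0 / (1.0 + np.exp(-z))

def neuron(x, w, b=0.0, activation=sigmoid):
    """One artificial neuron: a weighted sum of the inputs, then a nonlinearity."""
    z = np.dot(w, x) + b      # Step 1: weighted sum of the inputs Xi
    return activation(z)      # Step 2: nonlinear activation of that sum

x = np.array([0.5, -1.0, 2.0])   # inputs Xi
w = np.array([0.4, 0.3, -0.1])   # weights
print(neuron(x, w))              # weighted sum is -0.3, so this prints about 0.4256
print(neuron(x, w, activation=np.tanh))   # same sum, TanH activation instead
```

Swapping the `activation` argument is all it takes to move between TanH, Sigmoid, or ReLU; the weighted sum in step 1 stays the same.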
You usually choose from a few common activation functions for step 2: TanH, Sigmoid, or ReLU.
As shown in the above image, TanH and Sigmoid have a similar shape that resembles a stretched or squashed “S”, whereas ReLU’s shape is different and more artificial. It was designed that way to overcome some problems of TanH and Sigmoid.
You might wonder why TanH and Sigmoid look the way they do. My friend recently pointed me to a paper that explains this. I haven’t finished that paper yet, so I won’t go into details here, but to summarize: the “S” shape has some biological roots, as it resembles the way a real neuron reacts to input signals.
If we put the activation functions (TanH and Sigmoid) under the microscope, we can see that each of them is mostly flat at both ends and only curves within a very small range. If you pick an input x at random from a wide range and feed it into an activation function such as Sigmoid, the output is most likely very close to 0 or 1. Only when x is sampled from the small range close to 0 does the function return a value noticeably between 0 and 1.
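A quick numerical check of this claim: the sketch below samples inputs from a wide range and from a narrow range around 0, and counts how many Sigmoid outputs land within 0.05 of 0 or 1. The ranges [-10, 10] and [-1, 1], the 0.05 margin, and the helper name `saturated_fraction` are arbitrary choices for illustration.

```python
import numpy as np

def sigmoid(z):
    """Logistic sigmoid, mapping any real number into (0, 1)."""
    return 1.0 / (1.0 + np.exp(-z))

def saturated_fraction(x, margin=0.05):
    """Fraction of sigmoid outputs that fall within `margin` of 0 or 1."""
    y = sigmoid(x)
    return np.mean((y < margin) | (y > 1.0 - margin))

rng = np.random.default_rng(0)
wide = rng.uniform(-10.0, 10.0, size=100_000)   # inputs drawn from a wide range
narrow = rng.uniform(-1.0, 1.0, size=100_000)   # inputs drawn close to 0

print(f"wide range:   {saturated_fraction(wide):.1%} saturated")
print(f"narrow range: {saturated_fraction(narrow):.1%} saturated")
```

Since sigmoid(x) is within 0.05 of 0 or 1 whenever |x| exceeds ln 19 (about 2.94), roughly 70% of the wide-range samples come out saturated, while none of the narrow-range samples do.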
As you may have noticed, deep learning models are very sensitive to weight initialization. Without careful initialization, the same neural network trained on the same data may converge rapidly in one run and not converge at all in another.
This phenomenon is largely due to the above characteristic of activation functions. Recall that the input of an activation function is a weighted sum of the inputs Xi. If the weighted sum falls outside the sweet range of the activation function, the output of the neuron is almost fixed at its maximum or minimum. From backpropagation’s point of view, the “adjust-ability” of this neuron is very poor in this situation: the activation function is almost flat at the input point, so the gradient evaluated there is close to zero. A gradient close to zero gives almost no information about which direction to adjust the input, and the resulting weight updates are tiny.
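The flatness can be quantified with the derivative of the Sigmoid, which is s(z)(1 - s(z)). The sketch below evaluates it at a few points; backpropagation multiplies the incoming gradient by this slope, so a tiny slope means a tiny update. The sample points are arbitrary.

```python
import numpy as np

def sigmoid(z):
    """Logistic sigmoid."""
    return 1.0 / (1.0 + np.exp(-z))

def sigmoid_grad(z):
    """Derivative of the sigmoid: s(z) * (1 - s(z)); it peaks at 0.25 when z = 0."""
    s = sigmoid(z)
    return s * (1.0 - s)

for z in [0.0, 2.0, 5.0, 10.0]:
    print(f"z = {z:5.1f}   slope = {sigmoid_grad(z):.6f}")
```

The slope drops from 0.25 at z = 0 to about 0.0066 at z = 5 and about 0.000045 at z = 10, so a neuron whose weighted sum sits at 10 receives updates thousands of times smaller than one sitting at 0.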
In this case, the neuron is trapped (often called “saturated”). Because of its lack of “adjust-ability”, it is difficult for the neuron to escape this situation. (ReLU is better in that it has only one flat end.) Therefore, one needs to be careful when initializing a neural network’s weights. If the initial weights are too large, you will likely start with saturated neurons; if they are too small, the signal shrinks as it passes through the layers. Either way, training converges slowly or not at all.
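To see how the weight scale alone can saturate a layer, the sketch below initializes a single TanH layer with 500 inputs at different weight scales and measures how many outputs land within 0.05 of ±1. The layer sizes, the 0.05 margin, and the helper name `tanh_saturated_fraction` are illustrative assumptions; standard normal data stands in for normalized input data.

```python
import numpy as np

rng = np.random.default_rng(0)
n_inputs, n_neurons, n_samples = 500, 100, 1000
X = rng.standard_normal((n_samples, n_inputs))   # normalized input data

def tanh_saturated_fraction(weight_std):
    """Fraction of TanH outputs within 0.05 of +/-1 for a randomly initialized layer."""
    W = rng.normal(0.0, weight_std, size=(n_inputs, n_neurons))
    A = np.tanh(X @ W)               # weighted sums, then activation
    return np.mean(np.abs(A) > 0.95)

for std in [1.0, 0.1, 0.01]:
    print(f"weight std {std:5.2f}: {tanh_saturated_fraction(std):.1%} saturated")
```

Each weighted sum adds up 500 terms, so its standard deviation is roughly sqrt(500), about 22, times the weight scale. With unit-variance weights the sums land far outside TanH’s sweet range and the vast majority of outputs sit near ±1; shrinking the weights pulls the sums back toward 0.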
An initialization method called Xavier initialization was therefore introduced to save the day. The idea is to randomize the initial weights with a carefully chosen scale, so that the inputs of each activation function fall within its sweet range. Ideally, none of the neurons should start out trapped.
This post is a great resource on Xavier initialization. Basically, Xavier initialization tries to make sure that the inputs to each activation function have zero mean and unit variance. To do this, it assumes that the input data has been normalized to the same distribution. The more inputs a neuron has, the smaller its initial weights should be, to compensate for the larger number of terms in the weighted sum. In short, Xavier initialization tries to initialize weights with smarter values, so that neurons won’t start training in saturation.
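The scaling rule can be checked numerically. The Glorot and Bengio paper states the rule as a uniform distribution with variance 2/(n_in + n_out); frameworks also offer a normal variant with the same variance, which is what this sketch uses (many write-ups simplify it to 1/n_in). With equal fan-in and fan-out the two rules agree. The helper name `xavier_normal` and the layer sizes are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)

def xavier_normal(n_in, n_out):
    """Sketch of Glorot/Xavier normal initialization: Var(W) = 2 / (n_in + n_out)."""
    std = np.sqrt(2.0 / (n_in + n_out))
    return rng.normal(0.0, std, size=(n_in, n_out))

n_in, n_out = 500, 500
X = rng.standard_normal((1000, n_in))   # inputs normalized to zero mean, unit variance

Z_naive = X @ rng.standard_normal((n_in, n_out))   # weights with std 1
Z_xavier = X @ xavier_normal(n_in, n_out)          # weights with variance 1/500

print(f"naive  pre-activation variance: {Z_naive.var():.1f}")
print(f"xavier pre-activation variance: {Z_xavier.var():.2f}")
```

With unit-variance weights, the weighted sum of 500 unit-variance inputs has a variance of about 500, far outside any activation’s sweet range; with Xavier weights it stays close to 1.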
However, there is another factor besides the weights that may trap a neuron in saturation: the inputs X. You only have control over the inputs X of the first layer of the network, because they are the input data, which you can always normalize. But after you have trained the network for a few iterations, the internal weights change, and those changes alter the values of the inputs X to the internal layers, which you don’t control directly. It’s very likely that those internal inputs push neurons into saturation, which slows down the training process. This is the so-called internal covariate shift.
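A small simulation can make this concrete. The sketch below feeds normalized data through a TanH first layer and looks at the statistics of its outputs, which are the inputs X of the second layer. The “later” weights (tripled, plus a bias of 1.5) are a made-up stand-in for what a few gradient steps might do, not the result of real training; all names are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((10_000, 100))                     # normalized input data
W1 = rng.normal(0.0, np.sqrt(1.0 / 100), size=(100, 50))   # fan-in scaled first-layer weights

def layer2_inputs(W1, b1):
    """Outputs of layer 1, i.e. the inputs X that layer 2 receives."""
    return np.tanh(X @ W1 + b1)

before = layer2_inputs(W1, b1=np.zeros(50))

# Hypothetical drift after some training: weights grow and a bias develops.
W1_later = 3.0 * W1
b1_later = np.full(50, 1.5)
after = layer2_inputs(W1_later, b1_later)

print(f"before: mean {before.mean():+.2f}, std {before.std():.2f}")
print(f"after:  mean {after.mean():+.2f}, std {after.std():.2f}")
```

Nothing about the data changed, yet the second layer now sees inputs whose mean has moved well away from 0, so whatever scale its weights were initialized for no longer matches what it receives.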
What can we do about it? Can we modify the weights, similar to what Xavier initialization does? No, we can’t, because our weights are now being learned by gradient descent; changing them artificially would mess up the training process. Therefore, we have to alter the inputs X (of a batch), so that after some transformation the inputs X of a neuron still fall into the sweet range of its activation function.
Another way to think of this is that we squeeze and shift the activation function so that it adapts to the inputs X automatically. This is the core idea of batch normalization.
As a concrete example, suppose we have a group of inputs X to an activation function:
As you can see, the inputs (red dots) are trapped in saturation. Assuming our activation function is y = f(x), we can squeeze and shift it with this modification: y = f(x/gamma - beta), where gamma defines how much you squeeze the function and beta is a shift factor.
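The squeeze-and-shift idea can be tried directly with the formula y = f(x/gamma - beta). In the sketch below, the inputs, gamma = 2, and beta = 5 are hand-picked so that the transformed inputs land in [-1, 1]; in batch normalization such values would be learned rather than chosen by hand.

```python
import numpy as np

def sigmoid(z):
    """Logistic sigmoid, standing in for the activation f."""
    return 1.0 / (1.0 + np.exp(-z))

def sigmoid_grad(z):
    """Slope of the sigmoid at z."""
    s = sigmoid(z)
    return s * (1.0 - s)

# Inputs clustered far to the right: the plain sigmoid is saturated here.
x = np.array([8.0, 9.0, 10.0, 11.0, 12.0])

gamma, beta = 2.0, 5.0      # hand-picked squeeze and shift for this example
z = x / gamma - beta        # argument of f in y = f(x/gamma - beta)

print("plain   outputs:", np.round(sigmoid(x), 4))
print("plain   slopes: ", np.round(sigmoid_grad(x), 5))
print("shifted outputs:", np.round(sigmoid(z), 4))
print("shifted slopes: ", np.round(sigmoid_grad(z), 4))
```

The plain sigmoid returns values above 0.999 with slopes below 0.0004 for all five inputs, while the squeezed-and-shifted version maps them to arguments between -1 and 1, where the outputs spread between about 0.27 and 0.73 and the slopes are close to the maximum of 0.25.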
After this modification, or normalization, the new activation function looks like this (red line):
As you can see, the inputs X now fall into the sweet range again. So, really, batch normalization is about improving the adjust-ability of the neurons. During training, both gamma and beta are variables learned by gradient descent. The network learns the best gamma and beta for each neuron (for a whole layer, gamma and beta are vectors with one entry per neuron), which adds more flexibility to the network. For the mathematical details, I recommend this article.
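For comparison, here is a minimal sketch of the batch normalization forward pass as defined in the Ioffe and Szegedy paper (training mode only; the running averages used at inference time are omitted). It differs in form from the intuitive f(x/gamma - beta) above: each feature is first normalized with the mini-batch mean and variance, then multiplied by gamma and shifted by beta, and the result is what gets fed to the activation function. The function name `batch_norm_forward` and the `eps` default are illustrative choices.

```python
import numpy as np

def batch_norm_forward(X, gamma, beta, eps=1e-5):
    """Batch-normalize a mini-batch X of shape (batch_size, n_features).

    gamma and beta are learnable vectors with one entry per feature (neuron).
    """
    mu = X.mean(axis=0)                     # per-feature mini-batch mean
    var = X.var(axis=0)                     # per-feature mini-batch variance
    X_hat = (X - mu) / np.sqrt(var + eps)   # zero mean, unit variance per feature
    return gamma * X_hat + beta             # learnable scale and shift

rng = np.random.default_rng(0)
X = rng.normal(loc=10.0, scale=4.0, size=(64, 3))   # inputs stuck far from 0
gamma = np.ones(3)    # a common starting point: identity scale...
beta = np.zeros(3)    # ...and no shift
Y = batch_norm_forward(X, gamma, beta)
print(Y.mean(axis=0).round(6), Y.std(axis=0).round(3))
```

With gamma = 1 and beta = 0 every feature comes out with mean 0 and standard deviation close to 1; as gradient descent updates gamma and beta, the network can move each neuron’s inputs to whatever scale and offset works best.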
As an analogy, suppose you are a gym owner. You recently installed some showers in your gym, because your sweaty customers could really use a shower after a workout. However, after some time, you start to receive complaints: according to your customers, the shower water is too cold.
As you may have noticed, a shower valve is usually a knob like the one above. If you rotate the knob in one direction, the water becomes warmer; rotating it in the opposite direction gives you colder water. But there is a limited temperature range; you can’t make the water arbitrarily hot or cold. Past a certain point, you are trapped at a saturation point where the temperature can’t move any further toward the extreme.
Your customers complain because the adjustable range of the water temperature doesn’t overlap much with the temperature range people prefer: they think the water is too cold. On top of that, the shower valve is very sensitive; a tiny rotation turns the water to its hottest or coldest, so they can’t finely control the water temperature.
In this analogy, the valve angles people choose while showering are like the inputs X to an activation function, and the valve is the activation function. Given an angle (an input), the activation function (the valve) translates it into a water temperature, which can be viewed as the output of the activation function. Clearly, we are now trapped at a saturation point: most people prefer a water temperature that is even hotter than the maximum the shower head can provide. No matter how much customers turn the knob, they cannot get a satisfyingly warm shower.
How do you fix this? For the cold-water problem, you basically need to add a hot water pipe to the original pipe to raise the default, or base, water temperature. This is like the shift operation of batch normalization, which adjusts the mean of the inputs X. For the second issue, where the valve is too sensitive, all you need to do is require more rotation for the same change in temperature (scaling). This is the “squashing/stretching” part of batch normalization. After these modifications, your customers should be able to find their perfect temperature within a comfortable range.