While implementing batch normalization in Python from scratch, I got confused. A paper shows some figures comparing normalization methods, and I think its treatment of batch normalization may not be correct. Both the description and the figure look wrong to me.
Description from the paper:
Figure from the paper: As far as I can tell, the representation of batch normalization in the original paper is not correct, so I am posting the issue here for discussion. I think batch normalization should look like the following figure.
The key point is how to calculate mean and std.
Given feature maps of shape (batch_size, channel_number, width, height), the two candidates are:
mean = X.mean(axis=(0, 2, 3), keepdims=True)
or
mean = X.mean(axis=(0, 1), keepdims=True)
Which one is correct?
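To make the difference concrete, here is a quick shape check of the two candidates. It is only a sketch: the array `X` and its sizes (8 images, 3 channels, 4 x 5 pixels) are made-up values for illustration, and the names `mean_per_channel` and `mean_per_position` are mine.

```python
import numpy as np

# Hypothetical feature maps: (batch_size, channel_number, width, height)
rng = np.random.default_rng(0)
X = rng.normal(size=(8, 3, 4, 5))

# Reduce over batch and both spatial axes: one value per channel
mean_per_channel = X.mean(axis=(0, 2, 3), keepdims=True)

# Reduce over batch and channel axes: one value per spatial position
mean_per_position = X.mean(axis=(0, 1), keepdims=True)

print(mean_per_channel.shape)   # (1, 3, 1, 1)
print(mean_per_position.shape)  # (1, 1, 4, 5)
```

So the first option yields one statistic per channel, while the second yields one statistic per pixel location shared by all channels.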
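You should calculate the mean and std across all pixels of all images in the batch, separately for each channel, so use axis=(0, 2, 3). If the channels have roughly the same distribution, you may compute the mean and std across channels as well, in which case you can call mean() and std() without an axis argument, although that gives a single scalar rather than the per-channel statistics of standard batch normalization.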
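As a minimal sketch of the per-channel version (training-mode forward pass only, no running averages), something like the following should work. The function name `batch_norm_forward`, the `eps` value, and the example shapes are assumptions for illustration, not taken from the paper.

```python
import numpy as np

def batch_norm_forward(X, gamma, beta, eps=1e-5):
    """Training-mode batch norm for X of shape (N, C, W, H)."""
    # Per-channel statistics over batch and spatial axes
    mean = X.mean(axis=(0, 2, 3), keepdims=True)
    var = X.var(axis=(0, 2, 3), keepdims=True)
    # Normalize, then apply the learnable scale and shift
    X_hat = (X - mean) / np.sqrt(var + eps)
    return gamma * X_hat + beta

# Made-up input: 8 images, 3 channels, 4 x 5 pixels
rng = np.random.default_rng(0)
X = rng.normal(loc=2.0, scale=3.0, size=(8, 3, 4, 5))
gamma = np.ones((1, 3, 1, 1))   # scale, one per channel
beta = np.zeros((1, 3, 1, 1))   # shift, one per channel

out = batch_norm_forward(X, gamma, beta)
```

With gamma set to ones and beta to zeros, each channel of `out` should have mean close to 0 and std close to 1, which is an easy way to check the choice of axes.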
The figure in the article is correct: it takes the mean and std across H and W (the image dimensions) for each batch. Obviously, the channel is not shown in the 3D cube.