While implementing batch normalization in Python from scratch, I got confused. A paper presents some figures comparing normalization methods, and I think its treatment of batch normalization may be incorrect. Both the description and the figure look wrong to me.
Description from the paper:
Figure from the paper:
As far as I can tell, the representation of batch normalization in the original paper is not correct, so I am posting the issue here for discussion. I think batch normalization should look like the following figure.
The key point is which axes the mean and std are computed over.
With feature maps of shape (batch_size, channel_number, width, height), the mean would be computed as either
mean = X.mean(axis=(0, 2, 3), keepdims=True)
or
mean = X.mean(axis=(0, 1), keepdims=True)
Which one is correct?
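To make the comparison concrete, here is a minimal NumPy sketch of what each candidate produces. The array sizes, the `eps` value, and the variable names are assumptions chosen for illustration. As far as I know, batch normalization for convolutional feature maps is conventionally defined per channel, which corresponds to reducing over axes (0, 2, 3); the second variant instead averages over batch and channels and keeps one statistic per spatial position.

```python
import numpy as np

# Toy feature maps with shape (batch_size, channel_number, width, height).
# The sizes are arbitrary and only serve the illustration.
rng = np.random.default_rng(0)
X = rng.normal(loc=2.0, scale=3.0, size=(4, 3, 5, 5))

# Candidate 1: reduce over batch and both spatial axes.
# One statistic per channel remains.
mean_a = X.mean(axis=(0, 2, 3), keepdims=True)
var_a = X.var(axis=(0, 2, 3), keepdims=True)

# Candidate 2: reduce over batch and channel axes.
# One statistic per spatial position remains.
mean_b = X.mean(axis=(0, 1), keepdims=True)

print(mean_a.shape)  # (1, 3, 1, 1)
print(mean_b.shape)  # (1, 1, 5, 5)

# Normalizing with the per-channel statistics (eps is an assumed small constant).
eps = 1e-5
X_hat = (X - mean_a) / np.sqrt(var_a + eps)

# After normalization, each channel has mean close to 0 and variance close to 1.
channel_mean = X_hat.mean(axis=(0, 2, 3))
channel_var = X_hat.var(axis=(0, 2, 3))
```

The shapes alone show the difference: with axis=(0, 2, 3) there is one mean per channel, while with axis=(0, 1) there is one mean per pixel location shared across all channels.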