How to calculate batch normalization with Python?

Posted 2020-05-03 11:05

Question:

When I implemented batch normalization in Python from scratch, I got confused. A paper presents some figures illustrating normalization methods, and I think they may not be correct: both the description and the figure seem wrong to me.

Description from the paper:

Figure from the paper:

As far as I can tell, the paper's depiction of batch normalization is not correct. I am posting the issue here for discussion. I think batch normalization should look like the following figure.

The key point is how to calculate the mean and std. With feature maps of shape (batch_size, channel_number, width, height), the two candidates are mean = X.mean(axis=(0, 2, 3), keepdims=True) and mean = X.mean(axis=(0, 1), keepdims=True).

Which one is correct?
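One quick way to see how the two candidates differ is to compare the shapes of the statistics they produce. The sketch below uses random NumPy data in the (batch_size, channel_number, width, height) layout from the question; the array sizes are arbitrary choices for illustration.

```python
import numpy as np

# Arbitrary example sizes: (batch_size, channel_number, width, height)
rng = np.random.default_rng(0)
X = rng.standard_normal((8, 3, 4, 5))

# Candidate 1: reduce over batch and both spatial axes -> one mean per channel
mean_per_channel = X.mean(axis=(0, 2, 3), keepdims=True)

# Candidate 2: reduce over batch and channel -> one mean per spatial position
mean_per_position = X.mean(axis=(0, 1), keepdims=True)

print(mean_per_channel.shape)   # (1, 3, 1, 1)
print(mean_per_position.shape)  # (1, 1, 4, 5)
```

The first form keeps one statistic per channel, which is what batch normalization uses for convolutional feature maps; the second mixes all channels together and produces a separate statistic for every pixel location instead.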

Answer 1:

You should calculate the mean and std for each channel across all pixels of all images in the batch, so use axis=(0, 2, 3). If the channels have roughly the same distribution, you may also pool the statistics across channels; in that case, just call mean() and std() without an axis argument.
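A minimal from-scratch sketch of the training-time forward pass, assuming NumPy and the (batch_size, channel_number, width, height) layout from the question. The function name `batch_norm_forward` and the parameters `gamma`, `beta`, `eps`, and `per_channel` are hypothetical names chosen for illustration, not any library's API; running averages for inference are omitted.

```python
import numpy as np

def batch_norm_forward(X, gamma, beta, eps=1e-5, per_channel=True):
    """Normalize a 4-D batch X of shape (N, C, W, H).

    per_channel=True  -> statistics over axes (0, 2, 3), one mean/var per channel.
    per_channel=False -> statistics pooled over every element (all channels).
    gamma and beta are learnable scale/shift arrays of shape (1, C, 1, 1).
    """
    axes = (0, 2, 3) if per_channel else None  # None reduces over all axes
    mean = X.mean(axis=axes, keepdims=True)
    var = X.var(axis=axes, keepdims=True)  # biased (ddof=0) variance
    X_hat = (X - mean) / np.sqrt(var + eps)  # eps avoids division by zero
    return gamma * X_hat + beta

rng = np.random.default_rng(42)
# Give each channel a different scale and offset so the choice of axes matters
scale = np.array([1.0, 5.0, 0.1]).reshape(1, 3, 1, 1)
shift = np.array([0.0, 10.0, -3.0]).reshape(1, 3, 1, 1)
X = rng.standard_normal((16, 3, 8, 8)) * scale + shift
gamma = np.ones((1, 3, 1, 1))
beta = np.zeros((1, 3, 1, 1))

Y = batch_norm_forward(X, gamma, beta)
print(Y.mean(axis=(0, 2, 3)))  # each channel's mean is close to 0
print(Y.std(axis=(0, 2, 3)))   # each channel's std is close to 1

# Pooled statistics: fine only if channels share roughly the same distribution
Y_pooled = batch_norm_forward(X, gamma, beta, per_channel=False)
print(Y_pooled.mean(axis=(0, 2, 3)))  # per-channel means are no longer 0 here
```

Because the three channels in this synthetic example have very different offsets and scales, the pooled variant leaves them on different scales, which illustrates why pooling across channels is only reasonable when their distributions are similar.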

The figure in the article is correct: it takes the mean and std across H and W (the image dimensions) and across the whole batch. Note that the channel dimension is not shown in the 3D cube.
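To make the "one statistic per channel, computed over the whole batch and all pixels" reading concrete, the sketch below (NumPy, arbitrary random data) checks that the vectorized axis=(0, 2, 3) reduction matches an explicit loop that, for each channel c, flattens the slab X[:, c, :, :] of N*W*H values and reduces it as a single group.

```python
import numpy as np

rng = np.random.default_rng(1)
X = rng.standard_normal((4, 3, 5, 6))  # (batch_size, channel_number, width, height)

# Vectorized: one mean/std per channel
mean_vec = X.mean(axis=(0, 2, 3))
std_vec = X.std(axis=(0, 2, 3))

# Explicit loop: for each channel, gather all N*W*H values and reduce them together
mean_loop = np.empty(X.shape[1])
std_loop = np.empty(X.shape[1])
for c in range(X.shape[1]):
    values = X[:, c, :, :].ravel()  # every pixel of channel c across the whole batch
    mean_loop[c] = values.mean()
    std_loop[c] = values.std()

print(np.allclose(mean_vec, mean_loop), np.allclose(std_vec, std_loop))  # True True
```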