TensorFlow and Batch Normalization with Batch Size 1

Posted 2019-04-03 05:03

I have a question about my understanding of BatchNorm (BN from here on).

I have a convnet working nicely, and I was writing tests to check the output shape and range. I noticed that when I set batch_size = 1, my model outputs zeros (both logits and activations).

I prototyped the simplest convnet with BN:

Input => Conv + ReLU => BN => Conv + ReLU => BN => Conv Layer + Tanh

The model is initialized with Xavier initialization. I guess that during training BN does some calculations that require batch_size > 1.

I have found an issue in PyTorch that seems to talk about this: https://github.com/pytorch/pytorch/issues/1381

Could anyone explain this? It's still a little blurry for me.


Example Run:

Important: the TensorLayer library is required for this script to run: pip install tensorlayer

import tensorflow as tf
import tensorlayer as tl

import numpy as np

def conv_net(inputs, is_training):

    xavier_initilizer = tf.contrib.layers.xavier_initializer(uniform=True)
    normal_initializer = tf.random_normal_initializer(mean=1., stddev=0.02)

    # Input Layers
    network = tl.layers.InputLayer(inputs, name='input')

    fx = [64, 128, 256, 256, 256]

    for i, n_out_channel in enumerate(fx):

        with tf.variable_scope('h' + str(i + 1)):

            network = tl.layers.Conv2d(
                network,
                n_filter    = n_out_channel,
                filter_size = (5, 5),
                strides     = (2, 2),
                padding     = 'VALID',
                act         = tf.identity,
                W_init      = xavier_initilizer,
                name        = 'conv2d'
            )

            network = tl.layers.BatchNormLayer(
                network,
                act        = tf.identity,
                is_train   = is_training,
                gamma_init = normal_initializer,
                name       = 'batch_norm'
            )

            network = tl.layers.PReluLayer(
                layer  = network,
                a_init = tf.constant_initializer(0.2),
                name   = 'activation'
            )

    ############# OUTPUT LAYER ###############

    with tf.variable_scope('h' + str(len(fx) + 1)):
        # Alternative output head (disabled):
        #
        # network = tl.layers.FlattenLayer(network, name='flatten')
        #
        # network = tl.layers.DenseLayer(
        #     network,
        #     n_units = 100,
        #     act     = tf.identity,
        #     W_init  = xavier_initilizer,
        #     name    = 'dense'
        # )

        output_filter_size = tuple([int(i) for i in network.outputs.get_shape()[1:3]])

        network = tl.layers.Conv2d(
            network,
            n_filter    = 100,
            filter_size = output_filter_size,
            strides     = (1, 1),
            padding     = 'VALID',
            act         = tf.identity,
            W_init      = xavier_initilizer,
            name        = 'conv2d'
        )

        network = tl.layers.BatchNormLayer(
            network,
            act        = tf.identity,
            is_train   = is_training,
            gamma_init = normal_initializer,
            name       = 'batch_norm'
        )

        net_logits = network.outputs

        network.outputs = tf.nn.tanh(
            x        = network.outputs,
            name     = 'activation'
        )

        net_output = network.outputs

    return network, net_output, net_logits


if __name__ == '__main__':

    tf.logging.set_verbosity(tf.logging.DEBUG)

    #################################################
    #                MODEL DEFINITION               #
    #################################################

    PLH_SHAPE = [None, 256, 256, 3]

    input_plh = tf.placeholder(tf.float32, PLH_SHAPE, name='input_placeholder')

    convnet, net_out, net_logits = conv_net(input_plh, is_training=True)


    with tf.Session() as sess:
        tl.layers.initialize_global_variables(sess)

        convnet.print_params(details=True)

        #################################################
        #                  LAUNCH A RUN                 #
        #################################################

        for BATCH_SIZE in [1, 2]:

            INPUT_SHAPE = [BATCH_SIZE, 256, 256, 3]

            batch_data = np.random.random(size=INPUT_SHAPE)

            output, logits = sess.run(
                [net_out, net_logits],
                feed_dict={
                    input_plh: batch_data
                }
            )

            if tf.logging.get_verbosity() == tf.logging.DEBUG:
                print("\n\n###########################")

                print("\nBATCH SIZE = %d\n" % BATCH_SIZE)

            tf.logging.debug("output => Shape: %s - Mean: %e - Std: %f - Min: %f - Max: %f" % (
                output.shape,
                output.mean(),
                output.std(),
                output.min(),
                output.max()
            ))

            tf.logging.debug("logits => Shape: %s - Mean: %e - Std: %f - Min: %f - Max: %f" % (
                logits.shape,
                logits.mean(),
                logits.std(),
                logits.min(),
                logits.max()
            ))

            if tf.logging.get_verbosity() == tf.logging.DEBUG:
                print("###########################")

Gives the following output:

###########################

BATCH SIZE = 1

DEBUG:tensorflow:output => Shape: (1, 1, 1, 100) - Mean: 0.000000e+00 - Std: 0.000000 - Min: 0.000000 - Max: 0.000000
DEBUG:tensorflow:logits => Shape: (1, 1, 1, 100) - Mean: 0.000000e+00 - Std: 0.000000 - Min: 0.000000 - Max: 0.000000
###########################


###########################

BATCH SIZE = 2

DEBUG:tensorflow:output => Shape: (2, 1, 1, 100) - Mean: -1.430511e-08 - Std: 0.760749 - Min: -0.779634 - Max: 0.779634
DEBUG:tensorflow:logits => Shape: (2, 1, 1, 100) - Mean: -4.768372e-08 - Std: 0.998715 - Min: -1.044437 - Max: 1.044437
###########################

2 Answers
聊天终结者
Post #2 · 2019-04-03 05:33

Batch Normalization normalizes each output over the complete batch using the following formulas (from the original paper):

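$$\mu_B = \frac{1}{m}\sum_{i=1}^{m} x_i, \qquad \sigma_B^2 = \frac{1}{m}\sum_{i=1}^{m} (x_i - \mu_B)^2, \qquad \hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}, \qquad y_i = \gamma \hat{x}_i + \beta$$

where m is the batch size and ε is a small constant added for numerical stability.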

For example, suppose you have the following outputs (of size 3) for a batch size of 2:

[2, 4, 6]
[4, 6, 8]

The mean of each output over the batch is then:

[3, 5, 7]

Now look at the numerator in the formula above: it subtracts the mean from each element of the output. If the batch size is 1, the mean is exactly the same as the output, so the numerator evaluates to 0.
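To make the arithmetic concrete, here is a minimal NumPy sketch of the training-mode formula applied to the example above. The function name batch_norm_train and the value eps=1e-5 are illustrative assumptions, not TensorFlow's API; the statistics are taken over axis 0, the batch axis.

```python
import numpy as np

def batch_norm_train(x, gamma=1.0, beta=0.0, eps=1e-5):
    """Hypothetical helper: training-mode batch norm with statistics over axis 0 (the batch)."""
    mu = x.mean(axis=0)                     # per-output mean over the batch
    var = x.var(axis=0)                     # per-output (biased, 1/m) variance over the batch
    x_hat = (x - mu) / np.sqrt(var + eps)   # normalize
    return gamma * x_hat + beta             # scale and shift

batch2 = np.array([[2., 4., 6.],
                   [4., 6., 8.]])
print(batch2.mean(axis=0))        # [3. 5. 7.]
print(batch_norm_train(batch2))   # approximately [[-1 -1 -1], [1 1 1]]

batch1 = np.array([[2., 4., 6.]])
print(batch_norm_train(batch1))   # [[0. 0. 0.]]: the mean equals the single sample
```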

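As a side note, the variance in the denominator is also 0 in that case, but a small constant ε (epsilon) is added to it, so the denominator is √ε rather than 0. The result is therefore 0 (plus beta, which is 0 by default), not an undefined 0/0.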
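The zero does not come from any special 0/0 handling: batch norm adds a small epsilon to the variance before taking the square root. A NumPy sketch contrasting a plain 0/0 with the epsilon version (eps=1e-3 is an illustrative value, not necessarily the one TensorLayer uses):

```python
import numpy as np

x = np.array([[2., 4., 6.]])                # a single-sample "batch"
mu, var = x.mean(axis=0), x.var(axis=0)     # mu == x, var == 0

with np.errstate(invalid="ignore"):         # silence the 0/0 RuntimeWarning
    no_eps = (x - mu) / np.sqrt(var)        # 0 / 0 -> nan
with_eps = (x - mu) / np.sqrt(var + 1e-3)   # 0 / sqrt(eps) -> 0

print(no_eps)    # [[nan nan nan]]
print(with_eps)  # [[0. 0. 0.]]
```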

不美不萌又怎样
Post #3 · 2019-04-03 05:49

You should probably read an explanation of Batch Normalization, such as this one. You can also take a look at TensorFlow's related documentation.

Basically, there are 2 ways to do batch norm, and both have problems with a batch size of 1:

  • using a moving mean and variance per pixel, so they are tensors with the same shape as each sample in your batch. This is the one used in @layog's answer, and the one the original paper describes for fully connected layers.

  • using a moving mean and variance computed over the entire image / feature map, so they are just vectors (rank 1) of shape (n_channels,).
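The difference between the two is just which axes the statistics are computed over. A NumPy sketch, assuming a hypothetical batch of feature maps in NHWC layout (batch, height, width, channels), as TensorFlow uses by default:

```python
import numpy as np

# Hypothetical batch: 8 samples of 5x5 feature maps with 16 channels (NHWC).
x = np.random.default_rng(0).normal(size=(8, 5, 5, 16))

# 1st way: one statistic per pixel and channel, computed over the batch axis only.
mean_per_pixel = x.mean(axis=0)
print(mean_per_pixel.shape)    # (5, 5, 16), same shape as one sample

# 2nd way: one statistic per channel, computed over batch, height and width.
mean_per_channel = x.mean(axis=(0, 1, 2))
print(mean_per_channel.shape)  # (16,)
```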

In both cases, you'll have:

output = gamma * (input - mean) / sigma + beta,   where sigma = sqrt(variance + epsilon)

Beta is usually initialized to 0 and gamma to 1; gamma can even be dropped when BN is followed by a linear function, since the next layer can take over the scaling.
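When input equals mean, the gamma term vanishes and the output is exactly beta, so a zero-initialized beta gives a zero output whatever gamma is. A NumPy sketch of the formula above (bn_output is a hypothetical helper; gamma values near 1 mimic the question's normal(1, 0.02) initializer):

```python
import numpy as np

def bn_output(x, gamma, beta, eps=1e-5):
    """output = gamma * (input - mean) / sqrt(variance + eps) + beta, batch stats over axis 0."""
    mean, var = x.mean(axis=0), x.var(axis=0)
    return gamma * (x - mean) / np.sqrt(var + eps) + beta

x = np.array([[0.7, -1.2, 3.4]])           # a batch of size 1
gamma = np.array([1.02, 0.98, 1.01])       # gamma initialized near 1
beta_zero = np.zeros(3)                    # beta initialized to 0
beta_other = np.array([0.5, -0.1, 0.0])    # an arbitrary non-zero beta, for comparison

print(bn_output(x, gamma, beta_zero))      # [[0. 0. 0.]]
print(bn_output(x, gamma, beta_other))     # equals beta_other: gamma has no effect
```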

During training, the mean and variance are computed across the current batch, which causes problems when the batch has size 1:

  • in the 1st case, you'll get mean=input, so output=0
  • in the 2nd case, mean will be the average value over all pixels, so it's better; but if your width and height are also 1, then you get mean=input again, so you get output=0.
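Both cases can be checked with a small NumPy sketch (normalize is a hypothetical helper with gamma=1 and beta=0): per-pixel statistics zero out any batch of size 1, while per-channel statistics only do so when the feature map is 1x1.

```python
import numpy as np

def normalize(x, axes, eps=1e-5):
    """Normalize x with batch statistics computed over the given axes (gamma=1, beta=0)."""
    mean = x.mean(axis=axes, keepdims=True)
    var = x.var(axis=axes, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
big = rng.normal(size=(1, 5, 5, 16))    # batch size 1, 5x5 feature maps
tiny = rng.normal(size=(1, 1, 1, 100))  # batch size 1, 1x1 feature maps

out_1st = normalize(big, axes=0)                  # 1st way
out_2nd_big = normalize(big, axes=(0, 1, 2))      # 2nd way, 5x5
out_2nd_tiny = normalize(tiny, axes=(0, 1, 2))    # 2nd way, 1x1

print(np.abs(out_1st).max())       # 0.0
print(np.abs(out_2nd_big).max())   # non-zero
print(np.abs(out_2nd_tiny).max())  # 0.0
```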

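For convolutional layers, however, the original paper uses the 2nd way, and so does TensorLayer's BatchNormLayer, which computes the mean and variance over every axis except the last (channel) axis. In your model the last BN layer sees a 1×1 feature map (the output shape is (1, 1, 1, 100)), so with a batch size of 1 each channel's mean equals its input and you get 0. The argument in the link you're providing also seems to consider the 2nd method.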
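The question's model does end in the "width and height are also 1" situation. A plain-Python sketch of the 'VALID' output-size arithmetic (valid_conv_out is a hypothetical helper) traces the 256×256 input through the five 5×5, stride-2 conv layers in the script and the final conv whose filter covers the whole remaining map:

```python
def valid_conv_out(size, kernel, stride):
    """Output length of a 'VALID' (no padding) convolution along one spatial axis."""
    return (size - kernel) // stride + 1

size = 256
sizes = [size]
for _ in range(5):                      # layers h1..h5: 5x5 filters, stride 2, VALID
    size = valid_conv_out(size, kernel=5, stride=2)
    sizes.append(size)

# Output layer h6: filter size equals the remaining feature map, stride 1.
size = valid_conv_out(size, kernel=size, stride=1)
sizes.append(size)

print(sizes)  # [256, 126, 61, 29, 13, 5, 1]
```

So the BN layer right before the tanh normalizes each channel over batch × 1 × 1 values, which is a single value when the batch size is 1.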

In any case (whichever you're using), with BN you'll only get good results if you use a bigger batch size (say, at least 10).
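The moving mean and variance mentioned above are accumulated during training and only used at inference time; in training mode the current batch's statistics are used. The question's script builds the network with is_training=True, so even its test run uses batch statistics (presumably passing is_training=False, which becomes TensorLayer's is_train, would switch to the moving ones). The class below is a hypothetical minimal NumPy sketch of that mechanism, not the TensorFlow or TensorLayer implementation; momentum=0.9 and eps=1e-5 are illustrative values.

```python
import numpy as np

class TinyBatchNorm:
    """Hypothetical minimal per-channel batch norm with moving statistics (gamma=1, beta=0)."""

    def __init__(self, channels, momentum=0.9, eps=1e-5):
        self.moving_mean = np.zeros(channels)
        self.moving_var = np.ones(channels)
        self.momentum, self.eps = momentum, eps

    def __call__(self, x, training):
        axes = tuple(range(x.ndim - 1))      # every axis except the channel axis
        if training:
            mean, var = x.mean(axis=axes), x.var(axis=axes)
            # Update the exponential moving averages used later at inference time.
            m = self.momentum
            self.moving_mean = m * self.moving_mean + (1 - m) * mean
            self.moving_var = m * self.moving_var + (1 - m) * var
        else:
            mean, var = self.moving_mean, self.moving_var
        return (x - mean) / np.sqrt(var + self.eps)

rng = np.random.default_rng(0)
bn = TinyBatchNorm(channels=100)
for _ in range(50):                          # "train" on batches of 32 with 1x1 feature maps
    bn(rng.normal(loc=3.0, size=(32, 1, 1, 100)), training=True)

single = rng.normal(loc=3.0, size=(1, 1, 1, 100))
train_mode = bn(single, training=True)       # batch statistics of one sample
infer_mode = bn(single, training=False)      # moving statistics
print(np.abs(train_mode).max())              # 0.0
print(np.abs(infer_mode).max())              # non-zero
```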
