I've recently picked up TensorFlow and have been trying my best to adjust to the environment. It has been nothing but wonderful! However, batch normalization using tf.contrib.layers.batch_norm has been a little tricky. Right now, here is the function I'm using:
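import tensorflow as tf  # TensorFlow 1.x (tf.contrib was removed in 2.x)

def batch_norm(x, phase):
    # phase: True while training, False while testing
    return tf.contrib.layers.batch_norm(x, center=True, scale=True,
                                        is_training=phase, updates_collections=None)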
Using this, I followed most of the documentation (and Q&A) that I've found online, and it led me to the following conclusions:
1) is_training should be set to True for training and False for testing. This makes sense! When training, I had convergence (error < 1%, CIFAR-10 dataset).
However, during testing my results are terrible (error > 90%) UNLESS I pass updates_collections = None as an argument to the batch norm function above. Only with that argument does testing give me the error I expected.
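To make the mechanism concrete, here is a minimal pure-Python sketch of how I understand batch normalization to behave in the two modes. It is not tf.contrib's implementation: the class name TinyBatchNorm, the decay and eps values, the initial statistics, and the run_updates flag (a stand-in for whether the moving-average update ops actually get executed) are all assumptions made up for this illustration.

```python
import random


class TinyBatchNorm:
    """Hypothetical 1-D batch norm for illustration; not tf.contrib's implementation."""

    def __init__(self, decay=0.9, eps=1e-3):
        self.decay = decay        # assumed moving-average decay
        self.eps = eps            # small constant to avoid division by zero
        self.moving_mean = 0.0    # assumed initial values
        self.moving_var = 1.0

    def __call__(self, batch, is_training, run_updates=True):
        if is_training:
            # Training mode: normalize with the statistics of the current batch.
            mean = sum(batch) / len(batch)
            var = sum((v - mean) ** 2 for v in batch) / len(batch)
            if run_updates:
                # Stand-in for the moving-average update ops being executed.
                self.moving_mean = self.decay * self.moving_mean + (1 - self.decay) * mean
                self.moving_var = self.decay * self.moving_var + (1 - self.decay) * var
        else:
            # Inference mode: normalize with the stored moving averages.
            mean, var = self.moving_mean, self.moving_var
        return [(v - mean) / (var + self.eps) ** 0.5 for v in batch]


def sample_batch(n):
    """Draw n made-up activations with mean 5 and standard deviation 2."""
    return [random.gauss(5.0, 2.0) for _ in range(n)]


def average(values):
    return sum(values) / len(values)


random.seed(0)
updated, stale = TinyBatchNorm(), TinyBatchNorm()
for _ in range(200):
    batch = sample_batch(64)
    updated(batch, is_training=True, run_updates=True)
    stale(batch, is_training=True, run_updates=False)

test_batch = sample_batch(64)
mean_updated = average(updated(test_batch, is_training=False))
mean_stale = average(stale(test_batch, is_training=False))
print("test output mean, moving averages updated:", round(mean_updated, 2))
print("test output mean, moving averages never updated:", round(mean_stale, 2))
```

In this toy version, when the updates run, the test-mode output should be centered near 0. When they never run, the layer still normalizes with its initial statistics, so the output stays near the raw mean of about 5. That would at least be consistent with training looking fine while testing falls apart.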
I am also sure to use the following for training:
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
# Ensure the update ops run before the ops defined below
with tf.control_dependencies(update_ops):
    with tf.name_scope('Cross_Entropy'):
        # Mean softmax cross entropy between true labels (y_) and logits (y_conv)
        cross_entropy = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y_conv))
        tf.summary.scalar('cross_entropy', cross_entropy)  # Log cross entropy for TensorBoard
    with tf.name_scope('train'):
        # Minimize cross_entropy with the Adam optimizer
        train_step = tf.train.AdamOptimizer(1e-2).minimize(cross_entropy)
    with tf.name_scope('Train_Results'):
        with tf.name_scope('Correct_Prediction'):
            # True where the predicted class matches the true class
            correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_, 1))
        with tf.name_scope('Accuracy'):
            # Accuracy is the mean of the correct_prediction outputs
            accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
            tf.summary.scalar('accuracy', accuracy)  # Log classification accuracy for TensorBoard
This should make sure that the batch normalization parameters are updating during training.
So this leads me to believe that updates_collections = None is just a nice default for my batch normalization function that makes sure no batch normalization parameters are adjusted during testing. Am I correct?
Lastly: is it normal to get good results (the expected error) during the testing phase with batch normalization's training mode turned both on AND off? Using the batch norm function above, I was able to train well (is_training = True) and test well (is_training = False). However, when testing with is_training = True, I still got great results. This just gives me a bad feeling. Could someone explain why this is happening, or whether it should be happening at all?
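To show what I mean, here is a small pure-Python sketch of one possible reason testing with is_training = True might still look fine. It is only an illustration under assumptions: POP_MEAN, POP_STD, EPS, and the helper names are invented for this example, and the population statistics stand in for well-trained moving averages. It compares normalizing a test batch with those stored statistics (like is_training = False) against normalizing it with the batch's own statistics (like is_training = True).

```python
import random

EPS = 1e-3                      # assumed small constant for numerical stability
POP_MEAN, POP_STD = 5.0, 2.0    # assumed "true" statistics of the activations


def normalize(batch, mean, var):
    return [(v - mean) / (var + EPS) ** 0.5 for v in batch]


def batch_stats(batch):
    mean = sum(batch) / len(batch)
    var = sum((v - mean) ** 2 for v in batch) / len(batch)
    return mean, var


def max_gap(batch_size):
    """Largest per-element difference between the two ways of normalizing."""
    batch = [random.gauss(POP_MEAN, POP_STD) for _ in range(batch_size)]
    with_moving = normalize(batch, POP_MEAN, POP_STD ** 2)   # like is_training=False
    with_batch = normalize(batch, *batch_stats(batch))        # like is_training=True
    return max(abs(a - b) for a, b in zip(with_moving, with_batch))


random.seed(1)
gap_large = max_gap(1000)
gap_small = max_gap(2)
print("largest difference, batch of 1000:", round(gap_large, 3))
print("largest difference, batch of 2:", round(gap_small, 3))
```

With a large test batch drawn from the same distribution as the training data, the batch statistics are close to the stored ones, so the two modes give nearly the same outputs. With a very small batch the difference is typically much larger, so I suspect my good results with is_training = True depend on the test batch size.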
Thank you for your time!