Resetting TensorFlow streaming metrics' variables

Published 2019-02-15 21:00

Question:

I have a bunch of streaming metrics (tf.metrics.accuracy and custom streaming micro, macro and weighted F1-scores).

During training, I get the kind of plot below (never mind the overfitting).

This happens because, to compute the validation-set metrics, I call tf.local_variables_initializer to reset the metrics so that they only reflect the validation set.

This has two side effects:

  1. The spikes in the image
  2. Between validations, the training metrics keep aggregating over every batch since the last reset, which spans 2 epochs because validation only happens every 2 epochs
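To see where spikes like these can come from, here is a framework-free sketch (plain Python, not the TensorFlow API) of a streaming accuracy that, like `tf.metrics.accuracy`, keeps a running `total` and `count`. It assumes a single metric is shared by training and validation and is reset only before validation; the `StreamingAccuracy` class and the batch values are hypothetical.

```python
class StreamingAccuracy:
    """Hypothetical stand-in for a streaming accuracy metric.

    Like tf.metrics.accuracy, it keeps two running accumulators:
    the number of correct predictions and the number of examples.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        # Analogous to re-running the initializer on the metric's local variables.
        self.total = 0.0
        self.count = 0.0

    def update(self, labels, predictions):
        # Accumulates over every batch seen since the last reset.
        for y, p in zip(labels, predictions):
            self.total += float(y == p)
            self.count += 1.0
        return self.result()

    def result(self):
        # Returns 0.0 before any update instead of dividing by zero.
        return self.total / self.count if self.count else 0.0


acc = StreamingAccuracy()  # one metric shared by training and validation

# Two epochs of training: every batch is 3/4 correct.
for _ in range(8):
    acc.update([1, 1, 1, 1], [1, 1, 1, 0])
before_validation = acc.result()  # 24 / 32

acc.reset()  # reset so the value covers validation only
validation = acc.update([1, 1, 1, 1], [0, 0, 0, 0])  # 0 / 4

# Training resumes without a reset, so the validation counts leak in.
after_validation = acc.update([1, 1, 1, 1], [1, 1, 1, 0])  # 3 / 8

print(before_validation, validation, after_validation)
```

In this toy run the reported training value drops from 0.75 to 0.375 right after validation and then climbs back as further training batches dilute the validation counts, which is one plausible source of the spikes.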

I could partially solve the situation by having different tensors hold each metric (train vs val). But it would not solve point 2.
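Point 2 can be handled separately from the train/validation split: once each split has its own state, the training metric can also be reset on its own schedule, for example at the start of every epoch (in TensorFlow, that would mean re-initializing only the training metric's local variables). A hedged plain-Python sketch; `RunningMean`, `train_log` and all accuracy values are hypothetical, and averaging per-batch values matches example-weighted accuracy only when batches have equal size.

```python
class RunningMean:
    """Hypothetical stand-in for a streaming metric: mean of per-batch values."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.total = 0.0
        self.count = 0

    def update(self, value):
        self.total += value
        self.count += 1

    def result(self):
        return self.total / self.count if self.count else 0.0


def train_log(epochs, reset_train_each_epoch, validate_every=2):
    """Training metric reported at the end of each epoch.

    `epochs` holds made-up per-batch training accuracies for each epoch.
    """
    train, valid = RunningMean(), RunningMean()  # separate state per split
    log = []
    for epoch, batch_accs in enumerate(epochs, start=1):
        if reset_train_each_epoch:
            train.reset()  # forget batches from earlier epochs
        for acc in batch_accs:
            train.update(acc)
        if epoch % validate_every == 0:
            valid.reset()  # resets validation state only
            valid.update(0.5)  # made-up validation accuracy
        log.append(train.result())
    return log


epochs = [[1.0, 1.0], [0.0, 0.0], [0.5, 0.5]]
no_reset = train_log(epochs, reset_train_each_epoch=False)  # aggregates across epochs
per_epoch = train_log(epochs, reset_train_each_epoch=True)  # one value per epoch
print(no_reset, per_epoch)
```

Without the per-epoch reset the second epoch reports 0.5, averaging it with the first; with the reset it reports that epoch's own 0.0. The trade-off is a noisier value early in each epoch.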

I therefore have two questions:

  • In your experience, is this the behavior you would expect? If not, what is the solution?
  • Is there a way to have metrics stream only over the last n batches?
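On the second question: `tf.metrics.accuracy` takes no window-size argument; it accumulates everything since its variables were last initialized. One framework-agnostic workaround is to fetch per-batch labels and predictions (e.g. from `sess.run`) and keep only the last n batches' counts in a fixed-length buffer. A minimal sketch in plain Python; `WindowedAccuracy` is a hypothetical helper, not part of TensorFlow.

```python
from collections import deque


class WindowedAccuracy:
    """Hypothetical accuracy metric over only the last `window` batches."""

    def __init__(self, window):
        # One (num_correct, num_examples) pair per batch; with maxlen set,
        # appending to a full deque discards the oldest batch.
        self.batches = deque(maxlen=window)

    def update(self, labels, predictions):
        correct = sum(int(y == p) for y, p in zip(labels, predictions))
        self.batches.append((correct, len(labels)))
        return self.result()

    def result(self):
        # Example-weighted accuracy over the batches still in the window.
        total = sum(c for c, _ in self.batches)
        count = sum(n for _, n in self.batches)
        return total / count if count else 0.0


metric = WindowedAccuracy(window=2)
first = metric.update([1, 0, 1, 1], [1, 0, 1, 1])   # 4/4 correct
second = metric.update([1, 0, 1, 1], [1, 0, 0, 0])  # 2/4 correct
third = metric.update([1, 1, 1, 1], [0, 0, 0, 0])   # 0/4; first batch drops out
print(first, second, third)
```

For the F1 variants, the same buffer would hold per-class true-positive, false-positive and false-negative counts for each batch, from which micro, macro or weighted F1 can be recomputed over the window.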

Answer 1:

This behaviour is expected if you reset the metrics in the middle of training. If the training and validation metrics are two different ops (each with its own local variables), the training metrics don't aggregate the validation values. I will give an example of how to keep those metrics separate and how to reset only one of them.


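A toy example:

```python
import tensorflow as tf  # TensorFlow 1.x API (tf.placeholder, tf.metrics, tf.Session)

logits = tf.placeholder(tf.int64, [2, 3])
labels = tf.Variable([[0, 1, 0], [1, 0, 1]])

# Create two independent accuracy metrics; each name scope gets its own
# local variables (total and count), so they never share state.
with tf.name_scope('train'):
    train_acc, train_acc_op = tf.metrics.accuracy(labels=tf.argmax(labels, 1),
                                                  predictions=tf.argmax(logits, 1))
with tf.name_scope('valid'):
    valid_acc, valid_acc_op = tf.metrics.accuracy(labels=tf.argmax(labels, 1),
                                                  predictions=tf.argmax(logits, 1))
```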

Training:

```python
sess = tf.Session()

# Initialize the local variables, as they hold the metrics' accumulators;
# the global initializer covers the `labels` variable.
sess.run(tf.local_variables_initializer())
sess.run(tf.global_variables_initializer())

# Initial state
print(sess.run(train_acc, {logits: [[0, 1, 0], [1, 0, 1]]}))
print(sess.run(valid_acc, {logits: [[0, 1, 0], [1, 0, 1]]}))
# 0.0
# 0.0
```

The initial states are 0.0 as expected.

Now run the training metric's update op:

```python
# Training loop: only the training metric's update op is run
for _ in range(10):
    sess.run(train_acc_op, {logits: [[0, 1, 0], [1, 0, 1]]})
print(sess.run(train_acc, {logits: [[0, 1, 0], [1, 0, 1]]}))
# 1.0
print(sess.run(valid_acc, {logits: [[0, 1, 0], [1, 0, 1]]}))
# 0.0
```

Only the training accuracy was updated; the validation accuracy is still 0.0. Now run the validation update op:

```python
# Validation loop: only the validation metric's update op is run
for _ in range(10):
    sess.run(valid_acc_op, {logits: [[0, 1, 0], [0, 1, 0]]})
print(sess.run(valid_acc, {logits: [[0, 1, 0], [1, 0, 1]]}))
# 0.5
print(sess.run(train_acc, {logits: [[0, 1, 0], [1, 0, 1]]}))
# 1.0
```

Here the validation accuracy was updated to 0.5 (the predicted class matches the label in the first row but not in the second), while the training accuracy remained unchanged.

Let's reset only the validation metric's variables:

```python
# Select only the local variables created under the 'valid' name scope
stream_vars_valid = [v for v in tf.local_variables() if 'valid/' in v.name]
sess.run(tf.variables_initializer(stream_vars_valid))

print(sess.run(valid_acc, {logits: [[0, 1, 0], [1, 0, 1]]}))
# 0.0
print(sess.run(train_acc, {logits: [[0, 1, 0], [1, 0, 1]]}))
# 1.0
```

The validation accuracy was reset to zero while the training accuracy remained unchanged.
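One caveat on the variable filter above: `'valid/' in v.name` is a substring test, so it would also match other scopes whose names contain `valid/`, such as a hypothetical `invalid` scope. Checking the prefix is stricter. A plain-Python illustration, assuming metric variable names of the form `<scope>/accuracy/<var>:0`; the name list is made up.

```python
# Hypothetical local-variable names of the form "<scope>/accuracy/<var>:0".
names = [
    "train/accuracy/total:0",
    "train/accuracy/count:0",
    "valid/accuracy/total:0",
    "valid/accuracy/count:0",
    "invalid/accuracy/total:0",  # made-up scope that also contains "valid/"
]

# Substring test, as in the answer: also catches "invalid/...".
by_substring = [n for n in names if "valid/" in n]

# Prefix test: only variables created directly under the top-level "valid" scope.
by_prefix = [n for n in names if n.startswith("valid/")]

print(by_substring)
print(by_prefix)
```

In the TensorFlow code the corresponding change would be `[v for v in tf.local_variables() if v.name.startswith('valid/')]`.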