I am trying to output some summary scalars in an ML Engine experiment at both train and eval time. tf.summary.scalar('loss', loss) correctly outputs the summary scalars for both training and evaluation on the same plot in TensorBoard. However, I am also trying to output other metrics at both train and eval time, and they are only output at train time. That code immediately follows tf.summary.scalar('loss', loss) but does not appear to work. For example, the code below is only output for TRAIN, not for EVAL. The only difference is that these use custom accuracy functions, and those work for TRAIN:
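    if mode in (Modes.TRAIN, Modes.EVAL):
        loss = tf.contrib.legacy_seq2seq.sequence_loss(logits, outputs, weights)
        tf.summary.scalar('loss', loss)
        # Use a different name than the helper: inside a function,
        # `sequence_accuracy = sequence_accuracy(...)` makes the name local
        # and raises UnboundLocalError on the right-hand side.
        seq_accuracy = sequence_accuracy(targets, predictions, weights)
        tf.summary.scalar('sequence_accuracy', seq_accuracy)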
Why would loss plot in TensorBoard for both TRAIN and EVAL, while sequence_accuracy plots only for TRAIN?
Could this behavior be related to the warning I received, "Found more than one metagraph event per run. Overwriting the metagraph with the newest event."?
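Because the summary node in the graph is just a node. It still needs to be evaluated (producing a serialized protobuf string), and that string still needs to be written to an event file. In EVAL mode it isn't evaluated, because it isn't upstream of anything the Estimator runs there (the loss and the update ops of your eval_metric_ops), and even if it were evaluated, it wouldn't be written to a file unless a tf.train.SummarySaverHook wrote it (passed in your EstimatorSpec via training_chief_hooks for training or, in newer TensorFlow 1.x releases, evaluation_hooks for evaluation). During training the Estimator already installs such a hook, which saves all merged summaries every save_summary_steps; that is why your TRAIN plots work. During evaluation it only writes loss and the values of eval_metric_ops, which is why loss appears for both. Because the Estimator class doesn't assume you want any extra evaluation during training, evaluation is normally done only during the EVAL phase, and you just increase min_eval_frequency or checkpoint_frequency to get more evaluation datapoints. If you really really want to log a summary during training here's how you'd do it: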
But it's better to just increase your eval frequency and add an eval_metric_ops entry for accuracy using tf.metrics.accuracy (tf.contrib.metrics.streaming_accuracy in older releases).
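For reference, tf.metrics.accuracy returns a (value, update_op) pair. In EstimatorSpec you would pass something like eval_metric_ops={'sequence_accuracy': tf.metrics.accuracy(labels=targets, predictions=predictions, weights=weights)}, assuming targets and predictions are integer id tensors of the same shape. The update op runs on every eval batch, the value is read once at the end, and the Estimator writes that value to the eval directory next to loss. The sketch below is a hypothetical pure-Python analogue (StreamingAccuracy is illustrative, not a TensorFlow API) showing that a streamed metric accumulates weighted correct counts across batches rather than averaging per-batch scalars. If your custom sequence_accuracy already yields one scalar per batch, wrapping it in tf.metrics.mean is one option, though that averages batch values instead of weighting by token count.

```python
from dataclasses import dataclass
from typing import Optional, Sequence


@dataclass
class StreamingAccuracy:
    """Hypothetical pure-Python analogue of tf.metrics.accuracy.

    Two accumulators play the role of the metric's local variables:
    the weighted number of correct predictions and the total weight seen.
    """
    total: float = 0.0
    count: float = 0.0

    def update(self, labels: Sequence[int], predictions: Sequence[int],
               weights: Optional[Sequence[float]] = None) -> None:
        """Plays the role of update_op: called once per evaluation batch."""
        if weights is None:
            weights = [1.0] * len(labels)
        for label, pred, w in zip(labels, predictions, weights):
            self.total += w * (1.0 if label == pred else 0.0)
            self.count += w

    def result(self) -> float:
        """Plays the role of the value tensor: read once after the last batch.

        Returns 0.0 when no weight has been seen, instead of dividing by zero.
        """
        return self.total / self.count if self.count > 0 else 0.0


acc = StreamingAccuracy()
# Batch 1: the padding position gets weight 0, so 2 of 3 weighted tokens are correct.
acc.update([1, 2, 3, 0], [1, 2, 0, 0], weights=[1, 1, 1, 0])
# Batch 2: both tokens correct.
acc.update([4, 5], [4, 5])
print(acc.result())  # 4 correct / 5 weighted tokens = 0.8
```

Averaging the two per-batch accuracies (2/3 and 1.0) would give about 0.833 instead; the streamed value weights each token equally across the whole eval set, which is what a single eval datapoint in TensorBoard should reflect.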