I would like to manage my training with a tf.estimator.Estimator but am having some trouble using it alongside the tf.data API.
I have something like this:
def model_fn(features, labels, params, mode):
    # Defines the model's ops.
    # Initializes with tf.train.Scaffold.
    # Returns a tf.estimator.EstimatorSpec.

def input_fn():
    dataset = tf.data.TextLineDataset("test.txt")
    # map, shuffle, padded_batch, etc.
    iterator = dataset.make_initializable_iterator()
    return iterator.get_next()

estimator = tf.estimator.Estimator(model_fn)
estimator.train(input_fn)
As I can't use make_one_shot_iterator for my use case, my issue is that input_fn contains an iterator that should be initialized within model_fn (here, I use tf.train.Scaffold to initialize local ops).
Also, I understand that we can't simply use input_fn = iterator.get_next, otherwise the other ops would not be added to the same graph.
What is the recommended way to initialize the iterator?
As of TensorFlow 1.5, it is possible to make input_fn return a tf.data.Dataset directly; see commit c294fcfd.
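For example, your input_fn can then build and return the Dataset itself, and the Estimator will create and initialize the iterator for you. A minimal sketch, reusing the file name from your question (the preprocessing steps are placeholders):

```python
import tensorflow as tf

def input_fn():
    # Since TensorFlow 1.5, input_fn may return the tf.data.Dataset itself;
    # the Estimator creates the iterator and runs its initializer internally,
    # so no manual initialization is needed in model_fn.
    dataset = tf.data.TextLineDataset("test.txt")
    # map, shuffle, padded_batch, etc., as before.
    return dataset
```

You then pass this input_fn to estimator.train(input_fn) exactly as before; only the return value changes.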
For previous versions, you can add the iterator's initializer to the tf.GraphKeys.TABLE_INITIALIZERS collection and rely on the default initialization.
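A sketch of that workaround, assuming the same input_fn as in your question (registering the initializer is the only new line; the default Scaffold runs everything in that collection when the session is created):

```python
import tensorflow as tf

def input_fn():
    dataset = tf.data.TextLineDataset("test.txt")
    iterator = dataset.make_initializable_iterator()
    # Register the iterator's initializer so the default Scaffold
    # runs it alongside the table initializers at session creation,
    # without any changes to model_fn.
    tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS,
                         iterator.initializer)
    return iterator.get_next()
```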