
How to use tf.data's initializable iterator and reinitializable iterator with the Estimator API

Posted 2020-07-14 09:43

Question:

All the official Google tutorials use the one-shot iterator in their Estimator API implementations; I couldn't find any documentation on how to use tf.data's initializable iterator or reinitializable iterator instead of the one-shot iterator.

Can someone kindly show me how to switch between train_data and test_data using tf.data's initializable iterator or reinitializable iterator? We need to run a session and use a feed_dict to switch the dataset in the initializable iterator; it is a low-level API, and it is confusing how to use it as part of the Estimator API architecture.
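For context, this is the low-level pattern I mean; a minimal sketch with made-up data, just to show the session/feed_dict step that does not obviously fit into an input_fn:

import numpy as np
import tensorflow as tf

# Hypothetical in-memory data, only for illustration.
X_train = np.random.rand(100, 3).astype(np.float32)
y_train = np.random.randint(0, 2, size=100)

X_pl = tf.placeholder(X_train.dtype, X_train.shape)
y_pl = tf.placeholder(y_train.dtype, y_train.shape)

dataset = tf.data.Dataset.from_tensor_slices((X_pl, y_pl)).batch(32)
iterator = dataset.make_initializable_iterator()
next_example, next_label = iterator.get_next()

with tf.Session() as sess:
    # Switching between train and test data means re-running the initializer
    # with a different feed_dict -- but the Estimator owns the session.
    sess.run(iterator.initializer, feed_dict={X_pl: X_train, y_pl: y_train})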

PS: I did find that Google mentions "Note: Currently, one-shot iterators are the only type that is easily usable with an Estimator."

But is there any workaround within the community, or should we just stick with the one-shot iterator for some good reason?

Answer 1:

To use either initializable or reinitializable iterators, you must create a class that inherits from tf.train.SessionRunHook. This class then has access to the session used by the tf.estimator functions.

Here is a quick example that you can adapt to your needs:

import tensorflow as tf


class IteratorInitializerHook(tf.train.SessionRunHook):
    """Hook that runs the iterator initializer once the session has been created."""

    def __init__(self):
        super(IteratorInitializerHook, self).__init__()
        self.iterator_initializer_func = None  # Will be set in the input_fn

    def after_create_session(self, session, coord):
        # Runs right after the Estimator creates its session, so the iterator
        # is initialized before any batches are requested.
        self.iterator_initializer_func(session)


def get_inputs(X, y):
    iterator_initializer_hook = IteratorInitializerHook()

    def input_fn():
        # Placeholders keep the (possibly large) arrays out of the graph;
        # the data is fed only when the iterator is initialized.
        X_pl = tf.placeholder(X.dtype, X.shape)
        y_pl = tf.placeholder(y.dtype, y.shape)

        dataset = tf.data.Dataset.from_tensor_slices((X_pl, y_pl))
        dataset = ...  # apply your shuffle/batch/repeat transformations here
        ...

        iterator = dataset.make_initializable_iterator()
        next_example, next_label = iterator.get_next()

        # The hook calls this with the live session, feeding in the actual data.
        iterator_initializer_hook.iterator_initializer_func = \
            lambda sess: sess.run(iterator.initializer,
                                  feed_dict={X_pl: X, y_pl: y})

        return next_example, next_label

    return input_fn, iterator_initializer_hook

...

train_input_fn, train_iterator_initializer_hook = get_inputs(X_train, y_train)
test_input_fn, test_iterator_initializer_hook = get_inputs(X_test, y_test)

...

estimator.train(input_fn=train_input_fn,
                hooks=[train_iterator_initializer_hook])
estimator.evaluate(input_fn=test_input_fn,
                   hooks=[test_iterator_initializer_hook])

This is a modified version of code I found in a blog post by Sebastian Pölsterl. Have a look at the "Feeding data to an Estimator via the Dataset API" section.
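For the reinitializable variant, the same hook can drive it: build the iterator from the datasets' common structure and run whichever initializer you need in after_create_session. A rough, untested sketch; the get_reinit_inputs helper and the batching are placeholders of mine rather than part of the original code:

def get_reinit_inputs(train_data, test_data, init_on_train=True):
    # Reuses the IteratorInitializerHook defined above.
    iterator_initializer_hook = IteratorInitializerHook()

    def input_fn():
        # Building datasets directly from the arrays embeds them in the graph;
        # use placeholders as above if the arrays are large.
        train_dataset = tf.data.Dataset.from_tensor_slices(train_data).batch(32)
        test_dataset = tf.data.Dataset.from_tensor_slices(test_data).batch(32)

        # A reinitializable iterator is defined by the shared structure of
        # the datasets rather than by any single dataset.
        iterator = tf.data.Iterator.from_structure(train_dataset.output_types,
                                                   train_dataset.output_shapes)
        train_init_op = iterator.make_initializer(train_dataset)
        test_init_op = iterator.make_initializer(test_dataset)

        init_op = train_init_op if init_on_train else test_init_op
        iterator_initializer_hook.iterator_initializer_func = \
            lambda sess: sess.run(init_op)

        return iterator.get_next()

    return input_fn, iterator_initializer_hook

Since the Estimator builds a fresh graph for each train() and evaluate() call, the reinitializable iterator does not buy you much over the initializable one here; its main use is switching datasets within a single session.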



Answer 2:

Or you can simply use tf.estimator.train_and_evaluate: https://www.tensorflow.org/api_docs/python/tf/estimator/train_and_evaluate It lets you run evaluation during training without needing to care about iterators at all.
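A minimal sketch of how that fits together (train_input_fn, eval_input_fn, and estimator stand in for whatever you already have):

train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=10000)
eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn)

# Alternates training and evaluation according to the specs; each input_fn
# can simply build its tf.data pipeline and return the next-batch tensors.
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)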