How to save and restore a tf.estimator.Estimator m

Posted 2019-08-15 01:41

I started using TensorFlow recently and I am trying to get used to tf.estimator.Estimator objects. I would like to do something that seems quite natural: after training my classifier, i.e. an instance of tf.estimator.Estimator (with the train method), I would like to save it to a file (whatever the extension) and reload it later to predict the labels for some new data. Since the official documentation recommends using the Estimator APIs, I assume something this important is implemented and documented.

I saw on another page that the method for this is export_savedmodel (see the official documentation), but I simply don't understand the documentation. There is no explanation of how to use this method. What is the argument serving_input_fn? I never encountered it in the Creating Custom Estimators tutorial or in any of the other tutorials I read. After some googling, I discovered that around a year ago estimators were defined using another class (tf.contrib.learn.Estimator), and it looks like tf.estimator.Estimator reuses some of the previous APIs. But I can't find a clear explanation of this in the documentation.

Could someone please give me a toy example? Or explain to me how to define/find this serving_input_fn?

And then how would I load the trained classifier again?

Thank you for your help!

Edit: I discovered that one doesn't necessarily need to use export_savedmodel to save the model. Saving is actually done automatically. If we later define a new estimator with the same model_dir argument, it will also automatically restore the previous estimator, as explained here.

1 answer
在下西门庆
#2 · 2019-08-15 02:14

As you figured out, the estimator automatically saves and restores the model for you during training. export_savedmodel is useful when you want to deploy your model to the field (for example, providing the best model to TensorFlow Serving).
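To make the automatic save-and-restore behaviour concrete, here is a minimal sketch, assuming TensorFlow 1.x. The toy data, the DNNClassifier, the function name train_then_reload, and the /tmp/toy_estimator path are illustrative assumptions and do not come from the question.

```python
def train_then_reload(model_dir="/tmp/toy_estimator"):  # hypothetical path
    """Train a toy classifier, then predict with a *new* Estimator object."""
    # Heavy imports are deferred so that defining this function stays cheap.
    import numpy as np
    import tensorflow as tf  # TensorFlow 1.x assumed

    def build():
        columns = [tf.feature_column.numeric_column("x", shape=[4])]
        return tf.estimator.DNNClassifier(
            hidden_units=[8], feature_columns=columns,
            n_classes=3, model_dir=model_dir)

    def input_fn(x, y=None, training=False):
        return tf.estimator.inputs.numpy_input_fn(
            x={"x": x}, y=y, batch_size=16,
            num_epochs=None if training else 1, shuffle=training)

    rng = np.random.RandomState(0)
    features = rng.rand(64, 4).astype(np.float32)
    labels = rng.randint(0, 3, size=64).astype(np.int32)

    # Training writes checkpoints into model_dir as a side effect.
    build().train(input_fn=input_fn(features, labels, training=True), steps=50)

    # A fresh Estimator with the same model_dir picks up the latest
    # checkpoint when predict() is called; no explicit load step is needed.
    restored = build()
    return [int(p["class_ids"][0])
            for p in restored.predict(input_fn=input_fn(features[:5]))]
```

The second build() could just as well run in a different process or on a later day. Calling train again on an estimator with the same model_dir continues from the last checkpoint rather than starting over.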

Here is a simple example:

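    import tensorflow as tf  # TensorFlow 1.x API

    def serving_input_fn():
        # Placeholder that stands in for the training input pipeline at serving time.
        inputs = {'features': tf.placeholder(tf.float32, [None, 128, 128, 3])}
        return tf.estimator.export.ServingInputReceiver(inputs, inputs)

    # est is the trained Estimator; FLAGS.export_dir is the base directory for exports.
    est.export_savedmodel(export_dir_base=FLAGS.export_dir, serving_input_receiver_fn=serving_input_fn)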

Basically, serving_input_fn replaces the training dataset pipeline with a placeholder. At deployment time, you feed data to this placeholder as the model's input for inference or prediction.
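For the "how do I load it again" part of the question, one possible sketch follows, assuming TensorFlow 1.x. export_savedmodel writes each export into a timestamped subdirectory of export_dir_base, so the first helper picks the newest one. The helper names latest_export_dir and predict_from_export are hypothetical, and the prediction call assumes the model exposes a default serving signature that accepts the 'features' key from the placeholder dictionary.

```python
import os


def latest_export_dir(export_dir_base):
    """Return the newest timestamped subdirectory under export_dir_base."""
    stamped = [
        d for d in os.listdir(export_dir_base)
        if d.isdigit() and os.path.isdir(os.path.join(export_dir_base, d))
    ]
    if not stamped:
        raise FileNotFoundError("no exported model under %r" % export_dir_base)
    return os.path.join(export_dir_base, max(stamped, key=int))


def predict_from_export(export_dir_base, batch):
    """Load the newest SavedModel and run one batch through it."""
    import tensorflow as tf  # TensorFlow 1.x assumed; deferred until needed

    predict_fn = tf.contrib.predictor.from_saved_model(
        latest_export_dir(export_dir_base))
    # The key must match the one used in the ServingInputReceiver.
    return predict_fn({"features": batch})
```

Here batch would be a float32 array of shape [N, 128, 128, 3], matching the placeholder. If the export only contains a differently named signature, the signature_def_key argument of from_saved_model can select it.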
