Tensorflow Estimator: using predict() function in

Posted 2019-08-21 07:26

Question:

I have successfully (I hope) trained and evaluated a model using tf.estimator.Estimator, reaching a train/eval accuracy of around 83-85%. Now I would like to test the model on a separate dataset by calling the Estimator's predict() method, preferably from a separate script.

I've looked at this, which says that I need to export my model as a SavedModel, but is that really necessary? Judging from the documentation for the Estimator class, it seems I can simply pass the path to my checkpoint and graph files via the model_dir parameter. Does anyone have experience with this? Also, when I run my model on the same dataset I used for validation, I do not get the same performance as during the validation phase... :-(

Answer 1:

I think you just need a separate file containing your model_fn definition. Then, in another script, you instantiate the same Estimator class using the same model_fn definition and the same model_dir.
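A minimal sketch of this layout, assuming TensorFlow 2.x before 2.16 (tf.estimator was removed in 2.16). The three "files" are collapsed into one runnable block; the file names model_def.py, train.py and predict.py, the toy logistic-regression model_fn, the synthetic data and the temporary MODEL_DIR are all hypothetical stand-ins for your own code. The point it illustrates is that the predicting Estimator is built from scratch with nothing but the shared model_fn and the same model_dir.

```python
import tempfile

import numpy as np
import tensorflow as tf


# --- model_def.py (hypothetical shared module imported by both scripts) ---
def model_fn(features, labels, mode):
    """Logistic regression on a 2-D feature 'x' (illustrative model only)."""
    x = features["x"]
    w = tf.compat.v1.get_variable("w", [2, 1], initializer=tf.compat.v1.zeros_initializer())
    b = tf.compat.v1.get_variable("b", [1], initializer=tf.compat.v1.zeros_initializer())
    logits = tf.matmul(x, w) + b
    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode, predictions={"prob": tf.sigmoid(logits)})
    loss = tf.compat.v1.losses.sigmoid_cross_entropy(labels, logits)
    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(mode, loss=loss)
    train_op = tf.compat.v1.train.GradientDescentOptimizer(0.5).minimize(
        loss, global_step=tf.compat.v1.train.get_global_step())
    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)


# Synthetic, linearly separable data (stands in for your real datasets).
rng = np.random.RandomState(0)
X_train = rng.randn(200, 2).astype(np.float32)
y_train = (X_train.sum(axis=1) > 0).astype(np.float32).reshape(-1, 1)
X_test = rng.randn(50, 2).astype(np.float32)
y_test = (X_test.sum(axis=1) > 0).astype(np.float32)

MODEL_DIR = tempfile.mkdtemp()  # in practice: the model_dir of your training run


# --- train.py: train and let the Estimator write checkpoints to MODEL_DIR ---
def train_input_fn():
    ds = tf.data.Dataset.from_tensor_slices(({"x": X_train}, y_train))
    return ds.shuffle(200).repeat().batch(16)


tf.estimator.Estimator(model_fn=model_fn, model_dir=MODEL_DIR).train(train_input_fn, steps=200)


# --- predict.py: a brand-new Estimator built from the same model_fn and model_dir ---
def predict_input_fn():
    # Features only; labels are not needed in PREDICT mode.
    return tf.data.Dataset.from_tensor_slices({"x": X_test}).batch(16)


estimator = tf.estimator.Estimator(model_fn=model_fn, model_dir=MODEL_DIR)
probs = np.array([p["prob"][0] for p in estimator.predict(predict_input_fn)])
accuracy = float(np.mean((probs > 0.5) == y_test))
print("test accuracy:", accuracy)
```

In a real setup, predict.py would do `from model_def import model_fn` and point model_dir at the directory your training run wrote to; no export step is involved.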

This approach works because the Estimator rebuilds the graph by calling your model_fn and then restores the variables from the latest model.ckpt checkpoint in model_dir by itself, so you can continue training, evaluating, or predicting.
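One way to check this restore behavior, again a sketch assuming TensorFlow 2.x before 2.16 with a hypothetical toy model_fn and synthetic data: train briefly, then compare predictions from the original estimator, from a fresh one built on the same model_dir, and from the fresh one pinned to a specific checkpoint through predict()'s checkpoint_path argument. tf.train.latest_checkpoint returns the path prefix of the newest checkpoint recorded in model_dir.

```python
import tempfile

import numpy as np
import tensorflow as tf


def model_fn(features, labels, mode):
    """Hypothetical logistic-regression model_fn used only for this check."""
    x = features["x"]
    w = tf.compat.v1.get_variable("w", [2, 1], initializer=tf.compat.v1.zeros_initializer())
    b = tf.compat.v1.get_variable("b", [1], initializer=tf.compat.v1.zeros_initializer())
    logits = tf.matmul(x, w) + b
    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode, predictions={"prob": tf.sigmoid(logits)})
    loss = tf.compat.v1.losses.sigmoid_cross_entropy(labels, logits)
    train_op = tf.compat.v1.train.GradientDescentOptimizer(0.5).minimize(
        loss, global_step=tf.compat.v1.train.get_global_step())
    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)


rng = np.random.RandomState(1)
X = rng.randn(100, 2).astype(np.float32)
y = (X.sum(axis=1) > 0).astype(np.float32).reshape(-1, 1)
model_dir = tempfile.mkdtemp()


def train_input_fn():
    return tf.data.Dataset.from_tensor_slices(({"x": X}, y)).repeat().batch(20)


def predict_input_fn():
    return tf.data.Dataset.from_tensor_slices({"x": X}).batch(20)


def predict_probs(est, **kwargs):
    # Collect the 'prob' output of every example into one array.
    return np.array([p["prob"][0] for p in est.predict(predict_input_fn, **kwargs)])


trained = tf.estimator.Estimator(model_fn=model_fn, model_dir=model_dir)
trained.train(train_input_fn, steps=100)

# Prefix of the newest checkpoint, e.g. <model_dir>/model.ckpt-100
latest = tf.train.latest_checkpoint(model_dir)

# Fresh Estimator: model_fn rebuilds the graph, variables are restored from model_dir.
fresh = tf.estimator.Estimator(model_fn=model_fn, model_dir=model_dir)
p_trained = predict_probs(trained)
p_fresh = predict_probs(fresh)
# Pin an explicit checkpoint instead of relying on "latest in model_dir".
p_pinned = predict_probs(fresh, checkpoint_path=latest)
print(latest, np.allclose(p_trained, p_fresh), np.allclose(p_fresh, p_pinned))
```

The checkpoint_path argument is mainly useful when you want predictions from an older checkpoint rather than the most recent one.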