How can I obtain the last global_step from a tf.estimator.Estimator after train(...) finishes? For instance, a typical Estimator-based training routine might be set up like this:
import tensorflow as tf

n_epochs = 10
model_dir = '/path/to/model_dir'

def model_fn(features, labels, mode, params):
    # some code to build the model
    pass

def input_fn():
    ds = tf.data.Dataset()  # placeholder: obviously with specifying a data source
    # manipulate the dataset
    return ds

run_config = tf.estimator.RunConfig(model_dir=model_dir)
estimator = tf.estimator.Estimator(model_fn=model_fn, config=run_config)

for epoch in range(n_epochs):
    estimator.train(input_fn=input_fn)
    # Now I want to do something which requires knowing the last global step; how do I get it?
    my_custom_eval_method(global_step)
Only the evaluate() method returns a dictionary containing the global_step as a field. How can I get the global_step if, for some reason, I can't or don't want to use this method?
Recently, I found that the estimator has a get_variable_value API that can read a variable, such as the global step, from the latest checkpoint.
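A minimal sketch of how that call might be used, assuming the global step variable carries TensorFlow's default name "global_step" and that the estimator has already written at least one checkpoint (otherwise get_variable_value raises a ValueError). The helper name last_global_step is hypothetical and only there for illustration:

```python
def last_global_step(estimator, name="global_step"):
    """Return the global step stored in the estimator's latest checkpoint.

    `estimator` is expected to be a tf.estimator.Estimator (anything that
    exposes get_variable_value works). `name` is assumed to be the default
    global step variable name, "global_step".
    """
    # get_variable_value returns the variable's value as a NumPy scalar;
    # convert it to a plain Python int for convenience.
    return int(estimator.get_variable_value(name))


# Usage inside the training loop from the question (not executed here):
#
# for epoch in range(n_epochs):
#     estimator.train(input_fn=input_fn)
#     global_step = last_global_step(estimator)
#     my_custom_eval_method(global_step)
```

Because the value is read from the checkpoint in model_dir, it reflects the step at which the last checkpoint was saved, which is normally the final step of the train(...) call that just returned.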
Simply create a hook before the training loop: