Early stopping with tf.estimator, how?

Posted 2019-02-01 10:03

I'm using tf.estimator in TensorFlow 1.4, and tf.estimator.train_and_evaluate is great, but I need early stopping. What's the preferred way of adding that?

I assume there is some tf.train.SessionRunHook somewhere for this. I saw that there was an old contrib package with a ValidationMonitor that seemed to have early stopping, but it doesn't seem to be around anymore in 1.4. Or will the preferred way in the future be to rely on tf.keras (with which early stopping is really easy) instead of tf.estimator/tf.layers/tf.data, perhaps?

4 answers
迷人小祖宗
#2 · 2019-02-01 10:43

Another option that doesn't use hooks is to create a tf.contrib.learn.Experiment (which, even though it lives in contrib, seems to support the new tf.estimator.Estimator as well).

Then train via its (apparently experimental) continuous_train_and_eval method, passing an appropriately customized continuous_eval_predicate_fn.

According to the TensorFlow documentation, continuous_eval_predicate_fn is

A predicate function determining whether to continue eval after each iteration.

and is called with the eval_results from the last evaluation run. For early stopping, use a custom function that keeps the best result so far and a counter as state, and returns False once the early-stopping condition is reached.

Note added: As of TensorFlow 1.7, this approach relies on deprecated APIs (all of tf.contrib.learn is deprecated from that version onwards: https://www.tensorflow.org/api_docs/python/tf/contrib/learn )
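A minimal sketch of such a predicate, written as a closure so that the best value and the counter persist between calls. It assumes eval_results is a dict of metric values containing a 'loss' key, and it treats a missing result (for example before the first evaluation) as "keep going". The names make_early_stopping_predicate, patience and min_delta are hypothetical, chosen only for this illustration.

```python
def make_early_stopping_predicate(metric="loss", patience=5, min_delta=0.0):
    """Hypothetical helper: returns a predicate that says whether to keep training."""
    # Mutable state shared across calls: best value seen and evals without improvement.
    state = {"best": float("inf"), "bad_evals": 0}

    def predicate(eval_results):
        # Assumption: no results yet (or metric missing) means "continue".
        if not eval_results or metric not in eval_results:
            return True
        value = eval_results[metric]
        if value < state["best"] - min_delta:
            # Improvement: remember it and reset the counter.
            state["best"] = value
            state["bad_evals"] = 0
        else:
            state["bad_evals"] += 1
        # Returning False is what ends continuous_train_and_eval.
        return state["bad_evals"] < patience

    return predicate
```

With an existing Experiment object, it would then be passed as experiment.continuous_train_and_eval(continuous_eval_predicate_fn=make_early_stopping_predicate()). Counting evaluations rather than steps keeps the rule independent of how often checkpoints are written.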

趁早两清
#3 · 2019-02-01 10:47

Good news! tf.estimator now has early stopping support on master, and it looks like it will ship in 1.10.

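import os

import tensorflow as tf

estimator = tf.estimator.Estimator(model_fn, model_dir)

# The hook reads eval metrics from eval_dir(), which may not exist before the
# first evaluation, so create it up front. TODO This should not be expected IMO.
os.makedirs(estimator.eval_dir(), exist_ok=True)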

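# Request a stop if the eval 'loss' has not decreased for 1000 global steps;
# no stop is requested before global step 100.
early_stopping = tf.contrib.estimator.stop_if_no_decrease_hook(
    estimator,
    metric_name='loss',
    max_steps_without_decrease=1000,
    min_steps=100)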

tf.estimator.train_and_evaluate(
    estimator,
    train_spec=tf.estimator.TrainSpec(train_input_fn, hooks=[early_stopping]),
    eval_spec=tf.estimator.EvalSpec(eval_input_fn))
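To make max_steps_without_decrease and min_steps concrete, here is a plain-Python sketch of the decision rule as I understand it from the hook's documentation: evaluation results are keyed by the global step of the evaluated checkpoint, steps below min_steps are ignored, and a stop is requested once a step is at least max_steps_without_decrease steps past the step with the lowest loss. This is an illustration, not the library's code; should_stop and its dict input format are hypothetical.

```python
def should_stop(eval_history, max_steps_without_decrease, min_steps=0):
    """Hypothetical sketch. eval_history maps global step -> eval loss."""
    best_loss = None
    best_step = None
    for step in sorted(eval_history):
        if step < min_steps:
            continue  # too early to consider stopping
        loss = eval_history[step]
        if best_loss is None or loss < best_loss:
            best_loss, best_step = loss, step
        # Too many steps have passed since the best (lowest) loss.
        if step - best_step >= max_steps_without_decrease:
            return True
    return False
```

Since evaluations only run on saved checkpoints, the steps being compared are checkpoint steps, so the effective resolution of max_steps_without_decrease is roughly your checkpoint interval.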
戒情不戒烟
#4 · 2019-02-01 11:05

First, you must give the loss a name so the early-stopping hook can look it up in the graph. If your loss variable in the model_fn is named "loss", the line

copyloss = tf.identity(loss, name="loss")

right beneath it will work.

Then, create a hook with this code.

import tensorflow as tf

class EarlyStopping(tf.train.SessionRunHook):
    """Stops when the smoothed eval loss exceeds its minimum by `tolerance`."""

    def __init__(self, smoothing=.997, tolerance=.03):
        self.lowestloss = float("inf")
        self.currentsmoothedloss = -1  # sentinel: no loss seen yet
        self.tolerance = tolerance
        self.smoothing = smoothing

    def begin(self):
        # Find the tensor named via tf.identity(loss, name="loss") in the model_fn.
        graph = tf.get_default_graph()
        self.element = graph.get_operation_by_name("loss").outputs[0]

    def before_run(self, run_context):
        # Also fetch the loss on every evaluation step.
        return tf.train.SessionRunArgs([self.element])

    def after_run(self, run_context, run_values):
        loss = run_values.results[0]
        if self.currentsmoothedloss < 0:
            # First step: start the running average above the observed loss.
            self.currentsmoothedloss = loss * 1.5
        # Exponential moving average of the loss.
        self.currentsmoothedloss = (self.currentsmoothedloss * self.smoothing
                                    + loss * (1 - self.smoothing))
        if self.currentsmoothedloss < self.lowestloss:
            self.lowestloss = self.currentsmoothedloss
        if self.currentsmoothedloss > self.lowestloss + self.tolerance:
            run_context.request_stop()
            print("REQUESTED_STOP")
            # request_stop() alone does not end train_and_evaluate, so raise instead.
            raise ValueError('Model Stopping because loss is increasing from EarlyStopping hook')

This compares an exponentially smoothed validation loss with its lowest value so far and stops training once it rises above that minimum by more than tolerance. If it stops too early, increasing tolerance and/or smoothing will make it stop later. Keep smoothing below 1, or it will never stop.
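The effect of smoothing and tolerance can be checked without TensorFlow by replaying the same update rule used in after_run on a plain list of per-step losses. first_stop_index is a hypothetical helper written for this illustration; it mirrors the hook's logic, including the 1.5x initialisation of the running average.

```python
def first_stop_index(losses, smoothing=0.997, tolerance=0.03):
    """Hypothetical helper: index of the step where the hook would stop, else None."""
    lowest = float("inf")
    smoothed = -1.0  # same "nothing seen yet" sentinel as the hook
    for i, loss in enumerate(losses):
        if smoothed < 0:
            smoothed = loss * 1.5  # same initialisation as the hook
        # Exponential moving average, exactly as in after_run.
        smoothed = smoothed * smoothing + loss * (1 - smoothing)
        lowest = min(lowest, smoothed)
        if smoothed > lowest + tolerance:
            return i
    return None
```

For example, first_stop_index([1.0, 0.5, 0.4, 0.9, 1.2], smoothing=0.5, tolerance=0.1) returns 3. With smoothing=1.0 the average never moves off its initial value, so the function returns None for any input, which is why smoothing has to stay below one.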

You can replace the logic in after_run with something else if you want to stop based on a different condition.

Now, add this hook to the evaluation spec. Your code should look something like this:

eval_spec = tf.estimator.EvalSpec(
    input_fn=lambda: eval_input_fn(batchsize),
    steps=100,
    hooks=[EarlyStopping()])

Important note: when the hook is attached to the EvalSpec, run_context.request_stop() only ends the current evaluation run, so inside the train_and_evaluate call it doesn't stop training. That is why I raise a ValueError instead, which means you have to wrap the train_and_evaluate call in a try/except block like this:

try:
    tf.estimator.train_and_evaluate(classifier,train_spec,eval_spec)
except ValueError as e:
    print("training stopped")

If you don't do this, the script will crash with the uncaught ValueError when training stops.

仙女界的扛把子
#5 · 2019-02-01 11:06

Yes, there is tf.train.StopAtStepHook:

This hook requests stop after either a number of steps have been executed or a last step has been reached. Only one of the two options can be specified.

You can also implement your own stopping strategy based on the step results by subclassing tf.train.SessionRunHook:

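import tensorflow as tf

class MyHook(tf.train.SessionRunHook):
  def should_stop(self, run_values):
    # Hypothetical placeholder: your criterion (fetch tensors in before_run).
    return False

  def after_run(self, run_context, run_values):
    if self.should_stop(run_values):
      run_context.request_stop()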