Loading a checkpoint for fine-tuning with tf.estimator.

Posted 2020-07-10 09:18

We're trying to translate old training code into code compliant with tf.estimator.Estimator. In the original code we fine-tune a model on a target dataset. Only some layers are loaded from the checkpoint before training takes place, using a combination of variables_to_restore and init_fn with MonitoredTrainingSession. How can one achieve this kind of weight loading with the tf.estimator.Estimator approach?

Tags: tensorflow
2 Answers
Answer from 对你真心纯属浪费 · 2020-07-10 09:50

You have two options; the first one is simpler:

1. Call tf.train.init_from_checkpoint in your model_fn.

2. model_fn returns an EstimatorSpec, and you can supply a custom scaffold (a tf.train.Scaffold with an init_fn) via that EstimatorSpec.
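For option 1, tf.train.init_from_checkpoint takes an assignment map from checkpoint scopes to current-graph scopes, and only the listed scopes get restored. A minimal sketch of building such a map (the "encoder" scope name is an assumption, not from the original post):

```python
def build_assignment_map(scopes_to_restore):
    """Map checkpoint variable scopes to identical current-graph scopes.

    Passing this dict to tf.train.init_from_checkpoint restores only
    the listed scopes; all other variables keep their fresh initializers,
    which is what partial fine-tuning needs.
    """
    return {scope + "/": scope + "/" for scope in scopes_to_restore}

# Inside model_fn (TF 1.x style), after the graph has been built:
#   tf.train.init_from_checkpoint(checkpoint_path,
#                                 build_assignment_map(["encoder"]))

print(build_assignment_map(["encoder"]))  # {'encoder/': 'encoder/'}
```

Because init_from_checkpoint runs at graph-construction time inside model_fn, no separate init_fn or MonitoredTrainingSession plumbing is needed.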

Answer from 够拽才男人 · 2020-07-10 10:01
import tensorflow as tf

def model_fn(features, labels, mode):
  # your model definition here
  # ...
  pass  # must return a tf.estimator.EstimatorSpec

# specify your saved checkpoint path (or a directory containing one)
checkpoint_path = "model.ckpt"

# warm-start variables from the checkpoint before training begins
ws = tf.estimator.WarmStartSettings(ckpt_to_initialize_from=checkpoint_path)
est = tf.estimator.Estimator(model_fn=model_fn, warm_start_from=ws)
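By default WarmStartSettings restores all trainable variables. Since the question asks to load only some layers, the vars_to_warm_start argument accepts a regex over variable names. A sketch of the filtering behavior, with the TF call shown in comments (the "encoder" scope is an assumption; match it to your own graph):

```python
import re

# Regex passed as WarmStartSettings(vars_to_warm_start=...).
# Only variables whose names match are restored from the checkpoint.
WARM_START_PATTERN = "encoder/.*"

def should_warm_start(var_name, pattern=WARM_START_PATTERN):
    """Mirror of the name filter tf.estimator applies when warm-starting."""
    return re.match(pattern, var_name) is not None

# In the real model:
#   ws = tf.estimator.WarmStartSettings(
#       ckpt_to_initialize_from="model.ckpt",
#       vars_to_warm_start=WARM_START_PATTERN)
#   est = tf.estimator.Estimator(model_fn=model_fn, warm_start_from=ws)

print(should_warm_start("encoder/dense/kernel"))  # True
print(should_warm_start("logits/kernel"))         # False
```

This reproduces the original code's selective loading: encoder layers come from the checkpoint, while the new head (e.g. logits) trains from scratch.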