How to load two checkpoints in init_fn in slim.lea

Posted 2019-04-16 06:58

Question:

I want to load two checkpoints while using slim.learning.train. Currently I do this:

import tensorflow as tf
slim = tf.contrib.slim
init_fn = slim.assign_from_checkpoint_fn(model_path, variables_to_restore)
slim.learning.train(train_op, log_dir, init_fn=init_fn)

The problem is that model_path accepts only a single checkpoint file, but I want to restore from two checkpoints. I can think of two possible solutions:

  • Modify assign_from_checkpoint_fn (tf.contrib.framework.assign_from_checkpoint_fn) so that model_path can be a list of checkpoint files.
  • Merge the two checkpoints beforehand. (I couldn't find a tool for this.)

Can anyone help me?
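One way the second option (merging) could look, as a minimal sketch assuming both files are ordinary name-based checkpoints written by tf.train.Saver and that their variable names don't clash, apart from shared bookkeeping such as global_step. merge_checkpoints is a hypothetical helper, not a TensorFlow API: it reads every tensor with a checkpoint reader and writes them all into one new checkpoint through a throwaway graph. It uses the tf.compat.v1 API, which is assumed to be available (recent TF 1.x releases and TF 2.x).

```python
import tensorflow as tf

tf1 = tf.compat.v1  # TF1-style Graph/Session/Saver API


def merge_checkpoints(ckpt_paths, output_prefix):
    """Hypothetical helper: copy every variable from `ckpt_paths` into one checkpoint.

    If a name (e.g. global_step) appears in several checkpoints, the value
    from the first checkpoint in the list is kept.
    """
    values = {}
    for path in ckpt_paths:
        reader = tf1.train.NewCheckpointReader(path)
        for name in reader.get_variable_to_shape_map():
            values.setdefault(name, reader.get_tensor(name))

    graph = tf1.Graph()
    with graph.as_default():
        # Key the Saver by the original checkpoint names so they are written
        # unchanged, whatever names TF gives the in-graph variables.
        var_list = {name: tf1.Variable(value) for name, value in values.items()}
        saver = tf1.train.Saver(var_list=var_list)
        with tf1.Session(graph=graph) as sess:
            sess.run(tf1.global_variables_initializer())
            # Skip the .meta file: it would duplicate every value as a graph constant.
            return saver.save(sess, output_prefix, write_meta_graph=False)
```

The returned prefix can then be passed as model_path to assign_from_checkpoint_fn. Because the values are embedded in the graph as constants, very large models may run into the 2 GB GraphDef limit with this approach.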

Answer 1:

I found a solution: build one assign op per checkpoint with slim.assign_from_checkpoint, then define your own init function that runs both of them in the session, like this:

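import tensorflow as tf
slim = tf.contrib.slim  # TF 1.x only; tf.contrib was removed in TF 2.x

# flow_ckpt / resnet_ckpt are the two checkpoint paths; flow_var_to_restore /
# resnet_var_to_restore are the variables to load from each of them.
flow_init_assign_op, flow_init_feed_dict = slim.assign_from_checkpoint(
    flow_ckpt, flow_var_to_restore)

resnet_init_assign_op, resnet_init_feed_dict = slim.assign_from_checkpoint(
    resnet_ckpt, resnet_var_to_restore, ignore_missing_vars=True)

def init_fn(sess):
    # Run both restore ops in the session created by slim.learning.train.
    sess.run(flow_init_assign_op, flow_init_feed_dict)
    sess.run(resnet_init_assign_op, resnet_init_feed_dict)

slim.learning.train(train_op, log_dir, init_fn=init_fn)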
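The same idea extends to any number of checkpoints without editing the library: build one init function per checkpoint and call them in turn. Below is a minimal sketch; chain_init_fns is a hypothetical helper, not part of slim. It skips None entries because assign_from_checkpoint_fn returns None when ignore_missing_vars=True leaves nothing to restore. The demo uses stand-in functions instead of real sessions so it runs without TensorFlow.

```python
from typing import Any, Callable, Iterable, Optional

# An init function as slim.learning.train expects it: takes a session, returns nothing.
InitFn = Callable[[Any], None]


def chain_init_fns(init_fns: Iterable[Optional[InitFn]]) -> InitFn:
    """Hypothetical helper: combine several init functions into one.

    Each function is called with the same session, in list order, so a
    variable restored by more than one of them keeps the value from the last.
    """
    # Snapshot the input and drop None entries (nothing to restore).
    fns = [fn for fn in init_fns if fn is not None]

    def init_fn(sess):
        for fn in fns:
            fn(sess)

    return init_fn


# Usage with TF 1.x slim (not executed here):
# init_fn = chain_init_fns([
#     slim.assign_from_checkpoint_fn(flow_ckpt, flow_var_to_restore),
#     slim.assign_from_checkpoint_fn(resnet_ckpt, resnet_var_to_restore,
#                                    ignore_missing_vars=True),
# ])
# slim.learning.train(train_op, log_dir, init_fn=init_fn)

# Demo with stand-in init functions that record which session they received.
calls = []
demo = chain_init_fns([
    lambda sess: calls.append(("flow", sess)),
    None,
    lambda sess: calls.append(("resnet", sess)),
])
demo("fake-session")
print(calls)  # [('flow', 'fake-session'), ('resnet', 'fake-session')]
```

As with the assign ops above, the per-checkpoint functions should be created before calling slim.learning.train, since they add restore ops to the graph and the graph is presumably finalized once training starts.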