How to only restore variables in the checkpoint in TensorFlow?

Posted 2020-07-17 14:28

In TensorFlow, my model is based on a pre-trained model: I added a few more variables and removed some from the pre-trained model. When I restore the variables from the checkpoint file, I have to explicitly list every variable I added to the graph so that it is excluded. For example, I did:

import tensorflow as tf  # TF 1.x API
slim = tf.contrib.slim

exclude = [...]  # explicitly list all variables (or scopes) to exclude
variables_to_restore = slim.get_variables_to_restore(exclude=exclude)
saver = tf.train.Saver(variables_to_restore)

Is there a simpler way to do this? Namely: if a variable is not in the checkpoint, simply don't try to restore it.

Tags: tensorflow
3 answers
Lonely孤独者°
Post #2 · 2020-07-17 15:02

The only thing you can do is to first build the same model as the one in the checkpoint, and then restore the checkpoint values into that model. Once the variables have been restored, you can add new layers, delete existing layers, or change the weights of the layers.

There is one important point to be careful about, though: the newly added layers still need to be initialized. If you run tf.global_variables_initializer() after restoring, it re-initializes every variable and you lose the restored values. So you should initialize only the uninitialized variables, for example with the following function:

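import tensorflow as tf  # TF 1.x API


def initialize_uninitialized(sess):
    """Initialize only the global variables that are not initialized yet
    (e.g. newly added layers), leaving restored values untouched."""
    global_vars = tf.global_variables()
    # One boolean per variable: True if it already holds a value
    is_initialized = sess.run([tf.is_variable_initialized(var) for var in global_vars])
    not_initialized_vars = [v for (v, f) in zip(global_vars, is_initialized) if not f]

    # for i in not_initialized_vars: # only for testing
    #    print(i.name)

    if not_initialized_vars:
        sess.run(tf.variables_initializer(not_initialized_vars))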
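To see why only the uninitialized variables should be initialized, here is a toy model in plain Python. It does not use TensorFlow and is not how TensorFlow stores variables: the session state is assumed to be a dict from variable name to value, with `None` meaning "not initialized yet". The helpers `restore`, `initialize_all` and `initialize_uninitialized_toy` are hypothetical stand-ins for `Saver.restore`, `tf.global_variables_initializer()` and the function above, and the variable names are made up.

```python
# Toy model only: variable name -> value, where None means "not initialized yet".
# The helpers below are hypothetical stand-ins, not TensorFlow APIs.


def restore(state, checkpoint):
    """Stand-in for Saver.restore: copy values for names present in both."""
    for name, value in checkpoint.items():
        if name in state:
            state[name] = value


def initialize_all(state, init_value=0.0):
    """Stand-in for tf.global_variables_initializer(): overwrites every variable."""
    for name in state:
        state[name] = init_value


def initialize_uninitialized_toy(state, init_value=0.0):
    """Stand-in for initialize_uninitialized(sess): only touches unset variables."""
    for name, value in state.items():
        if value is None:
            state[name] = init_value


def fresh_state():
    # Two pretrained variables plus one newly added layer, all uninitialized.
    return {"conv1/weights": None, "conv1/biases": None, "new_fc/weights": None}


pretrained = {"conv1/weights": 1.5, "conv1/biases": 0.5}

# Restore, then run the global initializer: the restored values are lost.
overwritten = fresh_state()
restore(overwritten, pretrained)
initialize_all(overwritten)

# Restore, then initialize only what is still uninitialized: values survive.
preserved = fresh_state()
restore(preserved, pretrained)
initialize_uninitialized_toy(preserved)

print(overwritten)  # {'conv1/weights': 0.0, 'conv1/biases': 0.0, 'new_fc/weights': 0.0}
print(preserved)    # {'conv1/weights': 1.5, 'conv1/biases': 0.5, 'new_fc/weights': 0.0}
```

The order of the two steps is the whole point: restoring first and then initializing selectively is what keeps the pretrained values.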
【Aperson】
Post #3 · 2020-07-17 15:05

This is a more complete answer, which works in a non-distributed setting:

import tensorflow as tf  # TF 1.x; tf.contrib was removed in TF 2.x
from tensorflow.contrib.framework.python.framework import checkpoint_utils
slim = tf.contrib.slim


def scan_checkpoint_for_vars(checkpoint_path, vars_to_check):
    check_var_list = checkpoint_utils.list_variables(checkpoint_path)
    check_var_list = [x[0] for x in check_var_list]
    check_var_set = set(check_var_list)
    vars_in_checkpoint = [x for x in vars_to_check if x.name[:x.name.index(":")] in check_var_set]
    vars_not_in_checkpoint = [x for x in vars_to_check if x.name[:x.name.index(":")] not in check_var_set]
    return vars_in_checkpoint, vars_not_in_checkpoint


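def create_easy_going_scaffold(vars_in_checkpoint, vars_not_in_checkpoint):
    # Local init may run once every variable restored from the checkpoint is initialized
    model_ready_for_local_init_op = tf.report_uninitialized_variables(var_list=vars_in_checkpoint)
    # Initialize the variables missing from the checkpoint, plus the local variables
    # and tables that Scaffold's default local_init_op would otherwise have handled
    model_init_vars_not_in_checkpoint = tf.group(tf.variables_initializer(vars_not_in_checkpoint),
                                                 tf.local_variables_initializer(),
                                                 tf.tables_initializer())

    # Restore only the variables that actually exist in the checkpoint
    restoration_saver = tf.train.Saver(vars_in_checkpoint)
    eg_scaffold = tf.train.Scaffold(saver=restoration_saver,
                                    ready_for_local_init_op=model_ready_for_local_init_op,
                                    local_init_op=model_init_vars_not_in_checkpoint)
    return eg_scaffold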


# output_chkpt_dir, save_checkpoint_secs and save_summaries_secs are your own settings
all_vars = slim.get_variables()
ckpoint_file = tf.train.latest_checkpoint(output_chkpt_dir)
vars_in_checkpoint, vars_not_in_checkpoint = scan_checkpoint_for_vars(ckpoint_file, all_vars)
is_checkpoint_complete = len(vars_not_in_checkpoint) == 0

# Create session that can handle current checkpoint
if is_checkpoint_complete:
    # Checkpoint is full - all variables can be found there
    print('Using normal session')
    sess = tf.train.MonitoredTrainingSession(checkpoint_dir = output_chkpt_dir,
                                             save_checkpoint_secs = save_checkpoint_secs,
                                             save_summaries_secs = save_summaries_secs)
else:
    # Checkpoint is partial - some variables need to be initialized
    print('Using easy going session')
    eg_scaffold =  create_easy_going_scaffold(vars_in_checkpoint, vars_not_in_checkpoint)
    # Save all variables to next checkpoint
    saver = tf.train.Saver()
    hooks = [tf.train.CheckpointSaverHook(checkpoint_dir = output_chkpt_dir,
                                          save_secs = save_checkpoint_secs,
                                          saver = saver)]
    # Such session is a little slower during the first iteration
    sess = tf.train.MonitoredTrainingSession(checkpoint_dir = output_chkpt_dir,
                                             scaffold = eg_scaffold,
                                             hooks = hooks,
                                             save_summaries_secs = save_summaries_secs,
                                             save_checkpoint_secs = None)

with sess:
    ...  # training loop goes here
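The heart of `scan_checkpoint_for_vars` is a plain name comparison. A graph variable's `name` carries an output suffix such as `:0`, while the checkpoint keys returned by `list_variables` do not, which is why the code cuts the name at the `:`. The sketch below reproduces that partition in plain Python so it runs without TensorFlow; `split_by_checkpoint` is a hypothetical helper, and the variable names and shapes are made up for illustration.

```python
def split_by_checkpoint(graph_var_names, checkpoint_entries):
    """Partition graph variable names by whether the checkpoint contains them.

    graph_var_names: names as in Variable.name, e.g. "conv1/weights:0".
    checkpoint_entries: (name, shape) pairs, the format list_variables returns.
    """
    ckpt_names = {name for name, _shape in checkpoint_entries}
    in_ckpt, not_in_ckpt = [], []
    for full_name in graph_var_names:
        op_name = full_name.split(":")[0]  # drop the ":0" output suffix
        if op_name in ckpt_names:
            in_ckpt.append(full_name)
        else:
            not_in_ckpt.append(full_name)
    return in_ckpt, not_in_ckpt


# Hypothetical graph and checkpoint contents.
graph_vars = ["conv1/weights:0", "conv1/biases:0", "new_fc/weights:0"]
ckpt = [("conv1/weights", [3, 3, 3, 64]),
        ("conv1/biases", [64]),
        ("old_fc/weights", [1024, 10])]

print(split_by_checkpoint(graph_vars, ckpt))
# (['conv1/weights:0', 'conv1/biases:0'], ['new_fc/weights:0'])
```

Entries that exist only in the checkpoint, like `old_fc/weights` here, are simply ignored, which covers the "removed some variables" part of the question. In TF 1.x, `v.op.name` gives the same suffix-free string as the slicing in the answer.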
相关推荐>>
Post #4 · 2020-07-17 15:18

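You should first find the variables that are useful (meaning those that also exist in your graph), and then restore only the intersection of your graph's variables and the checkpoint's variables, rather than everything in the checkpoint. Since tf.train.list_variables returns (name, shape) pairs, the intersection has to be computed on variable names: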

ckpt_var_names = {name for name, _ in tf.train.list_variables(checkpoint_dir)}  # (name, shape) pairs
variables_can_be_restored = [v for v in tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES) if v.op.name in ckpt_var_names]

Then define a saver for just those variables and restore them:

temp_saver = tf.train.Saver(variables_can_be_restored)
ckpt_state = tf.train.get_checkpoint_state(checkpoint_dir)
print('Loading checkpoint %s' % ckpt_state.model_checkpoint_path)
temp_saver.restore(sess, ckpt_state.model_checkpoint_path)  # sess: an existing tf.Session
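Matching on names alone can still fail if a layer kept its name but changed size (for example a classifier resized from 10 to 20 classes), because restoring a variable whose shape differs from the checkpoint raises an error. A possible refinement is to compare shapes as well. The plain-Python sketch below shows that filter on made-up names and shapes; `restorable_names` is a hypothetical helper, and in TensorFlow the graph side would come from `v.op.name` and `v.shape.as_list()`.

```python
def restorable_names(graph_vars, checkpoint_entries):
    """Return graph variable names present in the checkpoint with the same shape.

    graph_vars: (op_name, shape) pairs for the variables in the current graph.
    checkpoint_entries: (name, shape) pairs like those from tf.train.list_variables.
    """
    ckpt_shapes = {name: list(shape) for name, shape in checkpoint_entries}
    return [name for name, shape in graph_vars
            if name in ckpt_shapes and ckpt_shapes[name] == list(shape)]


# Hypothetical graph and checkpoint contents.
graph_vars = [("conv1/weights", [3, 3, 3, 64]),
              ("fc/weights", [1024, 20]),     # resized: 10 -> 20 classes
              ("new_fc/weights", [20, 5])]    # not in the checkpoint at all
ckpt = [("conv1/weights", [3, 3, 3, 64]),
        ("fc/weights", [1024, 10])]

print(restorable_names(graph_vars, ckpt))  # ['conv1/weights']
```

Variables filtered out this way (here `fc/weights` and `new_fc/weights`) would then need to be initialized separately, as in the earlier answers.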