TensorFlow: Restoring Multiple Graphs

Posted 2019-03-22 06:01

Suppose we have two TensorFlow computation graphs, G1 and G2, with saved weights W1 and W2. Now we build a new graph G simply by constructing both G1 and G2 inside it. How can we restore both W1 and W2 into this new graph G?

Here is a simple example:

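import tensorflow as tf

V1 = tf.Variable(tf.zeros([1]))
saver_1 = tf.train.Saver()  # constructed when only V1 exists
V2 = tf.Variable(tf.zeros([1]))
saver_2 = tf.train.Saver()  # constructed after both V1 and V2 exist

sess = tf.Session()
saver_1.restore(sess, 'W1')  # 'W1', 'W2': checkpoint paths saved from G1 and G2
saver_2.restore(sess, 'W2')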

In this example, saver_1 successfully restores V1, but saver_2 fails with a NotFoundError.
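One likely cause, sketched here under the assumption of the TF 1.x graph-mode API (accessed through tensorflow.compat.v1 so it also runs on TF 2.x): a Saver built with no arguments captures every variable that exists at the moment it is constructed, so saver_2 covers V1 as well as V2, and a checkpoint written from G2 has no entry for V1. The sketch gives the variables explicit names "V1" and "V2" (an assumption; the question leaves them unnamed), saves through each saver into a temporary directory, and lists what each checkpoint contains.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # TF1-style graphs, sessions and Savers


def checkpoint_names(prefix):
    """Return the sorted variable names stored in a checkpoint."""
    return sorted(name for name, _shape in tf.train.list_variables(prefix))


with tf.Graph().as_default():
    V1 = tf.Variable(tf.zeros([1]), name="V1")
    saver_1 = tf.train.Saver()  # only V1 exists at this point
    V2 = tf.Variable(tf.zeros([1]), name="V2")
    saver_2 = tf.train.Saver()  # V1 and V2 both exist now

    ckpt_dir = tempfile.mkdtemp()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        prefix_1 = saver_1.save(sess, os.path.join(ckpt_dir, "saver_1"))
        prefix_2 = saver_2.save(sess, os.path.join(ckpt_dir, "saver_2"))

print(checkpoint_names(prefix_1))  # ['V1']
print(checkpoint_names(prefix_2))  # ['V1', 'V2']
```

Restoring with saver_2 therefore asks the checkpoint for V1 too; giving each saver an explicit variable list avoids that.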

1 Answer
兄弟一词,经得起流年.
Post #2 · 2019-03-22 06:06

You can use two savers, each of which looks for only one of the variables. A tf.train.Saver() created with no arguments covers every variable that exists in the graph at the moment it is constructed, so restoring it from a checkpoint that lacks any of those variables fails. Instead, you can pass the list of variables a saver should handle, as in tf.train.Saver([v1, ...]). For more details, see the documentation of the tf.train.Saver constructor: https://www.tensorflow.org/versions/r0.11/api_docs/python/state_ops.html#Saver
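tf.train.Saver also accepts a dict as its var_list, mapping the name stored in the checkpoint to the variable that should receive it. That matters when the original graphs both relied on TensorFlow's default variable name, as the question's unnamed variables would: in each separate graph the variable is called "Variable", but in the combined graph the second one is auto-named "Variable_1". A sketch, again assuming the TF 1.x API through tensorflow.compat.v1; the helper save_single and the checkpoint prefixes W1/W2 are hypothetical stand-ins for the question's G1/G2 and their checkpoints.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

ckpt_dir = tempfile.mkdtemp()


def save_single(value, prefix):
    """Hypothetical stand-in for G1/G2: one unnamed variable, saved to disk."""
    with tf.Graph().as_default():
        tf.Variable(tf.constant([value]))  # TensorFlow names it "Variable"
        saver = tf.train.Saver()
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            return saver.save(sess, os.path.join(ckpt_dir, prefix))


w1 = save_single(1.0, "W1")
w2 = save_single(2.0, "W2")

# Combined graph G: the second unnamed variable is auto-named "Variable_1".
with tf.Graph().as_default():
    V1 = tf.Variable(tf.zeros([1]))
    V2 = tf.Variable(tf.zeros([1]))
    # Map each checkpoint key to the variable that should receive its value.
    saver_1 = tf.train.Saver({"Variable": V1})
    saver_2 = tf.train.Saver({"Variable": V2})
    with tf.Session() as sess:
        saver_1.restore(sess, w1)
        saver_2.restore(sess, w2)
        restored = sess.run([V1, V2])

print(restored[0][0], restored[1][0])  # 1.0 2.0
```

A plain list is enough when the names already match; the dict form is for when they do not.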

Here's a simple working example. Suppose you do your computation in a file "save_vars.py" and it has the following code:

# TF 1.x API; on TF 2.x use: import tensorflow.compat.v1 as tf; tf.disable_v2_behavior()
import tensorflow as tf

# Graph 1 - set v1 to have value [1.0]
g1 = tf.Graph()
with g1.as_default():
    v1 = tf.Variable(tf.zeros([1]), name="v1")
    assign1 = v1.assign(tf.constant([1.0]))
    init1 = tf.global_variables_initializer()
    save1 = tf.train.Saver()

# Graph 2 - set v2 to have value [2.0]
g2 = tf.Graph()
with g2.as_default():
    v2 = tf.Variable(tf.zeros([1]), name="v2")
    assign2 = v2.assign(tf.constant([2.0]))
    init2 = tf.global_variables_initializer()
    save2 = tf.train.Saver()

# Do the computation for graph 1 and save (the session is closed on exit)
with tf.Session(graph=g1) as sess1:
    sess1.run(init1)
    print(sess1.run(assign1))
    save1.save(sess1, "tmp/v1.ckpt")

# Do the computation for graph 2 and save
with tf.Session(graph=g2) as sess2:
    sess2.run(init2)
    print(sess2.run(assign2))
    save2.save(sess2, "tmp/v2.ckpt")

Make sure a tmp directory exists, then run python save_vars.py; this writes the checkpoint files into tmp/.
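Because restore() matches variables by the names stored in the checkpoint, it can help to check those names before writing the restore script. A minimal sketch using tf.train.list_variables and tf.train.load_checkpoint (TF 1.x API via tensorflow.compat.v1); to stay self-contained it writes its own copy of graph 1 into a temporary directory rather than reading tmp/v1.ckpt, so to inspect your real files you would set prefix to "tmp/v1.ckpt" instead.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

# Stand-in for graph 1 of save_vars.py, saved to a temporary directory.
with tf.Graph().as_default():
    v1 = tf.Variable(tf.constant([1.0]), name="v1")
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        prefix = saver.save(sess, os.path.join(tempfile.mkdtemp(), "v1.ckpt"))

# (name, shape) pairs for every variable stored in the checkpoint.
contents = dict(tf.train.list_variables(prefix))
print(contents)

# A checkpoint reader gives access to the stored values themselves.
reader = tf.train.load_checkpoint(prefix)
print(reader.get_tensor("v1"))
```

If the printed names differ from the names in your restore graph, either rename the variables or use the dict form of var_list.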

Now, you can restore using a file named "restore_vars.py" with the following code:

import tensorflow as tf  # TF 1.x API

# The variables v1 and v2 that we want to restore; their names must match
# the names stored in the checkpoints ("v1" and "v2")
v1 = tf.Variable(tf.zeros([1]), name="v1")
v2 = tf.Variable(tf.zeros([1]), name="v2")

# saver1 will only look for v1
saver1 = tf.train.Saver([v1])
# saver2 will only look for v2
saver2 = tf.train.Saver([v2])
with tf.Session() as sess:
    # restore() assigns the saved values, so no initializer is needed
    saver1.restore(sess, "tmp/v1.ckpt")
    saver2.restore(sess, "tmp/v2.ckpt")
    print(sess.run(v1))
    print(sess.run(v2))

When you run python restore_vars.py, the output should be:

[1.]
[2.]

(at least on my computer that's the output). Feel free to post a comment if anything was unclear.
