tf.get_collection to extract variables of one scope

Posted 2020-07-19 02:58

I have n (e.g. n=3) scopes, with x (e.g. x=4) variables defined in each scope. The scopes are:

model/generator_0
model/generator_1
model/generator_2

Once I compute the loss, I want to extract and provide all the variables from only one of the scopes, based on a criterion evaluated at run-time. Hence the index idx of the scope that I select is an argmin tensor cast to int32:

<tf.Tensor 'model/Cast:0' shape=() dtype=int32>
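Roughly, the setup looks like this (variable shapes and losses are simplified placeholders, not the actual model):

import tensorflow as tf

n, x = 3, 4
losses = []
with tf.variable_scope('model'):
    for i in range(n):
        with tf.variable_scope('generator_%d' % i):
            # x variables per generator scope (shapes simplified for illustration)
            gen_vars = [tf.get_variable('w%d' % j, shape=[2, 2]) for j in range(x)]
            losses.append(tf.reduce_sum([tf.reduce_sum(v) for v in gen_vars]))
    # run-time criterion: pick the generator with the smallest loss
    idx = tf.cast(tf.argmin(tf.stack(losses)), tf.int32)  # 'model/Cast:0', dtype=int32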

I have already tried:

train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'model/generator_'+tf.cast(idx, tf.string)) 

which obviously did not work. Is there any way to get all the x variables belonging to that particular scope using idx, so that they can be passed to the optimizer?

Thanks in advance!

Vignesh Srinivasan

1 Answer
乱世女痞
#2 · 2020-07-19 03:38

You can do something like this in TF 1.0 rc1 or later:

import tensorflow as tf

v = tf.Variable(tf.ones(()))
loss = tf.identity(v)
# create the optimizer (and the variables it adds to the graph) inside a dedicated scope
with tf.variable_scope('adamoptim') as vs:
    optim = tf.train.AdamOptimizer(learning_rate=0.1).minimize(loss)
# fetch everything that was created under that scope by name
optim_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=vs.name)
print([var.name for var in optim_vars])  # => prints the list of vars created
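Applied to the scopes in the question, the same pattern gives one list of variables per generator scope. Note that the scope argument of tf.get_collection must be a plain Python string (it is matched as a regex), which is why concatenating a cast tensor into the name does not work; any selection driven by the run-time idx has to happen among these Python-level lists. A rough sketch:

# one trainable-variable list per generator scope, keyed by a Python int, not a tensor
gen_vars = [tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                              scope='model/generator_%d' % i)
            for i in range(3)]
print([len(vs) for vs in gen_vars])  # => [4, 4, 4] when x=4 variables per scope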