tensorflow: check if a scalar boolean tensor is True

Posted 2020-07-10 11:43

I want to control the execution of a function using a placeholder, but I keep getting the error "Using a tf.Tensor as a Python bool is not allowed". Here is the code that produces this error:

import tensorflow as tf
def foo(c):
  if c:
    print('This is true')
    #heavy code here
    return 10
  else:
    print('This is false')
    #different code here
    return 0

a = tf.placeholder(tf.bool)  #placeholder for a single boolean value
b = foo(a)
sess = tf.InteractiveSession()
res = sess.run(b, feed_dict = {a: True})
sess.close()

I tried changing "if c" to "if c is not None", without luck. How can I control foo by switching the placeholder a on and off?

Update: as @nessuno and @nemo point out, we must use tf.cond instead of if..else. The answer to my question is to redesign my function like this, where func1 and func2 are zero-argument functions that build the true and false branches:

import tensorflow as tf
def foo(c):
  return tf.cond(c, func1, func2)

a = tf.placeholder(tf.bool)  #placeholder for a single boolean value
b = foo(a)
sess = tf.InteractiveSession()
res = sess.run(b, feed_dict = {a: True})
sess.close() 

3 answers
smile是对你的礼貌
#2 · 2020-07-10 11:43

You have to use tf.cond, which defines the conditional as an operation inside the graph, so the branch is chosen when the graph runs rather than when it is built.

import tensorflow as tf

a = tf.placeholder(tf.bool)  #placeholder for a single boolean value
b = tf.cond(tf.equal(a, tf.constant(True)), lambda: tf.constant(10), lambda: tf.constant(0))
sess = tf.InteractiveSession()
res = sess.run(b, feed_dict = {a: True})
sess.close()
print(res)

10

一纸荒年 Trace。
#3 · 2020-07-10 11:43

A simpler way to tackle it:

In [50]: a = tf.placeholder(tf.bool)

In [51]: is_true = tf.count_nonzero([a])

In [52]: sess.run(is_true, {a: True})
Out[52]: 1

In [53]: sess.run(is_true, {a: False})
Out[53]: 0
女痞
#4 · 2020-07-10 11:58

Execution does not happen in Python but in the TensorFlow backend, to which you supply the computation graph it should execute. This means that every condition and control-flow construct you want to apply has to be expressed as a node in the computation graph.

For if conditions there is the cond operation:

b = tf.cond(c, 
           lambda: tf.constant(10), 
           lambda: tf.constant(0))
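
To make the "condition as a graph node" idea concrete, here is a toy model in plain Python. It is only a sketch: the classes Node, Placeholder, Constant and Cond and the run helper are hypothetical names invented for illustration and are not TensorFlow's API. It shows why a Python if cannot branch on a value that is not known until run time, and how a cond-style node defers the choice until the graph is evaluated.

```python
# Toy model only: these classes are hypothetical and are NOT TensorFlow's API.

class Node:
    """A value that is only known once the graph is run."""

    def __bool__(self):
        # Mirrors the error from the question: while the graph is being
        # built there is no value yet, so Python cannot branch on the node.
        raise TypeError("Using a Node as a Python bool is not allowed")

    def evaluate(self, feed):
        raise NotImplementedError


class Placeholder(Node):
    def evaluate(self, feed):
        return feed[self]  # value supplied at run time


class Constant(Node):
    def __init__(self, value):
        self.value = value

    def evaluate(self, feed):
        return self.value


class Cond(Node):
    """The branch is itself a node; it is chosen when the graph is evaluated."""

    def __init__(self, pred, true_fn, false_fn):
        self.pred = pred
        # Both branch functions are called up front, during construction.
        self.true_branch = true_fn()
        self.false_branch = false_fn()

    def evaluate(self, feed):
        if self.pred.evaluate(feed):  # a real Python bool by now
            return self.true_branch.evaluate(feed)
        return self.false_branch.evaluate(feed)


def run(node, feed):
    return node.evaluate(feed)


a = Placeholder()
b = Cond(a, lambda: Constant(10), lambda: Constant(0))

print(run(b, {a: True}))   # 10
print(run(b, {a: False}))  # 0

try:
    if a:  # what the original foo() attempted
        pass
except TypeError as err:
    print(err)
```

The point of the sketch is that foo(a) in the question runs its if exactly once, while the graph is being constructed, whereas the Cond node carries the decision into the graph so that each run can take a different branch.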