tensorflow: check if a scalar boolean tensor is True

Posted 2020-07-10 11:26

Question:

I want to control which branch of a function executes by using a boolean placeholder, but I keep getting the error "Using a tf.Tensor as a Python bool is not allowed". Here is the code that produces it:

import tensorflow as tf
def foo(c):
  if c:
    print('This is true')
    #heavy code here
    return 10
  else:
    print('This is false')
    #different code here
    return 0

a = tf.placeholder(tf.bool)  #placeholder for a single boolean value
b = foo(a)
sess = tf.InteractiveSession()
res = sess.run(b, feed_dict = {a: True})
sess.close()

I changed if c to if c is not None, without luck. How, then, can I control foo by switching the placeholder a on and off?
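A quick way to see where the error comes from: if c: asks Python for bool(c), and a placeholder has no value while the graph is being built, so the TypeError is raised on the line b = foo(a), before any session exists. The sketch below assumes TF 2.x with the tf.compat.v1 compatibility layer; with TF 1.x you can drop the compat.v1 prefix and the disable_eager_execution() call.

```python
import tensorflow as tf

# Assumption: TF 2.x running in TF 1.x-style graph mode via tf.compat.v1.
tf.compat.v1.disable_eager_execution()

a = tf.compat.v1.placeholder(tf.bool, shape=[])  # scalar boolean placeholder

try:
    # `if a:` implicitly calls bool(a). The placeholder is only a symbolic
    # graph node at this point, so TensorFlow refuses the conversion.
    bool(a)
    raised = False
except TypeError as err:
    raised = True
    print(type(err).__name__, err)
```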

Update: as @nessuno and @nemo pointed out, a Python if...else cannot branch on the value of a tensor; tf.cond has to be used instead. The answer to my question is to redesign my function like this:

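import tensorflow as tf

def func1():
  # heavy code here: must build and return TensorFlow ops
  return tf.constant(10)

def func2():
  # different code here
  return tf.constant(0)

def foo(c):
  return tf.cond(c, func1, func2)

a = tf.placeholder(tf.bool)  #placeholder for a single boolean value
b = foo(a)
sess = tf.InteractiveSession()
res = sess.run(b, feed_dict = {a: True})
sess.close()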
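One caveat with this redesign: tf.cond calls both branch functions while the graph is being built, so Python side effects in them (such as the print calls in the original foo) run at construction time, not when the session runs. The sketch below records those calls to make this visible; true_branch and false_branch are hypothetical names, and it assumes TF 2.x with tf.compat.v1. If you need logging at run time, it has to be a graph op (for example tf.print) rather than Python print.

```python
import tensorflow as tf

tf.compat.v1.disable_eager_execution()  # assumption: TF 2.x in graph mode

calls = []  # records which branch functions Python actually invoked

def true_branch():
    calls.append('true')   # Python side effect: happens while tracing
    return tf.constant(10)

def false_branch():
    calls.append('false')  # also happens while tracing
    return tf.constant(0)

a = tf.compat.v1.placeholder(tf.bool, shape=[])
b = tf.cond(a, true_branch, false_branch)
built_calls = list(calls)  # snapshot taken before any session runs

with tf.compat.v1.Session() as sess:
    res_true = sess.run(b, feed_dict={a: True})
    res_false = sess.run(b, feed_dict={a: False})

print(built_calls, calls, res_true, res_false)
```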

Answer 1:

You have to use tf.cond to define a conditional operation within the graph, which in turn changes the flow of the tensors.

import tensorflow as tf

a = tf.placeholder(tf.bool)  #placeholder for a single boolean value
b = tf.cond(tf.equal(a, tf.constant(True)), lambda: tf.constant(10), lambda: tf.constant(0))
sess = tf.InteractiveSession()
res = sess.run(b, feed_dict = {a: True})
sess.close()
print(res)

10
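Since a is already a boolean tensor, the tf.equal(a, tf.constant(True)) comparison is not required: the placeholder itself can serve as the tf.cond predicate. A minimal sketch comparing both forms for each fed value, assuming TF 2.x with tf.compat.v1:

```python
import tensorflow as tf

tf.compat.v1.disable_eager_execution()  # assumption: TF 2.x in graph mode

a = tf.compat.v1.placeholder(tf.bool, shape=[])

# The placeholder is passed to tf.cond directly as the predicate...
b_direct = tf.cond(a, lambda: tf.constant(10), lambda: tf.constant(0))
# ...which selects the same branch as comparing it with True first.
b_equal = tf.cond(tf.equal(a, tf.constant(True)),
                  lambda: tf.constant(10), lambda: tf.constant(0))

with tf.compat.v1.Session() as sess:
    results = {
        value: sess.run([b_direct, b_equal], feed_dict={a: value})
        for value in (True, False)
    }

print(results)
```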



Answer 2:

The actual execution does not happen in Python but in the TensorFlow backend, which you supply with the computation graph it is supposed to execute. Your Python code only builds that graph, so when foo(a) runs, a does not have a value yet. This means that every condition and control-flow decision you want to apply has to be expressed as a node in the computation graph.

For if conditions, there is the tf.cond operation:

import tensorflow as tf
c = tf.placeholder(tf.bool)  # scalar boolean placeholder
b = tf.cond(c,
            lambda: tf.constant(10),
            lambda: tf.constant(0))
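To make the "condition as a graph node" idea concrete without TensorFlow, here is a toy model of deferred execution. Every class and function in it (Node, Placeholder, Constant, Cond, run) is hypothetical and is not TensorFlow's API; it only mirrors the behaviour described above: a node refuses to act as a Python bool, while a Cond node makes its choice only when the graph is run with fed values.

```python
# Toy model of deferred execution (hypothetical classes, NOT TensorFlow's API).

class Node:
    def __bool__(self):
        # Mirrors TensorFlow's refusal: a graph node has no value yet.
        raise TypeError("Using a graph node as a Python bool is not allowed")

class Placeholder(Node):
    def evaluate(self, feed):
        return feed[self]  # the value is supplied only at run time

class Constant(Node):
    def __init__(self, value):
        self.value = value

    def evaluate(self, feed):
        return self.value

class Cond(Node):
    """Branching recorded as a node; the choice is made at run time."""
    def __init__(self, pred, true_node, false_node):
        self.pred, self.true_node, self.false_node = pred, true_node, false_node

    def evaluate(self, feed):
        # A plain Python `if` is fine here: pred is now an actual value.
        if self.pred.evaluate(feed):
            return self.true_node.evaluate(feed)
        return self.false_node.evaluate(feed)

def run(node, feed):
    """Stand-in for Session.run: evaluate `node` given placeholder values."""
    return node.evaluate(feed)

c = Placeholder()
b = Cond(c, Constant(10), Constant(0))
print(run(b, {c: True}), run(b, {c: False}))  # prints: 10 0
```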


Answer 3:

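A simpler way to check the value: tf.count_nonzero turns the boolean into 1 (True) or 0 (False). Here sess is a tf.InteractiveSession opened earlier: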

In [50]: a = tf.placeholder(tf.bool)

In [51]: is_true = tf.count_nonzero([a])

In [52]: sess.run(is_true, {a: True})
Out[52]: 1

In [53]: sess.run(is_true, {a: False})
Out[53]: 0
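A self-contained version of this answer, which also shows the general point behind it: once a value has been fetched with sess.run, it is an ordinary NumPy scalar, so a Python if works on it; only the symbolic tensor cannot be used that way. This is a sketch assuming TF 2.x with tf.compat.v1, where the function is available as tf.math.count_nonzero.

```python
import tensorflow as tf

tf.compat.v1.disable_eager_execution()  # assumption: TF 2.x in graph mode

a = tf.compat.v1.placeholder(tf.bool, shape=[])
is_true = tf.math.count_nonzero([a])  # int64 tensor: 1 if a is True, else 0

with tf.compat.v1.Session() as sess:
    count_true = sess.run(is_true, {a: True})
    count_false = sess.run(is_true, {a: False})
    # A fetched value is a concrete NumPy scalar, so Python `if` is allowed.
    fetched = sess.run(a, {a: True})
    branch = 'true' if fetched else 'false'

print(count_true, count_false, branch)
```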