Loop in TensorFlow

Posted 2019-08-18 06:51

I changed my question to explain my issue better:

I have a function, output_image = my_func(x), where x should have shape (1, 4, 4, 1).

Please help me to fix the error in this part:

out = tf.Variable(tf.zeros([1, 4, 4, 3]))
index = tf.constant(0)
def condition(index):
    return tf.less(index, tf.subtract(tf.shape(x)[3], 1))
def body(index):
    out[:, :, :, index].assign(my_func(x[:, :, :, index]))
    return tf.add(index, 1), out
out = tf.while_loop(condition, body, [index])

ValueError: The two structures don't have the same nested structure. First structure: type=list str=[] Second structure: type=list str=[<tf.Tensor 'while_10/Add_3:0' shape=() dtype=int32>, <tf.Variable 'Variable_2:0' shape=(1, 4, 4, 3) dtype=float32_ref>] More specifically: The two structures don't have the same number of elements. First structure: type=list str=[<tf.Tensor 'while_10/Identity:0' shape=() dtype=int32>]. Second structure: type=list str=[<tf.Tensor 'while_10/Add_3:0' shape=() dtype=int32>, <tf.Variable 'Variable_2:0' shape=(1, 4, 4, 3) dtype=float32_ref>]

I tested my code: I can get a result from out = my_func(x[:, :, :, i]) with different values of i, and while_loop also works when I comment out the line out[:, :, :, index].assign(my_func(x[:, :, :, index])). Something is wrong with that line.
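For reference, here is a minimal sketch of a version whose loop variables line up (assuming TF 1.x; the real my_func is replaced by a hypothetical shape-preserving stand-in). tf.while_loop requires condition and body to accept and return exactly the structure passed as loop_vars, and instead of assigning into a Variable slice inside the body, the output tensor is rebuilt functionally:

import tensorflow as tf

x = tf.placeholder(tf.float32, [1, 4, 4, 3])   # hypothetical input

def my_func(t):                                # stand-in, assumed shape-preserving
    return t * 2.0

index0 = tf.constant(0)
out0 = tf.zeros([1, 4, 4, 3])

def condition(index, out):                     # same arguments as body
    return tf.less(index, tf.shape(x)[3])      # loop over all 3 channels

def body(index, out):
    channel = my_func(x[:, :, :, index])       # shape (1, 4, 4)
    channel = tf.expand_dims(channel, axis=3)  # shape (1, 4, 4, 1)
    mask = tf.one_hot(index, depth=3)          # 1.0 at position `index`, else 0.0
    out = out * (1.0 - mask) + channel * mask  # write this channel, keep the rest
    return tf.add(index, 1), out               # must match loop_vars' structure

final_index, final_out = tf.while_loop(condition, body, [index0, out0])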

2 Answers
孤傲高冷的网名
#2 · 2019-08-18 07:46

I understand that there is no for-loop and so on, just while. Why?

According to Implementation of Control Flow in TensorFlow:

They should fit well with the dataflow model of TensorFlow, and should be amenable to parallel and distributed execution and automatic differentiation.

I think distributed dataflow graphs and automatic differentiation across devices could have been the constraints that led to the introduction of so few loop primitives.

There are several diagrams in that document that distributed-computing experts will understand better; a more thorough explanation is beyond me.

小情绪 Triste *
#3 · 2019-08-18 07:55

I understand that there is no for-loop and so on, just while. Why?

Control structures are hard to get right and hard to optimize. In your case, what if the next example in the same batch had 5 channels? You would need to run 5 loop iterations and either produce wrong results or waste compute on the first example, which has only 3 channels.

You need to think about what exactly you are trying to achieve. Commonly you would have different weights for each channel, so the system can't just create them out of thin air; they need to be trained properly.

If you just want to apply the same logic 3 times, rearrange your tensor to (3, 4, 4, 1), as sketched below. You get 3 results and can do what you want with them.
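A minimal sketch of that rearrangement (assuming TF 1.x and that my_func treats the first dimension as independent examples; the my_func here is a hypothetical stand-in):

import tensorflow as tf

x = tf.placeholder(tf.float32, [1, 4, 4, 3])       # hypothetical input

def my_func(t):                                    # stand-in for the real function
    return t * 2.0

# Move channels to the batch dimension: (1, 4, 4, 3) -> (3, 4, 4, 1)
x_batched = tf.transpose(x, perm=[3, 1, 2, 0])

out_batched = my_func(x_batched)                   # one call, no loop

# Restore the original layout if needed: (3, 4, 4, 1) -> (1, 4, 4, 3)
out = tf.transpose(out_batched, perm=[3, 1, 2, 0])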

Usually when you actually need for loops (when handling sequences) you pad the examples so that they all have the same length and generate a model where the loop is unrolled (you would have 3 different operations, one for each iteration of the loop). Look for dynamic_rnn or static_rnn (the first one can handle a different length for each example in the batch); see the sketch below.
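For illustration, a minimal dynamic_rnn sketch (TF 1.x; the shapes and sequence lengths are made up): it builds the loop with tf.while_loop internally and stops reading each example at its true length.

import tensorflow as tf

# Hypothetical: a batch of 2 sequences padded to length 5, feature size 8
inputs = tf.placeholder(tf.float32, [2, 5, 8])
seq_len = tf.placeholder(tf.int32, [2])       # true length of each example

cell = tf.nn.rnn_cell.BasicRNNCell(num_units=16)

# outputs: (2, 5, 16); positions past each seq_len are zeroed out
outputs, state = tf.nn.dynamic_rnn(
    cell, inputs, sequence_length=seq_len, dtype=tf.float32)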
