Specific kind of slicing over a tensor object in TensorFlow

Posted 2020-02-16 05:09

Question:

Summary of the question: is this kind of slicing followed by assignment supported in TensorFlow?

out[tf_a2[y],x[:,None]] = tf_a1[tf_a2[y],x[:,None]]
final = out[:-1]

Let me give an example. I have a tensor like this:

tf_a1 = tf.Variable([    [9.968594,  8.655439,  0.,        0.       ],
                         [0.,        8.3356,    0.,        8.8974   ],
                         [0.,        0.,        6.103182,  7.330564 ],
                         [6.609862,  0.,        3.0614321, 0.       ],
                         [9.497023,  0.,        3.8914037, 0.       ],
                         [0.,        8.457685,  8.602337,  0.       ],
                         [0.,        0.,        5.826657,  8.283971 ],
                         [0.,        0.,        0.,        0.       ]])

And I have this one:

tf_a2 = tf.constant([[1, 2, 5],
                    [1, 4, 6],
                    [0, 7, 7],
                    [2, 3, 6],
                    [2, 4, 7]])

Now I want to keep only those elements of tf_a1 whose row indices, taken in combinations of n (here n = 2), appear together in some row of tf_a2. What does that mean?

For example, in the first column of tf_a1 the nonzero entries are in rows (0, 3, 4). Is there any row in tf_a2 that contains one of the pairs of these indices: (0,3), (0,4) or (3,4)? No, there is no such row, so all the elements in that column become zero.

In the second column of tf_a1 the nonzero row indices are (0, 1, 5), which give the pairs (0,1), (0,5) and (1,5). The pair (1,5) appears in the first row of tf_a2, [1, 2, 5], so those two elements of tf_a1 are kept.
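The rule above can be spelled out literally with itertools.combinations. This is a plain-Python sketch for checking the expected output, not the NumPy approach used in this post; the helper name keep_by_combinations is hypothetical. For this data it produces the same matrix as the NumPy code, but the two are not identical in general: np.count_nonzero counts a repeated index in a row of tf_a2 (such as the 7 in [0, 7, 7]) once per occurrence, whereas this check only matches distinct indices.

```python
from itertools import combinations

# Same values as tf_a1 and tf_a2 above, as plain Python lists.
a1 = [[9.968594, 8.655439, 0., 0.],
      [0., 8.3356, 0., 8.8974],
      [0., 0., 6.103182, 7.330564],
      [6.609862, 0., 3.0614321, 0.],
      [9.497023, 0., 3.8914037, 0.],
      [0., 8.457685, 8.602337, 0.],
      [0., 0., 5.826657, 8.283971],
      [0., 0., 0., 0.]]
a2 = [[1, 2, 5], [1, 4, 6], [0, 7, 7], [2, 3, 6], [2, 4, 7]]
n = 2


def keep_by_combinations(a1, a2, n):
    """Hypothetical helper: keep a1[r][c] only if r belongs to some n-combination
    of column c's nonzero row indices that appears together in a row of a2."""
    a2_sets = [set(row) for row in a2]
    out = [[0.0] * len(a1[0]) for _ in a1]
    for col in range(len(a1[0])):
        nonzero_rows = [r for r in range(len(a1)) if a1[r][col] != 0]
        for combo in combinations(nonzero_rows, n):
            # Does any row of a2 contain every index of this combination?
            if any(set(combo) <= s for s in a2_sets):
                for r in combo:
                    out[r][col] = a1[r][col]
    return out


result = keep_by_combinations(a1, a2, n)
final = result[:-1]  # drop the last (all-zero) row, as in the NumPy code
```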

This is the working NumPy code:

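import numpy as np
n = 2
a1p = tf_a1.numpy()  # NumPy copy of tf_a1 (TF 2.x eager mode)
a2 = tf_a2.numpy()   # NumPy copy of tf_a2
y, x = np.where(np.count_nonzero(a1p[a2], axis=1) >= n)
out = np.zeros_like(a1p)
out[a2[y], x[:, None]] = a1p[a2[y], x[:, None]]
final = out[:-1]     # drop the last (all-zero) row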
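To see what each step of the NumPy code computes, here is a trace of the intermediate arrays on the example data. It is a sketch with the values copied into plain NumPy arrays; the names gathered, counts, rows and cols are illustrative and not from the original code.

```python
import numpy as np

# Same data as tf_a1 / tf_a2 above, as plain NumPy arrays.
a1p = np.array([[9.968594, 8.655439, 0., 0.],
                [0., 8.3356, 0., 8.8974],
                [0., 0., 6.103182, 7.330564],
                [6.609862, 0., 3.0614321, 0.],
                [9.497023, 0., 3.8914037, 0.],
                [0., 8.457685, 8.602337, 0.],
                [0., 0., 5.826657, 8.283971],
                [0., 0., 0., 0.]], dtype=np.float32)
a2 = np.array([[1, 2, 5], [1, 4, 6], [0, 7, 7], [2, 3, 6], [2, 4, 7]])
n = 2

gathered = a1p[a2]                           # (5, 3, 4): gathered[i, j] is row a2[i, j] of a1p
counts = np.count_nonzero(gathered, axis=1)  # (5, 4): nonzeros per (row of a2, column)
# counts -> [[0 2 2 2]
#            [1 1 2 2]
#            [1 1 0 0]
#            [1 0 3 2]
#            [1 0 2 1]]
y, x = np.where(counts >= n)                 # (row of a2, column) pairs that qualify
# y -> [0 0 0 1 1 3 3 4], x -> [1 2 3 2 3 2 3 2]
rows = a2[y]                                 # (8, 3): rows of a1p to copy for each pair
cols = x[:, None]                            # (8, 1): broadcasts against rows to (8, 3)
```

The assignment out[rows, cols] = a1p[rows, cols] then copies every (row, column) position named by a qualifying pair; zeros among them are copied as zeros.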

This is the expected output of this NumPy code (but I need this in TensorFlow):

[[0.        0.        0.        0.       ]
 [0.        8.3356    0.        8.8974   ]
 [0.        0.        6.103182  7.330564 ]
 [0.        0.        3.0614321 0.       ]
 [0.        0.        3.8914037 0.       ]
 [0.        8.457685  8.602337  0.       ]
 [0.        0.        5.826657  8.283971 ]]

The TensorFlow code should be something like this:

y, x = tf.where(tf.count_nonzero(tf.gather(tf_a1, tf_a2, axis=0), axis=1) >= n)
out = tf.zeros_like(tf_a1)
out[tf_a2[y],x[:,None]] = tf_a1[tf_a2[y],x[:,None]]
final = out[:-1]

Here tf.gather(tf_a1, tf_a2, axis=0) does the same thing as the NumPy fancy indexing tf_a1[tf_a2].
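A quick check of that equivalence, assuming TensorFlow 2.x with eager execution (so .numpy() is available). The sketch uses tf.math.count_nonzero, which is where the count function lives in TF 2.x.

```python
import numpy as np
import tensorflow as tf

# Same values as tf_a1 and tf_a2 above.
a1 = tf.constant([[9.968594, 8.655439, 0., 0.],
                  [0., 8.3356, 0., 8.8974],
                  [0., 0., 6.103182, 7.330564],
                  [6.609862, 0., 3.0614321, 0.],
                  [9.497023, 0., 3.8914037, 0.],
                  [0., 8.457685, 8.602337, 0.],
                  [0., 0., 5.826657, 8.283971],
                  [0., 0., 0., 0.]], dtype=tf.float32)
a2 = tf.constant([[1, 2, 5], [1, 4, 6], [0, 7, 7], [2, 3, 6], [2, 4, 7]])

gathered = tf.gather(a1, a2, axis=0)            # (5, 3, 4): gathered[i, j] = a1[a2[i, j]]
fancy = a1.numpy()[a2.numpy()]                  # NumPy fancy indexing a1[a2]
same = np.array_equal(gathered.numpy(), fancy)  # both give the same values

# The counting step also carries over directly:
counts = tf.math.count_nonzero(gathered, axis=1)  # (5, 4), same as np.count_nonzero(..., axis=1)
```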

Update 1

The only part that does not work is:

out[tf_a2[y],x[:,None]] = tf_a1[tf_a2[y],x[:,None]]
final = out[:-1]

Any idea how I can accomplish this in TensorFlow? Is this kind of sliced assignment supported on tensor objects at all?

Any help is appreciated:)
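One possible approach, sketched here as an assumption rather than a confirmed answer and assuming TensorFlow 2.x with eager execution: TensorFlow tensors do not support item assignment, so instead of writing into out, build the (row, column) index pairs explicitly and create a new tensor with tf.tensor_scatter_nd_update. Note also that tf.where with a single argument returns one (k, 2) index tensor, so y and x are sliced out of it rather than unpacked.

```python
import tensorflow as tf

tf_a1 = tf.Variable([[9.968594, 8.655439, 0., 0.],
                     [0., 8.3356, 0., 8.8974],
                     [0., 0., 6.103182, 7.330564],
                     [6.609862, 0., 3.0614321, 0.],
                     [9.497023, 0., 3.8914037, 0.],
                     [0., 8.457685, 8.602337, 0.],
                     [0., 0., 5.826657, 8.283971],
                     [0., 0., 0., 0.]])
tf_a2 = tf.constant([[1, 2, 5], [1, 4, 6], [0, 7, 7], [2, 3, 6], [2, 4, 7]])
n = 2

counts = tf.math.count_nonzero(tf.gather(tf_a1, tf_a2, axis=0), axis=1)  # (5, 4)
yx = tf.where(counts >= n)         # (k, 2) int64 tensor of [y, x] pairs
y, x = yx[:, 0], yx[:, 1]

rows = tf.gather(tf_a2, y)         # (k, 3), like tf_a2[y] in NumPy
cols = tf.broadcast_to(tf.cast(x, rows.dtype)[:, None], tf.shape(rows))  # like x[:, None]
indices = tf.stack([rows, cols], axis=-1)  # (k, 3, 2): one (row, col) pair per element

# Equivalent of out[tf_a2[y], x[:, None]] = tf_a1[tf_a2[y], x[:, None]]:
# start from zeros and write the values gathered from tf_a1 at those positions.
out = tf.tensor_scatter_nd_update(tf.zeros_like(tf_a1), indices,
                                  tf.gather_nd(tf_a1, indices))
final = out[:-1]  # drop the last (all-zero) row
```

Some index pairs occur more than once (for example (1, 3) comes from both [1, 2, 5] and [1, 4, 6]), but every duplicate carries the same value from tf_a1, so it does not matter which of them tf.tensor_scatter_nd_update applies.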