Non-continuous index slicing on a tensor object in TensorFlow

Posted 2019-08-18 22:11

I have looked at different ways of slicing in TensorFlow, namely tf.gather and tf.gather_nd. tf.gather only slices along a single dimension, and tf.gather_nd accepts only a single indices tensor that is applied to the input tensor.

What I need is different: I want to slice the input tensor using two different tensors, one indexing the rows and the other indexing the columns, and the two do not necessarily have the same shape.

For example, suppose this is the input tensor from which I want to extract values:

input_tf = tf.Variable([ [9.968594,  8.655439,  0.,        0.       ],
                         [0.,        8.3356,    0.,        8.8974   ],
                         [0.,        0.,        6.103182,  7.330564 ],
                         [6.609862,  0.,        3.0614321, 0.       ],
                         [9.497023,  0.,        3.8914037, 0.       ],
                         [0.,        8.457685,  8.602337,  0.       ],
                         [0.,        0.,        5.826657,  8.283971 ],
                         [0.,        0.,        0.,        0.       ]])

The second tensor holds the row indices:

rows_tf = tf.constant(
[[1, 2, 5],
 [1, 2, 5],
 [1, 2, 5],
 [1, 4, 6],
 [1, 4, 6],
 [2, 3, 6],
 [2, 3, 6],
 [2, 4, 7]])

The third tensor holds the column indices:

columns_tf = tf.constant(
[[1],
 [2],
 [3],
 [2],
 [3],
 [2],
 [3],
 [2]])

Now I want to index input_tf using rows_tf and columns_tf: rows [1 2 5] from rows_tf are paired with column [1] from columns_tf, then rows [1 2 5] again with column [2], and so on.

Likewise, rows [1 4 6] are paired with column [2].

In general, the i-th row of rows_tf, combined with the i-th row of columns_tf, extracts the i-th row of the output from input_tf.

So the expected output is:

[[8.3356,    0.,        8.457685 ],
 [0.,        6.103182,  8.602337 ],
 [8.8974,    7.330564,  0.       ],
 [0.,        3.8914037, 5.826657 ],
 [8.8974,    0.,        8.283971 ],
 [6.103182,  3.0614321, 5.826657 ],
 [7.330564,  0.,        8.283971 ],
 [6.103182,  3.8914037, 0.       ]]

For example, the first row [8.3356, 0., 8.457685] is extracted using rows [1, 2, 5] from rows_tf and column [1] from columns_tf, i.e. the elements at (row 1, column 1), (row 2, column 1) and (row 5, column 1) of input_tf.
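To pin down the intended semantics, here is a minimal plain-Python reference sketch (no TensorFlow involved) of the rule out[i][j] = input[rows[i][j]][columns[i][0]], using the same data as above. The helper name gather_rows_cols is hypothetical and only for illustration.

```python
# Same data as input_tf, rows_tf and columns_tf above, as plain Python lists.
data = [[9.968594, 8.655439, 0., 0.],
        [0., 8.3356, 0., 8.8974],
        [0., 0., 6.103182, 7.330564],
        [6.609862, 0., 3.0614321, 0.],
        [9.497023, 0., 3.8914037, 0.],
        [0., 8.457685, 8.602337, 0.],
        [0., 0., 5.826657, 8.283971],
        [0., 0., 0., 0.]]
rows = [[1, 2, 5], [1, 2, 5], [1, 2, 5], [1, 4, 6],
        [1, 4, 6], [2, 3, 6], [2, 3, 6], [2, 4, 7]]
columns = [[1], [2], [3], [2], [3], [2], [3], [2]]


def gather_rows_cols(data, rows, columns):
    """Hypothetical helper: out[i][j] = data[rows[i][j]][columns[i][0]]."""
    return [[data[r][col_row[0]] for r in row_idx]
            for row_idx, col_row in zip(rows, columns)]


out = gather_rows_cols(data, rows, columns)
print(out[0])  # [8.3356, 0.0, 8.457685]
```

Each output row reuses one column index for all of its row indices; that pairing is what a vectorized TensorFlow version has to reproduce.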

There are a couple of existing questions about slicing in TensorFlow, but they use tf.gather, or tf.gather_nd with tf.stack, which did not give me the desired output.

Needless to say, in NumPy this is easy: input_tf[rows_tf, columns_tf] does it directly, because advanced indexing broadcasts the (8, 3) row indices against the (8, 1) column indices.
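For reference, a minimal NumPy sketch of that one-liner, assuming the same data held in NumPy arrays (the names input_np, rows_np and columns_np are illustrative):

```python
import numpy as np

input_np = np.array([[9.968594, 8.655439, 0., 0.],
                     [0., 8.3356, 0., 8.8974],
                     [0., 0., 6.103182, 7.330564],
                     [6.609862, 0., 3.0614321, 0.],
                     [9.497023, 0., 3.8914037, 0.],
                     [0., 8.457685, 8.602337, 0.],
                     [0., 0., 5.826657, 8.283971],
                     [0., 0., 0., 0.]])
rows_np = np.array([[1, 2, 5], [1, 2, 5], [1, 2, 5], [1, 4, 6],
                    [1, 4, 6], [2, 3, 6], [2, 3, 6], [2, 4, 7]])
columns_np = np.array([[1], [2], [3], [2], [3], [2], [3], [2]])

# Index arrays of shape (8, 3) and (8, 1) broadcast to (8, 3), so
# out[i, j] = input_np[rows_np[i, j], columns_np[i, 0]].
out = input_np[rows_np, columns_np]
print(out.shape)  # (8, 3)
```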

I also looked at this notebook, which tries to simulate NumPy's advanced indexing, but it is still not as flexible as NumPy: https://github.com/SpinachR/ubuntuTest/blob/master/beautifulCodes/tensorflow_advanced_index_slicing.ipynb

This is what I have tried, which is not correct:

tf.gather(tf.transpose(tf.gather(input_tf,rows_tf)),columns_tf)

The output of this code has shape (8, 1, 3, 8), which is completely wrong.
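The shape follows from the fact that each tf.gather call indexes along a single axis and splices the whole index shape into the result, so rows and columns are never paired element by element. A small sketch tracing the shapes, assuming TensorFlow 2.x eager execution and using all-zero placeholder tensors of the same shapes (only the shapes matter here):

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)

x = tf.zeros([8, 4])                     # same shape as input_tf
rows = tf.zeros([8, 3], dtype=tf.int32)  # same shape as rows_tf
cols = tf.zeros([8, 1], dtype=tf.int32)  # same shape as columns_tf

step1 = tf.gather(x, rows)      # indices (8, 3) replace axis 0 of (8, 4): (8, 3, 4)
step2 = tf.transpose(step1)     # without perm, axes are reversed: (4, 3, 8)
step3 = tf.gather(step2, cols)  # indices (8, 1) replace axis 0: (8, 1, 3, 8)
print(step1.shape, step2.shape, step3.shape)
```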

Thanks in advance!

1 answer
Bombasti
#2 · 2019-08-18 22:35

The idea is to first build the list of sparse indices, i.e. (row, column) pairs formed by concatenating each row index with its column index, and then use tf.gather_nd to retrieve the values.
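As a plain-Python illustration of that idea (not part of the original answer; variable names are illustrative), flattening the row indices and repeating each column index once per row index yields the same 24 (row, column) pairs that the tf.reshape/tf.tile/tf.concat steps in the code build:

```python
rows = [[1, 2, 5], [1, 2, 5], [1, 2, 5], [1, 4, 6],
        [1, 4, 6], [2, 3, 6], [2, 3, 6], [2, 4, 7]]
columns = [[1], [2], [3], [2], [3], [2], [3], [2]]

# Row-major flattening, like tf.reshape(rows_tf, [-1, 1]).
flat_rows = [r for row in rows for r in row]
# Repeat each column index 3 times (3 row indices per output row),
# like tf.tile(columns_tf, [1, 3]) followed by a reshape.
flat_cols = [c for col in columns for c in col * 3]
# Pair them up, like tf.concat(..., axis=-1): one (row, column) pair per value.
pairs = list(zip(flat_rows, flat_cols))
print(pairs[:6])  # [(1, 1), (2, 1), (5, 1), (1, 2), (2, 2), (5, 2)]
```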


import tensorflow as tf  # TensorFlow 1.x API (tf.Session, tf.reset_default_graph)

tf.reset_default_graph()
input_tf = tf.Variable([ [9.968594,  8.655439,  0.,        0.       ],
                         [0.,        8.3356,    0.,        8.8974   ],
                         [0.,        0.,        6.103182,  7.330564 ],
                         [6.609862,  0.,        3.0614321, 0.       ],
                         [9.497023,  0.,        3.8914037, 0.       ],
                         [0.,        8.457685,  8.602337,  0.       ],
                         [0.,        0.,        5.826657,  8.283971 ],
                         [0.,        0.,        0.,        0.       ]])
rows_tf = tf.constant (
[[1, 2, 5],
 [1, 2, 5],
 [1, 2, 5],
 [1, 4, 6],
 [1, 4, 6],
 [2, 3, 6],
 [2, 3, 6],
 [2, 4, 7]])
columns_tf = tf.constant(
[[1],
 [2],
 [3],
 [2],
 [3],
 [2],
 [3],
 [2]])
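# Flatten the row indices into a column vector: shape [24, 1]
rows_tf = tf.reshape(rows_tf, shape=[-1, 1])
# Repeat each column index 3 times (once per row index), then flatten: shape [24, 1]
columns_tf = tf.reshape(
    tf.tile(columns_tf, multiples=[1, 3]),
    shape=[-1, 1])
# Pair them into (row, column) coordinates: shape [24, 2]
sparse_indices = tf.reshape(
    tf.concat([rows_tf, columns_tf], axis=-1),
    shape=[-1, 2])

# Look up each coordinate, then regroup 3 values per output row: shape [8, 3]
v = tf.gather_nd(input_tf, sparse_indices)
v = tf.reshape(v, [-1, 3])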

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  # print('rows\n', sess.run(rows_tf))
  # print('columns\n', sess.run(columns_tf))
  print(sess.run(v))

The result is:

[[ 8.3355999   0.          8.45768547]
 [ 0.          6.10318184  8.60233688]
 [ 8.8973999   7.33056402  0.        ]
 [ 0.          3.89140368  5.82665682]
 [ 8.8973999   0.          8.28397083]
 [ 6.10318184  3.06143212  5.82665682]
 [ 7.33056402  0.          8.28397083]
 [ 6.10318184  3.89140368  0.        ]]
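On TensorFlow 2.x (eager execution assumed), a possibly tidier sketch of the same idea avoids the flatten/tile/reshape round trip: broadcast columns_tf to the shape of rows_tf with tf.broadcast_to, stack the two into (row, column) pairs of shape [8, 3, 2], and let tf.gather_nd return the [8, 3] result directly. The last two lines show an alternative using the batch_dims argument of tf.gather, available in recent TensorFlow versions. This is not from the original answer, and it uses tf.constant instead of tf.Variable since nothing is trained.

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)

input_tf = tf.constant([[9.968594, 8.655439, 0., 0.],
                        [0., 8.3356, 0., 8.8974],
                        [0., 0., 6.103182, 7.330564],
                        [6.609862, 0., 3.0614321, 0.],
                        [9.497023, 0., 3.8914037, 0.],
                        [0., 8.457685, 8.602337, 0.],
                        [0., 0., 5.826657, 8.283971],
                        [0., 0., 0., 0.]])
rows_tf = tf.constant([[1, 2, 5], [1, 2, 5], [1, 2, 5], [1, 4, 6],
                       [1, 4, 6], [2, 3, 6], [2, 3, 6], [2, 4, 7]])
columns_tf = tf.constant([[1], [2], [3], [2], [3], [2], [3], [2]])

# Broadcast the [8, 1] column indices to the [8, 3] shape of rows_tf,
# mirroring what NumPy does implicitly for input[rows, columns].
cols_b = tf.broadcast_to(columns_tf, tf.shape(rows_tf))

# Stack into (row, column) pairs of shape [8, 3, 2]; gather_nd returns
# one value per pair, so the result is already [8, 3].
indices = tf.stack([rows_tf, cols_b], axis=-1)
v = tf.gather_nd(input_tf, indices)
print(v.numpy())

# Alternative: select each output row's column first ([8, 8]: row i holds
# column columns_tf[i] of input_tf), then index it per row with batch_dims=1.
cols_first = tf.gather(tf.transpose(input_tf), columns_tf[:, 0])
v_alt = tf.gather(cols_first, rows_tf, batch_dims=1)
```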