I'm using TensorFlow's `tf.unsorted_segment_sum` method, and it works fine when the tensor I pass as `data` has only one row. For example:
```python
tf.unsorted_segment_sum(tf.constant([0.2, 0.1, 0.5, 0.7, 0.8]),
                        tf.constant([0, 0, 1, 2, 2]), 3)
```
Gives the right result:
```
array([ 0.3, 0.5 , 1.5 ], dtype=float32)
```
The question is: if I use a tensor with several rows, how can I get the result for each row? For instance, if I try a tensor with two rows:
```python
tf.unsorted_segment_sum(tf.constant([[0.2, 0.1, 0.5, 0.7, 0.8],
                                     [0.2, 0.2, 0.5, 0.7, 0.8]]),
                        tf.constant([[0, 0, 1, 2, 2],
                                     [0, 0, 1, 2, 2]]), 3)
```
The result I would expect is:
```
array([ [ 0.3, 0.5 , 1.5 ], [ 0.4, 0.5, 1.5 ] ], dtype=float32)
```
But what I get is:
```
array([ 0.7, 1. , 3. ], dtype=float32)
```
Does anyone know how to obtain the result for each row without using a for loop?
Thanks in advance!
EDIT:
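While the solution below may cover some more unusual cases (such as rows that use different segment ids), this problem can be solved much more easily just by transposing the data. It turns out that, even though `tf.unsorted_segment_sum` does not have an `axis` parameter, it can work along a single axis, as long as it is the first one: with one-dimensional `segment_ids`, each slice `data[i]` is added into output row `segment_ids[i]`. So you can transpose the data so that the segmented axis comes first, apply `tf.unsorted_segment_sum` with the shared one-dimensional ids, and transpose the result back.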
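A minimal sketch of the transposition approach, assuming TensorFlow 2.x with eager execution (where the op is exposed as `tf.math.unsorted_segment_sum`) and that every row uses the same segment ids. For contrast, it also reproduces the question's two-row call: when `segment_ids` has the same shape as `data`, every element is summed into one of the segments regardless of its row, which is why that call returned `[0.7, 1.0, 3.0]`.

```python
import tensorflow as tf  # assumes TensorFlow 2.x with eager execution

data = tf.constant([[0.2, 0.1, 0.5, 0.7, 0.8],
                    [0.2, 0.2, 0.5, 0.7, 0.8]])
num_segments = 3

# The question's call: ids with the same shape as data. Every element is
# summed into one of the 3 segments, so the two rows get mixed together.
mixed = tf.math.unsorted_segment_sum(
    data, tf.constant([[0, 0, 1, 2, 2],
                       [0, 0, 1, 2, 2]]), num_segments)

# Transposition trick: one 1-D id per column, shared by every row.
shared_ids = tf.constant([0, 0, 1, 2, 2])
# tf.transpose(data) has shape (5, 2); with 1-D ids, slice i (column i of
# data) is added into output row shared_ids[i], giving shape (3, 2).
per_segment = tf.math.unsorted_segment_sum(
    tf.transpose(data), shared_ids, num_segments)
# Transpose back: one row of segment sums per input row, shape (2, 3).
result = tf.transpose(per_segment)

print(mixed.numpy())   # approximately [0.7, 1.0, 3.0]
print(result.numpy())  # approximately [[0.3, 0.5, 1.5], [0.4, 0.5, 1.5]]
```

Note that this only helps when all rows share one set of ids; if each row has its own ids, the transposed call has no single 1-D `segment_ids` to use.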
ORIGINAL POST:
`tf.unsorted_segment_sum` does not support operating along a single axis. The simplest solution would be to apply the operation to each row and then stack the results back together. However, this has drawbacks: 1) it only works for statically shaped tensors (that is, you need to have a fixed number of rows), and 2) it may not be as efficient. The first one could be circumvented using a `tf.while_loop`, but that would be complicated, and it would also require concatenating the rows one by one, which is very inefficient. Besides, you already said you want to avoid loops.

A better option is to use different ids for each row. For example, you could add `num_segments * row_index` to each value in `segment_ids`, so that each row is guaranteed its own set of ids (row 0 keeps ids `0` to `num_segments - 1`, row 1 uses `num_segments` to `2 * num_segments - 1`, and so on). Then you can apply the operation once, with `num_segments * num_rows` segments, and reshape the result to `[num_rows, num_segments]` to get the tensor you want.
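A sketch of both approaches from the original post, assuming TensorFlow 2.x with eager execution (`tf.math.unsorted_segment_sum`). The helper names `rowwise_by_unstacking` and `rowwise_unsorted_segment_sum` are hypothetical, not TensorFlow API. The second one implements the offset-id idea: row `r`'s ids are shifted by `num_segments * r`, everything is summed in a single call over `num_segments * num_rows` segments, and the flat result is reshaped to `(num_rows, num_segments)`. Unlike the transposition trick, it also handles rows whose ids differ.

```python
import numpy as np
import tensorflow as tf  # assumes TensorFlow 2.x with eager execution


def rowwise_by_unstacking(data, segment_ids, num_segments):
    """Per-row approach (hypothetical helper name).

    Needs the number of rows to be known statically, because
    tf.unstack requires a fixed size along axis 0.
    """
    return tf.stack([
        tf.math.unsorted_segment_sum(row, row_ids, num_segments)
        for row, row_ids in zip(tf.unstack(data), tf.unstack(segment_ids))
    ])


def rowwise_unsorted_segment_sum(data, segment_ids, num_segments):
    """Offset-id approach (hypothetical helper name).

    data and segment_ids both have shape (num_rows, row_length), and every
    id is assumed to lie in [0, num_segments).
    """
    # Dynamic row count, so this also works when the number of rows is unknown.
    num_rows = tf.shape(segment_ids)[0]
    # Shift row r's ids by num_segments * r so that no two rows share an id.
    row_offsets = num_segments * tf.range(num_rows)[:, tf.newaxis]
    offset_ids = segment_ids + row_offsets
    # A single sum over num_segments * num_rows distinct ids...
    flat_sums = tf.math.unsorted_segment_sum(
        data, offset_ids, num_segments * num_rows)
    # ...reshaped into one row of num_segments sums per input row.
    return tf.reshape(flat_sums, [-1, num_segments])


data = tf.constant([[0.2, 0.1, 0.5, 0.7, 0.8],
                    [0.2, 0.2, 0.5, 0.7, 0.8]])
same_ids = tf.constant([[0, 0, 1, 2, 2],
                        [0, 0, 1, 2, 2]])
# Rows with different ids, which the transposition trick cannot handle.
different_ids = tf.constant([[0, 0, 1, 2, 2],
                             [2, 1, 1, 0, 0]])

result = rowwise_unsorted_segment_sum(data, same_ids, 3)
result_different = rowwise_unsorted_segment_sum(data, different_ids, 3)
print(result.numpy())            # approximately [[0.3, 0.5, 1.5], [0.4, 0.5, 1.5]]
print(result_different.numpy())  # approximately [[0.3, 0.5, 1.5], [1.5, 0.7, 0.2]]

# Sanity check: both approaches agree on the same inputs.
np.testing.assert_allclose(
    result_different.numpy(),
    rowwise_by_unstacking(data, different_ids, 3).numpy(), rtol=1e-5)
```

Because `num_rows` comes from `tf.shape` rather than from a Python `len`, the offset-id helper does not depend on a static row count, which is the first drawback of the per-row version.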