Does anyone know how to extract the top n largest values per row of a rank 2 tensor?
For instance, if I wanted the top 2 values of a tensor of shape [2,4] with values:
[[40, 30, 20, 10], [10, 20, 30, 40]]
The desired condition matrix would look like: [[True, True, False, False],[False, False, True, True]]
Once I have the condition matrix, I can use tf.select to choose actual values.
Thank you for your assistance!
You can also use tf.contrib.framework.argsort. Moreover, you can replace 2 with a 1-D tensor so that each row/column can have a different n value.

You can do it using the built-in tf.nn.top_k function:
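A minimal sketch of that call, assuming TensorFlow 2.x with eager execution and the example tensor from the question (the variable names are only illustrative):

```python
import tensorflow as tf

# Example tensor from the question, shape [2, 4].
x = tf.constant([[40, 30, 20, 10], [10, 20, 30, 40]])

# top_k works along the last dimension, so for a rank 2 tensor it
# returns the k largest entries of each row, sorted in descending order.
top = tf.nn.top_k(x, k=2)

top_values = top.values    # the 2 largest values of each row
top_indices = top.indices  # their column positions within each row

print(top_values.numpy())   # [[40 30], [40 30]]
print(top_indices.numpy())  # [[0 1], [3 2]]
```

Note that this gives the values and their indices, not yet the boolean condition matrix asked for in the question.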
To get boolean True/False values, you can first get the k-th value and then use tf.greater_equal:
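A sketch of this approach, again assuming TensorFlow 2.x with eager execution; the names k, kth and mask are illustrative, and tf.where is used in place of tf.select, which was renamed to tf.where in TensorFlow 1.0:

```python
import tensorflow as tf

# Example tensor from the question, shape [2, 4].
x = tf.constant([[40, 30, 20, 10], [10, 20, 30, 40]])
k = 2

# The values returned by top_k are sorted in descending order, so the
# last column holds the k-th largest value of each row.
values, _ = tf.nn.top_k(x, k=k)
kth = values[:, -1:]  # shape [2, 1], kept 2-D so it broadcasts over the row

# True wherever an entry is at least as large as its row's k-th value.
mask = tf.greater_equal(x, kth)

# Keep the selected values and zero out the rest.
selected = tf.where(mask, x, tf.zeros_like(x))

print(mask.numpy())
print(selected.numpy())
```

For the example input, mask matches the condition matrix from the question. If a row contains ties at the k-th value, more than k entries in that row will be True.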