Is there any way to perform a dictionary lookup based on a string tensor in TensorFlow?
In plain Python, I'd do something like

value = dictionary[key]

Now I'd like to do the same thing at TensorFlow runtime, when I have my key as a string tensor. Something like

value_tensor = tf.dict_lookup(string_tensor)

would be nice.
You might find tensorflow.contrib.lookup helpful (TensorFlow 1.x only; tf.contrib was removed in TensorFlow 2.0):
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lookup/lookup_ops.py
https://www.tensorflow.org/api_docs/python/tf/contrib/lookup/HashTable
In particular, you can do:
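import tensorflow as tf  # TensorFlow 1.x only; tf.contrib was removed in 2.0

keys = tf.constant(["a", "b", "c"])  # illustrative keys and values
values = tf.constant([1, 2, 3], dtype=tf.int64)
input_tensor = tf.constant(["b", "z"])
table = tf.contrib.lookup.HashTable(
    tf.contrib.lookup.KeyValueTensorInitializer(keys, values), -1
)
out = table.lookup(input_tensor)  # keys missing from the table map to -1
with tf.Session():  # table.init.run() and out.eval() need a default session
    table.init.run()  # the table must be initialized before any lookup runs
    print(out.eval())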
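In TensorFlow 2.x, where tf.contrib no longer exists, the same idea is available as tf.lookup.StaticHashTable, which is initialized when it is constructed and needs no session. The sketch below wraps it in a hypothetical make_dict_lookup helper (not a TensorFlow API) that turns a Python dict, assumed here to map str to int, into something close to the tf.dict_lookup the question asks for:

```python
import tensorflow as tf  # assumes TensorFlow 2.x with eager execution


def make_dict_lookup(dictionary, default_value=-1):
    """Hypothetical helper: builds a string-tensor lookup from a str -> int dict."""
    keys = tf.constant(list(dictionary.keys()))
    values = tf.constant(list(dictionary.values()), dtype=tf.int64)
    # StaticHashTable is immutable and initialized at construction time.
    table = tf.lookup.StaticHashTable(
        tf.lookup.KeyValueTensorInitializer(keys, values),
        default_value=default_value,  # returned for keys not in the dict
    )
    # The bound method keeps a reference to the table, so it stays alive.
    return table.lookup


dict_lookup = make_dict_lookup({"apple": 0, "banana": 1, "cherry": 2})
string_tensor = tf.constant(["banana", "durian", "apple"])
value_tensor = dict_lookup(string_tensor)
print(value_tensor.numpy())
```

Because the table is built once outside any traced function, the returned lookup can be called eagerly, as here, or from inside a tf.function.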
TensorFlow is a data flow language built around tensors, and it has no general-purpose map or dictionary type. However, depending on what you need, when you're using the Python wrapper you can maintain a dictionary in the driver process (which runs in Python) and use it to interact with TensorFlow graph execution. For example, you could run one step of a TensorFlow graph in a session, return a string value to the Python driver, use it as a key into the driver's dictionary, and use the retrieved value to decide which computation to request from the session next. This is probably not a good solution if these lookups are performance critical, because each one requires a round trip between the session and the Python process.
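A minimal sketch of this driver-side pattern, assuming graph-mode execution through the tf.compat.v1 API (available in TensorFlow 1.13+ and 2.x). The placeholders, the threshold of 10, and the scale_by_key dictionary are hypothetical, chosen only to show the round trip:

```python
import tensorflow as tf  # uses tf.compat.v1, so it runs on TF 1.13+ and TF 2.x

graph = tf.Graph()
with graph.as_default():
    x = tf.compat.v1.placeholder(tf.float32, shape=[None])
    # Step 1 (hypothetical): the graph emits a string key computed from the data.
    key_op = tf.cond(
        tf.reduce_sum(x) > 10.0,
        lambda: tf.constant("large"),
        lambda: tf.constant("small"),
    )
    # Step 2: a second computation parameterized by a value fed from Python.
    factor = tf.compat.v1.placeholder(tf.float32, shape=[])
    scaled = x * factor

# Hypothetical dictionary that lives in the Python driver, not in the graph.
scale_by_key = {"large": 2.0, "small": 0.5}

data = [3.0, 4.0, 5.0]
with tf.compat.v1.Session(graph=graph) as sess:
    # Run the first step; a scalar string tensor comes back to Python as bytes.
    key = sess.run(key_op, feed_dict={x: data}).decode("utf-8")
    # Ordinary Python dict lookup in the driver process.
    chosen = scale_by_key[key]
    # Feed the looked-up value into the next session call.
    result = sess.run(scaled, feed_dict={x: data, factor: chosen})

print(key, result)
```

Each lookup costs two sess.run calls plus a transfer to Python, which is why this approach suits control decisions between steps rather than per-element lookups inside a hot loop.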