I want to use scatter_nd_update to change the contents of the tensor returned by tf.nn.embedding_lookup(). However, the returned tensor is not mutable, and scatter_nd_update() requires a mutable tensor (a variable) as input. I spent a lot of time trying to find a solution, including using gen_state_ops._temporary_variable and tf.sparse_to_dense, but all attempts failed. Is there an elegant solution to this?
with tf.device('/cpu:0'), tf.name_scope("embedding"):
    self.W = tf.Variable(
        tf.random_uniform([vocab_size, embedding_size], -1.0, 1.0),
        name="W")
    self.embedded_chars = tf.nn.embedding_lookup(self.W, self.input_x)
    updates = tf.constant(0, shape=[embedding_size])
    for i in range(1, sequence_length - 2):
        indices = [None, i]
        tf.scatter_nd_update(self.embedded_chars, indices, updates)
    self.embedded_chars_expanded = tf.expand_dims(self.embedded_chars, -1)
tf.nn.embedding_lookup simply returns a slice of the larger matrix, so the simplest solution is to update the values of that matrix itself, which in your case is self.W:

    self.embedded_chars = tf.nn.embedding_lookup(self.W, self.input_x)

Since self.W is a variable, it is compatible with tf.scatter_nd_update
. Note that you can't update just any tensor, only variables.
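To make the scatter semantics concrete, here is a minimal NumPy analogue of updating rows of the embedding matrix in place (the shapes vocab_size=5, embedding_size=4 and the chosen rows are assumptions for illustration); the variable W plays the role of the mutable tf.Variable, which is what allows the in-place assignment:

```python
import numpy as np

# NumPy analogue of tf.scatter_nd_update on the embedding matrix:
# `indices` selects rows of W, `updates` overwrites those rows in place.
# W stands in for the mutable variable; a read-only tensor (like the
# result of embedding_lookup) could not be assigned to like this.
W = np.random.uniform(-1.0, 1.0, size=(5, 4))  # vocab_size=5, embedding_size=4
indices = np.array([1, 3])                     # rows to overwrite
updates = np.zeros((2, 4))                     # new values for those rows

W[indices] = updates  # in-place row update
```

Any subsequent lookup of rows 1 or 3 then returns the new (zero) values, which is exactly why updating self.W also changes what embedding_lookup returns.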
Another option is to create a new variable just for the selected slice, assign self.embedded_chars to it, and perform the update on that variable afterwards.
Caveat: in both cases, you're blocking the gradients to train the embedding matrix, so double check that overwriting the learned value is really what you want.
This problem stemmed from not clearly understanding the difference between tensors and variables in TensorFlow. Later, with more knowledge of tensors, the solution that came to mind was:
with tf.device('/cpu:0'), tf.name_scope("embedding"):
    self.W = tf.Variable(
        tf.random_uniform([vocab_size, embedding_size], -1.0, 1.0),
        name="W")
    self.embedded_chars = tf.nn.embedding_lookup(self.W, self.input_x)
    # Stop at sequence_length - 2 so that i + 2 stays in range; a slice
    # size of -1 keeps the full batch dimension.
    for i in range(0, sequence_length - 2, 2):
        self.tslice = tf.slice(self.embedded_chars, [0, i, 0], [-1, 1, 128])
        self.tslice2 = tf.slice(self.embedded_chars, [0, i + 1, 0], [-1, 1, 128])
        self.tslice3 = tf.slice(self.embedded_chars, [0, i + 2, 0], [-1, 1, 128])
        self.toffset1 = tf.subtract(self.tslice, self.tslice2)
        self.toffset2 = tf.subtract(self.tslice2, self.tslice3)
        self.tconcat = tf.concat([self.toffset1, self.toffset2], 1)
    self.embedded_chars_expanded = tf.expand_dims(self.embedded_chars, -1)
The functions used (tf.slice, tf.subtract, tf.concat) all accept plain tensors as input. Just avoid functions like tf.scatter_nd_update that require a variable as input.
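The slice/subtract/concat pattern above can be sketched with a NumPy analogue (the shapes batch=2, sequence_length=6, embedding_size=128 are assumptions for illustration); NumPy basic slicing stands in for tf.slice:

```python
import numpy as np

# NumPy sketch of the slice/subtract/concat pattern:
# take three consecutive positions, compute their pairwise offsets,
# and concatenate the offsets along the sequence axis.
embedded = np.random.rand(2, 6, 128)  # (batch, sequence_length, embedding_size)

for i in range(0, 6 - 2, 2):
    s1 = embedded[:, i:i + 1, :]      # like tf.slice(..., [0, i, 0], [-1, 1, 128])
    s2 = embedded[:, i + 1:i + 2, :]
    s3 = embedded[:, i + 2:i + 3, :]
    off1 = s1 - s2                    # like tf.subtract
    off2 = s2 - s3
    concat = np.concatenate([off1, off2], axis=1)  # like tf.concat(..., 1)
```

Everything here is pure functional computation on tensors (new arrays are produced, nothing is mutated), which is why no variable is needed.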