How to modify the tensor returned from tf.nn.embedding_lookup

Posted 2019-07-26 17:42

Question:

I want to use scatter_nd_update to change the contents of the tensor returned by tf.nn.embedding_lookup(). However, the returned tensor is not mutable, and scatter_nd_update() requires a mutable tensor as input. I spent a lot of time trying to find a solution, including using gen_state_ops._temporary_variable and tf.sparse_to_dense, but unfortunately all attempts failed.

Is there an elegant solution to this?

with tf.device('/cpu:0'), tf.name_scope("embedding"):
    self.W = tf.Variable(
        tf.random_uniform([vocab_size, embedding_size], -1.0, 1.0),
        name="W")
    self.embedded_chars = tf.nn.embedding_lookup(self.W, self.input_x)
    updates = tf.constant(0, shape=[embedding_size])
    for i in range(1, sequence_length - 2):
        indices = [None, i]
        tf.scatter_nd_update(self.embedded_chars, indices, updates)
    self.embedded_chars_expanded = tf.expand_dims(self.embedded_chars, -1)

Answer 1:

tf.nn.embedding_lookup simply returns a slice of the larger matrix, so the simplest solution is to update the value of that matrix itself, which in your case is self.W:

self.embedded_chars = tf.nn.embedding_lookup(self.W, self.input_x)

Since it's a variable, it is compatible with tf.scatter_nd_update. Note that you can't update just any tensor, only variables.

Another option is to create a new variable just for the selected slice, assign self.embedded_chars to it and perform an update afterwards.


Caveat: in both cases, you're blocking the gradients to train the embedding matrix, so double check that overwriting the learned value is really what you want.
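To make the mechanics concrete, here is a minimal NumPy analogy (not the TensorFlow API itself; the toy matrix W and the indices are made up): a lookup just gathers rows from the source matrix, so writing into the source is what changes what later lookups return.

```python
import numpy as np

# Hypothetical toy embedding matrix standing in for self.W.
W = np.random.uniform(-1.0, 1.0, size=(10, 4))
input_x = [2, 5, 2]

# Analogue of tf.nn.embedding_lookup: gather rows by index.
embedded = W[input_x]

# Analogue of tf.scatter_nd_update on the variable: overwrite row 2.
W[2] = np.zeros(4)

# A fresh lookup now reflects the update.
embedded_after = W[input_x]
assert (embedded_after[0] == 0).all() and (embedded_after[2] == 0).all()
```

In TensorFlow the lookup reads from the variable at execution time, so updating self.W before the lookup runs has the same effect.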



Answer 2:

This problem stemmed from not clearly understanding the difference between tensors and variables in the TensorFlow context. With a better understanding of tensors, the solution I came up with is:

```
with tf.device('/cpu:0'), tf.name_scope("embedding"):
    self.W = tf.Variable(
        tf.random_uniform([vocab_size, embedding_size], -1.0, 1.0),
        name="W")
    self.embedded_chars = tf.nn.embedding_lookup(self.W, self.input_x)
    # Stop at sequence_length - 2 so the i + 2 slice stays in bounds.
    for i in range(0, sequence_length - 2, 2):
        # A slice size of -1 in the batch dimension keeps all examples.
        self.tslice = tf.slice(self.embedded_chars, [0, i, 0], [-1, 1, embedding_size])
        self.tslice2 = tf.slice(self.embedded_chars, [0, i + 1, 0], [-1, 1, embedding_size])
        self.tslice3 = tf.slice(self.embedded_chars, [0, i + 2, 0], [-1, 1, embedding_size])
        self.toffset1 = tf.subtract(self.tslice, self.tslice2)
        self.toffset2 = tf.subtract(self.tslice2, self.tslice3)
        self.tconcat = tf.concat([self.toffset1, self.toffset2], 1)
    self.embedded_chars_expanded = tf.expand_dims(self.embedded_chars, -1)
```

The functions used here, tf.slice, tf.subtract, and tf.concat, all accept plain tensors as input. Just avoid functions like tf.scatter_nd_update that require a variable as input.
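The slice-and-subtract pattern can be sketched with a NumPy analogy (assumed toy shapes: batch 2, sequence length 5, embedding size 128; `tf.slice(x, [0, i, 0], [-1, 1, emb])` corresponds to `x[:, i:i+1, :]`):

```python
import numpy as np

batch, seq_len, emb = 2, 5, 128  # assumed toy dimensions
embedded_chars = np.random.rand(batch, seq_len, emb)

chunks = []
# Stop before seq_len - 2 so the i + 2 slice stays in bounds.
for i in range(0, seq_len - 2, 2):
    s1 = embedded_chars[:, i:i + 1, :]      # ~ tf.slice(x, [0, i, 0],     [-1, 1, emb])
    s2 = embedded_chars[:, i + 1:i + 2, :]  # ~ tf.slice(x, [0, i + 1, 0], [-1, 1, emb])
    s3 = embedded_chars[:, i + 2:i + 3, :]  # ~ tf.slice(x, [0, i + 2, 0], [-1, 1, emb])
    off1 = s1 - s2                          # ~ tf.subtract
    off2 = s2 - s3
    chunks.append(np.concatenate([off1, off2], axis=1))  # ~ tf.concat(..., 1)

result = np.concatenate(chunks, axis=1)
```

Each iteration produces the offsets between adjacent positions without ever mutating the lookup result, which is why no variable is needed.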