Using stop_gradient with AdamOptimizer in TensorFlow

Posted 2020-06-25 19:19

I am trying to implement a training/fine-tuning framework in which a certain set of parameters stays fixed in each backpropagation iteration. I want to be able to change which parameters are updated and which are fixed from one iteration to the next. The TensorFlow function tf.stop_gradient, which forces the gradients of the selected parameters to zero, is very useful for this purpose. It works fine with different optimizers as long as the set of fixed parameters does not change between iterations, and it also handles a changing set of fixed parameters when used with stochastic gradient descent.

My problem is that tf.stop_gradient cannot handle a changing set when it is used with the Adam optimizer. More specifically, it does keep the gradients of the fixed parameters at zero in the output of the optimizer's compute_gradients, but when the gradients are applied (apply_gradients), the values of the fixed parameters do change. I suppose this is because the optimization step in Adam is not zero even when the gradient is zero (based on Algorithm 1 in Kingma and Ba's paper). Is there a cheap way of freezing a varying set of parameters in each Adam iteration, without explicitly saving the previous iteration's values of the fixed parameters?

More Details:

Suppose I have a single-layer network with a weight matrix variable W and a binary mask matrix placeholder MW that specifies which elements of W should be updated in each iteration (a value of 1 marks an element that is updated). Instead of using W directly to write the input/output relationship of this layer, I modify it as below

masked_W = MW*W + tf.stop_gradient(tf.abs(1-MW)*W)

to mask certain elements of W from having non-zero gradients. Then I use masked_W to form the output of the layer, so the loss of the network depends on this masked variable. The point is that MW changes in each iteration. Suppose W is a vector of 4 elements initialized to the all-zero vector. Here is what happens:

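import tensorflow as tf  # TensorFlow 1.x API

opt = tf.train.AdamOptimizer(1e-5)
grads_vars = opt.compute_gradients(loss, var_list=[W])
# build the update op once; this also creates Adam's slot variables
train_op = opt.apply_gradients(grads_vars)
# initialize after the slots exist, otherwise they are left uninitialized
sess.run(tf.global_variables_initializer())

# initial value of W=[0,0,0,0]

# first iteration:
MW_val = [0,1,1,0]
feed_dict={MW:MW_val, x: batch_of_data, y_:batch_of_labels}
sess.run(train_op, feed_dict=feed_dict)
# gradient of  W=[0,xx,xx,0]
# new value of W=[0,a,b,0]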

where xx are some non-zero gradient values, and a and b are new values of updating elements of W. In the second iteration, we change the value assigned to the binary mask matrix MW to [1,0,0,1], hence we expect to have fixed values for W[1] and W[2] and updating values for W[0] and W[3]. But this is what happens:

# second iteration (train_op is the op returned by opt.apply_gradients(grads_vars))
MW_val = [1,0,0,1]
feed_dict={MW:MW_val, x: batch_of_data, y_:batch_of_labels}
sess.run(train_op, feed_dict=feed_dict)
# gradient of  W=[xx,0,0,xx]
# new value of W=[c,aa,bb,d]

That is, although the gradients of W[1] and W[2] are zero, they get new values (aa != a and bb != b). When changing the optimizer from Adam to SGD, the values of fixed parameters stay the same as expected.
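
The behavior described above can be reproduced without TensorFlow. The following is a minimal sketch for a single scalar parameter, assuming the update rule of Algorithm 1 in Kingma and Ba's paper with its default hyperparameters; the helper names adam_step and sgd_step are made up for illustration and are not TensorFlow APIs, and TensorFlow's own implementation may differ in details such as how epsilon is applied.

```python
import math

def adam_step(theta, m, v, grad, t, lr=1e-5, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update for a scalar parameter (hypothetical helper, not a TF API)."""
    m = beta1 * m + (1.0 - beta1) * grad          # first-moment estimate
    v = beta2 * v + (1.0 - beta2) * grad * grad   # second-moment estimate
    m_hat = m / (1.0 - beta1 ** t)                # bias-corrected moments
    v_hat = v / (1.0 - beta2 ** t)
    theta = theta - lr * m_hat / (math.sqrt(v_hat) + eps)
    return theta, m, v

def sgd_step(theta, grad, lr=1e-5):
    """Plain gradient descent: a zero gradient gives a zero step."""
    return theta - lr * grad

# Iteration 1: the parameter receives a non-zero gradient.
theta, m, v = adam_step(0.0, 0.0, 0.0, grad=0.5, t=1)
after_first = theta

# Iteration 2: the gradient is zero, but m is still non-zero from iteration 1.
theta, m, v = adam_step(theta, m, v, grad=0.0, t=2)
after_second = theta
adam_moved = after_second != after_first

# Same two iterations with SGD.
sgd_theta = sgd_step(0.0, 0.5)
sgd_moved = sgd_step(sgd_theta, 0.0) != sgd_theta

print(adam_moved, sgd_moved)  # prints: True False
```

In this sketch the second Adam step is driven entirely by the decayed first moment, which is why a zero gradient alone does not freeze the parameter.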

2 Answers
chillily
#2 · 2020-06-25 19:31

Alternatively, you can pass different subsets of the variables to apply_gradients when you want to update different subsets of variables.

一夜七次
#3 · 2020-06-25 19:33

I found a solution to my question and am sharing it here in case others find it useful. After the first iteration, the moment estimates of the parameters that were updated in that iteration are already non-zero. Therefore, even if their gradients are set to zero in the second iteration, they are still updated because of their non-zero moment tensors. Using tf.stop_gradient alone is not enough to prevent the updates; we also have to reset their moments. For the Adam optimizer, this can be done through the optimizer's get_slot method: opt.get_slot(par, 'm') and opt.get_slot(par, 'v') give access to the first and second moment tensors of parameter par, respectively. In the example from the question, we have to add the following lines to freeze W[1] and W[2] in the second iteration:

# moments of W after the first iteration (get_slot returns variables, so evaluate them)
m_slot = opt.get_slot(W, 'm')
v_slot = opt.get_slot(W, 'v')
m_vals, v_vals = sess.run([m_slot, v_slot])
# mask copies of them before running the second iteration
masked_m_vals = m_vals.copy()
masked_v_vals = v_vals.copy()
masked_m_vals[[1,2]]=0
masked_v_vals[[1,2]]=0
sess.run(m_slot.assign(masked_m_vals))
sess.run(v_slot.assign(masked_v_vals))

It is better to save the masked moments, in the example above m_vals[[1,2]] and v_vals[[1,2]], so that if we relax the fixing constraint on W[1] and W[2] in the third iteration, we can restore their moments to the values they had after the first iteration.
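
The save, zero, and restore sequence can be checked on a toy version of the 4-element example. This is only a sketch in plain Python, assuming the Algorithm 1 update rule; adam_step, the gradient values, and the saved_m/saved_v dictionaries are made up for illustration and stand in for the TensorFlow slot variables.

```python
import math

def adam_step(theta, m, v, grad, t, lr=1e-5, beta1=0.9, beta2=0.999, eps=1e-8):
    """Element-wise Adam update on plain lists (hypothetical helper, not a TF API)."""
    new_theta, new_m, new_v = [], [], []
    for th, mi, vi, g in zip(theta, m, v, grad):
        mi = beta1 * mi + (1.0 - beta1) * g
        vi = beta2 * vi + (1.0 - beta2) * g * g
        m_hat = mi / (1.0 - beta1 ** t)
        v_hat = vi / (1.0 - beta2 ** t)
        new_theta.append(th - lr * m_hat / (math.sqrt(v_hat) + eps))
        new_m.append(mi)
        new_v.append(vi)
    return new_theta, new_m, new_v

W, m, v = [0.0] * 4, [0.0] * 4, [0.0] * 4

# Iteration 1: mask [0,1,1,0], so only W[1] and W[2] get non-zero gradients.
W, m, v = adam_step(W, m, v, grad=[0.0, 0.3, -0.7, 0.0], t=1)
W_after_first = list(W)

# Before iteration 2: save the moments of the elements to freeze, then zero them.
frozen = [1, 2]
saved_m = {i: m[i] for i in frozen}
saved_v = {i: v[i] for i in frozen}
for i in frozen:
    m[i] = 0.0
    v[i] = 0.0

# Iteration 2: mask [1,0,0,1], so the frozen elements have zero gradient.
W, m, v = adam_step(W, m, v, grad=[0.2, 0.0, 0.0, -0.4], t=2)
frozen_unchanged = all(W[i] == W_after_first[i] for i in frozen)

# Before iteration 3: unfreeze W[1] and W[2] by restoring the saved moments.
for i in frozen:
    m[i] = saved_m[i]
    v[i] = saved_v[i]

print(frozen_unchanged)  # prints: True
```

In this sketch, zeroing the first moment alone already yields a zero step when the gradient is zero; the second moment is zeroed as well only to mirror the lines above.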
