I am trying to implement a training/fine-tuning framework where, in each backpropagation iteration, a certain set of parameters stays fixed. I want to be able to change the set of updated and fixed parameters from iteration to iteration. The TensorFlow method `tf.stop_gradient`, which blocks gradients from flowing through its argument (so the affected parameters receive zero gradient), is very useful for this purpose. It works perfectly well with different optimizers as long as the set of updated and fixed parameters does not change between iterations, and it can also handle a varying set of updated and fixed parameters when used with stochastic gradient descent. My problem is that `tf.stop_gradient` cannot handle the varying case when used with the Adam optimizer. More specifically, it does keep the gradients of the fixed parameters at zero in the output of `opt.compute_gradients`, but when the gradients are applied (`opt.apply_gradients`), the values of the fixed parameters still change. I suppose this is because the optimization step in Adam is not zero even when the gradient is zero (based on Algorithm 1 in Kingma and Ba's paper). Is there a cheap way of freezing a varying set of parameters in each Adam iteration, without explicitly saving the previous iteration's values of the fixed parameters?
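The suspicion about Algorithm 1 can be checked directly. With step size $\alpha$, decay rates $\beta_1, \beta_2$, a small constant $\epsilon$, and gradient $g_t$ at step $t$, one Adam step on a parameter $\theta$ is

$$
\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1-\beta_1)\, g_t, \qquad v_t = \beta_2 v_{t-1} + (1-\beta_2)\, g_t^2,\\
\hat m_t &= \frac{m_t}{1-\beta_1^t}, \qquad \hat v_t = \frac{v_t}{1-\beta_2^t}, \qquad
\theta_t = \theta_{t-1} - \alpha\,\frac{\hat m_t}{\sqrt{\hat v_t}+\epsilon}.
\end{aligned}
$$

Setting $g_t = 0$ gives

$$
\theta_t - \theta_{t-1} = -\alpha\,\frac{\beta_1 m_{t-1}/(1-\beta_1^t)}{\sqrt{\beta_2 v_{t-1}/(1-\beta_2^t)}+\epsilon},
$$

which is non-zero whenever $m_{t-1} \neq 0$, whereas plain SGD, $\theta_t = \theta_{t-1} - \alpha g_t$, leaves $\theta$ unchanged. TensorFlow's `AdamOptimizer` folds the bias correction into the step size and places $\epsilon$ slightly differently, but the step is still proportional to $m_t$, so the conclusion is the same.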
More Details:
Suppose I have a single-layer network with a weight matrix variable `W` and a binary mask matrix placeholder `MW` that specifies which elements of `W` should be updated in each iteration (a value of 1 in `MW` marks an element to be updated). Instead of using `W` directly to write the input/output relationship of this layer, I modify it as below to prevent certain elements of `W` from receiving non-zero gradients:

```python
# Forward value equals W; the gradient w.r.t. W is MW * dL/d(masked_W),
# because the stop_gradient term contributes nothing to backpropagation.
masked_W = MW * W + tf.stop_gradient(tf.abs(1 - MW) * W)
```

Then I use `masked_W` to form the output of the layer, so the loss of the network depends on this masked variable. The point is that `MW` changes in each iteration. Suppose `W` is a vector of 4 elements initialized to the all-zero vector. Here is what happens:
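```python
opt = tf.train.AdamOptimizer(1e-5)
grads_vars = opt.compute_gradients(loss, [W])
# Build the update op once; this also creates Adam's slot variables (m, v).
train_op = opt.apply_gradients(grads_vars)
# Initialize after train_op exists so the slot variables are initialized too.
sess.run(tf.global_variables_initializer())
# initial value of W = [0, 0, 0, 0]
# first iteration:
MW_val = [0, 1, 1, 0]
feed_dict = {MW: MW_val, x: batch_of_data, y_: batch_of_labels}
sess.run(train_op, feed_dict=feed_dict)
# gradient of W = [0, xx, xx, 0]
# new value of W = [0, a, b, 0]
```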
where `xx` are some non-zero gradient values, and `a` and `b` are the new values of the updated elements of `W`. In the second iteration, we change the value fed to the binary mask `MW` to [1,0,0,1], so we expect `W[1]` and `W[2]` to stay fixed and `W[0]` and `W[3]` to be updated. But this is what happens:
```python
# second iteration
MW_val = [1, 0, 0, 1]
feed_dict = {MW: MW_val, x: batch_of_data, y_: batch_of_labels}
sess.run(opt.apply_gradients(grads_vars), feed_dict=feed_dict)
# gradient of W = [xx, 0, 0, xx]
# new value of W = [c, aa, bb, d]
```
That is, although the gradients of `W[1]` and `W[2]` are zero, they get new values (`aa != a` and `bb != b`). When I switch the optimizer from Adam to SGD, the values of the fixed parameters stay the same, as expected.
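The behaviour can be reproduced without TensorFlow. Below is a minimal pure-Python sketch of the 4-element example, assuming Adam in the form of Kingma and Ba's Algorithm 1 with TensorFlow's default hyperparameters (learning rate 1e-3 rather than the question's 1e-5). The gradient values are made up to stand in for `xx`, and `adam_step`, `sgd_step`, and `raw_grads` are hypothetical names for illustration.

```python
import math


def adam_step(w, g, m, v, t, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    """One element-wise Adam step in the form of Kingma & Ba's Algorithm 1.

    Updates the lists w (parameters), m and v (moment estimates) in place.
    """
    for i in range(len(w)):
        m[i] = b1 * m[i] + (1 - b1) * g[i]
        v[i] = b2 * v[i] + (1 - b2) * g[i] ** 2
        m_hat = m[i] / (1 - b1 ** t)  # bias-corrected first moment
        v_hat = v[i] / (1 - b2 ** t)  # bias-corrected second moment
        w[i] -= lr * m_hat / (math.sqrt(v_hat) + eps)


def sgd_step(w, g, lr=1e-3):
    """Plain gradient descent: a zero gradient means a zero step."""
    for i in range(len(w)):
        w[i] -= lr * g[i]


# Hypothetical unmasked gradients standing in for the "xx" values.
raw_grads = [[0.5, -0.3, 0.8, 0.2], [0.4, 0.6, -0.7, 0.1]]
masks = [[0, 1, 1, 0], [1, 0, 0, 1]]  # MW_val for iterations 1 and 2

w_adam, m, v = [0.0] * 4, [0.0] * 4, [0.0] * 4
w_sgd = [0.0] * 4
history_adam, history_sgd = [], []
for t, (raw, mask) in enumerate(zip(raw_grads, masks), start=1):
    # The stop_gradient construction leaves MW * dL/d(masked_W) as the gradient of W.
    g = [mk * gr for mk, gr in zip(mask, raw)]
    adam_step(w_adam, g, m, v, t)
    sgd_step(w_sgd, g)
    history_adam.append(list(w_adam))
    history_sgd.append(list(w_sgd))

for name, hist in (("Adam", history_adam), ("SGD", history_sgd)):
    print(name, "W[1], W[2] after iter 1:", hist[0][1:3], "after iter 2:", hist[1][1:3])
```

With these numbers, Adam moves `W[1]` and `W[2]` in the second step even though their gradients are exactly zero, because their first moments are still non-zero from the first step; the SGD values stay put.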
Alternatively, you can pass different subsets of the variables to `apply_gradients` when you want to update different subsets of variables.
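The idea of passing only a subset can be illustrated with a toy optimizer. This is a minimal sketch, not TensorFlow's API: `TinyAdam` is a hypothetical class that keeps Adam state per named scalar parameter and, like `apply_gradients`, takes (gradient, variable) pairs and only touches the variables it is given.

```python
import math


class TinyAdam:
    """Hypothetical toy optimizer (not a TensorFlow class) holding Adam state per named scalar."""

    def __init__(self, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
        self.lr, self.b1, self.b2, self.eps = lr, b1, b2, eps
        self.t = 0               # shared step counter used for bias correction
        self.m, self.v = {}, {}  # per-parameter first and second moments

    def apply_gradients(self, grads_and_names, params):
        """Run one Adam step on the named parameters only; all others are left untouched."""
        self.t += 1
        for g, name in grads_and_names:
            self.m[name] = self.b1 * self.m.get(name, 0.0) + (1 - self.b1) * g
            self.v[name] = self.b2 * self.v.get(name, 0.0) + (1 - self.b2) * g * g
            m_hat = self.m[name] / (1 - self.b1 ** self.t)
            v_hat = self.v[name] / (1 - self.b2 ** self.t)
            params[name] -= self.lr * m_hat / (math.sqrt(v_hat) + self.eps)


params = {"a": 0.0, "b": 0.0}
opt = TinyAdam()
opt.apply_gradients([(0.5, "a"), (-0.2, "b")], params)  # step 1 updates both
b_after_step1, m_b_after_step1 = params["b"], opt.m["b"]
opt.apply_gradients([(0.3, "a")], params)  # step 2: "b" is simply not passed
print(params)
```

Because `b` never enters step 2, neither its value nor its moments change. In TensorFlow the unit passed to `apply_gradients` is a whole `tf.Variable`, so freezing individual elements of `W` this way would require splitting `W` into separate variables.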
I found a solution to my question and am sharing it here in case others find it useful. After the first iteration, the moment estimates of the parameters that were updated in the first iteration are already non-zero. Therefore, even if their gradients are set to zero in the second iteration, they will still be updated because of their non-zero moment tensors. To prevent these updates, using `tf.stop_gradient` alone is not enough; we have to remove their moments as well. For the Adam optimizer, this can be done through the optimizer's `get_slot` method: `opt.get_slot(par, 'm')` and `opt.get_slot(par, 'v')` give access to the first and second moment tensors of parameter `par`, respectively. In the example of the question, freezing `W[1]` and `W[2]` in the second iteration means setting entries 1 and 2 of both slot tensors of `W` to zero before running the update.

It is better to save the masked moments first (with `m_vals` and `v_vals` holding the values of the `m` and `v` slots of `W`, these are `m_vals[[1,2]]` and `v_vals[[1,2]]`), so that if in the third iteration we relax the fixing constraint on `W[1]` and `W[2]`, we can restore their moments to the values they had after the first iteration.
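The save, zero, and restore recipe can be sketched without TensorFlow by mirroring Adam's `m` and `v` slots as plain lists. This is a minimal sketch, assuming the Algorithm 1 form of Adam and made-up gradients; `adam_step`, `frozen`, and `saved` are hypothetical names. In a TF1 graph the lists would correspond to the slot variables returned by `opt.get_slot(W, 'm')` and `opt.get_slot(W, 'v')`, which one could read with `sess.run` and overwrite (for example with a `tf.assign` op) before running the train step; that wiring is an assumption and is not shown.

```python
import math


def adam_step(w, g, m, v, t, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    """One element-wise Adam step (Kingma & Ba, Algorithm 1); updates w, m, v in place."""
    for i in range(len(w)):
        m[i] = b1 * m[i] + (1 - b1) * g[i]
        v[i] = b2 * v[i] + (1 - b2) * g[i] ** 2
        m_hat = m[i] / (1 - b1 ** t)
        v_hat = v[i] / (1 - b2 ** t)
        w[i] -= lr * m_hat / (math.sqrt(v_hat) + eps)


# Iteration 1, mask [0, 1, 1, 0]; gradient values are hypothetical.
w, m, v = [0.0] * 4, [0.0] * 4, [0.0] * 4
adam_step(w, [0.0, -0.3, 0.8, 0.0], m, v, t=1)

# Iteration 2, mask [1, 0, 0, 1]: save, then zero, the moments of W[1] and W[2].
frozen = [1, 2]
saved = {i: (m[i], v[i]) for i in frozen}  # plays the role of m_vals[[1,2]], v_vals[[1,2]]
for i in frozen:
    m[i] = 0.0
    v[i] = 0.0
w_before = list(w)
adam_step(w, [0.4, 0.0, 0.0, 0.1], m, v, t=2)  # masked gradient is zero at frozen entries

# Iteration 3 unfreezes W[1] and W[2]: put their saved moments back.
for i, (m_i, v_i) in saved.items():
    m[i], v[i] = m_i, v_i

print("before iter 2:", w_before)
print("after iter 2: ", w)
```

With `m[i] = 0` and a zero gradient, the bias-corrected first moment is zero, so the step for that element is exactly zero while `W[0]` and `W[3]` are updated normally. Under this update rule, zeroing `m` alone would already freeze the element; zeroing `v` as well follows the answer.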