NaN issue with the Adam solver

Posted 2019-02-05 00:43

Question:

I'm training networks with the Adam solver and ran into the problem that the optimization hits 'nan' at some point, even though the loss seems to decrease nicely up to that point. It happens only for some specific configurations and only after a couple of thousand iterations. For example, the network with a batch size of 5 has the problem, while with a batch size of 1 it works. So I started debugging my code:

1) The first thing that came to my mind was to check the inputs when the net hits 'nan', but they look reasonable (correctly labeled ground truth and input with an okayish value range).

2) While searching I discovered tf.verify_tensor_all_finite(..) and put it all over my code to see which tensor becomes 'nan' first. I could narrow the problem down to the following lines:

kernel = tf.verify_tensor_all_finite(kernel, 'kernel')
in_tensor = tf.verify_tensor_all_finite(in_tensor, 'in_tensor')
tmp_result = tf.nn.conv2d_transpose(value=in_tensor, filter=kernel, output_shape=output_shape,
                strides=strides, padding='SAME')
tmp_result = tf.verify_tensor_all_finite(tmp_result, 'convres')

This throws an error, which reads:

InvalidArgumentError (see above for traceback): convres : Tensor had NaN values
     [[Node: upconv_logits5_fs/VerifyFinite_2/CheckNumerics = CheckNumerics[T=DT_FLOAT, _class=["loc:@upconv_logits5_fs/conv2d_transpose"], message="convres", _device="/job:localhost/replica:0/task:0/gpu:0"](upconv_logits5_fs/conv2d_transpose)]]
     [[Node: Adam/update/_2794 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/cpu:0", send_device="/job:localhost/replica:0/task:0/gpu:0", send_device_incarnation=1, tensor_name="edge_154_Adam/update", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/cpu:0"]()]]

Now I'm not sure what happened here.

I guess that during the forward pass everything went well, because the scalar loss didn't trigger an error and the kernel and input were still valid numbers. It seems as if some Adam update node modifies the value of my upconv_logits5_fs towards nan. This transposed convolution op is the very last one in my network and therefore the first one to be updated.

I'm working with a tf.nn.softmax_cross_entropy_with_logits() loss and put tf.verify_tensor_all_finite() on all its inputs and outputs, but they don't trigger errors. The only conclusion I can draw is that there might be a numerical issue with the Adam solver.

  • What do you think about that conclusion?
  • Does anybody have any idea how to proceed or what I could try?

Your help is very much appreciated.

EDIT: I was able to solve my problem by increasing the solver's epsilon parameter from 1e-8 to 1e-4. It seemed that some of my parameters tend to have very little to zero variance, which resulted in tf.sqrt(0.0 + epsilon) and caused numerical problems.
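
To illustrate why epsilon matters, here is a minimal sketch of the Adam update for a single scalar parameter, written in plain Python rather than TensorFlow. It follows the textbook formulation with bias correction; TensorFlow's own implementation may arrange the terms differently, so treat the function `adam_update` and the numbers as illustrative assumptions only. The point it shows: when the gradient is tiny, a small epsilon lets Adam rescale it to a nearly full-size step, while a larger epsilon damps it.

```python
import math

def adam_update(grads, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """Run Adam on one scalar parameter and return the size of the last step.

    Hypothetical helper for illustration, not TensorFlow's implementation.
    """
    m = 0.0  # running mean of the gradient (first moment)
    v = 0.0  # running mean of the squared gradient (second moment)
    step = 0.0
    for t, g in enumerate(grads, start=1):
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g * g
        m_hat = m / (1 - beta1 ** t)  # bias-corrected first moment
        v_hat = v / (1 - beta2 ** t)  # bias-corrected second moment
        step = lr * m_hat / (math.sqrt(v_hat) + eps)
    return step

# A parameter whose gradient is consistently tiny (close to numerical noise).
tiny_grads = [1e-7] * 1000

step_small_eps = adam_update(tiny_grads, eps=1e-8)
step_large_eps = adam_update(tiny_grads, eps=1e-4)

# Step size relative to the learning rate.
print("eps=1e-8:", step_small_eps / 1e-3)
print("eps=1e-4:", step_large_eps / 1e-3)
```

With eps=1e-8 the step stays close to the full learning rate even though the gradient is only 1e-7; with eps=1e-4 it shrinks by roughly three orders of magnitude.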

Answer 1:

I have run into the same issue multiple times. The reason behind this problem is the use of softmax and cross-entropy. When you compute the gradient and divide by zero or inf, you get nan, which propagates through all your parameters.

A few pointers for tracking down the cause:

  • if the error starts increasing and NaN appears afterwards: training is diverging because the learning rate is too high
  • if NaNs appear suddenly: saturating units yielding a non-differentiable gradient
  • NaN from computing log(0)
  • NaN due to floating-point issues (weights that are too large) or activations on the output
  • 0/0, inf/inf, inf*weight...
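
The arithmetic cases in the list above can be reproduced in a few lines of plain Python. This is only a sketch of IEEE 754 behaviour, not TensorFlow code. Note one difference: Python itself raises an exception for `0.0 / 0.0` and `math.log(0.0)`, so those two cases are left out here, whereas typical GPU kernels silently return nan or -inf.

```python
import math

inf = float("inf")

# Operations on infinities that have no defined result yield NaN.
nan_sources = {
    "inf - inf": inf - inf,
    "inf * 0": inf * 0.0,
    "inf / inf": inf / inf,
}

# Once a single weight is NaN, everything computed from it is NaN as well.
weights = [0.5, nan_sources["inf - inf"], -0.25]
inputs = [1.0, 2.0, 3.0]
activation = sum(w * x for w, x in zip(weights, inputs))

for name, value in nan_sources.items():
    print(name, "->", value)
print("activation ->", activation)
```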

Solutions:

  • Reduce the learning rate
  • Change the weight initialization
  • Use L2 regularization
  • Safe softmax (add a small value inside the log, i.e. log(x + eps))
  • Gradient clipping

In my case, reducing the learning rate solved the issue, but I'm still working on optimizing it further.
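
As a sketch of the "safe softmax" idea, the two hypothetical helpers below compute a cross-entropy loss for a one-hot target in plain Python. The first works directly on logits with the log-sum-exp trick, so it never exponentiates a large number; the second is the "small value added to log(x)" variant for code that already has probabilities. The names and the default `eps` are assumptions for illustration.

```python
import math

def stable_cross_entropy(logits, target_index):
    """Cross-entropy of softmax(logits) against a one-hot target.

    Subtracting the maximum logit keeps every exponent <= 0, so exp()
    cannot overflow.
    """
    shift = max(logits)
    log_sum_exp = shift + math.log(sum(math.exp(z - shift) for z in logits))
    return log_sum_exp - logits[target_index]

def epsilon_cross_entropy(probs, target_index, eps=1e-12):
    """Cross-entropy from probabilities, guarding against log(0) with eps."""
    return -math.log(probs[target_index] + eps)

# Logits this extreme would overflow a naive exp()-based softmax.
loss_large = stable_cross_entropy([1000.0, 0.0, -1000.0], target_index=1)

# The target class has probability exactly 0; the loss stays finite.
loss_zero = epsilon_cross_entropy([1.0, 0.0, 0.0], target_index=1)

print(loss_large, loss_zero)
```

The epsilon variant biases the loss slightly, which is why computing the loss from logits is usually the preferable form when logits are available.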
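
Gradient clipping, listed among the solutions above, can be sketched in plain Python as follows. The helper `clip_by_global_norm` here is a hypothetical stand-in that rescales a flat list of gradients so their combined L2 norm does not exceed a threshold; TensorFlow ships a function of the same name, `tf.clip_by_global_norm`, that operates on lists of tensors.

```python
import math

def clip_by_global_norm(grads, max_norm):
    """Rescale grads so that their combined L2 norm is at most max_norm.

    Returns the (possibly rescaled) gradients and the original norm.
    Clipping limits large but finite gradients; it cannot repair values
    that are already NaN or inf.
    """
    global_norm = math.sqrt(sum(g * g for g in grads))
    if global_norm <= max_norm:
        return list(grads), global_norm
    scale = max_norm / global_norm
    return [g * scale for g in grads], global_norm

clipped, norm_before = clip_by_global_norm([3.0, 4.0], max_norm=1.0)
norm_after = math.sqrt(sum(g * g for g in clipped))

print(norm_before, clipped, norm_after)
```

Because all gradients are scaled by the same factor, the update direction is preserved and only its length is limited.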



Answer 2:

One additional step that was not included in Feras's answer, and that cost me a day of debugging:

Increase the precision of your variables. I had a network where a lot of variables were defined as float16. The network worked fine with all optimizers except Adam and Adadelta. After hours of debugging I switched to tf.float64 and it worked.
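
One plausible mechanism for the float16 failure, offered as an assumption rather than something this answer confirms, is that the small quantities in Adam's denominator underflow to zero in half precision. The sketch below uses Python's `struct` module to round values to IEEE 754 half precision; the helper `to_float16` and the example gradient are hypothetical.

```python
import math
import struct

def to_float16(x):
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

eps = 1e-8                                 # Adam's default epsilon
grad = 1e-3                                # a small but ordinary gradient
second_moment = (1 - 0.999) * grad * grad  # about 1e-9 after one step

eps_fp16 = to_float16(eps)
v_fp16 = to_float16(second_moment)
denominator_fp16 = to_float16(math.sqrt(v_fp16) + eps_fp16)

print("eps in float16:        ", eps_fp16)
print("v in float16:          ", v_fp16)
print("sqrt(v) + eps, float16:", denominator_fp16)
print("1e-4 in float16:       ", to_float16(1e-4))
```

The smallest positive float16 value is roughly 6e-8, so both the default epsilon and the squared gradient round to zero, and the update would divide by zero. An epsilon of 1e-4 survives the rounding.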



Answer 3:

This is a numerical stability issue. I would suggest trying a lower learning rate to see if that resolves your problem.