No shape error in tensorflow graph construction bu

Posted 2019-08-03 23:47

Question:

TensorFlow raises no error during graph construction, but I get a shape mismatch error when the graph is executed, inside tf.gradients (I guess the error occurs during backpropagation).

This is the error I get:

InvalidArgumentError (see above for traceback):
Input to reshape is a tensor with 16777216 values, but the requested shape has 4096
[[Node: gradients/truediv_grad/Reshape = Reshape[T=DT_FLOAT, Tshape=DT_INT32, _device="/job:localhost/replica:0/task:0 /device:GPU:0"](gradients/truediv_grad/Sum, gradients/truediv_grad/Shape)]]
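
The two numbers in the message are related: 16777216 = 4096 * 4096. One possible reading, offered only as an assumption since the post does not show the model, is that an elementwise operation somewhere in the backward pass broadcast a tensor of shape (4096, 1) against one of shape (4096,), which yields a (4096, 4096) result. The pure-Python sketch below reproduces that arithmetic; `broadcast_shape` and `num_elements` are hypothetical helpers written for this illustration, not TensorFlow functions.

```python
from functools import reduce
from operator import mul


def broadcast_shape(a, b):
    """Shape produced by an elementwise op under NumPy/TensorFlow broadcasting rules."""
    result = []
    # Compare dimensions from the trailing end; a missing leading dimension counts as 1.
    for i in range(1, max(len(a), len(b)) + 1):
        da = a[-i] if i <= len(a) else 1
        db = b[-i] if i <= len(b) else 1
        if da != db and 1 not in (da, db):
            raise ValueError("shapes %r and %r cannot be broadcast" % (a, b))
        result.append(max(da, db))
    return tuple(reversed(result))


def num_elements(shape):
    """Total number of values in a tensor of the given shape."""
    return reduce(mul, shape, 1)


n = 4096
column = (n, 1)   # assumed shape of one operand
flat = (n,)       # assumed shape of the other operand

product_shape = broadcast_shape(column, flat)
print(product_shape, num_elements(product_shape))  # (4096, 4096) 16777216
print(num_elements(flat))                          # 4096
```

If this reading applies, the Reshape in the gradient of the division receives 16777216 values where the forward input only had 4096, which matches the message above.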

Answer 1:

I solved the issue using two techniques:

1. Apparently, if you are creating custom ops and gradients, you need to be very explicit about providing shape information to TensorFlow, using set_shape or tf.reshape.

2. Also, when you register your gradient with tf.RegisterGradient (the registered function takes op and grad as inputs), you need to be careful when chaining the gradients, i.e. dy/dx = dy/dz * dz/dx. Say dy/dz is the custom gradient we have created and dz/dx is the gradient of the previous ops, as per the chain rule of differentiation.

# The string passed to tf.RegisterGradient is the op type name the gradient is registered under.
@tf.RegisterGradient("Mygrad")
def Mygrad(op, grad):
    # ... do stuff with op.inputs and calculate the custom gradient, say cust_grad or dy/dz ...
    return cust_grad * grad

I changed this to the following:

@tf.RegisterGradient("Mygrad")
def Mygrad(op, grad):
    # ... do stuff with op.inputs and calculate the custom gradient, say cust_grad or dy/dz ...
    return tf.matmul(tf.reshape(cust_grad, [calculated_shape]), tf.reshape(grad, expected_shape))
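
The first point of the answer, giving TensorFlow explicit shape information, can be sketched as follows. This is a minimal illustration and not the poster's code: it assumes the custom op is wrapped with py_func (whose outputs carry no static shape), it uses the tf.compat.v1 graph API so that it also runs under TensorFlow 2.x, and `scale_np` is a hypothetical stand-in for the custom computation.

```python
import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1            # TF1-style graph API, also shipped with TF 2.x
tf1.disable_eager_execution()


def scale_np(a):
    # Hypothetical stand-in for a custom op implemented outside the graph.
    return (2.0 * a).astype(np.float32)


x = tf1.placeholder(tf.float32, shape=[4, 3])
y = tf1.py_func(scale_np, [x], tf.float32)

rank_before = y.shape.ndims   # None: the graph has no static shape for y
y.set_shape(x.shape)          # declare the static shape explicitly
shape_after = y.shape.as_list()

with tf1.Session() as sess:
    out = sess.run(y, feed_dict={x: np.ones((4, 3), dtype=np.float32)})

print(rank_before, shape_after, out.shape)
```

set_shape only adds static shape information to the graph and does not change the data at run time, whereas tf.reshape inserts an op that fails at run time if the element count does not match the requested shape.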