How does model.fit work in Keras?

Posted 2019-08-07 00:46

My previous post (and the error I got) is this one. Since then I found a different way of writing the function so that it is TensorFlow compatible. I tested it on its own and it works fine. However, when I try to integrate it into Keras, it fails. This is the solution from my previous post:

import tensorflow as tf  # TensorFlow 1.x API (tf.Session, tf.random_uniform)

# original_dim and DIM are defined elsewhere in the script (not shown);
# DIM * DIM must equal 36, the length of each row of sol.

graph = tf.Graph()
with graph.as_default():
    i = tf.Variable(0)
    error = tf.Variable(initial_value=0, dtype=tf.float64)
    sol = tf.random_uniform(shape=[10, 36], dtype=tf.float64, maxval=1)

    def cond(i, sol, error):
        return tf.less(i, 9)

    def body(i, sol, error):
        # i is incremented before indexing, so rows 1..9 of sol are visited
        i = tf.add(i, 1)
        original_reshaped_elem = original_dim * sol[i]
        original_reshaped_elem = tf.reshape(original_reshaped_elem, [DIM, DIM])

        # pad columns: column 1 on the left, column DIM-1 on the right
        a = tf.reshape(original_reshaped_elem[:, DIM - 1], [DIM, 1])
        b = tf.reshape(original_reshaped_elem[:, 1], [DIM, 1])
        original_reshaped_elem = tf.concat([b, original_reshaped_elem], axis=1)
        original_reshaped_elem = tf.concat([original_reshaped_elem, a], axis=1)

        # pad rows: row 1 on top, row DIM-1 at the bottom
        c = tf.reshape(original_reshaped_elem[DIM - 1, :], [1, DIM + 2])
        d = tf.reshape(original_reshaped_elem[1, :], [1, DIM + 2])
        original_reshaped_elem = tf.concat([d, original_reshaped_elem], axis=0)
        reshaped_elem_extended = tf.concat([original_reshaped_elem, c], axis=0)

        # norm of the column norms == Frobenius norm of the padded block
        frob = tf.norm(tf.norm(reshaped_elem_extended, ord=2, axis=0), ord=2, axis=0)
        error = tf.add(error, frob)
        # the running total is divided by 36 on every iteration
        error_1 = tf.divide(error, 36)
        return [i, sol, error_1]

# entering the session also makes `graph` the default graph,
# so the loop below is built in the same graph as i, sol and error
with tf.Session(graph=graph) as session:
    tf.global_variables_initializer().run()
    result = tf.while_loop(cond, body, [i, sol, error])
    final_loss = tf.divide(result[2], 10)
    print(final_loss.eval())
    print(result[1].eval())
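To check the graph's output numerically, here is a NumPy sketch that mirrors the loop as written: it visits rows 1 to 9 of `sol` (because `i` is incremented before indexing), pads each DIM x DIM block with the same rows and columns, takes the Frobenius norm, divides the running total by 36 on every iteration, and finally divides by 10. The names `padded_norm` and `loop_reference` are hypothetical, and `original_dim` and `dim` are passed in as arguments because their values are not shown in the post. Since `sol` comes from `tf.random_uniform`, you would need to compare against values fetched in one `session.run([final_loss, result[1]])` call so both come from the same sample.

```python
import numpy as np


def padded_norm(row, original_dim, dim):
    """Mirror of one loop iteration: pad a dim x dim block, return its Frobenius norm."""
    # original_dim and dim are arguments because their values are not in the post
    m = (original_dim * np.asarray(row, dtype=np.float64)).reshape(dim, dim)
    a = m[:, dim - 1:dim]                    # column DIM-1, shape (dim, 1)
    b = m[:, 1:2]                            # column 1, shape (dim, 1)
    m = np.concatenate([b, m, a], axis=1)    # shape (dim, dim + 2)
    c = m[dim - 1:dim, :]                    # row DIM-1, shape (1, dim + 2)
    d = m[1:2, :]                            # row 1, shape (1, dim + 2)
    m = np.concatenate([d, m, c], axis=0)    # shape (dim + 2, dim + 2)
    # 2-norm of the per-column 2-norms, same as the nested tf.norm calls
    return float(np.linalg.norm(np.linalg.norm(m, axis=0)))


def loop_reference(sol, original_dim, dim):
    """Reproduce the while_loop as written: rows 1..9, divide by 36 each step, then by 10."""
    error = 0.0
    for i in range(1, 10):
        error = (error + padded_norm(sol[i], original_dim, dim)) / 36
    return error / 10


print(padded_norm(np.ones(36), 1, 6))
```

Because row 0 is never read, a `sol` whose only non-zero row is row 0 gives a loss of 0.0 in both this mirror and the TF loop.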

This is how I call it in my model:

result = tf.while_loop(cond, body, [i, inputs, error])
final_loss = tf.divide(result[2], 10)
vae.add_loss(final_loss)
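When the loss is attached to the model, `inputs` takes the place of `sol`, so the number of rows is whatever batch size Keras feeds in, while `cond` still stops after nine iterations and `body` reads up to row 9. A sketch, assuming each row of the batch holds DIM*DIM values, of computing the same padded norm for every row at once with slicing and concatenation instead of a `tf.while_loop`. It is written in NumPy so it can be checked without a session, and `batched_padded_norms` is a hypothetical name. `tf.reshape`, `tf.concat`, slicing and `tf.reduce_sum` accept the same batched shapes, so porting it should be mostly mechanical, but that port is an assumption to verify against your TensorFlow version.

```python
import numpy as np


def batched_padded_norms(batch, original_dim, dim):
    """Padded Frobenius norm of every row of a (batch_size, dim*dim) array, no loop."""
    # hypothetical helper; works for any batch size, including a short last batch
    x = original_dim * np.asarray(batch, dtype=np.float64)
    x = x.reshape(-1, dim, dim)                      # (B, dim, dim)
    left = x[:, :, 1:2]                              # column 1
    right = x[:, :, dim - 1:dim]                     # column dim-1
    x = np.concatenate([left, x, right], axis=2)     # (B, dim, dim + 2)
    top = x[:, 1:2, :]                               # row 1
    bottom = x[:, dim - 1:dim, :]                    # row dim-1
    x = np.concatenate([top, x, bottom], axis=1)     # (B, dim + 2, dim + 2)
    return np.sqrt(np.sum(x ** 2, axis=(1, 2)))      # one Frobenius norm per row


print(batched_padded_norms(np.ones((3, 36)), 1, 6))
```

How the per-row values are reduced to one scalar (for example their mean) is a separate choice; a mean would not reproduce the loop's repeated division by 36.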

Then I get this error again:

ValueError: An operation has `None` for gradient. Please make sure that all of your ops have a gradient defined (i.e. are differentiable). Common ops without gradient: K.argmax, K.round, K.eval.

So I want to know: how does model.fit work in Keras? Does it instantiate the graph? I couldn't find any clear documentation about how it works, which I need in order to integrate my loss function correctly.
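For context, in Keras 2.x with the TensorFlow backend, `fit` does not build a new graph per batch. Roughly, `compile` collects the total loss (the compiled loss plus every tensor passed to `add_loss`), and the first call to `fit` builds one training function from it: the optimizer computes the gradients of that loss with respect to `model.trainable_weights` and creates the update ops, and each batch is then fed into that function in the backend's session and default graph. The quoted ValueError is raised in that gradient step when the gradient for some trainable weight is `None`, which is what `tf.gradients` returns for a weight the loss has no path back to. The toy below models only that check in pure Python; `Node`, `mul`, `add`, `gradients` and `get_gradients` are hypothetical names, not Keras or TensorFlow APIs.

```python
class Node:
    """Toy scalar node: a value plus local derivatives to the nodes it was computed from."""

    def __init__(self, value, parents=(), local_grads=()):
        self.value = value
        self.parents = parents
        self.local_grads = local_grads


def mul(a, b):
    return Node(a.value * b.value, (a, b), (b.value, a.value))


def add(a, b):
    return Node(a.value + b.value, (a, b), (1.0, 1.0))


def gradients(loss, params):
    """Reverse-mode accumulation; a param the loss never reaches gets None."""
    order, seen = [], set()

    def visit(node):  # post-order DFS: dependencies before dependents
        if id(node) in seen:
            return
        seen.add(id(node))
        for parent in node.parents:
            visit(parent)
        order.append(node)

    visit(loss)
    grads = {id(loss): 1.0}
    for node in reversed(order):
        g = grads[id(node)]
        for parent, local in zip(node.parents, node.local_grads):
            grads[id(parent)] = grads.get(id(parent), 0.0) + g * local
    return [grads.get(id(p)) for p in params]


def get_gradients(loss, params):
    """Toy version of the optimizer check that raises the error from the question."""
    grads = gradients(loss, params)
    if any(g is None for g in grads):
        raise ValueError("An operation has `None` for gradient.")
    return grads


w = Node(0.5)   # a trainable weight
x = Node(3.0)   # the model input
y = mul(w, x)   # the model output

loss_from_output = mul(y, y)                       # reaches w through y
loss_from_input = mul(x, x)                        # never touches w
combined = add(loss_from_output, loss_from_input)  # total loss still reaches w

print(get_gradients(loss_from_output, [w]))  # [9.0]
print(get_gradients(combined, [w]))          # [9.0]
print(gradients(loss_from_input, [w]))       # [None]
```

In this toy, calling `get_gradients(loss_from_input, [w])` raises the ValueError, while adding that disconnected term to a loss that does reach `w` does not.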
