Java - train loaded TensorFlow model

Posted 2019-06-12 20:44

Question:

Does anyone know whether it is possible to continue training a model after it has been loaded into Java from TensorFlow Python? I've come up with this snippet of code, but it did not work (yes, the output fed in is the same as the input):

for (int i = 0; i < 10000; i++) {
    // Feeds the same tensor as "input" and "output" and fetches only the "cost" op.
    Tensor cost = b.session().runner().feed("input", input).feed("output", input).fetch("cost").run().get(0);
    System.out.println(cost);
}

This is what is printed out 10000 times:

FLOAT tensor with shape []

And after all the iterations, the predictions are the same as they were before.

Moreover, if it is possible to continue training the loaded model, is it possible to update the saved model's weights and biases?

Answer 1:

You are feeding inputs and fetching only the loss; this evaluates the model but does not train it. To train, you need to feed batches of data and run the update op on each step (for example, the op returned by optimizer.minimize).
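The difference between evaluating the loss and running an update can be sketched without TensorFlow at all. The toy model below (y = w * x) is an assumption made purely for illustration, as are the names TrainStepSketch, cost, and trainStep; none of them come from the question's graph. It only shows that evaluating the cost leaves the weight untouched, while a gradient-descent step changes it.

```java
public class TrainStepSketch {
    // Single trainable weight of a toy model y = w * x (stands in for the graph's variables).
    static double w = 0.0;

    // Squared-error cost; evaluating it reads w but never modifies it.
    static double cost(double x, double target) {
        double err = w * x - target;
        return err * err;
    }

    // One gradient-descent update: the analogue of running the optimizer's update op.
    static void trainStep(double x, double target, double learningRate) {
        double grad = 2.0 * (w * x - target) * x; // d(cost)/dw
        w -= learningRate * grad;
    }

    public static void main(String[] args) {
        double x = 2.0, target = 6.0; // the ideal weight is 3.0

        // Like the loop in the question: only the cost is evaluated.
        for (int i = 0; i < 1000; i++) {
            cost(x, target);
        }
        System.out.println("w after only evaluating the cost: " + w);

        // Like a real training loop: an update runs on every iteration.
        for (int i = 0; i < 1000; i++) {
            trainStep(x, target, 0.05);
        }
        System.out.println("w after running update steps: " + w);
    }
}
```

In the TensorFlow 1.x Java API, the corresponding change to the question's loop would be to add a call such as .addTarget("train") on the runner before .run(). Here "train" is a hypothetical op name: the update op must have been given a name when the graph was built in Python so that Java can refer to it.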

It is possible to do this from Java, but the training infrastructure in Python is more mature: it includes threads that prefetch input data into queues, detection of when the input is exhausted, saving of summaries, and distributed training.