Tensorflow:freeze_graph在python不存储变量值(Tensorflow: f

2019-09-27 10:47发布

我试图冻结已经与蟒蛇API训练有素的图形,加载和在C ++中使用它。 我知道我应该使用freeze_graph.py ,看着这里的例子freeze_graph_test.py来学习如何使用它。 我有蟒蛇以下代码:

def load_ckpt_and_save_graph(path_to_ckpt=None, output_dir=None, output_name="subsign_classifier_frozen.pb",
                             meta_graph=None):
    if output_dir is None:
        output_dir = OUTPUT_DIR
    if path_to_ckpt is None:
        path_to_ckpt = tf.train.latest_checkpoint(CKPT_DIR, LATEST_CKPT_FILENAME)
    if meta_graph is None:
        meta_graph = MODEL_FILENAME_PATH

    tf.reset_default_graph()
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph(meta_graph)
        saver.restore(sess, path_to_ckpt)
        graph_path = tf.train.write_graph(sess.graph_def, logdir=output_dir, name=output_name, as_text=False)
        output_nodes = ["full_prediction/output_length_prediction", "full_prediction/output_class_prediction_1"]
        freeze_graph.freeze_graph(input_graph=graph_path, input_saver=saver._name, input_binary=True,
                                  input_checkpoint=path_to_ckpt, output_node_names=",".join(output_nodes),
                                  restore_op_name="save/restore_all", filename_tensor_name= "save/Const:0",  
                                  output_graph=output_name, clear_devices=True, initializer_nodes="",
                                  variable_names_blacklist="")
    log_and_print('load_ckpt_and_save_graph/ : Graph saved to %s' % output_name)

此代码运行正常,并打印:

INFO:tensorflow:Froze 39 variables.
Converted 39 variables to const ops.
632 ops in the final graph.
load_ckpt_and_save_graph/ : Graph saved to subsign_classifier_frozen.pb

然而,结果似乎并没有包含对变量的值:生成的文件大小只有2MB(就像由write_graph节省了graph_def,检查点时的重量为57MB),并试图在C ++中运行它(与加载后ReadBinaryProto )抛出以下状态:

Failed precondition: Attempting to use uninitialized value Classes_readout/Classes_logits3/bias
     [[Node: Classes_readout/Classes_logits3/bias/read = Identity[T=DT_FLOAT, _class=["loc:@Classes_readout/Classes_logits3/bias"], _device="/job:localhost/replica:0/task:0/cpu:0"](Classes_readout/Classes_logits3/bias)]]

我想我必须滥用一些freeze_graph的参数,但没有我尝试拿到了更好的东西...

编辑3:

使用freeze_graph作为脚本不会产生一个有效的.pb文件,大小约为19MB的

  • 怎么能是超过检查点小很多? 难道真的在这一切吗?

  • 我真的宁愿让Python版本的工作,做你们有什么想法,为什么不?

文章来源: Tensorflow: freeze_graph in python does not store the variable values