How do I find the variable names and values that are saved in a TensorFlow checkpoint?

Posted 2019-01-16 15:17

I want to see the variables saved in a TensorFlow checkpoint along with their values. How can I find the names of the variables stored in a checkpoint?

I used tf.train.NewCheckpointReader, which is explained here. However, it is not covered in the official TensorFlow documentation. Is there another way?
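For reference, a minimal sketch of what that looks like, assuming a placeholder checkpoint prefix of model.ckpt (not a file from this question):

import tensorflow as tf

# NewCheckpointReader takes the checkpoint prefix, not a physical file name.
reader = tf.train.NewCheckpointReader('model.ckpt')

# Map of variable name -> shape for everything stored in the checkpoint.
for name, shape in reader.get_variable_to_shape_map().items():
    print(name, shape)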

4 answers
戒情不戒烟
#2 · 2019-01-16 15:26

You can use the inspect_checkpoint.py tool.

For example, if you stored the checkpoint in the current directory, you can print the variables and their values as follows:

import tensorflow as tf
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file


# Resolve the prefix of the most recent checkpoint in the current directory.
latest_ckp = tf.train.latest_checkpoint('./')
# Print every tensor name together with its values.
print_tensors_in_checkpoint_file(latest_ckp, all_tensors=True, tensor_name='')
看我几分像从前
#3 · 2019-01-16 15:37

Adding more parameter details for print_tensors_in_checkpoint_file:

file_name: not a physical file, just the prefix of the checkpoint filenames

If no tensor_name is provided, prints the tensor names and shapes in the checkpoint file. If tensor_name is provided, prints the content of that tensor. (inspect_checkpoint.py)

If all_tensor_names is True, prints all the tensor names.

If all_tensors is True, prints all the tensor names and the corresponding content.

N.B. all_tensors and all_tensor_names override tensor_name
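
To illustrate how these flags combine, here is a small sketch (the prefix './model.ckpt' is a placeholder, and the all_tensor_names argument is only present in newer TensorFlow versions):

from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file

ckpt_prefix = './model.ckpt'  # placeholder checkpoint prefix

# Names and shapes only.
print_tensors_in_checkpoint_file(ckpt_prefix, tensor_name='',
                                 all_tensors=False, all_tensor_names=True)

# Every tensor name plus its full contents (overrides tensor_name).
print_tensors_in_checkpoint_file(ckpt_prefix, tensor_name='', all_tensors=True)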

The star\"
#4 · 2019-01-16 15:42

Example usage:

from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
import os
checkpoint_path = os.path.join(model_dir, "model.ckpt")

# List ALL tensors.
# Example output: v0/Adam (DT_FLOAT) [3,3,1,80]
print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='')

# List contents of v0 tensor.
# Example output: tensor_name:  v0 [[[[  9.27958265e-02   7.40226209e-02   4.52989563e-02   3.15700471e-02
print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='v0')

# List contents of v1 tensor.
print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='v1')

Update: the all_tensors argument was added to print_tensors_in_checkpoint_file in TensorFlow 0.12.0-rc0, so you may need to add all_tensors=False or all_tensors=True as required.
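
For instance, on those newer versions the calls above would become (using the same checkpoint_path defined earlier):

# Names and shapes only.
print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='',
                                 all_tensors=False)

# Contents of the v0 tensor only.
print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='v0',
                                 all_tensors=False)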

Alternative method:

from tensorflow.python import pywrap_tensorflow
import os

checkpoint_path = os.path.join(model_dir, "model.ckpt")
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()

for key in var_to_shape_map:
    print("tensor_name: ", key)
    print(reader.get_tensor(key))  # Remove this line if you only want to print the variable names
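
As a side note, more recent TensorFlow releases (roughly 1.14+ and 2.x, to the best of my knowledge) expose the same reader through a public API, so a sketch along these lines avoids importing the private pywrap_tensorflow module:

import tensorflow as tf

# list_variables returns (name, shape) pairs for everything in the checkpoint.
# load_checkpoint returns a CheckpointReader, as above, for reading the values.
reader = tf.train.load_checkpoint(checkpoint_path)
for name, shape in tf.train.list_variables(checkpoint_path):
    print("tensor_name: ", name, shape)
    print(reader.get_tensor(name))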

Hope it helps.

走好不送
#5 · 2019-01-16 15:43

A few more details.

If your model is saved using the V2 format, for example with the following files in the directory /my/dir/

model-10000.data-00000-of-00001
model-10000.index
model-10000.meta

then the file_name parameter should only be the prefix, that is

print_tensors_in_checkpoint_file(file_name='/my/dir/model-10000', tensor_name='', all_tensors=True)

See https://github.com/tensorflow/tensorflow/issues/7696 for a discussion.
