I train a model with a placeholder for `is_training`:

```python
is_training_ph = tf.placeholder(tf.bool)
```
However, once training and validation are done, I would like to permanently inject the constant `False` for this value and then "re-optimize" the graph (i.e. using `optimize_for_inference`). Is there something along the lines of `freeze_graph` that will do this?
One possibility is to use the `tf.import_graph_def()` function and its `input_map` argument to rewrite the value of that tensor in the graph. For example, you could structure your program as follows:
```python
import tensorflow as tf

with tf.Graph().as_default() as training_graph:
    # Build model.
    is_training_ph = tf.placeholder(tf.bool, name="is_training")
    # ...

training_graph_def = training_graph.as_graph_def()

with tf.Graph().as_default() as temp_graph:
    # Every consumer of "is_training:0" is rewired to read the constant instead.
    # Note: imported nodes get the default "import/" name prefix.
    tf.import_graph_def(training_graph_def,
                        input_map={is_training_ph.name: tf.constant(False)})

temp_graph_def = temp_graph.as_graph_def()
```
After building `temp_graph_def`, you can use it as the input to `freeze_graph`.
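A self-contained sketch of the `input_map` approach, assuming a TensorFlow install that provides `tf.compat.v1` (late 1.x or 2.x). The toy graph and the helper names `build_training_graph_def` and `bake_in_constant` are hypothetical; the "model" simply adds 100 to its input when `is_training` is true, so it is easy to see which mode the rewired graph ends up in.

```python
import tensorflow as tf

tf1 = tf.compat.v1  # graph-mode APIs (placeholder, Session, import_graph_def)


def build_training_graph_def():
    """Hypothetical toy model whose output depends on an is_training flag."""
    with tf.Graph().as_default() as graph:
        is_training_ph = tf1.placeholder(tf.bool, name="is_training")
        x = tf.constant([1.0, 2.0, 3.0], name="x")
        # Adds 100 in training mode; passes x through unchanged otherwise.
        y = tf.add(x, 100.0 * tf.cast(is_training_ph, tf.float32), name="y")
    return graph.as_graph_def(), is_training_ph.name, y.name


def bake_in_constant(graph_def, placeholder_tensor_name, value):
    """Re-import graph_def with every use of the placeholder rewired to a constant."""
    with tf.Graph().as_default() as graph:
        tf1.import_graph_def(
            graph_def,
            input_map={placeholder_tensor_name: tf.constant(value)},
        )
    return graph


graph_def, ph_name, y_name = build_training_graph_def()
inference_graph = bake_in_constant(graph_def, ph_name, False)

with tf1.Session(graph=inference_graph) as sess:
    # import_graph_def prefixes names with "import/" by default. No feed is
    # needed, because nothing downstream reads the placeholder anymore.
    result = sess.run("import/" + y_name)

print(result)  # [1. 2. 3.]
```

The placeholder node itself is still present in the re-imported graph, just disconnected; `freeze_graph` keeps only the ancestors of the output nodes you name, so it should drop it along the way.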
An alternative, which might be more compatible with the `freeze_graph` and `optimize_for_inference` scripts (they make assumptions about variable names and checkpoint keys), would be to modify TensorFlow's `graph_util.convert_variables_to_constants()` function so that it converts placeholders instead:
```python
import numpy as np
import tensorflow as tf


def convert_placeholders_to_constants(input_graph_def,
                                      placeholder_to_value_map):
    """Replaces placeholders in the given tf.GraphDef with constant values.

    Args:
      input_graph_def: GraphDef object holding the network.
      placeholder_to_value_map: A map from the names of placeholder nodes in
        `input_graph_def` (op names, without the ":0" suffix) to constant values.

    Returns:
      GraphDef containing a simplified version of the original.
    """
    output_graph_def = tf.GraphDef()
    # Keep the version info and any function library of the original graph.
    output_graph_def.versions.CopyFrom(input_graph_def.versions)
    output_graph_def.library.CopyFrom(input_graph_def.library)
    for node in input_graph_def.node:
        output_node = tf.NodeDef()
        if node.op == "Placeholder" and node.name in placeholder_to_value_map:
            # Swap in a Const node with the same name and dtype, so existing
            # consumers (which refer to it by name) keep working unchanged.
            output_node.op = "Const"
            output_node.name = node.name
            dtype = node.attr["dtype"].type
            data = np.asarray(placeholder_to_value_map[node.name],
                              dtype=tf.as_dtype(dtype).as_numpy_dtype)
            output_node.attr["dtype"].type = dtype
            output_node.attr["value"].CopyFrom(tf.AttrValue(
                tensor=tf.contrib.util.make_tensor_proto(data,
                                                         dtype=dtype,
                                                         shape=data.shape)))
        else:
            # Copy every other node through untouched.
            output_node.CopyFrom(node)
        output_graph_def.node.extend([output_node])
    return output_graph_def
```
...then you could build `training_graph_def` as above, and write:

```python
temp_graph_def = convert_placeholders_to_constants(training_graph_def,
                                                   {is_training_ph.op.name: False})
```
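Before handing the result to `freeze_graph`, it can be worth confirming that the placeholder really became a constant. The helper below, `baked_constant_value`, is a hypothetical name used for illustration (it is not part of TensorFlow); it assumes `tf.make_ndarray` is available (TF 2.x and late 1.x) to decode the node's `value` attribute.

```python
import tensorflow as tf


def baked_constant_value(graph_def, node_name):
    """Hypothetical helper: return the value baked into node `node_name`.

    Raises ValueError if the node is missing or is not a Const
    (for example, if it is still a Placeholder).
    """
    for node in graph_def.node:
        if node.name == node_name:
            if node.op != "Const":
                raise ValueError("%s is a %s, not a Const" % (node_name, node.op))
            # Decode the TensorProto stored in the "value" attr into a NumPy array.
            return tf.make_ndarray(node.attr["value"].tensor)
    raise ValueError("no node named %s" % node_name)
```

For example, `baked_constant_value(temp_graph_def, is_training_ph.op.name)` should give a scalar array holding `False`; if the node is still a `Placeholder`, the helper raises `ValueError` instead.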