TensorFlow: Optimize for Inference a SavedModel exported by an Estimator

Published 2019-05-28 16:16

Question:

I'm trying to optimize a saved graph for inference, so I can use it in Android.

My first attempt at using the optimize_for_inference script failed with

google.protobuf.message.DecodeError: Truncated message

So my question is: are my input/output nodes wrong, or can the script simply not handle SavedModels (even though they share the .pb extension with frozen graphs)?

Regarding the first: since with Estimators we provide an input_fn instead of the data itself, what should be considered the input node? The first TF operation applied to it? Like:

x = x_dict['gestures']

# Data input is a 1-D vector of x_dim * y_dim features ("pixels")
# Reshape to match format [Height x Width x Channel]
# Tensor input becomes 4-D: [Batch Size, Height, Width, Channel]
x = tf.reshape(x, shape=[-1, x_dim, y_dim, 1], name='input')

(...)

pred_probs = tf.nn.softmax(logits, name='output')
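
In case it's relevant: here's a rough sketch (TF 1.x, not verified, with export_dir standing in for my SavedModel directory) of how I understand one can list the node names and SignatureDefs of the exported graph to check what actually ended up named 'input' and 'output':

import tensorflow as tf

export_dir = '/path/to/export/1234567890'  # placeholder: the Estimator's export directory

with tf.Session(graph=tf.Graph()) as sess:
    # Load the serving MetaGraph from the SavedModel and print every node in the GraphDef
    meta_graph = tf.saved_model.loader.load(
        sess, [tf.saved_model.tag_constants.SERVING], export_dir)
    for node in sess.graph_def.node:
        print(node.name, node.op)
    # The SignatureDefs also record which tensors the Estimator exported as inputs/outputs
    print(meta_graph.signature_def)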

BTW: if loading a SavedModel on Android works any differently, I'd like to know that too.

Thank you in advance!

Answer 1:

Update: There are good instructions at https://www.tensorflow.org/mobile/prepare_models which include an explanation of what to do with SavedModels. You can freeze your SavedModel by passing --input_saved_model_dir to freeze_graph.py.
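
For example, an invocation along those lines might look like this (a sketch: the paths are placeholders, and 'output' assumes the softmax op named 'output' in your question is the output node):

python -m tensorflow.python.tools.freeze_graph \
    --input_saved_model_dir=/path/to/saved_model \
    --output_node_names=output \
    --output_graph=/tmp/frozen_graph.pb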

They're both protocol buffers (.pb), but unfortunately they're different messages (i.e. different file formats). Theoretically you could first extract a MetaGraph from the SavedModel, then "freeze" the MetaGraph's GraphDef (move variables into constants), then run this script on the frozen GraphDef. In that case you'd want your input_fn to be just placeholders.
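
If you go the manual route, a minimal sketch (assuming TensorFlow 1.x, with export_dir plus the 'input'/'output' node names from your question standing in for your own values) could look like:

import tensorflow as tf
from tensorflow.python.tools import optimize_for_inference_lib

export_dir = '/path/to/saved_model'  # placeholder: the Estimator's export directory
input_nodes = ['input']              # assumption: the reshape op named 'input' in the question
output_nodes = ['output']            # assumption: the softmax op named 'output' in the question

with tf.Session(graph=tf.Graph()) as sess:
    # Extract the serving MetaGraph from the SavedModel
    tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], export_dir)
    # "Freeze": fold the variables into constants so the GraphDef is self-contained
    frozen = tf.graph_util.convert_variables_to_constants(
        sess, sess.graph_def, output_nodes)

# Run the optimize-for-inference pass on the frozen GraphDef
optimized = optimize_for_inference_lib.optimize_for_inference(
    frozen, input_nodes, output_nodes, tf.float32.as_datatype_enum)

with tf.gfile.GFile('/tmp/optimized_graph.pb', 'wb') as f:
    f.write(optimized.SerializeToString())

The result is a plain frozen GraphDef .pb rather than a SavedModel, which is the format the optimize_for_inference tooling expects.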

You could also add a +1 reaction on one of the "SavedModel support for Android" GitHub issues. Medium-term we'd like to standardize on SavedModel; sorry you've run into this!