How to read output from a TensorFlow model in Java

Posted 2019-08-21 15:47

Question:

I am trying to use TensorFlow Lite with the ssdlite_mobilenet_v2_coco model from https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md, converted to a tflite file, to detect objects in the camera stream of my Android app (Java). I execute

    interpreter.run(input, output);

where input is an image converted to a ByteBuffer and output is a float array of size [1][10][4] to match the output tensor.

How do I convert this float array into readable output, e.g. the coordinates of the bounding box, the name of the object, and the probability?

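Answer 1:

OK, I figured it out. First I ran the following commands in Python (in newer TensorFlow releases, where tf.contrib has been removed, the same class is tf.lite.Interpreter):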

>>> import tensorflow as tf
>>> interpreter = tf.contrib.lite.Interpreter("detect.tflite")

With the tflite model loaded, I then ran:

>>> interpreter.allocate_tensors()
>>> input_details = interpreter.get_input_details()
>>> output_details = interpreter.get_output_details()

Now I had the details of exactly how the input and output should look:

>>> input_details
[{'name': 'normalized_input_image_tensor', 'index': 308, 'shape': array([  1, 300, 300,   3], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}]

So the input is the image converted to a float32 tensor of shape [1, 300, 300, 3], that is, 300 x 300 pixels with 3 color channels.

>>> output_details
[{'name': 'TFLite_Detection_PostProcess', 'index': 300, 'shape': array([ 1, 10,  4], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}, {'name': 'TFLite_Detection_PostProcess:1', 'index': 301, 'shape': array([ 1, 10], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}, {'name': 'TFLite_Detection_PostProcess:2', 'index': 302, 'shape': array([ 1, 10], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}, {'name': 'TFLite_Detection_PostProcess:3', 'index': 303, 'shape': array([1], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}]

This gave me the spec of the model's outputs: there are four output tensors, not one. So I needed to change

interpreter.run(input, output);

to

interpreter.runForMultipleInputsOutputs(inputs, map_of_indices_to_outputs);

where "inputs" is:

private Object[] inputs = new Object[1];
inputs[0] = imgData; // imgData: the image converted to a ByteBuffer
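
As a rough illustration of what "image converted to ByteBuffer" can mean for the float32 input of shape [1, 300, 300, 3] shown above, here is a minimal sketch in plain Java. The class name InputBufferSketch and the method toInputBuffer are hypothetical, and the normalization constants (mean 127.5, std 127.5, mapping pixels to [-1, 1]) are an assumption, not something stated in this answer; check what your model expects. On Android the ARGB pixel array would typically come from Bitmap.getPixels on a bitmap already scaled to 300 x 300.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

public class InputBufferSketch {
    // Taken from the input_details shape [1, 300, 300, 3].
    public static final int INPUT_SIZE = 300;
    // ASSUMPTION: normalization constants; not given in the answer.
    public static final float IMAGE_MEAN = 127.5f;
    public static final float IMAGE_STD = 127.5f;

    /** Packs row-major ARGB pixels (INPUT_SIZE * INPUT_SIZE) into a float32 RGB buffer. */
    public static ByteBuffer toInputBuffer(int[] argbPixels) {
        if (argbPixels.length != INPUT_SIZE * INPUT_SIZE) {
            throw new IllegalArgumentException("expected " + INPUT_SIZE * INPUT_SIZE + " pixels");
        }
        // 3 channels per pixel, 4 bytes per float32 value.
        ByteBuffer buffer = ByteBuffer.allocateDirect(INPUT_SIZE * INPUT_SIZE * 3 * 4);
        buffer.order(ByteOrder.nativeOrder());
        for (int pixel : argbPixels) {
            buffer.putFloat((((pixel >> 16) & 0xFF) - IMAGE_MEAN) / IMAGE_STD); // red
            buffer.putFloat((((pixel >> 8) & 0xFF) - IMAGE_MEAN) / IMAGE_STD);  // green
            buffer.putFloat(((pixel & 0xFF) - IMAGE_MEAN) / IMAGE_STD);         // blue
        }
        buffer.rewind(); // reset the position so the buffer is read from the start
        return buffer;
    }
}
```

The result of toInputBuffer could then be used as imgData above.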

And map_of_indices_to_outputs is:

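import java.util.Map;
import java.util.TreeMap;

// Output arrays sized to match the shapes reported by get_output_details().
private Map<Integer, Object> output_map = new TreeMap<>();
private float[][][] boxes = new float[1][10][4];
private float[][] scores = new float[1][10];
private float[][] classes = new float[1][10];
// Keys are the output positions (0, 1, 2), not the tensor indices 300-302.
output_map.put(0, boxes);
output_map.put(1, classes);
output_map.put(2, scores);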

Now, after running, boxes holds the coordinates of the 10 detected objects, classes holds each object's index in the COCO label file (you must add 1 to get the right key!), and scores holds the probabilities.
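
To turn these three arrays into something readable, a sketch along the following lines may help. The names DetectionDecodeSketch, Detection and decode are hypothetical, as is the labels list (assumed to hold the lines of the COCO label file). It also assumes that each box is stored as [ymin, xmin, ymax, xmax] with values normalized to 0..1, which is the usual layout for this post-processing op but is not stated in this answer, so verify it against your model.

```java
import java.util.ArrayList;
import java.util.List;

public class DetectionDecodeSketch {
    /** One readable detection: label, probability, and box in pixel coordinates. */
    public static final class Detection {
        public final String label;
        public final float score;
        public final float left, top, right, bottom;

        Detection(String label, float score, float left, float top, float right, float bottom) {
            this.label = label;
            this.score = score;
            this.left = left;
            this.top = top;
            this.right = right;
            this.bottom = bottom;
        }
    }

    /** Keeps detections with score >= minScore and scales their boxes to the image size. */
    public static List<Detection> decode(float[][][] boxes, float[][] classes, float[][] scores,
                                         List<String> labels, int imageWidth, int imageHeight,
                                         float minScore) {
        List<Detection> result = new ArrayList<>();
        for (int i = 0; i < scores[0].length; i++) {
            if (scores[0][i] < minScore) {
                continue; // skip low-confidence slots
            }
            // The +1 offset described in the answer: class index -> line in the label file.
            int labelIndex = (int) classes[0][i] + 1;
            String label = labelIndex < labels.size() ? labels.get(labelIndex) : "unknown";
            // ASSUMPTION: box layout is [ymin, xmin, ymax, xmax], normalized to 0..1.
            float[] box = boxes[0][i];
            result.add(new Detection(label, scores[0][i],
                    box[1] * imageWidth, box[0] * imageHeight,
                    box[3] * imageWidth, box[2] * imageHeight));
        }
        return result;
    }
}
```

The minScore threshold is a design choice of this sketch: the model always fills all 10 slots, so some cut-off is needed to drop the weak ones.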

Hope this helps somebody in the future.