How can I create a TensorProto for TensorFlow in Java?

Posted 2019-07-26 23:36

We're using tensorflow/serving for inference. It exposes a gRPC service, and we can generate the Java classes from the proto files.

We can generate the PredictionService from https://github.com/tensorflow/serving/blob/master/tensorflow_serving/apis/prediction_service.proto, but how can I construct the TensorProto objects from a multi-dimensional array?

We have examples for Python (building from an ndarray) and for C++. It would be great if someone has tried this in Java.

There's some work on running TensorFlow in Java, and there's a blog post about it, but I'm not sure whether it works or how we could use it without pulling in its dependencies.

1 answer
Answerer: account suspended
Reply #2 · 2019-07-27 00:04

TensorProto supports two representations for the content of the tensor:

  1. The various repeated *_val fields (such as TensorProto.float_val, TensorProto.int_val), which store the content as a linear array of primitive elements, in row-major order.

  2. The TensorProto.tensor_content field, which stores the content as a single byte array corresponding to the result of tensorflow::Tensor::AsProtoTensorContent(). (In general, this representation is the in-memory representation of a tensorflow::Tensor converted to a byte array, but the DT_STRING type is handled differently.)
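For the second representation, the float data has to be serialized into raw bytes yourself. The sketch below only produces the byte array, using plain java.nio and no TensorFlow classes; the class and method names are hypothetical. It assumes the serving host is little-endian (true for x86-64 and most ARM machines), since tensor_content mirrors the in-memory layout.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

public class TensorContentPacker {
    // Hypothetical helper: packs a rectangular 2-D float array into bytes,
    // row-major, 4 bytes per element.
    // ASSUMPTION: the server that decodes tensor_content is little-endian.
    public static byte[] packFloats(float[][] data) {
        int rows = data.length;
        int cols = rows == 0 ? 0 : data[0].length;
        ByteBuffer buf = ByteBuffer.allocate(rows * cols * Float.BYTES)
                .order(ByteOrder.LITTLE_ENDIAN);
        for (float[] row : data) {
            if (row.length != cols) {
                throw new IllegalArgumentException("ragged array: all rows must have the same length");
            }
            for (float v : row) {
                buf.putFloat(v); // writes 4 bytes in little-endian order
            }
        }
        return buf.array();
    }

    public static void main(String[] args) {
        byte[] bytes = packFloats(new float[][] {{1f, 2f}, {3f, 4f}});
        System.out.println(bytes.length); // prints 16
    }
}
```

With the generated classes, the result would presumably be attached via builder.setTensorContent(com.google.protobuf.ByteString.copyFrom(bytes)), together with the dtype and shape fields.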

It will probably be easier to generate TensorProto objects using the first format, although it is less efficient. Assuming you have a 2-D float array called tensorData in your Java program, you can use the following code as a starting point:

import org.tensorflow.framework.TensorProto;

float[][] tensorData = ...;
TensorProto.Builder builder = TensorProto.newBuilder();

// Set the shape and dtype fields.
// ...

// Set the float_val field: append every element row by row (row-major order).
for (int i = 0; i < tensorData.length; ++i) {
    for (int j = 0; j < tensorData[i].length; ++j) {
        builder.addFloatVal(tensorData[i][j]);
    }
}

TensorProto tensorProto = builder.build();
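The loop above handles exactly two dimensions. For arrays of arbitrary rank, one option is to walk the array reflectively and collect both the shape and the row-major values. This is a sketch under stated assumptions: the class name is hypothetical, it supports only float arrays, and it rejects ragged arrays because a tensor must be rectangular.

```java
import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.List;

public class FloatTensorFlattener {
    // One entry per dimension, outermost first.
    public final List<Long> shape = new ArrayList<>();
    // All elements in row-major order (last index varies fastest).
    public final List<Float> values = new ArrayList<>();

    // Hypothetical helper: flattens float[], float[][], float[][][], ...
    public static FloatTensorFlattener flatten(Object array) {
        FloatTensorFlattener result = new FloatTensorFlattener();
        // Record the shape by following the first element at each nesting level.
        Object level = array;
        while (level != null && level.getClass().isArray()) {
            int len = Array.getLength(level);
            result.shape.add((long) len);
            level = len > 0 ? Array.get(level, 0) : null;
        }
        result.collect(array, 0);
        return result;
    }

    // Recursively appends leaf values, checking every level matches the shape.
    private void collect(Object node, int depth) {
        int len = Array.getLength(node);
        if (len != shape.get(depth)) {
            throw new IllegalArgumentException("ragged array at depth " + depth);
        }
        if (node instanceof float[]) {
            for (float v : (float[]) node) {
                values.add(v);
            }
        } else {
            for (int i = 0; i < len; i++) {
                collect(Array.get(node, i), depth + 1);
            }
        }
    }
}
```

Assuming the standard protobuf-java code generation for tensor.proto and tensor_shape.proto, the results would map onto the builder roughly as builder.setDtype(DataType.DT_FLOAT), one TensorShapeProto.Dim (set via setSize) per entry of shape, and builder.addAllFloatVal(values).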