Flatten a dataset in TensorFlow

Published 2019-07-07 21:16

Question:

I am trying to convert a TensorFlow dataset so that each element is a single-valued tensor. The dataset currently looks like this:

[12 43 64 34 45 2 13 54] [34 65 34 67 87 12 23 43] [23 53 23 1 5] ...

After the transformation it should look like this:

[12] [43] [64] [34] [45] [2] [13] [54] [34] [65] [34] [67] [87] [12] ...

My initial idea was to use flat_map on the dataset and then convert each tensor into a list of tensors using reshape and unstack:

output_labels = self.dataset.flat_map(convert_labels)

...

def convert_labels(tensor):
    id_list = tf.unstack(tf.reshape(tensor, [-1, 1]))
    return tf.data.Dataset.from_tensors(id_list)

However, the shape of each tensor is only partially known (i.e., (?, 1)), which is why the unstack operation fails. Is there any way to still "concat" the different tensors without explicitly iterating over them?
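The failure can be reproduced with a small sketch. It assumes TensorFlow 2.4 or later (for the `output_signature` argument of `from_generator`) and uses a hypothetical in-memory `rows` list built from the question's sample data as a stand-in for `self.dataset`. Inside `flat_map` the function is traced once with a symbolic input of static shape `(None,)`, and `tf.unstack` needs the size of the unstacked axis at trace time to decide how many tensors to return.

```python
import tensorflow as tf

# Hypothetical stand-in for self.dataset: rows of different lengths.
rows = [[12, 43, 64, 34, 45, 2, 13, 54],
        [34, 65, 34, 67, 87, 12, 23, 43],
        [23, 53, 23, 1, 5]]

# Each element is a 1-D int32 tensor whose length is not known statically.
dataset = tf.data.Dataset.from_generator(
    lambda: iter(rows),
    output_signature=tf.TensorSpec(shape=(None,), dtype=tf.int32))

def convert_labels(tensor):
    # During tracing, `tensor` has static shape (None,), so the reshaped
    # tensor has shape (None, 1) and unstack cannot infer how many
    # outputs to create.
    id_list = tf.unstack(tf.reshape(tensor, [-1, 1]))
    return tf.data.Dataset.from_tensors(id_list)

try:
    dataset.flat_map(convert_labels)  # tracing happens here
    error = None
except ValueError as exc:
    error = exc

print(type(error).__name__)
```

The error is raised when the pipeline is built, not when it is iterated, because that is when `flat_map` traces `convert_labels`.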

Answer 1:

Your solution is very close, but Dataset.flat_map() takes a function that returns a tf.data.Dataset object rather than a list of tensors. Fortunately, the Dataset.from_tensor_slices() method fits your use case exactly, because it splits a tensor along its first dimension into a variable number of elements:

output_labels = self.dataset.flat_map(tf.data.Dataset.from_tensor_slices)
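A runnable sketch of this approach in eager TensorFlow 2.x follows. It assumes TensorFlow 2.4 or later and uses a hypothetical `rows` list (the question's sample data) in place of `self.dataset`. Passing `from_tensor_slices` directly yields scalar elements; reshaping each row to `[-1, 1]` first yields elements of shape `(1,)`, matching the `[12] [43] ...` layout asked for in the question.

```python
import tensorflow as tf

# Hypothetical stand-in for self.dataset: rows of different lengths.
rows = [[12, 43, 64, 34, 45, 2, 13, 54],
        [34, 65, 34, 67, 87, 12, 23, 43],
        [23, 53, 23, 1, 5]]

dataset = tf.data.Dataset.from_generator(
    lambda: iter(rows),
    output_signature=tf.TensorSpec(shape=(None,), dtype=tf.int32))

# flat_map maps each row to a Dataset and concatenates the results;
# from_tensor_slices splits a row along axis 0, whatever its length.
scalars = dataset.flat_map(tf.data.Dataset.from_tensor_slices)

# Reshaping to (-1, 1) first gives elements of shape (1,), e.g. [12].
singletons = dataset.flat_map(
    lambda row: tf.data.Dataset.from_tensor_slices(tf.reshape(row, [-1, 1])))

print([int(x) for x in scalars])
print([x.numpy().tolist() for x in singletons][:3])  # [[12], [43], [64]]
```

No static length is needed because `from_tensor_slices` produces as many elements as the row actually has at run time, which is exactly what `tf.unstack` cannot do during tracing.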

Note that the tf.contrib.data.unbatch() transformation implements the same functionality and has a slightly more efficient implementation in the current master branch of TensorFlow (to be included in the 1.9 release):

output_labels = self.dataset.apply(tf.contrib.data.unbatch())
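The `tf.contrib` module was removed in TensorFlow 2.0; in TensorFlow 2.x the same transformation is available as the `Dataset.unbatch()` method. The sketch below assumes TensorFlow 2.4 or later and again uses a hypothetical `rows` list instead of `self.dataset`; it checks that `unbatch()` and the `flat_map(from_tensor_slices)` approach produce the same elements in the same order.

```python
import tensorflow as tf

# Hypothetical stand-in for self.dataset: rows of different lengths.
rows = [[12, 43, 64, 34, 45, 2, 13, 54],
        [34, 65, 34, 67, 87, 12, 23, 43],
        [23, 53, 23, 1, 5]]

dataset = tf.data.Dataset.from_generator(
    lambda: iter(rows),
    output_signature=tf.TensorSpec(shape=(None,), dtype=tf.int32))

# unbatch() splits each element along its first dimension, which may
# differ from one element to the next.
unbatched = dataset.unbatch()

# The flat_map equivalent from the answer, for comparison.
via_flat_map = dataset.flat_map(tf.data.Dataset.from_tensor_slices)

print([int(x) for x in unbatched] == [int(x) for x in via_flat_map])  # True
```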