Can the map function supplied to `tf.data.Dataset.

Posted 2019-02-20 15:54

Question:

I'd like to create a tf.data.Dataset.from_generator(...) dataset. I need to pass in a Python generator.

I would like to pass an element of a previous dataset to the generator, like so:

```python
dataset = dataset.interleave(
  map_func=lambda x: tf.data.Dataset.from_generator(generator=lambda: gen(x), output_types=tf.int64),
  cycle_length=2
)
```

Here, gen(...) takes a value that points to some data (for example, a filename) which gen knows how to access.

This fails because gen receives a tensor object, not a Python/NumPy value.

Is there a way to resolve the tensor object to a value inside of gen(...)?

I interleave the generators so that I can manipulate the list of data pointers/filenames with other dataset operations, such as .shuffle() and .repeat(), without baking those operations into the gen(...) function. That would be necessary if I built the generator directly from the list of data pointers/filenames.

I want to use the generator because a large number of data values will be generated per data-pointer/filename.
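For concreteness, a plain-Python `gen(...)` of the kind described above might look like this sketch. It is hypothetical: it assumes the data pointer is the path of a text file holding one integer per line, and it accepts the pointer as either `str` or `bytes`, because a value that starts out in a TensorFlow string tensor typically reaches Python code as `bytes`.

```python
import os
import tempfile
from typing import Iterator, Union


def gen(pointer: Union[str, bytes]) -> Iterator[int]:
    """Hypothetical generator: yield one integer per non-blank line of a file.

    `pointer` is assumed to be a filename. It may arrive as bytes when it
    comes from a TensorFlow string tensor, so decode it first.
    """
    filename = pointer.decode("utf-8") if isinstance(pointer, bytes) else pointer
    with open(filename, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:  # skip blank lines
                yield int(line)


if __name__ == "__main__":
    # Write a small example file, then stream values out of it lazily.
    with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as tmp:
        tmp.write("1\n2\n3\n")
        path = tmp.name
    try:
        print(list(gen(path)))           # [1, 2, 3]
        print(list(gen(path.encode())))  # same values when given bytes
    finally:
        os.remove(path)
```

Because `gen` is a generator, values are produced one line at a time rather than all at once, which is what makes it attractive when each filename yields many values.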

Answer 1:

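TensorFlow now supports passing tensor arguments to the generator through the `args` parameter of `from_generator`. The tensors in `args` are evaluated and passed to the generator as NumPy values: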

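```python
import tensorflow as tf

def map_func(tensor):
    # `tensor` is evaluated and passed to `generator` (your own generator function) as a NumPy value
    dataset = tf.data.Dataset.from_generator(generator, tf.float32, args=(tensor,))
    return dataset
```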
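Applied to the question's setup, the `args` approach might look like the following sketch. It assumes TensorFlow 2.x (it iterates eagerly with `as_numpy_iterator()`). The `gen` here is a hypothetical stand-in that fabricates three values per filename instead of reading a real file, and the filenames are placeholders. `shuffle()` and `repeat()` act on the filename dataset, outside the generator, as the question wants.

```python
import tensorflow as tf


def gen(filename):
    # Hypothetical stand-in for reading a file. A scalar string tensor passed
    # via `args` is expected to arrive here as Python bytes, so decode it.
    name = filename.decode("utf-8")
    for i in range(3):
        yield 10 * len(name) + i  # placeholder for values read from the file


# Placeholder "filenames"; dataset-level operations stay outside the generator.
filenames = tf.data.Dataset.from_tensor_slices(["a", "bb", "ccc"])
filenames = filenames.shuffle(buffer_size=3).repeat(2)

dataset = filenames.interleave(
    # `args=(x,)` hands the element tensor to `gen` as a NumPy/bytes value,
    # unlike `lambda: gen(x)`, which would pass the symbolic tensor itself.
    lambda x: tf.data.Dataset.from_generator(gen, output_types=tf.int64, args=(x,)),
    cycle_length=2,
)

values = [int(v) for v in dataset.as_numpy_iterator()]
print(sorted(values))
```

The order of `values` depends on the shuffle, but each filename contributes its three values twice. Newer TensorFlow releases prefer `output_signature` over `output_types`; the sketch keeps `output_types` to match the question.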


Answer 2:

The answer is indeed no. Here are a couple of relevant GitHub issues (open as of the time of this writing) that track further developments on the question:

https://github.com/tensorflow/tensorflow/issues/13101

https://github.com/tensorflow/tensorflow/issues/16343