Input multiple files into Tensorflow dataset

Published 2019-04-01 11:43

Question:

I have the following input_fn.

def input_fn(filenames, batch_size):
    # Create a dataset containing the text lines.
    dataset = tf.data.TextLineDataset(filenames).skip(1)

    # Parse each line.
    dataset = dataset.map(_parse_line)

    # Shuffle, repeat, and batch the examples.
    dataset = dataset.shuffle(10000).repeat().batch(batch_size)

    # Return the dataset.
    return dataset

It works great if filenames=['file1.csv'] or filenames=['file2.csv']. It gives me an error if filenames=['file1.csv', 'file2.csv']. The TensorFlow documentation says that filenames is a tf.string tensor containing one or more filenames. How should I import multiple files?

Here is the error; it seems the .skip(1) in the input_fn above is being ignored:

InvalidArgumentError: Field 0 in record 0 is not a valid int32: row_id
 [[Node: DecodeCSV = DecodeCSV[OUT_TYPE=[DT_INT32, DT_INT32, DT_FLOAT, DT_INT32, DT_FLOAT, ..., DT_INT32, DT_INT32, DT_INT32, DT_INT32, DT_INT32], field_delim=",", na_value="", use_quote_delim=true](arg0, DecodeCSV/record_defaults_0, DecodeCSV/record_defaults_1, DecodeCSV/record_defaults_2, DecodeCSV/record_defaults_3, DecodeCSV/record_defaults_4, DecodeCSV/record_defaults_5, DecodeCSV/record_defaults_6, DecodeCSV/record_defaults_7, DecodeCSV/record_defaults_8, DecodeCSV/record_defaults_9, DecodeCSV/record_defaults_10, DecodeCSV/record_defaults_11, DecodeCSV/record_defaults_12, DecodeCSV/record_defaults_13, DecodeCSV/record_defaults_14, DecodeCSV/record_defaults_15, DecodeCSV/record_defaults_16, DecodeCSV/record_defaults_17, DecodeCSV/record_defaults_18)]]
 [[Node: IteratorGetNext = IteratorGetNext[output_shapes=[[?], [?], [?], [?], [?], ..., [?], [?], [?], [?], [?]], output_types=[DT_FLOAT, DT_INT32, DT_INT32, DT_STRING, DT_STRING, ..., DT_INT32, DT_FLOAT, DT_INT32, DT_INT32, DT_INT32], _device="/job:localhost/replica:0/task:0/device:CPU:0"](Iterator)]]

Answer 1:

You have the right idea using tf.data.TextLineDataset. However, your current implementation yields every line of every file in its input tensor of filenames except the first line of the first file. Because .skip(1) is applied to the concatenated dataset as a whole, it only skips the very first line of the very first file; the header row of every subsequent file is read as data, which is why parsing fails on the row_id header.
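The difference is easiest to see in plain Python, with itertools standing in for the dataset ops (this is an analogy, not TensorFlow code): concatenating first and then skipping drops only one line overall, while skipping inside each file before concatenating drops every header:

```python
from itertools import chain, islice

file1 = ["header", "a1", "a2"]
file2 = ["header", "b1", "b2"]

# TextLineDataset(filenames).skip(1): concatenate all files, then skip once.
concat_then_skip = list(islice(chain(file1, file2), 1, None))
# The second file's header survives: ['a1', 'a2', 'header', 'b1', 'b2']

# flat_map with a per-file skip: skip inside each file, then concatenate.
skip_then_concat = list(chain.from_iterable(
    islice(f, 1, None) for f in (file1, file2)))
# Every header is dropped: ['a1', 'a2', 'b1', 'b2']
```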

Based on the example in the Datasets guide, you should adapt your code to first create a regular Dataset from the filenames, then use flat_map to read each filename with TextLineDataset, skipping its first row:

d = tf.data.Dataset.from_tensor_slices(filenames)
# Get a dataset from each file, skipping the first line of each file
d = d.flat_map(lambda filename: tf.data.TextLineDataset(filename).skip(1))
d = d.map(_parse_line)  # And whatever else you need to do

Here, flat_map creates a new dataset from every element of the original dataset by reading the contents of the file and skipping the first line.
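A complete input_fn following this approach might look like the sketch below. It assumes TF 2.x eager-style APIs, and the _parse_line shown here (two int32 columns via tf.io.decode_csv) is a hypothetical stand-in for the asker's real 19-column parser:

```python
import tensorflow as tf

def _parse_line(line):
    # Hypothetical parser: decode two int32 columns per line.
    # Replace record_defaults with your actual 19-column schema.
    fields = tf.io.decode_csv(line, record_defaults=[0, 0])
    return tf.stack(fields)

def input_fn(filenames, batch_size):
    # Start from a dataset of filenames so skip(1) can run per file.
    dataset = tf.data.Dataset.from_tensor_slices(filenames)

    # Read each file's lines, dropping that file's header row.
    dataset = dataset.flat_map(
        lambda filename: tf.data.TextLineDataset(filename).skip(1))

    # Parse each CSV line.
    dataset = dataset.map(_parse_line)

    # Shuffle, repeat, and batch the examples, as before.
    dataset = dataset.shuffle(10000).repeat().batch(batch_size)
    return dataset
```

With this structure, adding more files to filenames requires no other changes: each file's header is skipped inside flat_map before the lines are merged into one stream.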