I want to train a model on about 2TB of image data stored on Google Cloud Storage. I saved the image data as separate TFRecords and tried to use the TensorFlow Data API, following this example:
https://medium.com/@moritzkrger/speeding-up-keras-with-tfrecord-datasets-5464f9836c36
But it seems like Keras' `model.fit(...)` doesn't support validation for TFRecord datasets, based on https://github.com/keras-team/keras/pull/8388
Is there a better approach for processing large amounts of data with Keras from ML Engine that I'm missing?
Thanks a lot!
If you are willing to use `tf.keras` instead of actual Keras, you can instantiate a `TFRecordDataset` with the `tf.data` API and pass that directly to `model.fit()`. Bonus: you get to stream directly from Google Cloud Storage, no need to download the data first:
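Roughly like this (a minimal sketch, not a drop-in solution: the `gs://my-bucket/...` path, the `image`/`label` feature names, the image size, and the 10-class model are placeholders you'd swap for your own; I'm also using the TF 2.x `tf.io` names, which sit directly under `tf.` in 1.9, e.g. `tf.parse_single_example` and `tf.gfile.Glob`):

```python
import tensorflow as tf

# Hypothetical GCS path and batch size -- adjust to your own setup.
TRAIN_FILES = tf.io.gfile.glob('gs://my-bucket/train/*.tfrecord')
BATCH_SIZE = 32

def parse_example(serialized):
    # Assumes each record holds a JPEG-encoded image and an int64 label.
    features = tf.io.parse_single_example(
        serialized,
        features={
            'image': tf.io.FixedLenFeature([], tf.string),
            'label': tf.io.FixedLenFeature([], tf.int64),
        })
    image = tf.image.decode_jpeg(features['image'], channels=3)
    image = tf.image.resize(image, [224, 224])
    image = tf.cast(image, tf.float32) / 255.0
    return image, features['label']

# TFRecordDataset reads gs:// paths directly -- records are streamed,
# never downloaded as a whole.
train_dataset = (
    tf.data.TFRecordDataset(TRAIN_FILES)
    .map(parse_example, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .shuffle(2048)
    .batch(BATCH_SIZE)
    .repeat()
    .prefetch(tf.data.experimental.AUTOTUNE))

# Placeholder model with 10 output classes.
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(224, 224, 3)),
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(10, activation='softmax'),
])
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# Because the dataset repeats indefinitely, steps_per_epoch is required
# (more on computing it below).
model.fit(train_dataset, epochs=5, steps_per_epoch=1000)
```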
To include validation data, create a `TFRecordDataset` with your validation TFRecords and pass that one to the `validation_data` argument of `model.fit()`. Note: this is possible as of TensorFlow 1.9.
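Continuing the sketch above (the validation path is again a placeholder):

```python
# Build the validation pipeline with the same parse function.
VAL_FILES = tf.io.gfile.glob('gs://my-bucket/val/*.tfrecord')

val_dataset = (
    tf.data.TFRecordDataset(VAL_FILES)
    .map(parse_example, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .batch(BATCH_SIZE))

model.fit(train_dataset,
          epochs=5,
          steps_per_epoch=1000,
          validation_data=val_dataset,
          validation_steps=50)   # see below for computing these counts
```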
Final note: you'll need to specify the `steps_per_epoch` argument. A hack I use to get the total number of examples across all TFRecord files is to simply iterate over the files and count:
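For example (a sketch that assumes eager execution, i.e. TF 2.x; it reads every record once, so for 2TB you'd want to run it once and cache the result):

```python
def count_examples(file_paths):
    # Iterates over every record in the given TFRecord files and counts them.
    # Expensive for large datasets -- compute once and store the number.
    return sum(1 for _ in tf.data.TFRecordDataset(file_paths))
```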
Which you can use to compute `steps_per_epoch`:
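Along these lines (again building on the placeholder names above):

```python
steps_per_epoch = count_examples(TRAIN_FILES) // BATCH_SIZE
validation_steps = count_examples(VAL_FILES) // BATCH_SIZE

model.fit(train_dataset,
          epochs=5,
          steps_per_epoch=steps_per_epoch,
          validation_data=val_dataset,
          validation_steps=validation_steps)
```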