I got this deprecation warning while using Model.fit_generator in TensorFlow:
WARNING:tensorflow: Model.fit_generator (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
Please use Model.fit, which supports generators.
How can I use Model.fit instead of Model.fit_generator?
Model.fit_generator is deprecated starting from TensorFlow 2.1.0, which is currently in rc1.
You can find the documentation for tf-2.1.0-rc1 here: https://www.tensorflow.org/versions/r2.1/api_docs/python/tf/keras/Model#fit
As you can see, the first argument of Model.fit (x) can take a generator, so you can just pass your generator to it.
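A minimal sketch of that, assuming TensorFlow 2.1 or later: a plain Python generator that loops forever over NumPy batches is passed as the first argument of Model.fit. Because such a generator never ends, steps_per_epoch tells Keras how many batches make up one epoch, just as it did for fit_generator. The helper batch_generator, the random data and the tiny model are hypothetical placeholders, not part of the original answer.

```python
import numpy as np
import tensorflow as tf


def batch_generator(x, y, batch_size):
    """Hypothetical helper: yield shuffled (inputs, targets) batches forever."""
    n = len(x)
    while True:
        idx = np.random.permutation(n)
        for start in range(0, n - batch_size + 1, batch_size):
            batch = idx[start:start + batch_size]
            yield x[batch], y[batch]


# Random placeholder data: 256 samples with 8 features, binary labels
x_train = np.random.rand(256, 8).astype("float32")
y_train = np.random.randint(0, 2, size=(256, 1)).astype("float32")

model = tf.keras.Sequential([
    tf.keras.Input(shape=(8,)),
    tf.keras.layers.Dense(16, activation="relu"),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(optimizer="adam", loss="binary_crossentropy")

# Previously: model.fit_generator(gen, steps_per_epoch=..., epochs=...)
history = model.fit(
    batch_generator(x_train, y_train, batch_size=32),
    steps_per_epoch=len(x_train) // 32,  # 8 batches per epoch
    epochs=2,
    verbose=0,
)
```

Without steps_per_epoch, an infinite generator would never finish the first epoch, so keep that argument when migrating calls that relied on it.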
As the TensorFlow documentation puts it:
x: Input data. It could be:
- A Numpy array (or array-like), or a list of arrays (in case the model has multiple inputs).
- A TensorFlow tensor, or a list of tensors (in case the model has multiple inputs).
- A dict mapping input names to the corresponding array/tensors, if the model has named inputs.
- A tf.data dataset. Should return a tuple of either (inputs, targets) or (inputs, targets, sample_weights)
- A generator or keras.utils.Sequence returning (inputs, targets) or (inputs, targets, sample weights). A more detailed description of unpacking behavior for iterator types (Dataset, generator, Sequence) is given below.
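The list above also accepts a keras.utils.Sequence. A minimal sketch, assuming TensorFlow 2.x; the class ArraySequence, the random data and the model are hypothetical. A Sequence reports its own length, so Model.fit can infer the number of batches per epoch and steps_per_epoch can be omitted. The iterators returned by flow_from_directory are Sequence subclasses too, which is why the example in this answer passes no steps_per_epoch.

```python
import math

import numpy as np
import tensorflow as tf


class ArraySequence(tf.keras.utils.Sequence):
    """Hypothetical Sequence serving (inputs, targets) batches from arrays."""

    def __init__(self, x, y, batch_size):
        super().__init__()
        self.x, self.y, self.batch_size = x, y, batch_size

    def __len__(self):
        # Number of batches per epoch; the last batch may be smaller
        return math.ceil(len(self.x) / self.batch_size)

    def __getitem__(self, idx):
        # Slice out batch number `idx`
        sl = slice(idx * self.batch_size, (idx + 1) * self.batch_size)
        return self.x[sl], self.y[sl]


# Random placeholder data: 100 samples with 8 features, binary labels
x_train = np.random.rand(100, 8).astype("float32")
y_train = np.random.randint(0, 2, size=(100, 1)).astype("float32")
train_seq = ArraySequence(x_train, y_train, batch_size=32)

model = tf.keras.Sequential([
    tf.keras.Input(shape=(8,)),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(optimizer="adam", loss="binary_crossentropy")

# No steps_per_epoch needed: fit() uses len(train_seq), i.e. 4 batches
history = model.fit(train_seq, epochs=2, verbose=0)
```

Since batches are fetched by index, a Sequence also lets Keras use multiple workers safely, which a plain generator does not guarantee.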
So you can simply pass the generator to Model.fit, just as you would have passed it to Model.fit_generator:
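```python
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Rescale pixel values from [0, 255] to [0, 1]
data_gen_train = ImageDataGenerator(rescale=1/255.)
data_gen_valid = ImageDataGenerator(rescale=1/255.)

# train_dir, validation_dir and a compiled `model` are assumed to be defined already
train_generator = data_gen_train.flow_from_directory(train_dir, target_size=(128, 128), batch_size=128, class_mode="binary")
valid_generator = data_gen_valid.flow_from_directory(validation_dir, target_size=(128, 128), batch_size=128, class_mode="binary")

# Pass the generators straight to fit(), exactly as with fit_generator()
model.fit(train_generator, epochs=2, validation_data=valid_generator)
```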