Memory leak tf.data + Keras

Posted 2019-07-23 19:35

I have a memory leak in my training pipeline and don't know how to fix it.

I use TensorFlow version 1.9.0 and Keras (tf.keras) version 2.1.6-tf with Python 3.5.2.

This is what my training pipeline looks like:

for i in range(num_epochs):
    training_data = training_set.make_one_shot_iterator().get_next()
    hist = model.fit(training_data[0],
                     [training_data[1], training_data[2], training_data[3]],
                     steps_per_epoch=steps_per_epoch_train, epochs=1,
                     verbose=1, callbacks=[history, MemoryCallback()])


    # custom validation

It looks like the memory of the iterator is not freed after the iterator is exhausted. I have already tried del training_data after model.fit; it didn't work.

Can anybody give some hints?

Edit: This is how I create the dataset.

dataset = tf.data.TFRecordDataset(tfrecords_filename)
dataset = dataset.map(map_func=preprocess_fn, num_parallel_calls=8)
dataset = dataset.shuffle(100)
dataset = dataset.batch(batch_size=batch_size)
dataset = dataset.prefetch(1)

1 Answer

我命由我不由天
Answered 2019-07-23 20:02

Including the repeat() method to reinitialize your iterator might solve your problem. You can take a look at the Input Pipeline Performance Guide to figure out a good, optimized order of these methods for your requirements.

dataset = dataset.shuffle(100)
dataset = dataset.repeat() # Can specify num_epochs as input if needed
dataset = dataset.batch(batch_size=batch_size)
dataset = dataset.prefetch(1)
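As a rough pure-Python analogy (this is not the tf.data API itself, just a sketch of the idea): repeat() turns the finite record stream into an endless one, so a single iterator created once can feed every epoch, instead of building a fresh one-shot iterator per epoch that is exhausted and then lingers in memory.

```python
import itertools

# Hypothetical stand-in for one pass over a tiny dataset of 4 records.
records = [0, 1, 2, 3]

# Without repeat(): the iterator is exhausted after one epoch.
one_shot = iter(records)
epoch1 = list(one_shot)    # consumes everything
leftover = list(one_shot)  # nothing left for a second epoch

# With repeat(): itertools.cycle plays the role of dataset.repeat(),
# so the same iterator can serve an arbitrary number of epochs.
repeating = itertools.cycle(records)
steps_per_epoch = len(records)
epoch_a = [next(repeating) for _ in range(steps_per_epoch)]
epoch_b = [next(repeating) for _ in range(steps_per_epoch)]

print(epoch1, leftover, epoch_a, epoch_b)
```

This is why the answer below can drop the Python-level epoch loop entirely: the repeating iterator never runs dry, and epochs are delimited purely by steps_per_epoch.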

If you can afford to do the validation as part of the fit method, you can use something like the code below and drop the loop altogether to make your life easier.

training_data = training_set.make_one_shot_iterator().get_next()
# val_data refers to your validation dataset;
# steps_per_epochs_val is the number of validation batches
hist = model.fit(training_data[0],
                 [training_data[1], training_data[2], training_data[3]],
                 validation_data=val_data.make_one_shot_iterator(),
                 validation_steps=steps_per_epochs_val,
                 steps_per_epoch=steps_per_epoch_train, epochs=num_epochs,
                 verbose=1, callbacks=[history, MemoryCallback()])
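Note that once repeat() is in the pipeline, Keras relies entirely on steps_per_epoch to know where each epoch ends. One common way to compute it (an assumption here, since the question doesn't show how steps_per_epoch_train is derived, and the sample counts below are made up) is the number of batches in one full pass over the data:

```python
import math

# Hypothetical dataset sizes; the question does not state them.
num_train_examples = 10000
batch_size = 32

# batch() without drop_remainder emits a final partial batch,
# so one full pass over the data takes ceil(examples / batch_size) steps.
steps_per_epoch_train = math.ceil(num_train_examples / batch_size)
print(steps_per_epoch_train)  # 313
```

If steps_per_epoch is set larger than this, each "epoch" will silently wrap around into the next pass over the (repeated, reshuffled) data.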

Reference: https://github.com/keras-team/keras/blob/master/examples/mnist_dataset_api.py
