How to incorporate validation data in Keras's

Posted 2019-04-15 17:52

This is an extension to the problem I faced in an earlier post.

I am applying the following code in Keras to do data augmentation. (I do not want to use model.fit_generator for the time being, so I loop over datagen.flow manually.)

from keras.preprocessing.image import ImageDataGenerator

datagen = ImageDataGenerator(
    featurewise_center=False,
    featurewise_std_normalization=False,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    horizontal_flip=True)

# compute quantities required for featurewise normalization
# (std, mean, and principal components if ZCA whitening is applied)
datagen.fit(x_train)


# alternative to model.fit_generator
for e in range(epochs):
    print('Epoch', e)
    batches = 0
    for x_batch, y_batch in datagen.flow(x_train, y_train, batch_size=32):
        model.fit(x_batch, y_batch)
        batches += 1
        if batches >= len(x_train) / 32:
            # we need to break the loop by hand because
            # the generator loops indefinitely
            break
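The manual break is needed because datagen.flow keeps yielding batches indefinitely. A small pure-Python sketch of the same loop, with a hypothetical infinite_batches generator standing in for datagen.flow (no augmentation or shuffling), shows that the condition batches >= len(x_train) / 32 stops after ceil(len(x_train) / 32) batches, i.e. one pass over the training data:

```python
def infinite_batches(x, y, batch_size):
    """Hypothetical stand-in for datagen.flow: yields (x, y) batches forever."""
    while True:
        for start in range(0, len(x), batch_size):
            yield x[start:start + batch_size], y[start:start + batch_size]


def one_epoch_batch_sizes(x, y, batch_size=32):
    """Run the manual loop from the question and record each batch size."""
    sizes = []
    batches = 0
    for x_batch, y_batch in infinite_batches(x, y, batch_size):
        sizes.append(len(x_batch))  # model.fit(x_batch, y_batch) would go here
        batches += 1
        if batches >= len(x) / batch_size:
            break  # the generator never stops on its own
    return sizes


print(one_epoch_batch_sizes(list(range(100)), [0] * 100))  # [32, 32, 32, 4]
```

With 100 samples the loop stops after the fourth batch, because 4 is the first count that reaches 100 / 32 = 3.125; the last batch is the partial one.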

I would like to incorporate the validation data into the model.fit loop I am running. For instance, within the for loop I want to replace model.fit(x_batch, y_batch) with something similar to model.fit(x_batch, y_batch, validation_data=(x_val, y_val)).

I am a bit confused about how to incorporate this validation step when looping over datagen.flow manually. Any insights on how to proceed are welcome.

2 answers
地球回转人心会变
Post #2 · 2019-04-15 18:22

Even though this post is a few months old, I think this could still be useful: since Keras 2.1.5, you can pass the validation_split parameter to the ImageDataGenerator constructor and then select the subset ('training' or 'validation') through the subset argument of the flow and flow_from_directory methods.

For example:

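from keras.preprocessing.image import ImageDataGenerator

datagen = ImageDataGenerator(
    featurewise_center=False,
    featurewise_std_normalization=False,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    horizontal_flip=True,
    validation_split=0.2)  # reserve 20% of x_train for validation
datagen.fit(x_train)
model = ...

# both subsets come from the same generator, so the validation
# batches are augmented as well
train_iter = datagen.flow(x_train, y_train, batch_size=batch_size, subset='training')
val_iter = datagen.flow(x_train, y_train, batch_size=batch_size, subset='validation')

model.fit_generator(train_iter,
                    steps_per_epoch=len(train_iter),    # batches per training epoch
                    epochs=epochs,
                    validation_data=val_iter,
                    validation_steps=len(val_iter))     # batches per validation pass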
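As far as I can tell from the keras-preprocessing source, subset does not sample at random: flow makes a single cut at int(len(x) * validation_split), returns the samples before the cut as 'validation' and the rest as 'training'. Two consequences are worth checking against your installed version: arrays sorted by label should be shuffled before calling flow, and the step counts should be computed per subset. The sketch below mimics that cut in plain Python; keras_like_split and steps_for are hypothetical helpers, not Keras functions.

```python
import math


def keras_like_split(x, y, validation_split):
    """Hypothetical helper mimicking the single contiguous cut that the
    subset option appears to make: samples before the cut -> 'validation',
    samples after it -> 'training'."""
    split_idx = int(len(x) * validation_split)
    training = (x[split_idx:], y[split_idx:])
    validation = (x[:split_idx], y[:split_idx])
    return training, validation


def steps_for(n_samples, batch_size):
    """Number of batches needed to see every sample of a subset once."""
    return math.ceil(n_samples / batch_size)


# 10 samples sorted by label: the cut is not random, so the validation
# subset ends up containing only class 0
x = list(range(10))
y = [0] * 5 + [1] * 5
(x_tr, y_tr), (x_val, y_val) = keras_like_split(x, y, validation_split=0.2)
print(y_val)  # [0, 0]

# step counts for 1000 samples, validation_split=0.2, batch_size=32
n_val = int(1000 * 0.2)
print(steps_for(1000 - n_val, 32), steps_for(n_val, 32))  # 25 7
```

For NumPy arrays, shuffling both arrays with the same index array (idx = np.random.permutation(len(x_train)), then x_train[idx] and y_train[idx]) before calling flow avoids the sorted-label problem.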
叼着烟拽天下
Post #3 · 2019-04-15 18:40

I assume you have already split your data into training and validation sets. If not, you will need to do so before using the following suggestion.
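If you still need to create x_val and y_val, a random hold-out split is enough. A minimal sketch, assuming list-like data; train_val_split, the 20% default and the fixed seed are illustrative choices, not part of this answer:

```python
import random


def train_val_split(x, y, val_fraction=0.2, seed=0):
    """Randomly hold out val_fraction of the samples (hypothetical helper)."""
    if len(x) != len(y):
        raise ValueError("x and y must have the same length")
    indices = list(range(len(x)))
    random.Random(seed).shuffle(indices)  # one permutation shared by x and y
    n_val = int(len(x) * val_fraction)
    val_idx, train_idx = indices[:n_val], indices[n_val:]
    x_train = [x[i] for i in train_idx]
    y_train = [y[i] for i in train_idx]
    x_val = [x[i] for i in val_idx]
    y_val = [y[i] for i in val_idx]
    return x_train, y_train, x_val, y_val


x_train, y_train, x_val, y_val = train_val_split(list(range(10)), list(range(10, 20)))
print(len(x_train), len(x_val))  # 8 2
```

With NumPy arrays you would more typically use scikit-learn's train_test_split, which performs the same shuffled split in one call.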

You can create a second data generator from the validation data, then iterate over it at the same time as the training data generator. I have added further explanation as comments in the code below.

Here is your code, modified to do this (you may still want to adjust a few things):

from keras.preprocessing.image import ImageDataGenerator

# unchanged from your code
tr_datagen = ImageDataGenerator(
    featurewise_center=False,
    featurewise_std_normalization=False,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    horizontal_flip=True)

# create a second generator for validation
val_datagen = ImageDataGenerator()    # don't perform augmentation on validation data


# compute quantities required for featurewise normalization
# (std, mean, and principal components if ZCA whitening is applied)

tr_datagen.fit(x_train)    # can leave this out if not standardising or whitening
# also only needed when standardising or whitening; in that case give val_datagen
# the same featurewise flags and fit it on x_train (not x_val), so validation
# data is normalised with the training statistics
val_datagen.fit(x_train)

# alternative to model.fit_generator
for e in range(epochs):
    print('Epoch', e)
    batches = 0

    # iterate over both generators in lockstep; in Python 3 zip() is lazy
    # (in Python 2 use itertools.izip, since zip() there would never return)
    for (x_batch, y_batch), (val_x, val_y) in zip(
            tr_datagen.flow(x_train, y_train, batch_size=32),
            val_datagen.flow(x_val, y_val, batch_size=32)):
        # validation_data must be lowercase; each call validates on one batch
        model.fit(x_batch, y_batch, validation_data=(val_x, val_y))
        batches += 1
        if batches >= len(x_train) / 32:
            # we need to break the loop by hand because
            # the generator loops indefinitely
            break
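Two things are easy to miss in the loop above: zip() pulls exactly one batch from each generator per step, so the validation numbers printed by each model.fit call come from a single validation batch, and because the validation generator is also infinite it simply wraps around when x_val is shorter than x_train. A pure-Python sketch with a hypothetical cycle_batches stand-in for .flow() (no augmentation, no shuffling) makes the pairing visible:

```python
def cycle_batches(data, batch_size):
    """Hypothetical stand-in for datagen.flow: yields batches forever."""
    while True:
        for start in range(0, len(data), batch_size):
            yield data[start:start + batch_size]


def paired_steps(train, val, batch_size):
    """Pair one training batch with one validation batch per step, for one
    epoch of the training data, mirroring the zip() loop above."""
    steps = []
    for tr_batch, val_batch in zip(cycle_batches(train, batch_size),
                                   cycle_batches(val, batch_size)):
        steps.append((tr_batch, val_batch))
        if len(steps) >= len(train) / batch_size:
            break  # both generators are infinite, so stop by hand
    return steps


for tr_batch, val_batch in paired_steps(list(range(8)), ["a", "b", "c"], batch_size=2):
    print(tr_batch, val_batch)
# [0, 1] ['a', 'b']
# [2, 3] ['c']
# [4, 5] ['a', 'b']
# [6, 7] ['c']
```

If you would rather get one score over the whole validation set per epoch, a common alternative is to call model.evaluate(x_val, y_val) once after the inner loop instead of passing validation_data to every model.fit call.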