I am building a simple Sequential model in Keras (TensorFlow backend). During training I want to inspect the individual training batches and model predictions. Therefore, I am trying to create a custom `Callback` that saves the model predictions and targets for each training batch. However, the model is not using the current batch for prediction, but the entire training data.

How can I hand over only the current training batch to the `Callback`?

And how can I access the batches and targets that the `Callback` saves in `self.predhis` and `self.targets`?

My current version looks as follows:
```python
callback_list = [prediction_history((self.x_train, self.y_train))]
self.model.fit(self.x_train, self.y_train, batch_size=self.batch_size,
               epochs=self.n_epochs, validation_data=(self.x_val, self.y_val),
               callbacks=callback_list)

class prediction_history(keras.callbacks.Callback):
    def __init__(self, train_data):
        self.train_data = train_data
        self.predhis = []
        self.targets = []

    def on_batch_end(self, batch, logs={}):
        x_train, y_train = self.train_data
        self.targets.append(y_train)
        # this predicts on the entire training data, not on the current batch
        prediction = self.model.predict(x_train)
        self.predhis.append(prediction)
        tf.logging.info("Prediction shape: {}".format(prediction.shape))
        tf.logging.info("Targets shape: {}".format(y_train.shape))
```
NOTE: the original accepted answer was wrong, as pointed out in the comments. Since it's accepted and cannot be deleted, I've rewritten it to provide a working answer.
After model compilation, the placeholder tensor for `y_true` is in `model.targets` and `y_pred` is in `model.outputs`.

To save the values of these placeholders at each batch, you can:

1. First copy the values of these tensors into variables.
2. Evaluate these variables in `on_batch_end`, and store the resulting arrays.

Now step 1 is a bit involved, because you'll have to add a `tf.assign` op to the training function `model.train_function`. Using the current Keras API, this can be done by providing a `fetches` argument to `K.function()` when the training function is constructed.

In `model._make_train_function()`, there's a line:
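Roughly, from the Keras 2.1.x source (quoted from memory here, so the exact argument list may differ slightly between versions):

```python
self.train_function = K.function(inputs,
                                 [self.total_loss] + self.metrics_tensors,
                                 updates=updates,
                                 name='train_function',
                                 **self._function_kwargs)
```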
The `fetches` argument containing the `tf.assign` ops can be provided via `model._function_kwargs` (only works after Keras 2.1.0).

As an example:
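The following sketch assumes Keras >= 2.1.0 on the TensorFlow 1.x backend; the callback name `CollectOutputAndTarget` and the toy model are illustrative:

```python
from keras.layers import Dense
from keras.models import Sequential
from keras.callbacks import Callback
from keras import backend as K
import tensorflow as tf
import numpy as np

class CollectOutputAndTarget(Callback):
    def __init__(self):
        super(CollectOutputAndTarget, self).__init__()
        self.targets = []  # collect y_true batches
        self.outputs = []  # collect y_pred batches

        # the shape of these two variables changes with the batch shape;
        # `validate_shape=False` is needed to handle the last, smaller batch
        self.var_y_true = tf.Variable(0., validate_shape=False)
        self.var_y_pred = tf.Variable(0., validate_shape=False)

    def on_batch_end(self, batch, logs=None):
        # evaluate the variables and save the resulting arrays
        self.targets.append(K.eval(self.var_y_true))
        self.outputs.append(K.eval(self.var_y_pred))

# build a simple model; compile first so that
# model.targets and model.outputs are prepared
model = Sequential([Dense(5, input_shape=(10,))])
model.compile(loss='mse', optimizer='adam')

# set up the `tf.assign` ops that copy the placeholders into the variables
cbk = CollectOutputAndTarget()
fetches = [tf.assign(cbk.var_y_true, model.targets[0], validate_shape=False),
           tf.assign(cbk.var_y_pred, model.outputs[0], validate_shape=False)]
model._function_kwargs = {'fetches': fetches}  # only works after Keras 2.1.0

# fit the model and check the results
X = np.random.rand(10, 10)
Y = np.random.rand(10, 5)
model.fit(X, Y, batch_size=8, callbacks=[cbk])
```

After fitting, `cbk.targets` and `cbk.outputs` each hold one array per batch.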
Unless the number of samples can be divided by the batch size, the final batch will have a different size than the other batches. So `K.variable()` and `K.update()` can't be used in this case; you'll have to use `tf.Variable(..., validate_shape=False)` and `tf.assign(..., validate_shape=False)` instead.

To verify the correctness of the saved arrays, you can add one line in `training.py` to print out the shuffled index array, which should then appear during fitting:
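The placement, sketched against the batch loop in `keras/engine/training.py` of Keras 2.1.x (the surrounding lines are quoted from memory and may differ between versions):

```python
if shuffle == 'batch':
    index_array = _batch_shuffle(index_array, batch_size)
elif shuffle:
    np.random.shuffle(index_array)

print('Index array:', repr(index_array))  # <-- the added line

batches = _make_batches(num_train_samples, batch_size)
```

With the toy data above (10 samples, batch size 8), this prints something like `Index array: array([8, 9, 3, 5, 4, 7, 1, 0, 6, 2])`; the exact permutation is random.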
And you can check whether `cbk.targets` is the same as `Y[index_array]`:
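A minimal sketch of that check, using the hypothetical index array printed above:

```python
index_array = np.array([8, 9, 3, 5, 4, 7, 1, 0, 6, 2])  # copied from the printout
saved = np.concatenate(cbk.targets)        # stack the per-batch arrays
print(saved.shape)                         # (10, 5)
print(np.allclose(saved, Y[index_array]))  # True (up to float32 casting)
```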
As you can see, there are two batches in `cbk.targets` (one "full batch" of size 8 and the final batch of size 2), and the row order is the same as in `Y[index_array]`.

One problem with @Yu-Yang's solution is that it relies on `model._function_kwargs`, which is not guaranteed to keep working, as it is not part of the public API. In particular, in TF2 with eager execution, session kwargs seem to be either not accepted at all or run preemptively due to eager mode.

Therefore, here is my solution, tested on `tensorflow==2.1.0`. The trick is to replace `fetches` by a Keras metric, in which the assignment operations from `fetches` are made during training.

This even enables a Keras-only solution if the batch size divides the number of samples; otherwise, another trick has to be applied when initializing TensorFlow variables with a `None` shape, similar to `validate_shape=False` in earlier solutions (compare https://github.com/tensorflow/tensorflow/issues/35667).

Importantly, `tf.keras` behaves differently from `keras` (sometimes just ignoring assignments, or seeing variables as Keras symbolic tensors), so this updated solution takes care of both implementations (`Keras==2.3.1` and `tensorflow==2.1.0`).
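A minimal sketch of the metric trick for the `tf.keras` side, assuming `tensorflow==2.1.0`; the names `collect_batch` and `CollectOutputAndTarget` are illustrative, and the standalone-`keras` variant would follow the same pattern with its own backend calls:

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras.callbacks import Callback
from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Sequential

# Variables with an unspecified shape, so that the (possibly smaller)
# final batch can be assigned as well -- the TF2 counterpart of
# `validate_shape=False` (compare tensorflow/tensorflow#35667).
var_y_true = tf.Variable(0., shape=tf.TensorShape(None))
var_y_pred = tf.Variable(0., shape=tf.TensorShape(None))

def collect_batch(y_true, y_pred):
    # Runs inside the train step as a "metric": copy the current batch
    # into the variables as a side effect and return a dummy value.
    var_y_true.assign(tf.cast(y_true, tf.float32))
    var_y_pred.assign(tf.cast(y_pred, tf.float32))
    return tf.constant(0.)

class CollectOutputAndTarget(Callback):
    def __init__(self):
        super().__init__()
        self.targets = []  # y_true batches
        self.outputs = []  # y_pred batches

    def on_train_batch_end(self, batch, logs=None):
        # The assignments ran during this batch's train step,
        # so the variables can simply be read out here.
        self.targets.append(var_y_true.numpy())
        self.outputs.append(var_y_pred.numpy())

model = Sequential([Dense(5, input_shape=(10,))])
model.compile(loss='mse', optimizer='adam', metrics=[collect_batch])

cbk = CollectOutputAndTarget()
X = np.random.rand(10, 10)
Y = np.random.rand(10, 5)
model.fit(X, Y, batch_size=8, callbacks=[cbk])
```

As before, `cbk.targets` ends up with one array per batch (sizes 8 and 2 in this toy setup).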