I'm using a generator to make sequential training data for a hierarchical recurrent model, which needs the outputs of the previous batch to generate the inputs for the next batch. This is similar to the Keras argument `stateful=True`, which saves the hidden states for the next batch, except my case is more complicated, so I can't just use that as-is.
So far I tried putting a hack in the loss function:
def custom_loss(y_true, y_pred):
    """Attempted (non-working) hack: copy predictions into a global while computing the loss.

    Evaluates ``y_pred[0]`` and ``y_pred[1]`` in the backend session returned by
    ``K.get_session()`` (``K`` is presumably ``keras.backend`` — confirm) and stores
    the results in slots 0 and 1 of the global ``output_ref``.

    NOTE(review): there is no ``return`` statement, so this returns ``None`` instead
    of a loss tensor; a Keras loss function is expected to return one, which is
    likely part of why this fails.
    NOTE(review): the loss function is presumably invoked while Keras is building
    the symbolic graph, when ``y_pred`` has no concrete value to ``.eval()`` yet —
    confirm against your Keras/TensorFlow version.
    NOTE(review): if ``y_pred`` is a single output tensor, ``y_pred[0]`` and
    ``y_pred[1]`` index its first axis (likely the batch), not two separate model
    outputs — confirm the intent.
    """
    # `global` does not create the name: output_ref must already exist at module
    # level and support item assignment, since slots 0 and 1 are overwritten below.
    global output_ref
    output_ref[0] = y_pred[0].eval(session=K.get_session())
    output_ref[1] = y_pred[1].eval(session=K.get_session())
but that didn't compile, and I'm hoping there's a better way. Would Keras callbacks help here?
Learned from here (the original link was not preserved):
model.compile(optimizer='adam')

# HACK (must run after compile): register an intermediate layer's output as an
# extra "metric" so its value is reported under `s_name` in the per-batch
# `logs` dict that Keras passes to callbacks.
# NOTE(review): relies on the `metrics_names` / `metrics_tensors` attributes
# that older standalone Keras sets up in compile(); confirm they exist in your
# Keras version before relying on this.
output_layers = ['gru']  # names of the layers whose output should be reported
s_name = 's'  # key under which the output appears in `logs`

s_tensors = [layer.output for layer in model.layers if layer.name in output_layers]
# One name (`s_name`) is appended, so exactly one tensor must go with it.
# Otherwise the names and tensors presumably fall out of step (they appear to be
# paired by position) and the callback either cannot find `s_name` or reports
# the wrong tensor.
if len(s_tensors) != 1:
    raise ValueError(
        'expected exactly one layer whose name is in %r, found %d; '
        'check [layer.name for layer in model.layers] for the actual names'
        % (output_layers, len(s_tensors)))

model.metrics_names += [s_name]
model.metrics_tensors += s_tensors
class my_callback(Callback):
    """Print the value reported under the module-level key ``s_name`` after every batch.

    The value is expected in ``logs`` because the compile-time hack appends
    ``s_name`` to ``model.metrics_names``.
    """

    def on_batch_end(self, batch, logs=None):
        # `logs` defaults to None, so normalize it to an empty dict first. A
        # missing entry then raises KeyError naming `s_name`, instead of the
        # opaque "'NoneType' object is not subscriptable".
        logs = {} if logs is None else logs
        s_pred = logs[s_name]
        print('s_pred:', s_pred)
# Train with the callback attached so the value stored under `s_name` is printed
# after every batch. The `...` is a placeholder for the usual fit() arguments.
model.fit(..., callbacks=[my_callback()])
I use this with the TensorFlow version of Keras (`tf.keras`), but it should also work with standalone Keras once the `tf.` references in the code below are adapted.
import tensorflow as tf
class ModelOutput:
    """Metric-shaped wrapper that remembers the last ``y_true``/``y_pred`` it received.

    Register ``save_output`` as a metric in ``model.compile(metrics=...)``. Each
    call records its arguments on the instance so other code, such as a
    callback, can read them afterwards.
    """

    def __init__(self, name):
        # Descriptive label only; nothing in this class reads it.
        self.name = name
        # Nothing has been captured until save_output is first called.
        self.y_true, self.y_pred = None, None

    def save_output(self, y_true, y_pred):
        """Store both arguments on the instance and return a constant ``True`` tensor.

        The return value only exists so the function can stand in as a metric;
        it carries no information about the inputs.
        """
        self.y_true, self.y_pred = y_true, y_pred
        return tf.constant(True)
class ModelOutputCallback(tf.keras.callbacks.Callback):
    """Callback giving per-batch access to values captured by ``ModelOutput`` metrics.

    Args:
        model_outputs: ``ModelOutput`` instances whose ``save_output`` methods
            were registered as metrics in ``model.compile``.
    """

    def __init__(self, model_outputs):
        super().__init__()
        self.model_outputs = model_outputs

    def on_train_batch_end(self, batch, logs=None):
        """Run after each training batch; put your own logic here.

        Read ``output.y_true`` / ``output.y_pred`` for each ``output`` in
        ``self.model_outputs``; they hold the tensors most recently passed
        to ``save_output``.

        NOTE(review): if Keras runs the training step as a compiled graph, the
        stored values may be symbolic tensors rather than concrete arrays, and
        compiling with ``run_eagerly=True`` may be needed. Confirm for your
        TensorFlow version.
        """
        # A docstring-only body is valid, unlike the original comment-only body
        # (a SyntaxError). Replace this placeholder with real per-batch logic.
# One ModelOutput per model output; the names are descriptive labels only.
model_outputs = [
    ModelOutput('rbox_score_map'),
    ModelOutput('rbox_shapes'),
    ModelOutput('rbox_angles'),
]

# Note the extra [] around m.save_output: this example is for a model with
# 3 outputs, and when you type it out, metrics must be a list of lists. Each
# inner list presumably applies to the output at the same position.
# Iterate the module-level `model_outputs` defined just above. This is script
# scope, so there is no `self` here.
model.compile(..., metrics=[[m.save_output] for m in model_outputs])
model.fit(..., callbacks=[ModelOutputCallback(model_outputs)])