Keras get model outputs after each batch

Posted 2020-07-28 05:21

Question:

I'm using a generator to produce sequential training data for a hierarchical recurrent model, which needs the outputs of the previous batch to generate the inputs for the next batch. This is similar to what the Keras argument stateful=True does (it keeps the hidden states for the next batch), except that my case is more complicated, so I can't just use that as-is.

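So far I have tried putting a hack in the loss function:

from keras import backend as K

def custom_loss(y_true, y_pred):
    # Attempt to copy the current predictions into a global buffer.
    # Note: y_pred is a symbolic tensor when Keras builds the training graph,
    # so .eval() cannot compute it without the model's inputs, and this
    # function never returns a loss value.
    global output_ref
    output_ref[0] = y_pred[0].eval(session=K.get_session())
    output_ref[1] = y_pred[1].eval(session=K.get_session())

but that didn't compile, and I hope there's a better way. Would Keras callbacks help here?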
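If you can give up fit() and the generator in favour of a hand-written training loop (a different technique from the callback route asked about here), the predictions come back from every training step directly, so the next batch can be built from them. A minimal sketch assuming TensorFlow 2.x; the toy GRU model, the sizes, and the make_batch() helper are hypothetical stand-ins for the real hierarchical model and its data generator.

```python
import numpy as np
import tensorflow as tf

# Hypothetical sizes and toy model standing in for the hierarchical model.
BATCH, STEPS, FEATS, UNITS = 4, 5, 3, 8

model = tf.keras.Sequential([
    tf.keras.Input(shape=(STEPS, FEATS)),
    tf.keras.layers.GRU(UNITS),
    tf.keras.layers.Dense(FEATS),
])
optimizer = tf.keras.optimizers.Adam()
loss_fn = tf.keras.losses.MeanSquaredError()

def make_batch(prev_output, rng):
    """Hypothetical input builder: puts the previous batch's outputs into the new inputs."""
    x = rng.standard_normal((BATCH, STEPS, FEATS)).astype("float32")
    if prev_output is not None:
        x[:, 0, :] = prev_output  # previous predictions become the first timestep
    y = rng.standard_normal((BATCH, FEATS)).astype("float32")
    return x, y

def train_step(x, y):
    # One gradient update; returns the loss and the predictions used for it.
    with tf.GradientTape() as tape:
        y_pred = model(x, training=True)
        loss = loss_fn(y, y_pred)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss, y_pred

rng = np.random.default_rng(0)
prev_output = None
losses = []
for step in range(3):
    x, y = make_batch(prev_output, rng)
    loss, y_pred = train_step(x, y)
    prev_output = y_pred.numpy()  # concrete values, ready before the next batch is built
    losses.append(float(loss))
```

train_step runs eagerly here for simplicity; it can be wrapped in tf.function for speed once it works.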

Answer 1:

Learned from here:

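from keras.callbacks import Callback

model.compile(optimizer='adam')
# Hack after compile (and before the first fit): register the 'gru' layer's
# output as an extra "metric" so Keras computes it on every batch and passes
# it to callbacks in the batch logs. This relies on internals of older
# standalone Keras (2.2.x); metrics_tensors was removed in Keras 2.3.
output_layers = ['gru']
s_name = 's'
model.metrics_names += [s_name]
model.metrics_tensors += [layer.output for layer in model.layers if layer.name in output_layers]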

class MyCallback(Callback):
    def on_batch_end(self, batch, logs=None):
        # logs now has the extra entry: the GRU output for this batch
        s_pred = logs[s_name]
        print('s_pred:', s_pred)

model.fit(..., callbacks=[MyCallback()])
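Because model.metrics_tensors no longer exists in newer Keras, a route that avoids internals altogether (a different technique from the hack above) is to train batch by batch with train_on_batch and read the layer's output through a second Model that shares the same layers. A minimal sketch assuming TensorFlow 2.x; the toy model and sizes are hypothetical, and only the layer name 'gru' is taken from the answer.

```python
import numpy as np
import tensorflow as tf

# Hypothetical toy model with a layer named 'gru', mirroring the answer.
BATCH, STEPS, FEATS, UNITS = 4, 5, 3, 8
inputs = tf.keras.Input(shape=(STEPS, FEATS))
hidden = tf.keras.layers.GRU(UNITS, name="gru")(inputs)
outputs = tf.keras.layers.Dense(1)(hidden)
model = tf.keras.Model(inputs, outputs)
model.compile(optimizer="adam", loss="mse")

# Second model that shares the same layers (and weights) and stops at 'gru'.
gru_model = tf.keras.Model(inputs, model.get_layer("gru").output)

rng = np.random.default_rng(0)
gru_outputs = []
for step in range(3):
    x = rng.standard_normal((BATCH, STEPS, FEATS)).astype("float32")
    y = rng.standard_normal((BATCH, 1)).astype("float32")
    model.train_on_batch(x, y)
    # GRU output for the batch just trained on, computed with the updated weights
    s_pred = gru_model.predict_on_batch(x)
    gru_outputs.append(s_pred)
```

The trade-off is an extra forward pass per batch, and because it runs after the weight update, s_pred differs slightly from the activations that produced that step's gradient.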


Answer 2:

I use this with the TensorFlow version of Keras (tf.keras), but it should also work in standalone Keras.

import tensorflow as tf

class ModelOutput:
    """Class wrapper for a metric that stores the output passed to it."""
    def __init__(self, name):
        self.name = name
        self.y_true = None
        self.y_pred = None

    def save_output(self, y_true, y_pred):
        # Keras calls this like any metric function; keep the tensors for the callback
        self.y_true = y_true
        self.y_pred = y_pred
        return tf.constant(True)  # dummy metric value

class ModelOutputCallback(tf.keras.callbacks.Callback):
    def __init__(self, model_outputs):
        super().__init__()
        self.model_outputs = model_outputs

    def on_train_batch_end(self, batch, logs=None):
        # Use self.model_outputs to get the outputs here, e.g.:
        for m in self.model_outputs:
            print(m.name, m.y_pred)

model_outputs = [
    ModelOutput('rbox_score_map'),
    ModelOutput('rbox_shapes'),
    ModelOutput('rbox_angles'),
]

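# Note the extra [] around m.save_output: this example is for a model with
# 3 outputs, and when metrics are listed per output, each entry must itself
# be a list of metrics.
# In TF 2.x, run_eagerly=True is needed for the stored y_pred to hold each
# batch's actual values; otherwise it stays a symbolic graph tensor.
model.compile(..., metrics=[[m.save_output] for m in model_outputs], run_eagerly=True)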

model.fit(..., callbacks=[ModelOutputCallback(model_outputs)])
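To check the approach end to end, here is a small self-contained run on a single-output toy model (hypothetical sizes and output name "score"). It assumes TensorFlow 2.x and relies on run_eagerly=True so that the y_pred stored by save_output is the concrete tensor of the batch just trained on; the callback records its shape once per batch.

```python
import numpy as np
import tensorflow as tf

class ModelOutput:
    """Metric-shaped hook that remembers the last (y_true, y_pred) it was given."""
    def __init__(self, name):
        self.name = name
        self.y_true = None
        self.y_pred = None

    def save_output(self, y_true, y_pred):
        self.y_true = y_true
        self.y_pred = y_pred
        return tf.constant(True)  # dummy metric value

class ModelOutputCallback(tf.keras.callbacks.Callback):
    def __init__(self, model_outputs):
        super().__init__()
        self.model_outputs = model_outputs
        self.seen_shapes = []  # one entry per batch and output, for illustration

    def on_train_batch_end(self, batch, logs=None):
        for m in self.model_outputs:
            # With run_eagerly=True this is the batch's actual prediction tensor
            self.seen_shapes.append(tuple(m.y_pred.shape))

# Hypothetical single-output toy model.
inputs = tf.keras.Input(shape=(5, 3))
outputs = tf.keras.layers.Dense(1)(tf.keras.layers.GRU(8)(inputs))
model = tf.keras.Model(inputs, outputs)

out = ModelOutput("score")
callback = ModelOutputCallback([out])
# Single output, so a flat list of metrics is enough here.
model.compile(optimizer="adam", loss="mse", metrics=[out.save_output], run_eagerly=True)

rng = np.random.default_rng(0)
x = rng.standard_normal((8, 5, 3)).astype("float32")
y = rng.standard_normal((8, 1)).astype("float32")
model.fit(x, y, batch_size=4, epochs=1, verbose=0, callbacks=[callback])
```

Eager execution makes training slower, which is the cost of getting per-batch values out of a metric function this way.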