Calculating precision, recall and F1 in Keras v2

Posted 2019-08-20 18:05

Question:

There is already a question on how to obtain precision, recall and F1 scores in Keras v2. Here is the method I'm using; my question is: am I doing it right?

First of all, F. Chollet says he removed these three metrics from version 2 of Keras because they were batch-based and hence not reliable. I'm following an idea by basque21 that uses a Callback with an on_epoch_end method. Shouldn't this be batch-independent, since the scores are calculated at the end of the epoch (that is, after all batches have finished)?
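To make the batch-dependence concern concrete, here is a minimal sketch in plain Python. The labels, predictions and batch split are made-up illustration data, and `precision` is a hypothetical helper, not a Keras function. It compares the mean of per-batch precisions with the precision computed once over all samples, which is what a computation at epoch end over the full validation set amounts to.

```python
def precision(y_true, y_pred):
    """Precision for class 1: TP / (TP + FP); 0.0 if nothing was predicted positive."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    return tp / (tp + fp) if (tp + fp) else 0.0

# Made-up labels and already-thresholded predictions, split into two batches of four.
y_true = [1, 0, 0, 0, 1, 1, 1, 0]
y_pred = [1, 0, 0, 0, 1, 1, 1, 1]
batches = [(y_true[:4], y_pred[:4]), (y_true[4:], y_pred[4:])]

# Batch-based estimate: average the precision of each batch.
batch_mean = sum(precision(t, p) for t, p in batches) / len(batches)
# Epoch-level value: one computation over every sample.
global_value = precision(y_true, y_pred)

print("mean of per-batch precision:", batch_mean)
print("precision over all samples: ", global_value)
```

The two numbers differ because each batch contains a different number of positive predictions, so averaging the per-batch ratios weights the samples unevenly.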

Here is the code I'm using. In the model.fit method I add the argument callbacks=[metrics] and I define a dictionary myhistory and a class Metrics as follows (code adapted from basque21):

import keras
import numpy
from sklearn.metrics import precision_recall_fscore_support

# Per-epoch scores for class 1, filled in by the callback below.
myhistory = {'prec': [], 'reca': [], 'f1': []}

class Metrics(keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        # Predict on the whole validation set, then threshold at 0.5.
        predict = numpy.asarray(self.model.predict(self.validation_data[0]))
        predict = numpy.array(predict.flatten() >= 0.5, dtype=int)
        targ = self.validation_data[1]
        targ = numpy.array(targ.flatten() >= 0.5, dtype=int)
        # prf[0], prf[1] and prf[2] hold per-class precision, recall and F1.
        self.prf = precision_recall_fscore_support(targ, predict)
        print("Precision/recall/f1 class 0 is {}/{}/{}, precision/recall/f1 class 1 is {}/{}/{}".format(self.prf[0][0], self.prf[1][0], self.prf[2][0], self.prf[0][1], self.prf[1][1], self.prf[2][1]))
        # Keep only the class 1 scores for plotting.
        myhistory['prec'].append(self.prf[0][1])
        myhistory['reca'].append(self.prf[1][1])
        myhistory['f1'].append(self.prf[2][1])

metrics = Metrics()
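For readers unsure what the indices `prf[0][1]`, `prf[1][1]` and `prf[2][1]` pick out, the following is a hand-rolled sketch in plain Python. It is a stand-in for illustration only, not scikit-learn's implementation: `per_class_scores` is a hypothetical helper that mirrors the first three arrays returned by `precision_recall_fscore_support` (the fourth, support, is omitted), and the probabilities and labels are invented.

```python
def per_class_scores(y_true, y_pred, classes=(0, 1)):
    """Return (precisions, recalls, f1s), each a list indexed by class label."""
    precisions, recalls, f1s = [], [], []
    for c in classes:
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != c and p == c)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == c and p != c)
        prec = tp / (tp + fp) if (tp + fp) else 0.0
        rec = tp / (tp + fn) if (tp + fn) else 0.0
        f1 = 2 * prec * rec / (prec + rec) if (prec + rec) else 0.0
        precisions.append(prec)
        recalls.append(rec)
        f1s.append(f1)
    return precisions, recalls, f1s

# Invented sigmoid outputs and ground-truth labels.
probs = [0.9, 0.4, 0.6, 0.2, 0.7, 0.1]
labels = [1, 1, 0, 0, 1, 0]

# Same 0.5 threshold as in the callback.
preds = [int(p >= 0.5) for p in probs]

prf = per_class_scores(labels, preds)
print("class 1 precision/recall/f1:", prf[0][1], prf[1][1], prf[2][1])
```

The first index selects the metric and the second selects the class, so `prf[2][1]` is the F1 score of class 1.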

Once fitting is finished I display everything on a common plot as follows:

import matplotlib.pyplot as plt

# history is assumed to be the object returned by model.fit(...)
loss = history.history['loss']
val_loss = history.history['val_loss']
precision = [1-x for x in myhistory['prec']]
recall = [1-x for x in myhistory['reca']]
f_one = [1-x for x in myhistory['f1']]

epochs = range(1, len(loss) + 1)

# "bo" is for "blue dot"
plt.plot(epochs, loss, 'bo', label='Training loss')
# b is for "solid blue line"
plt.plot(epochs, val_loss, 'b', label='Validation loss')
plt.plot(epochs, precision, 'r', label='One minus precision')
plt.plot(epochs, recall, 'g', label='One minus recall')
plt.plot(epochs, f_one, 'm', label='One minus f1')
plt.title('Training and validation loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()

plt.show()

I'm plotting "one minus precision", etc., so that all curves stay in a range close to zero. What I don't understand, though, is that, as you can see in the plot, although the validation loss is increasing (due to overfitting), precision and recall seem to vary in such a way that F1 stays relatively constant. What am I doing wrong?

The example is taken from Chapter 3 of F. Chollet's book. It is about classifying IMDB texts, and I am only displaying scores for class 1.
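One possible reading of the rising validation loss alongside a flat F1, offered as an assumption and not as a diagnosis of this particular run: cross-entropy depends on how confident each predicted probability is, while precision, recall and F1 only see the decisions left after thresholding at 0.5. The plain-Python sketch below uses invented probabilities and two hypothetical helpers, `binary_crossentropy` and `f1_positive`, to show a case where the decisions are identical but the loss is much higher.

```python
import math

def binary_crossentropy(y_true, probs):
    """Mean binary cross-entropy over all samples."""
    total = sum(t * math.log(p) + (1 - t) * math.log(1 - p)
                for t, p in zip(y_true, probs))
    return -total / len(y_true)

def f1_positive(y_true, probs, threshold=0.5):
    """F1 for class 1 after thresholding the probabilities."""
    preds = [int(p >= threshold) for p in probs]
    tp = sum(1 for t, p in zip(y_true, preds) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, preds) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, preds) if t == 1 and p == 0)
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0

labels = [1, 1, 0, 0]
# Invented outputs: the same two mistakes, made with growing confidence.
early = [0.70, 0.40, 0.30, 0.60]
late = [0.99, 0.01, 0.01, 0.99]

loss_early, loss_late = binary_crossentropy(labels, early), binary_crossentropy(labels, late)
f1_early, f1_late = f1_positive(labels, early), f1_positive(labels, late)

print("loss:", loss_early, "->", loss_late)
print("F1:  ", f1_early, "->", f1_late)
```

In this toy case the thresholded predictions never change, so F1 is unchanged, while the loss grows because the wrong answers are given with more confidence.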