Accuracy metric of a subsection of categories in K

Published 2020-02-15 05:59

Question:

I've got a 3-class classification problem; let's call the classes 0, 1, and 2. In my case, class 0 is not important: whatever gets classified as class 0 is irrelevant. What matters is accuracy, precision, recall, and error rate for classes 1 and 2 only. I would like to define an accuracy metric that looks only at the subset of the data relating to classes 1 and 2 and reports it while the model is training.

I am not asking for code for accuracy, F1, or precision/recall; I've found those and can implement them myself. What I'm asking for is code that selects a subset of the categories to compute these metrics on. Visually, with a confusion matrix, given:

|   | 0  | 1 | 2 |
|---|----|---|---|
| 0 | 10 | 3 | 4 |
| 1 | 2  | 5 | 1 |
| 2 | 8  | 5 | 9 |

During training, I would like to compute accuracy on the following subset only:

|   | 1 | 2 |
|---|---|---|
| 1 | 5 | 1 |
| 2 | 5 | 9 |
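The subset above is the lower-right 2×2 block of the full matrix, so its accuracy is the sum of that block's diagonal divided by the block's total. A minimal NumPy sketch using the example counts from this post (`subset_accuracy` is a hypothetical helper name, not a library function):

```python
import numpy as np

# Confusion matrix from the question (class order 0, 1, 2 on both axes).
cm = np.array([
    [10, 3, 4],
    [2, 5, 1],
    [8, 5, 9],
])


def subset_accuracy(cm, classes):
    """Accuracy restricted to the rows AND columns listed in `classes`."""
    sub = cm[np.ix_(classes, classes)]  # for [1, 2]: [[5, 1], [5, 9]]
    return np.trace(sub) / sub.sum()    # correct guesses / all guesses in the block


print(subset_accuracy(cm, [1, 2]))     # 0.7
print(subset_accuracy(cm, [0, 1, 2]))  # ~0.5106 (24 / 47)
```

Because both the row and the column for class 0 are dropped, the result is the same whichever axis holds the true labels.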

Possible idea: take the argmax of y_pred and of y_true, concatenate them, drop every sample in which class 0 appears, convert what remains back into one-hot arrays, and compute a simple binary accuracy on that?
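The idea above can be sketched in plain NumPy, for example on arrays returned by `model.predict` (so outside the training graph). This is a minimal sketch assuming one-hot `y_true` and probability rows in `y_pred`; `subset_accuracy_from_probs` is a hypothetical name. It compares class indices directly, which gives the same answer as converting back to one-hot and comparing whole rows:

```python
import numpy as np


def subset_accuracy_from_probs(y_true, y_pred, classes=(1, 2)):
    """Accuracy over samples whose true AND predicted class are both in `classes`.

    y_true: one-hot labels, shape (n_samples, n_classes)
    y_pred: predicted probabilities, shape (n_samples, n_classes)
    """
    true_cls = np.argmax(y_true, axis=-1)
    pred_cls = np.argmax(y_pred, axis=-1)
    # Drop every sample where class 0 (or any class outside `classes`) appears.
    keep = np.isin(true_cls, classes) & np.isin(pred_cls, classes)
    if not keep.any():
        return float("nan")  # nothing left to score
    # Comparing indices is equivalent to comparing the re-one-hotted rows.
    return float(np.mean(true_cls[keep] == pred_cls[keep]))


y_true = np.eye(3)[[0, 1, 2, 2, 1]]
y_pred = np.array([
    [0.9, 0.05, 0.05],  # true 0 -> dropped
    [0.1, 0.7, 0.2],    # true 1, pred 1 -> correct
    [0.2, 0.5, 0.3],    # true 2, pred 1 -> wrong
    [0.1, 0.1, 0.8],    # true 2, pred 2 -> correct
    [0.6, 0.3, 0.1],    # true 1, pred 0 -> dropped
])
print(subset_accuracy_from_probs(y_true, y_pred))  # 0.6666666666666666
```

Dropping samples whose *prediction* is 0 (not only those whose label is 0) is what matches the 2×2 table; if only samples labelled 0 should be ignored, keep just the `np.isin(true_cls, classes)` condition.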

Edit: I've tried to exclude class 0 with the code below, but the result doesn't make sense: class 0 effectively gets folded into class 1 (that is, the true positives of both class 0 and class 1 end up counted as class 1). Still looking for help; can anybody help out, please?

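```python
from keras import backend as Ker  # Keras 2.x backend API, as used in the question

# This solution does not work :(
def my_acc(y_true, y_pred):
    # Excluding the 0-category by keeping only the columns for classes 1 and 2
    y_true_cust = y_true[:, 1:3]
    y_pred_cust = y_pred[:, 1:3]
    # Binary accuracy source code, slightly edited: compare against the ROUNDED predictions
    y_pred_cat = Ker.round(y_pred_cust)
    eql_cust = Ker.cast(Ker.equal(y_true_cust, y_pred_cat), Ker.floatx())
    # Caveat: a class-0 sample predicted as class 0 becomes [0, 0] vs [0, 0]
    # after the slice and still counts as correct, so class 0 is not really excluded.
    return Ker.mean(eql_cust, axis=-1)
```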
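Slicing columns keeps every sample, so class-0 samples are still scored. A minimal sketch of an in-training alternative, assuming TensorFlow 2.x with `tf.keras` and one-hot labels (the class `SubsetAccuracy` and its `classes` argument are hypothetical, not part of Keras), masks whole samples instead and keeps running counts across batches:

```python
import tensorflow as tf


class SubsetAccuracy(tf.keras.metrics.Metric):
    """Running accuracy over samples whose true AND predicted class are both in `classes`.

    Hypothetical helper (not part of Keras). Assumes one-hot `y_true`
    (as with categorical_crossentropy) and ignores `sample_weight`.
    """

    def __init__(self, classes=(1, 2), name="subset_accuracy", **kwargs):
        super().__init__(name=name, **kwargs)
        self.class_ids = tuple(classes)
        # Running counts, accumulated over all batches until the metric is reset.
        self.correct = self.add_weight(name="correct", initializer="zeros")
        self.total = self.add_weight(name="total", initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        classes = tf.constant(self.class_ids, dtype=tf.int64)
        true_cls = tf.argmax(y_true, axis=-1)
        pred_cls = tf.argmax(y_pred, axis=-1)
        # Keep a sample only if both its label and its prediction are in `classes`;
        # with the default (1, 2) this drops row 0 and column 0 of the confusion matrix.
        in_true = tf.reduce_any(tf.equal(true_cls[:, None], classes[None, :]), axis=-1)
        in_pred = tf.reduce_any(tf.equal(pred_cls[:, None], classes[None, :]), axis=-1)
        keep = tf.cast(tf.logical_and(in_true, in_pred), tf.float32)
        hits = tf.cast(tf.equal(true_cls, pred_cls), tf.float32) * keep
        self.correct.assign_add(tf.reduce_sum(hits))
        self.total.assign_add(tf.reduce_sum(keep))

    def result(self):
        # Returns 0.0 instead of NaN while no sample has survived the mask yet.
        return tf.math.divide_no_nan(self.correct, self.total)
```

It would be passed to `model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy", SubsetAccuracy(classes=(1, 2))])`. Accumulating correct/total counts, rather than returning a per-batch ratio from a plain function, means the reported value is the accuracy of the 2×2 block over the whole epoch instead of an average of per-batch ratios.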

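@Ashwin Geet D'Sa: here is the arithmetic for the example matrix above, first on all three classes and then on classes 1 and 2 only: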

```python
correct_guesses_3cat = 10 + 5 + 9
print(correct_guesses_3cat)  # 24

total_guesses_3cat = 10 + 3 + 4 + 2 + 5 + 1 + 8 + 5 + 9
print(total_guesses_3cat)  # 47

accuracy_3cat = correct_guesses_3cat / total_guesses_3cat
print(f"{accuracy_3cat:.1%}")  # 51.1%

correct_guesses_2cat = 5 + 9
print(correct_guesses_2cat)  # 14

total_guesses_2cat = 5 + 1 + 5 + 9
print(total_guesses_2cat)  # 20

accuracy_2cat = correct_guesses_2cat / total_guesses_2cat
print(f"{accuracy_2cat:.1%}")  # 70.0%
```