How to plot ROC curve in Python

2019-03-08 01:57发布

I am trying to plot a ROC curve to evaluate the accuracy of a prediction model I developed in Python using logistic regression packages. I have computed the true positive rate as well as the false positive rate; however, I am unable to figure out how to plot these correctly using matplotlib and calculate the AUC value. How could I do that?

8条回答
smile是对你的礼貌
2楼-- · 2019-03-08 02:39

This is the simplest way to plot an ROC curve, given a set of ground truth labels and predicted probabilities. Best part is, it plots the ROC curve for ALL classes, so you get multiple neat-looking curves as well

import scikitplot as skplt
import matplotlib.pyplot as plt

y_true = # ground truth labels
y_probas = # predicted probabilities generated by sklearn classifier
skplt.metrics.plot_roc_curve(y_true, y_probas)
plt.show()

Here's a sample curve generated by plot_roc_curve. I used the sample digits dataset from scikit-learn so there are 10 classes. Notice that one ROC curve is plotted for each class.

ROC Curves

Disclaimer: Note that this uses the scikit-plot library, which I built.

查看更多
成全新的幸福
3楼-- · 2019-03-08 02:42

I have made a simple function included in a package for the ROC curve. I just started practicing machine learning so please also let me know if this code has any problem!

Have a look at the github readme file for more details! :)

https://github.com/bc123456/ROC

from sklearn.metrics import confusion_matrix, accuracy_score, roc_auc_score, roc_curve
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

def plot_ROC(y_train_true, y_train_prob, y_test_true, y_test_prob):
    '''
    a funciton to plot the ROC curve for train labels and test labels.
    Use the best threshold found in train set to classify items in test set.
    '''
    fpr_train, tpr_train, thresholds_train = roc_curve(y_train_true, y_train_prob, pos_label =True)
    sum_sensitivity_specificity_train = tpr_train + (1-fpr_train)
    best_threshold_id_train = np.argmax(sum_sensitivity_specificity_train)
    best_threshold = thresholds_train[best_threshold_id_train]
    best_fpr_train = fpr_train[best_threshold_id_train]
    best_tpr_train = tpr_train[best_threshold_id_train]
    y_train = y_train_prob > best_threshold

    cm_train = confusion_matrix(y_train_true, y_train)
    acc_train = accuracy_score(y_train_true, y_train)
    auc_train = roc_auc_score(y_train_true, y_train)

    print 'Train Accuracy: %s ' %acc_train
    print 'Train AUC: %s ' %auc_train
    print 'Train Confusion Matrix:'
    print cm_train

    fig = plt.figure(figsize=(10,5))
    ax = fig.add_subplot(121)
    curve1 = ax.plot(fpr_train, tpr_train)
    curve2 = ax.plot([0, 1], [0, 1], color='navy', linestyle='--')
    dot = ax.plot(best_fpr_train, best_tpr_train, marker='o', color='black')
    ax.text(best_fpr_train, best_tpr_train, s = '(%.3f,%.3f)' %(best_fpr_train, best_tpr_train))
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.0])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC curve (Train), AUC = %.4f'%auc_train)

    fpr_test, tpr_test, thresholds_test = roc_curve(y_test_true, y_test_prob, pos_label =True)

    y_test = y_test_prob > best_threshold

    cm_test = confusion_matrix(y_test_true, y_test)
    acc_test = accuracy_score(y_test_true, y_test)
    auc_test = roc_auc_score(y_test_true, y_test)

    print 'Test Accuracy: %s ' %acc_test
    print 'Test AUC: %s ' %auc_test
    print 'Test Confusion Matrix:'
    print cm_test

    tpr_score = float(cm_test[1][1])/(cm_test[1][1] + cm_test[1][0])
    fpr_score = float(cm_test[0][1])/(cm_test[0][0]+ cm_test[0][1])

    ax2 = fig.add_subplot(122)
    curve1 = ax2.plot(fpr_test, tpr_test)
    curve2 = ax2.plot([0, 1], [0, 1], color='navy', linestyle='--')
    dot = ax2.plot(fpr_score, tpr_score, marker='o', color='black')
    ax2.text(fpr_score, tpr_score, s = '(%.3f,%.3f)' %(fpr_score, tpr_score))
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.0])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC curve (Test), AUC = %.4f'%auc_test)
    plt.savefig('ROC', dpi = 500)
    plt.show()

    return best_threshold

A sample roc graph produced by this code

查看更多
登录 后发表回答