
I am using scikit-learn's RocCurveDisplay to plot the ROC curve with its AUC. It draws the curve correctly, but the plot does not show the y = x line. I want to add a dotted y = x line to the plot. How can I add it? I did not find any parameter in RocCurveDisplay() for it.

1 Answer

Best answer

By default, RocCurveDisplay does not draw the diagonal, so you need to add it yourself with the plot() function of the matplotlib.pyplot module after the ROC curve has been drawn.

Here is an example that combines plt.plot() with RocCurveDisplay:

import matplotlib.pyplot as plt
from sklearn import metrics

# test data for the ROC/AUC plot
y = [0, 1, 1, 0, 1, 0, 0, 1, 0, 0]
pred = [0.4, 0.89, 0.84, 0.74, 0.36, 0.94, 0.59, 0.71, 0.46, 0.12]

# calculate the false positive rate, true positive rate and AUC
fpr, tpr, thresholds = metrics.roc_curve(y, pred)
roc_auc = metrics.auc(fpr, tpr)

# draw the ROC curve (creates a new figure and axes)
display = metrics.RocCurveDisplay(fpr=fpr, tpr=tpr, roc_auc=roc_auc)
display.plot()
plt.title("AUC ROC curve")

# draw the y = x diagonal on the current axes
plt.plot([0, 1], [0, 1], color='red', lw=1, linestyle='--')

plt.show()

The plt.plot([0, 1], [0, 1], ...) call draws the y = x line from (0, 0) to (1, 1). Note that linestyle='--' gives a dashed line; use linestyle=':' if you want it dotted.

