Spaces:
Runtime error
Runtime error
| # Copyright (c) Guangsheng Bao. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import matplotlib.pyplot as plt | |
| from sklearn.metrics import roc_curve, precision_recall_curve, auc | |
# Plot colour cycle with 15 entries. Only the first eight are distinct
# colourblind-friendly hues; entries 9-15 repeat entries 1-7, so plots with
# more than eight series will reuse colours.
_BASE_PALETTE = [
    "#0072B2",
    "#009E73",
    "#D55E00",
    "#CC79A7",
    "#F0E442",
    "#56B4E9",
    "#E69F00",
    "#000000",
]
COLORS = _BASE_PALETTE + _BASE_PALETTE[:7]
def get_roc_metrics(real_preds, sample_preds):
    """Compute the ROC curve and its AUC for two groups of scores.

    Scores in ``real_preds`` are labelled 0 and scores in ``sample_preds``
    are labelled 1, so a higher score is treated as evidence for the
    ``sample`` class.

    Args:
        real_preds: Iterable of numeric scores for the label-0 group.
        sample_preds: Iterable of numeric scores for the label-1 group.

    Returns:
        A ``(fpr, tpr, roc_auc)`` tuple: two plain lists of floats holding
        the false/true positive rates, and the area under the curve as a
        float.
    """
    # Materialise as lists so "+" below concatenates. With NumPy arrays the
    # "+" operator would add element-wise (or fail to broadcast), leaving
    # the scores misaligned with the labels.
    real_scores = list(real_preds)
    sample_scores = list(sample_preds)

    labels = [0] * len(real_scores) + [1] * len(sample_scores)
    scores = real_scores + sample_scores

    fpr, tpr, _ = roc_curve(labels, scores)
    roc_auc = auc(fpr, tpr)
    return fpr.tolist(), tpr.tolist(), float(roc_auc)
def get_precision_recall_metrics(real_preds, sample_preds):
    """Compute the precision-recall curve and its AUC for two score groups.

    Scores in ``real_preds`` are labelled 0 and scores in ``sample_preds``
    are labelled 1, so a higher score is treated as evidence for the
    ``sample`` class.

    Args:
        real_preds: Iterable of numeric scores for the label-0 group.
        sample_preds: Iterable of numeric scores for the label-1 group.

    Returns:
        A ``(precision, recall, pr_auc)`` tuple: two plain lists of floats
        and the area under the precision-recall curve as a float.
    """
    # Materialise as lists so "+" below concatenates. With NumPy arrays the
    # "+" operator would add element-wise (or fail to broadcast), leaving
    # the scores misaligned with the labels.
    real_scores = list(real_preds)
    sample_scores = list(sample_preds)

    labels = [0] * len(real_scores) + [1] * len(sample_scores)
    scores = real_scores + sample_scores

    precision, recall, _ = precision_recall_curve(labels, scores)
    # Recall is the x-axis of the PR curve, so it is passed to auc() first.
    pr_auc = auc(recall, precision)
    return precision.tolist(), recall.tolist(), float(pr_auc)