5个版本 (3个稳定版)
1.0.2 | 2022年6月6日 |
---|---|
1.0.1 | 2022年1月12日 |
0.1.1 | 2021年8月25日 |
0.1.0 | 2021年5月11日 |
#174 in 机器学习
每月190次下载
81KB
1.5K SLoC
eval-metrics
机器学习评估指标
设计
该库的目标是提供一组直观的函数,用于计算机器学习中常见的评估指标。指标分为分类和回归两个模块,分类模块支持二分类和多分类任务。这种二分类和多分类之间的区分是为了强调两种情况下某些指标之间存在微妙差异(即多分类指标通常需要平均方法)。由于各种数值原因,指标可能无法定义,在这些情况下使用Result
类型来明确表示这一事实。
支持的指标
指标 | 任务 | 描述 |
---|---|---|
准确率 | 二分类 | 二分类准确率 |
精确度 | 二分类 | 二分类精确度 |
召回率 | 二分类 | 二分类召回率 |
F-1 | 二分类 | 精确度和召回率的调和平均值 |
MCC | 二分类 | 马修斯相关系数 |
ROC曲线 | 二分类 | 接收器操作特征曲线 |
AUC | 二分类 | ROC曲线下的面积 |
PR曲线 | 二分类 | 精确度-召回率曲线 |
AP | 二分类 | 平均精确度 |
准确率 | 多分类 | 多分类准确率 |
精确度 | 多分类 | 多分类精确度 |
召回率 | 多分类 | 多分类召回率 |
F-1 | 多分类 | 多分类F-1 |
Rk | 多分类 | 戈罗金(2004年)描述的K类别相关系数 |
M-AUC | 多分类 | 汉德和蒂尔(2001年)描述的多分类AUC |
RMSE | 回归 | 均方根误差 |
MSE | 回归 | 均方误差 |
MAE | 回归 | 平均绝对误差 |
R-Square | 回归 | 确定系数 |
相关度 | 回归 | 线性相关系数 |
用法
二分类
BinaryConfusionMatrix
结构提供了计算常见二分类指标的函数。
use eval_metrics::error::EvalError;
use eval_metrics::classification::BinaryConfusionMatrix;
fn main() -> Result<(), EvalError> {
// note: these scores could also be f32 values
let scores = vec![0.5, 0.2, 0.7, 0.4, 0.1, 0.3, 0.8, 0.9];
let labels = vec![false, false, true, false, true, false, false, true];
let threshold = 0.5;
// compute confusion matrix from scores and labels
let matrix = BinaryConfusionMatrix::compute(&scores, &labels, threshold)?;
// counts
let tpc = matrix.tp_count;
let fpc = matrix.fp_count;
let tnc = matrix.tn_count;
let fnc = matrix.fn_count;
// metrics
let acc = matrix.accuracy()?;
let pre = matrix.precision()?;
let rec = matrix.recall()?;
let f1 = matrix.f1()?;
let mcc = matrix.mcc()?;
// print matrix to console
println!("{}", matrix);
Ok(())
}
o=========================o
| Label |
o=========================o
| Positive | Negative |
o==============o============o============|============o
| | Positive | 2 | 2 |
| Prediction |============|------------|------------|
| | Negative | 1 | 3 |
o==============o============o=========================o
除了来自混淆矩阵的指标外,还可以计算ROC曲线和PR曲线,提供AUC和AP等指标。
use eval_metrics::error::EvalError;
use eval_metrics::classification::{RocCurve, RocPoint, PrCurve, PrPoint};
fn main() -> Result<(), EvalError> {
// note: these scores could also be f32 values
let scores = vec![0.5, 0.2, 0.7, 0.4, 0.1, 0.3, 0.8, 0.9];
let labels = vec![false, false, true, false, true, false, false, true];
// construct roc curve
let roc = RocCurve::compute(&scores, &labels)?;
// compute auc
let auc = roc.auc();
// inspect roc curve points
roc.points.iter().for_each(|point| {
let tpr = point.tp_rate;
let fpr = point.fp_rate;
let thresh = point.threshold;
});
// construct pr curve
let pr = PrCurve::compute(&scores, &labels)?;
// compute average precision
let ap = pr.ap();
// inspect pr curve points
pr.points.iter().for_each(|point| {
let pre = point.precision;
let rec = point.recall;
let thresh = point.threshold;
});
Ok(())
}
多分类
MultiConfusionMatrix
结构提供了计算常见多分类指标的函数。此外,对于这些指标中的几个,必须显式提供平均方法。
use eval_metrics::error::EvalError;
use eval_metrics::classification::{MultiConfusionMatrix, Averaging};
fn main() -> Result<(), EvalError> {
// note: these scores could also be f32 values
let scores = vec![
vec![0.3, 0.1, 0.6],
vec![0.5, 0.2, 0.3],
vec![0.2, 0.7, 0.1],
vec![0.3, 0.3, 0.4],
vec![0.5, 0.1, 0.4],
vec![0.8, 0.1, 0.1],
vec![0.3, 0.5, 0.2]
];
let labels = vec![2, 1, 1, 2, 0, 2, 0];
// compute confusion matrix from scores and labels
let matrix = MultiConfusionMatrix::compute(&scores, &labels)?;
// get counts
let counts = &matrix.counts;
// metrics
let acc = matrix.accuracy()?;
let mac_pre = matrix.precision(&Averaging::Macro)?;
let wgt_pre = matrix.precision(&Averaging::Weighted)?;
let mac_rec = matrix.recall(&Averaging::Macro)?;
let wgt_rec = matrix.recall(&Averaging::Weighted)?;
let mac_f1 = matrix.f1(&Averaging::Macro)?;
let wgt_f1 = matrix.f1(&Averaging::Weighted)?;
let rk = matrix.rk()?;
// print matrix to console
println!("{}", matrix);
Ok(())
}
o===================================o
| Label |
o===================================o
| Class-1 | Class-2 | Class-3 |
o==============o===========o===========|===========|===========o
| | Class-1 | 1 | 1 | 1 |
| |===========|-----------|-----------|-----------|
| Prediction | Class-2 | 1 | 1 | 0 |
| |===========|-----------|-----------|-----------|
| | Class-3 | 0 | 0 | 2 |
o==============o===========o===================================o
除了这些全局指标外,还可以获得每个类的指标。
use eval_metrics::error::EvalError;
use eval_metrics::classification::{MultiConfusionMatrix};
fn main() -> Result<(), EvalError> {
// note: these scores could also be f32 values
let scores = vec![
vec![0.3, 0.1, 0.6],
vec![0.5, 0.2, 0.3],
vec![0.2, 0.7, 0.1],
vec![0.3, 0.3, 0.4],
vec![0.5, 0.1, 0.4],
vec![0.8, 0.1, 0.1],
vec![0.3, 0.5, 0.2]
];
let labels = vec![2, 1, 1, 2, 0, 2, 0];
// compute confusion matrix from scores and labels
let matrix = MultiConfusionMatrix::compute(&scores, &labels)?;
// per-class metrics
let pca = matrix.per_class_accuracy();
let pcp = matrix.per_class_precision();
let pcr = matrix.per_class_recall();
let pcf = matrix.per_class_f1();
let pcm = matrix.per_class_mcc();
// print per-class metrics to console
println!("{:?}", pca);
println!("{:?}", pcp);
println!("{:?}", pcr);
println!("{:?}", pcf);
println!("{:?}", pcm);
Ok(())
}
[Ok(0.5714285714285714), Ok(0.7142857142857143), Ok(0.8571428571428571)]
[Ok(0.3333333333333333), Ok(0.5), Ok(1.0)]
[Ok(0.5), Ok(0.5), Ok(0.6666666666666666)]
[Ok(0.4), Ok(0.5), Ok(0.8)]
[Ok(0.09128709291752773), Ok(0.3), Ok(0.7302967433402215)]
除了从混淆矩阵派生的指标外,还提供了Hand和Till(2001)描述的M-AUC(多类AUC)指标作为独立函数。
let mauc = m_auc(&scores, &labels)?;
回归
所有回归指标都在一对分数和标签上操作。
use eval_metrics::error::EvalError;
use eval_metrics::regression::*;
fn main() -> Result<(), EvalError> {
// note: these could also be f32 values
let scores = vec![0.4, 0.7, -1.2, 2.5, 0.3];
let labels = vec![0.2, 1.1, -0.9, 1.3, -0.2];
// root mean squared error
let rmse = rmse(&scores, &labels)?;
// mean squared error
let mse = mse(&scores, &labels)?;
// mean absolute error
let mae = mae(&scores, &labels)?;
// coefficient of determination
let rsq = rsq(&scores, &labels)?;
// pearson correlation coefficient
let corr = corr(&scores, &labels)?;
Ok(())
}