5个版本 (3个稳定版)

1.0.2 2022年6月6日
1.0.1 2022年1月12日
0.1.1 2021年8月25日
0.1.0 2021年5月11日

#174 in 机器学习

Download history 26/week @ 2024-04-21 6/week @ 2024-04-28 35/week @ 2024-05-05 2/week @ 2024-05-12 7/week @ 2024-05-19 1/week @ 2024-05-26 48/week @ 2024-06-02 28/week @ 2024-06-09 15/week @ 2024-06-16 33/week @ 2024-06-23 12/week @ 2024-06-30 19/week @ 2024-07-07 19/week @ 2024-07-14 38/week @ 2024-07-21 110/week @ 2024-07-28 23/week @ 2024-08-04

每月190次下载

MIT/Apache

81KB
1.5K SLoC

eval-metrics

机器学习评估指标

crates.io License


设计

该库的目标是提供一组直观的函数,用于计算机器学习中常见的评估指标。指标分为分类和回归两个模块,分类模块支持二分类和多分类任务。这种二分类和多分类之间的区分是为了强调两种情况下某些指标之间存在微妙差异(即多分类指标通常需要平均方法)。由于各种数值原因,指标可能无法定义,在这些情况下使用Result类型来明确表示这一事实。

支持的指标

指标 任务 描述
准确率 二分类 二分类准确率
精确度 二分类 二分类精确度
召回率 二分类 二分类召回率
F-1 二分类 精确度和召回率的调和平均值
MCC 二分类 马修斯相关系数
ROC曲线 二分类 接收器操作特征曲线
AUC 二分类 ROC曲线下的面积
PR曲线 二分类 精确度-召回率曲线
AP 二分类 平均精确度
准确率 多分类 多分类准确率
精确度 多分类 多分类精确度
召回率 多分类 多分类召回率
F-1 多分类 多分类F-1
Rk 多分类 戈罗金(2004年)描述的K类别相关系数
M-AUC 多分类 汉德和蒂尔(2001年)描述的多分类AUC
RMSE 回归 均方根误差
MSE 回归 均方误差
MAE 回归 平均绝对误差
R-Square 回归 确定系数
相关度 回归 线性相关系数

用法

二分类

BinaryConfusionMatrix结构提供了计算常见二分类指标的函数。

use eval_metrics::error::EvalError;
use eval_metrics::classification::BinaryConfusionMatrix;

fn main() -> Result<(), EvalError> {
    // note: these scores could also be f32 values
    let scores = vec![0.5, 0.2, 0.7, 0.4, 0.1, 0.3, 0.8, 0.9];
    let labels = vec![false, false, true, false, true, false, false, true];
    let threshold = 0.5;

    // compute confusion matrix from scores and labels
    let matrix = BinaryConfusionMatrix::compute(&scores, &labels, threshold)?;

    // counts
    let tpc = matrix.tp_count;
    let fpc = matrix.fp_count;
    let tnc = matrix.tn_count;
    let fnc = matrix.fn_count;

    // metrics
    let acc = matrix.accuracy()?;
    let pre = matrix.precision()?;
    let rec = matrix.recall()?;
    let f1 = matrix.f1()?;
    let mcc = matrix.mcc()?;

    // print matrix to console
    println!("{}", matrix);
    Ok(())
}
                            o=========================o
                            |          Label          |
                            o=========================o
                            |  Positive  |  Negative  |
o==============o============o============|============o
|              |  Positive  |     2      |     2      |
|  Prediction  |============|------------|------------|
|              |  Negative  |     1      |     3      |
o==============o============o=========================o

除了来自混淆矩阵的指标外,还可以计算ROC曲线和PR曲线,提供AUC和AP等指标。

use eval_metrics::error::EvalError;
use eval_metrics::classification::{RocCurve, RocPoint, PrCurve, PrPoint};

fn main() -> Result<(), EvalError> {
    // note: these scores could also be f32 values
    let scores = vec![0.5, 0.2, 0.7, 0.4, 0.1, 0.3, 0.8, 0.9];
    let labels = vec![false, false, true, false, true, false, false, true];

    // construct roc curve
    let roc = RocCurve::compute(&scores, &labels)?;
    // compute auc
    let auc = roc.auc();
    // inspect roc curve points
    roc.points.iter().for_each(|point| {
        let tpr = point.tp_rate;
        let fpr = point.fp_rate;
        let thresh = point.threshold;
    });

    // construct pr curve
    let pr = PrCurve::compute(&scores, &labels)?;
    // compute average precision
    let ap = pr.ap();
    // inspect pr curve points
    pr.points.iter().for_each(|point| {
        let pre = point.precision;
        let rec = point.recall;
        let thresh = point.threshold;
    });
    Ok(())
}

多分类

MultiConfusionMatrix结构提供了计算常见多分类指标的函数。此外,对于这些指标中的几个,必须显式提供平均方法。

use eval_metrics::error::EvalError;
use eval_metrics::classification::{MultiConfusionMatrix, Averaging};

fn main() -> Result<(), EvalError> {
    // note: these scores could also be f32 values
    let scores = vec![
        vec![0.3, 0.1, 0.6],
        vec![0.5, 0.2, 0.3],
        vec![0.2, 0.7, 0.1],
        vec![0.3, 0.3, 0.4],
        vec![0.5, 0.1, 0.4],
        vec![0.8, 0.1, 0.1],
        vec![0.3, 0.5, 0.2]
    ];
    let labels = vec![2, 1, 1, 2, 0, 2, 0];

    // compute confusion matrix from scores and labels
    let matrix = MultiConfusionMatrix::compute(&scores, &labels)?;

    // get counts
    let counts = &matrix.counts;

    // metrics
    let acc = matrix.accuracy()?;
    let mac_pre = matrix.precision(&Averaging::Macro)?;
    let wgt_pre = matrix.precision(&Averaging::Weighted)?;
    let mac_rec = matrix.recall(&Averaging::Macro)?;
    let wgt_rec = matrix.recall(&Averaging::Weighted)?;
    let mac_f1 = matrix.f1(&Averaging::Macro)?;
    let wgt_f1 = matrix.f1(&Averaging::Weighted)?;
    let rk = matrix.rk()?;

    // print matrix to console
    println!("{}", matrix);
    Ok(())
}
                           o===================================o
                           |               Label               |
                           o===================================o
                           |  Class-1  |  Class-2  |  Class-3  |
o==============o===========o===========|===========|===========o
|              |  Class-1  |     1     |     1     |     1     |
|              |===========|-----------|-----------|-----------|
|  Prediction  |  Class-2  |     1     |     1     |     0     |
|              |===========|-----------|-----------|-----------|
|              |  Class-3  |     0     |     0     |     2     |
o==============o===========o===================================o

除了这些全局指标外,还可以获得每个类的指标。

use eval_metrics::error::EvalError;
use eval_metrics::classification::{MultiConfusionMatrix};

fn main() -> Result<(), EvalError> {
    // note: these scores could also be f32 values
    let scores = vec![
        vec![0.3, 0.1, 0.6],
        vec![0.5, 0.2, 0.3],
        vec![0.2, 0.7, 0.1],
        vec![0.3, 0.3, 0.4],
        vec![0.5, 0.1, 0.4],
        vec![0.8, 0.1, 0.1],
        vec![0.3, 0.5, 0.2]
    ];
    let labels = vec![2, 1, 1, 2, 0, 2, 0];

    // compute confusion matrix from scores and labels
    let matrix = MultiConfusionMatrix::compute(&scores, &labels)?;
    
    // per-class metrics
    let pca = matrix.per_class_accuracy();
    let pcp = matrix.per_class_precision();
    let pcr = matrix.per_class_recall();
    let pcf = matrix.per_class_f1();
    let pcm = matrix.per_class_mcc();
    
    // print per-class metrics to console
    println!("{:?}", pca);
    println!("{:?}", pcp);
    println!("{:?}", pcr);
    println!("{:?}", pcf);
    println!("{:?}", pcm);
    Ok(())
}
[Ok(0.5714285714285714), Ok(0.7142857142857143), Ok(0.8571428571428571)]
[Ok(0.3333333333333333), Ok(0.5), Ok(1.0)]
[Ok(0.5), Ok(0.5), Ok(0.6666666666666666)]
[Ok(0.4), Ok(0.5), Ok(0.8)]
[Ok(0.09128709291752773), Ok(0.3), Ok(0.7302967433402215)]

除了从混淆矩阵派生的指标外,还提供了Hand和Till(2001)描述的M-AUC(多类AUC)指标作为独立函数。

let mauc = m_auc(&scores, &labels)?;

回归

所有回归指标都在一对分数和标签上操作。

use eval_metrics::error::EvalError;
use eval_metrics::regression::*;

fn main() -> Result<(), EvalError> {

    // note: these could also be f32 values
    let scores = vec![0.4, 0.7, -1.2, 2.5, 0.3];
    let labels = vec![0.2, 1.1, -0.9, 1.3, -0.2];

    // root mean squared error
    let rmse = rmse(&scores, &labels)?;
    // mean squared error
    let mse = mse(&scores, &labels)?;
    // mean absolute error
    let mae = mae(&scores, &labels)?;
    // coefficient of determination
    let rsq = rsq(&scores, &labels)?;
    // pearson correlation coefficient
    let corr = corr(&scores, &labels)?;
    Ok(())
}

依赖项