4 个版本

0.1.3 2024 年 1 月 24 日
0.1.2 2023 年 11 月 7 日
0.1.1 2019 年 5 月 15 日
0.1.0 2018 年 12 月 19 日

机器学习 类别中排名第 18

Download history 283/week @ 2024-03-11 148/week @ 2024-03-18 35/week @ 2024-03-25 154/week @ 2024-04-01 118/week @ 2024-04-08 87/week @ 2024-04-15 121/week @ 2024-04-22 128/week @ 2024-04-29 100/week @ 2024-05-06 212/week @ 2024-05-13 236/week @ 2024-05-20 385/week @ 2024-05-27 481/week @ 2024-06-03 493/week @ 2024-06-10 77/week @ 2024-06-17 99/week @ 2024-06-24

每月下载量 1,182
用于 4 包(2 个直接使用)

Apache-2.0 许可

175KB
2.5K SLoC

MesaTEE GBDT-RS : 一个快速且安全的 GBDT 库,支持 TEEs,如 Intel SGX 和 ARM TrustZone

Build Status codecov

MesaTEE GBDT-RS 是用 Safe Rust 编写的梯度提升决策树库。库中没有不安全代码。

MesaTEE GBDT-RS 提供训练和推理功能。它可以使用由 xgboost 训练的模型进行推理任务。

新功能!MesaTEE GBDT-RS 的 论文 已被 IEEE S&P'19 接受!

支持的任务

支持训练和推理的任务

  1. 线性回归:使用 SquaredError 和 LAD 损失类型
  2. 二分类(标记为 1 和 -1):使用 LogLikelyhood 损失类型

与 xgboost 兼容

目前,MesaTEE GBDT-RS 支持 xgboost 训练的模型进行推理。模型应使用以下配置在 xgboost 中训练

  1. booster: gbtree
  2. objective: "reg:linear", "reg:logistic", "binary:logistic", "binary:logitraw", "multi:softprob", "multi:softmax" 或 "rank:pairwise".

我们已经测试过 MesaTEE GBDT-RS 与 xgboost 0.81 和 0.82 兼容

快速入门

训练步骤

  1. 设置配置
  2. 加载训练数据
  3. 训练模型
  4. (可选)保存模型

推理步骤

  1. 加载模型
  2. 加载测试数据
  3. 推理测试数据

示例

    use gbdt::config::Config;
    use gbdt::decision_tree::{DataVec, PredVec};
    use gbdt::gradient_boost::GBDT;
    use gbdt::input::{InputFormat, load};

    let mut cfg = Config::new();
    cfg.set_feature_size(22);
    cfg.set_max_depth(3);
    cfg.set_iterations(50);
    cfg.set_shrinkage(0.1);
    cfg.set_loss("LogLikelyhood"); 
    cfg.set_debug(true);
    cfg.set_data_sample_ratio(1.0);
    cfg.set_feature_sample_ratio(1.0);
    cfg.set_training_optimization_level(2);

    // load data
    let train_file = "dataset/agaricus-lepiota/train.txt";
    let test_file = "dataset/agaricus-lepiota/test.txt";

    let mut input_format = InputFormat::csv_format();
    input_format.set_feature_size(22);
    input_format.set_label_index(22);
    let mut train_dv: DataVec = load(train_file, input_format).expect("failed to load training data");
    let test_dv: DataVec = load(test_file, input_format).expect("failed to load test data");

    // train and save model
    let mut gbdt = GBDT::new(&cfg);
    gbdt.fit(&mut train_dv);
    gbdt.save_model("gbdt.model").expect("failed to save the model");

    // load model and do inference
    let model = GBDT::load_model("gbdt.model").expect("failed to load the model");
    let predicted: PredVec = model.predict(&test_dv);

示例代码

  • 线性回归:examples/iris.rs
  • 二分类:examples/agaricus-lepiota.rs

使用 xgboost 训练的模型

步骤

  1. 使用 xgboost 训练模型
  2. 使用 examples/convert_xgboost.py 转换模型
    • 用法:python convert_xgboost.py xgboost_model_path objective output_path
    • 注意 convert_xgboost.py 依赖于 xgboost Python 库。转换后的模型可以在没有 xgboost 的机器上使用
  3. 在 rust 代码中,调用 GBDT::load_from_xgboost(model_path, objective) 来加载模型
  4. 进行推理
  5. (可选) 调用 GBDT::save_model 将模型保存到 MesaTEE GBDT-RS 本地格式。

示例代码

  • "reg:linear": examples/test-xgb-reg-linear.rs
  • "reg:logistic": examples/test-xgb-reg-logistic.rs
  • "binary:logistic": examples/test-xgb-binary-logistic.rs
  • "binary:logitraw": examples/test-xgb-binary-logistic.rs
  • "multi:softprob": examples/test-xgb-multi-softprob.rs
  • "multi:softmax": examples/test-xgb-multi-softmax.rs
  • "rank:pairwise": examples/test-xgb-rank-pairwise.rs

多线程

训练

在此阶段,MesaTEE GBDT-RS 的训练是单线程的。

推理

相关的推理函数是单线程的。但它们是线程安全的。我们在 example/test-multithreads.rs 提供了一个使用多线程的推理示例

SGX 使用

由于 MesaTEE GBDT-RS 是用纯 rust 编写的,借助 rust-sgx-sdk,它可以在 sgx enclaves 中轻松使用

gbdt_sgx = { git = "https://github.com/mesalock-linux/gbdt-rs" }

这将导入一个名为 gbdt_sgx 的 crate。如果您更喜欢 gbdt 作为常规

gbdt = { package = "gbdt_sgx", git = "https://github.com/mesalock-linux/gbdt-rs" }

有关更多信息及具体示例,请参阅目录 sgx/gbdt-sgx-test

许可

Apache 2.0

作者

李天意 @n0b0dyCN [email protected]

李通欣 @litongxin1991 [email protected]

丁雨 @dingelish [email protected]

指导委员会

魏涛,张雨龙

致谢

感谢 @qiyiping 在之前的出色工作 gbdt。在我们开始这个项目之前,我们阅读了他的代码。

依赖关系

~0.9–2.2MB
~46K SLoC