4 个版本
0.1.3 | 2024 年 1 月 24 日 |
---|---|
0.1.2 | 2023 年 11 月 7 日 |
0.1.1 | 2019 年 5 月 15 日 |
0.1.0 | 2018 年 12 月 19 日 |
在 机器学习 类别中排名第 18
每月下载量 1,182
用于 4 个 包(2 个直接使用)
175KB
2.5K SLoC
MesaTEE GBDT-RS : 一个快速且安全的 GBDT 库,支持 TEEs,如 Intel SGX 和 ARM TrustZone
MesaTEE GBDT-RS 是用 Safe Rust 编写的梯度提升决策树库。库中没有不安全代码。
MesaTEE GBDT-RS 提供训练和推理功能。它可以使用由 xgboost 训练的模型进行推理任务。
新功能!MesaTEE GBDT-RS 的 论文 已被 IEEE S&P'19 接受!
支持的任务
支持训练和推理的任务
- 线性回归:使用 SquaredError 和 LAD 损失类型
- 二分类(标记为 1 和 -1):使用 LogLikelyhood 损失类型
与 xgboost 兼容
目前,MesaTEE GBDT-RS 支持 xgboost 训练的模型进行推理。模型应使用以下配置在 xgboost 中训练
- booster: gbtree
- objective: "reg:linear", "reg:logistic", "binary:logistic", "binary:logitraw", "multi:softprob", "multi:softmax" 或 "rank:pairwise".
我们已经测试过 MesaTEE GBDT-RS 与 xgboost 0.81 和 0.82 兼容
快速入门
训练步骤
- 设置配置
- 加载训练数据
- 训练模型
- (可选)保存模型
推理步骤
- 加载模型
- 加载测试数据
- 推理测试数据
示例
use gbdt::config::Config;
use gbdt::decision_tree::{DataVec, PredVec};
use gbdt::gradient_boost::GBDT;
use gbdt::input::{InputFormat, load};
let mut cfg = Config::new();
cfg.set_feature_size(22);
cfg.set_max_depth(3);
cfg.set_iterations(50);
cfg.set_shrinkage(0.1);
cfg.set_loss("LogLikelyhood");
cfg.set_debug(true);
cfg.set_data_sample_ratio(1.0);
cfg.set_feature_sample_ratio(1.0);
cfg.set_training_optimization_level(2);
// load data
let train_file = "dataset/agaricus-lepiota/train.txt";
let test_file = "dataset/agaricus-lepiota/test.txt";
let mut input_format = InputFormat::csv_format();
input_format.set_feature_size(22);
input_format.set_label_index(22);
let mut train_dv: DataVec = load(train_file, input_format).expect("failed to load training data");
let test_dv: DataVec = load(test_file, input_format).expect("failed to load test data");
// train and save model
let mut gbdt = GBDT::new(&cfg);
gbdt.fit(&mut train_dv);
gbdt.save_model("gbdt.model").expect("failed to save the model");
// load model and do inference
let model = GBDT::load_model("gbdt.model").expect("failed to load the model");
let predicted: PredVec = model.predict(&test_dv);
示例代码
- 线性回归:examples/iris.rs
- 二分类:examples/agaricus-lepiota.rs
使用 xgboost 训练的模型
步骤
- 使用 xgboost 训练模型
- 使用 examples/convert_xgboost.py 转换模型
- 用法:python convert_xgboost.py xgboost_model_path objective output_path
- 注意 convert_xgboost.py 依赖于 xgboost Python 库。转换后的模型可以在没有 xgboost 的机器上使用
- 在 rust 代码中,调用 GBDT::load_from_xgboost(model_path, objective) 来加载模型
- 进行推理
- (可选) 调用 GBDT::save_model 将模型保存到 MesaTEE GBDT-RS 本地格式。
示例代码
- "reg:linear": examples/test-xgb-reg-linear.rs
- "reg:logistic": examples/test-xgb-reg-logistic.rs
- "binary:logistic": examples/test-xgb-binary-logistic.rs
- "binary:logitraw": examples/test-xgb-binary-logistic.rs
- "multi:softprob": examples/test-xgb-multi-softprob.rs
- "multi:softmax": examples/test-xgb-multi-softmax.rs
- "rank:pairwise": examples/test-xgb-rank-pairwise.rs
多线程
训练
在此阶段,MesaTEE GBDT-RS 的训练是单线程的。
推理
相关的推理函数是单线程的。但它们是线程安全的。我们在 example/test-multithreads.rs 提供了一个使用多线程的推理示例
SGX 使用
由于 MesaTEE GBDT-RS 是用纯 rust 编写的,借助 rust-sgx-sdk,它可以在 sgx enclaves 中轻松使用
gbdt_sgx = { git = "https://github.com/mesalock-linux/gbdt-rs" }
这将导入一个名为 gbdt_sgx
的 crate。如果您更喜欢 gbdt
作为常规
gbdt = { package = "gbdt_sgx", git = "https://github.com/mesalock-linux/gbdt-rs" }
有关更多信息及具体示例,请参阅目录 sgx/gbdt-sgx-test
。
许可
Apache 2.0
作者
李天意 @n0b0dyCN [email protected]
李通欣 @litongxin1991 [email protected]
丁雨 @dingelish [email protected]
指导委员会
魏涛,张雨龙
致谢
感谢 @qiyiping 在之前的出色工作 gbdt。在我们开始这个项目之前,我们阅读了他的代码。
依赖关系
~0.9–2.2MB
~46K SLoC