9 个不稳定版本
使用旧版 Rust 2015
0.5.0 | 2018年7月29日 |
---|---|
0.4.3 | 2018年5月22日 |
0.4.2 | 2016年8月28日 |
0.3.1 | 2016年3月1日 |
0.1.0 | 2015年12月6日 |
#1889 在 算法 中
167 每月下载量
用于 2 crates
1MB
19K SLoC
rustlearn
Rust 机器学习包。
有关完整使用说明,请参阅 API 文档。
简介
本软件包包含了一些常见机器学习算法的有效实现。
目前,rustlearn
使用自己的基本稠密和稀疏数组类型,但一旦在该领域出现明确的胜者,我将乐意使用更稳健的解决方案。
特性
矩阵原语
模型
所有模型都支持在稠密和稀疏数据上拟合和预测,并且在准确性和性能方面应与 Python sklearn
实现大致相当。
交叉验证
度量
并行化
许多模型支持并行模型拟合和预测。
模型序列化
模型序列化支持通过 serde
。
使用 rustlearn
使用方法应该很简单。
- 导入所有线性代数原语和常见特性的预导入模块
use rustlearn::prelude::*;
- 从子模块中导入单个模型和工具
use rustlearn::prelude::*;
use rustlearn::linear_models::sgdclassifier::Hyperparameters;
// more imports
示例
逻辑回归
use rustlearn::prelude::*;
use rustlearn::datasets::iris;
use rustlearn::cross_validation::CrossValidation;
use rustlearn::linear_models::sgdclassifier::Hyperparameters;
use rustlearn::metrics::accuracy_score;
let (X, y) = iris::load_data();
let num_splits = 10;
let num_epochs = 5;
let mut accuracy = 0.0;
for (train_idx, test_idx) in CrossValidation::new(X.rows(), num_splits) {
let X_train = X.get_rows(&train_idx);
let y_train = y.get_rows(&train_idx);
let X_test = X.get_rows(&test_idx);
let y_test = y.get_rows(&test_idx);
let mut model = Hyperparameters::new(X.cols())
.learning_rate(0.5)
.l2_penalty(0.0)
.l1_penalty(0.0)
.one_vs_rest();
for _ in 0..num_epochs {
model.fit(&X_train, &y_train).unwrap();
}
let prediction = model.predict(&X_test).unwrap();
accuracy += accuracy_score(&y_test, &prediction);
}
accuracy /= num_splits as f32;
随机森林
use rustlearn::prelude::*;
use rustlearn::ensemble::random_forest::Hyperparameters;
use rustlearn::datasets::iris;
use rustlearn::trees::decision_tree;
let (data, target) = iris::load_data();
let mut tree_params = decision_tree::Hyperparameters::new(data.cols());
tree_params.min_samples_split(10)
.max_features(4);
let mut model = Hyperparameters::new(tree_params, 10)
.one_vs_rest();
model.fit(&data, &target).unwrap();
// Optionally serialize and deserialize the model
// let encoded = bincode::serialize(&model).unwrap();
// let decoded: OneVsRestWrapper<RandomForest> = bincode::deserialize(&encoded).unwrap();
let prediction = model.predict(&data).unwrap();
贡献
欢迎提交拉取请求。
要运行基本测试,请运行 cargo test
。
运行 cargo test --features "all_tests" --release
将运行所有测试,包括生成和慢速测试。运行 cargo bench --features bench
(仅在夜间分支上)将运行基准测试。
依赖项
~0.8–1.6MB
~31K SLoC