9 个不稳定版本

使用旧版 Rust 2015

0.5.0 2018年7月29日
0.4.3 2018年5月22日
0.4.2 2016年8月28日
0.3.1 2016年3月1日
0.1.0 2015年12月6日

#1889算法

Download history 38/week @ 2024-03-12 37/week @ 2024-03-19 30/week @ 2024-03-26 87/week @ 2024-04-02 24/week @ 2024-04-09 40/week @ 2024-04-16 33/week @ 2024-04-23 33/week @ 2024-04-30 58/week @ 2024-05-07 56/week @ 2024-05-14 31/week @ 2024-05-21 81/week @ 2024-05-28 37/week @ 2024-06-04 31/week @ 2024-06-11 42/week @ 2024-06-18 52/week @ 2024-06-25

167 每月下载量
用于 2 crates

Apache-2.0

1MB
19K SLoC

Rust 16K SLoC // 0.0% comments C++ 2.5K SLoC // 0.1% comments Python 306 SLoC // 0.4% comments

rustlearn

Circle CI Crates.io

Rust 机器学习包。

有关完整使用说明,请参阅 API 文档

简介

本软件包包含了一些常见机器学习算法的有效实现。

目前,rustlearn 使用自己的基本稠密和稀疏数组类型,但一旦在该领域出现明确的胜者,我将乐意使用更稳健的解决方案。

特性

矩阵原语

模型

所有模型都支持在稠密和稀疏数据上拟合和预测,并且在准确性和性能方面应与 Python sklearn 实现大致相当。

交叉验证

度量

并行化

许多模型支持并行模型拟合和预测。

模型序列化

模型序列化支持通过 serde

使用 rustlearn

使用方法应该很简单。

  • 导入所有线性代数原语和常见特性的预导入模块
use rustlearn::prelude::*;
  • 从子模块中导入单个模型和工具
use rustlearn::prelude::*;

use rustlearn::linear_models::sgdclassifier::Hyperparameters;
// more imports

示例

逻辑回归

use rustlearn::prelude::*;
use rustlearn::datasets::iris;
use rustlearn::cross_validation::CrossValidation;
use rustlearn::linear_models::sgdclassifier::Hyperparameters;
use rustlearn::metrics::accuracy_score;


let (X, y) = iris::load_data();

let num_splits = 10;
let num_epochs = 5;

let mut accuracy = 0.0;

for (train_idx, test_idx) in CrossValidation::new(X.rows(), num_splits) {

    let X_train = X.get_rows(&train_idx);
    let y_train = y.get_rows(&train_idx);
    let X_test = X.get_rows(&test_idx);
    let y_test = y.get_rows(&test_idx);

    let mut model = Hyperparameters::new(X.cols())
                                    .learning_rate(0.5)
                                    .l2_penalty(0.0)
                                    .l1_penalty(0.0)
                                    .one_vs_rest();

    for _ in 0..num_epochs {
        model.fit(&X_train, &y_train).unwrap();
    }

    let prediction = model.predict(&X_test).unwrap();
    accuracy += accuracy_score(&y_test, &prediction);
}

accuracy /= num_splits as f32;

随机森林

use rustlearn::prelude::*;

use rustlearn::ensemble::random_forest::Hyperparameters;
use rustlearn::datasets::iris;
use rustlearn::trees::decision_tree;

let (data, target) = iris::load_data();

let mut tree_params = decision_tree::Hyperparameters::new(data.cols());
tree_params.min_samples_split(10)
    .max_features(4);

let mut model = Hyperparameters::new(tree_params, 10)
    .one_vs_rest();

model.fit(&data, &target).unwrap();

// Optionally serialize and deserialize the model

// let encoded = bincode::serialize(&model).unwrap();
// let decoded: OneVsRestWrapper<RandomForest> = bincode::deserialize(&encoded).unwrap();

let prediction = model.predict(&data).unwrap();

贡献

欢迎提交拉取请求。

要运行基本测试,请运行 cargo test

运行 cargo test --features "all_tests" --release 将运行所有测试,包括生成和慢速测试。运行 cargo bench --features bench(仅在夜间分支上)将运行基准测试。

依赖项

~0.8–1.6MB
~31K SLoC