1个不稳定版本
0.1.0 | 2022年12月31日 |
---|
#499 in 机器学习
150KB
3K SLoC
Rust中的机器学习
通过实现经典机器学习算法来学习Rust编程语言。该项目完全独立完成,不依赖任何第三方库,充当一个 引导机器学习库。
❗❗❗:积极寻求代码审查,并欢迎对修复错误或代码重构的建议。请随时分享您的想法。很高兴接受建议!
基础知识
- NdArray模块,正如其名。它实现了任意维度的
broadcast
、矩阵运算
、permute
等。由于Rust的自动向量化,矩阵乘法使用了SIMD。 - Dataset模块,支持自定义加载数据、重新格式化、
normalize
、shuffle
和Dataloader
。提供了一些流行的数据集预处理配方。
算法
- 决策树,支持分类和回归任务。提供了如
gini
或entropy
的信息增益。 - 逻辑回归,支持正则化(
Lasso
、Ridge
和L-inf
) - 线性回归,与逻辑回归相同,但用于回归任务。
- 朴素贝叶斯,可以自由处理离散或连续的特征值。
- SVM,使用SGD和Hinge Loss进行优化,具有线性核。
- nn模块,包含
linear(MLP)
和一些可以自由堆叠并通过梯度反向传播进行优化的activation
函数。 - KNN,支持
KdTree
和普通的BruteForceSearch
。 - K-Means,使用无监督学习方法聚类数据
开始
让我们使用KNN算法来解决一个分类任务。更多示例可以在examples
目录中找到。
-
为测试创建一些合成数据
use std::collections::HashMap; let features = vec![ vec![0.6, 0.7, 0.8], vec![0.7, 0.8, 0.9], vec![0.1, 0.2, 0.3], ]; let labels = vec![0, 0, 1]; // so it is a binary classifiction task, 0 is for the large label, 1 is for the small label let mut label_map = HashMap::new(); label_map.insert(0, "large".to_string()); label_map.insert(1, "small".to_string());
-
将数据转换为
dataset
use mlinrust::dataset::Dataset; let dataset = Dataset::new(features, labels, Some(label_map));
-
将数据集拆分为
train
和valid
集,并通过标准归一化进行归一化let mut temp = dataset.split_dataset(vec![2.0, 1.0], 0); // [2.0, 1.0] is the split fraction, 0 is the seed let (mut train_dataset, mut valid_dataset) = (temp.remove(0), temp.remove(0)); use mlinrust::dataset::utils::{normalize_dataset, ScalerType}; normalize_dataset(&mut train_dataset, ScalerType::Standard); normalize_dataset(&mut valid_dataset, ScalerType::Standard);
-
使用
KdTree
构建和训练我们的 KNN 模型use mlinrust::model::knn::{KNNAlg, KNNModel, KNNWeighting}; // KdTree is one implementation of KNN; 1 defines the k of neighbours; Weighting decides the way of ensemble prediction; train_dataset is for training KNN; Some(2) is the param of minkowski distance let model = KNNModel::new(KNNAlg::KdTree, 1, Some(KNNWeighting::Distance), train_dataset, Some(2));
-
评估模型
use mlinrust::utils::evaluate; let (correct, acc) = evaluate(&valid_dataset, &model); println!("evaluate results\ncorrect {correct} / total {}, acc = {acc:.5}", test_dataset.len());
待办
- 模型权重序列化以进行保存和加载
- Boosting/bagging
- 使用多线程进行矩阵乘法
- 重构代码,真诚地请求资深开发者的评论
参考
- scikit-learn
- 周志华教授的《机器学习西瓜书》,链接:[机器学习西瓜书](https://cs.nju.edu.cn/zhouzh/zhouzh.files/publication/MLbook2016.htm),作者:[周志华教授](https://cs.nju.edu.cn/zhouzh/index.htm)
谢谢
Rust 社区。我从 rust-lang Discord 收到了很多帮助。
许可
在 GPL-v3 许可下。商业用途严格禁止。