7个版本 (破坏性更新)
0.7.0 | 2023年5月1日 |
---|---|
0.6.0 | 2023年4月2日 |
0.5.0 | 2023年3月8日 |
0.4.0 | 2023年3月6日 |
0.1.0 | 2023年2月21日 |
#415 in 机器学习
每月21次下载
用于 learnwell
81KB
1.5K SLoC
runnt (rust neural net)
非常简单的全连接神经网络。
当你只想用最少的依赖和代码行数拼凑一些东西时。
目标是创建一个全连接网络,在数据上运行,用大约10行代码得到结果
由于找不到一个没有外部依赖且易于使用的rust库,所以创建了此库。
如果您发现任何错误或优化,欢迎提出问题或PR。
功能
- 全连接神经网络
- 最小依赖
- 不需要外部静态库/dlls
- 回归和分类
- 可以定义层的大小
- 可以定义激活类型
- 可以保存/加载模型
- 随机、小批量、梯度下降
- 正则化
- 数据集管理器
- csv
- onehot编码
- 归一化
- 报告
如何使用
简单示例
你需要的是NN和数据
//XOR
use runnt::{nn::NN,activation::ActivationType};
let inputs = [[0., 0.], [0., 1.], [1., 0.], [1., 1.]];
let outputs = [[0.], [1.], [1.], [0.]];
let mut nn = NN::new(&[2, 8, 1])
.with_learning_rate(0.2)
.with_hidden_type(ActivationType::Tanh)
.with_output_type(ActivationType::Linear);
for i in 0..5000 {
nn.fit_one(&inputs[i % 4], &outputs[i % 4]);
}
带数据集和报告的简单示例
Dataset
使得加载数据和转换数据变得容易一些
train
使得运行时代和报告变得容易
在< 10行代码内完成完整的神经网络并报告
let set = Dataset::builder()
.read_csv("examples/data/iris.csv")
.add_input_columns(&[0, 1, 2, 3], Conversion::NormaliseMean)
.add_target_columns(&[4], Conversion::OneHot)
.allocate_to_test_data(0.2)
.build();
let mut net = NN::new(&[set.input_size(), 32, set.target_size()]).with_learning_rate(0.15);
net.train(&set, 1000, 8, 100, ReportAccuracy::CorrectClassification);
带有数据集、报告和保存
let set = Dataset::builder()
.read_csv(r"/temp/diamonds.csv")
.allocate_to_test_data(0.2)
.add_input_columns(&[0, 4, 5, 7, 8, 9], Conversion::NormaliseMean)
.add_input_columns(&[1, 2, 3], Conversion::OneHot)
.add_target_columns(
&[6],
Conversion::Function(|f| f.parse::<f32>().unwrap_or_default() / 1_000.),
)
.build();
let save_path = r"network.txt";
let mut net = if std::path::PathBuf::from_str(save_path).unwrap().exists() {
NN::load(save_path)
} else {
NN::new(&[set.input_size(), 32, set.target_size()])
};
//run for 100 epochs, with batch size 32 and report every 10 epochs
net.train(&set, 100, 32, 10, ReportAccuracy::RSquared);
net.save(save_path);
依赖关系
~1.5MB
~26K SLoC