2 个版本
0.1.1 | 2024 年 4 月 17 日 |
---|---|
0.1.0 | 2024 年 4 月 17 日 |
#240 在 机器学习
50KB
1K SLoC
关于
Sprout 是一个用 Rust 编写的 简单机器学习库,没有使用任何现有的 ML 或线性代数库。我创建 Sprout 是为了更好地理解 ML 概念。主要特性
- 全连接层
- 卷积层
- 小批量梯度下降
- 归一化
- 模型保存/加载到 JSON
使用方法
Sprout 使用包含 Layer 结构体的 Vec,将其传递给 Network 结构体,如下所示use Sprouts::{Layer::{Layer, LayerType}, network::Network, activation::ActivationFunction::*, loss_function::LossType::*}
let layers = vec![
Layer::dense([2, 3], Sigmoid),
Layer::dense([3, 1], Sigmoid),
];
// Network::new(layers, learning_rate, batch_size, loss_function);
let nn = Network::new(layers, 0.2, 1, MSE);
//Prints network's loss and epoch progress in the terminal
nn.dense_train(true);
//data: Vec<[Inputs, Outputs]>
let data: Vec<[Vec<f64>; 2]> = vec![
[vec![1.0, 0.0], vec![0.0]],
[vec![0.0, 0.0], vec![1.0]],
[vec![1.0, 1.0], vec![1.0]],
[vec![0.0, 1.0], vec![0.0]],
];
//dense_train(data, epochs)
nn.dense_train(data.clone(), 10000);
for i in 0..data.len() {
println!("Input: {:?} || Output: {:?} || Target: {:?}",data[i][0].clone(), nn.dense_forward(data[i][0].clone()), data[i][1].clone());
}
目前支持的层只有卷积和密集层,池化层将是下一个计划。
很快会完善 README...
许可协议
本项目遵循 MIT 许可协议。
依赖项
~3.5–4.5MB
~92K SLoC