2 个版本

0.1.1 2024 年 4 月 17 日
0.1.0 2024 年 4 月 17 日

#240机器学习

MIT 许可协议

50KB
1K SLoC

Sprout Logo

关于

Sprout 是一个用 Rust 编写的 简单机器学习库,没有使用任何现有的 ML 或线性代数库。我创建 Sprout 是为了更好地理解 ML 概念。

主要特性

  • 全连接层
  • 卷积层
  • 小批量梯度下降
  • 归一化
  • 模型保存/加载到 JSON

使用方法

Sprout 使用包含 Layer 结构体的 Vec,将其传递给 Network 结构体,如下所示
use Sprouts::{Layer::{Layer, LayerType}, network::Network, activation::ActivationFunction::*, loss_function::LossType::*}

let layers = vec![
    Layer::dense([2, 3], Sigmoid),
    Layer::dense([3, 1], Sigmoid),
];

// Network::new(layers, learning_rate, batch_size, loss_function);
let nn = Network::new(layers, 0.2, 1, MSE);

//Prints network's loss and epoch progress in the terminal
nn.dense_train(true);

//data: Vec<[Inputs, Outputs]>
let data: Vec<[Vec<f64>; 2]> = vec![
    [vec![1.0, 0.0], vec![0.0]],
    [vec![0.0, 0.0], vec![1.0]],
    [vec![1.0, 1.0], vec![1.0]],
    [vec![0.0, 1.0], vec![0.0]],
];  

//dense_train(data, epochs)
nn.dense_train(data.clone(), 10000);

for i in 0..data.len() {
    println!("Input: {:?} || Output: {:?} || Target: {:?}",data[i][0].clone(), nn.dense_forward(data[i][0].clone()), data[i][1].clone());
}

目前支持的层只有卷积和密集层,池化层将是下一个计划。

很快会完善 README...

许可协议

本项目遵循 MIT 许可协议

依赖项

~3.5–4.5MB
~92K SLoC