4 个版本 (2 个重大更新)
0.3.1 | 2023 年 3 月 13 日 |
---|---|
0.3.0 | 2023 年 3 月 13 日 |
0.2.0 | 2021 年 11 月 17 日 |
0.1.0 | 2021 年 11 月 12 日 |
#298 在 机器学习
每月 下载 28 次
13KB
246 行
[email protected] data:image/s3,"s3://crabby-images/aefdd/aefdd5564e944cc313a2d55e7fb6d3cbbd49732f" alt="构建状态"
data:image/s3,"s3://crabby-images/6aa03/6aa03c240b2c978a2c1524abcf06e54a72ec3491" alt="codecov"
vinyana - 巴利语中的“心”。
目标
本项目的目标是创建一个易于使用和理解的 Rust 神经网络库。
使用方法
use std::path::Path;
use rand::prelude::*;
use vinyana::{activation::ActivationType, NeuralNetwork};
fn main() {
let mut nn = NeuralNetwork::new(vec![2, 2, 1]);
nn.set_learning_rate(0.01); // default is 0.01 but you can change it
nn.set_activation(ActivationType::Tanh); // default is Sigmoid but you can change it
// We will train this network with 4 scenarios of XOR problem
let scenarios = vec![
(vec![1.0, 1.0], vec![0.0f32]),
(vec![0.0, 1.0], vec![1.0]),
(vec![1.0, 0.0], vec![1.0]),
(vec![0.0, 0.0], vec![0.0]),
];
let mut rng = thread_rng();
for _ in 0..500000 {
let random = rng.gen_range(0..4) as usize;
let (train_data, target_data) = scenarios.get(random).unwrap();
// we will pick a random scenario from the dataset and feed it to the network with the expected target
nn.train(train_data.clone(), target_data.clone())
}
let result = nn.predict(vec![1.0, 0.0]);
println!("Result: {:?}", result);
// we can store our trained model and play with it later
nn.save(Path::new("xor_model.nn")).unwrap();
}
// Load your model from file
let nn = NeuralNetwork::load(Path::new("xor_model.nn")).unwrap();
let result = nn.predict(vec![1.0, 1.0]);
println!("{:?}", result);
依赖关系
~2.5–3.5MB
~68K SLoC