5个版本
使用旧的Rust 2015
0.1.3 | 2017年11月2日 |
---|---|
0.1.2 | 2017年11月2日 |
0.1.1 | 2017年11月2日 |
0.1.0 | 2017年10月28日 |
0.0.0 | 2017年8月12日 |
#632 in 机器学习
27KB
618 行
neural_network
这个crate是一个模块化神经网络的实现。它旨在使神经网络任何学习方法的实现变得非常简单且没有运行时成本。如果您想了解如何实现任何学习技术,可以查看src/back_prop文件夹,其中包含反向传播的实现。如果您只是想使用神经网络解决问题,可以参考下方的教程,了解如何将其添加到您的项目以及如何使用它。
将neural_network添加到您的项目
将以下行添加到您的Cargo.toml文件中的依赖项
neural_network = "0.1.3"
然后,在您的src/main.rs或src/lib.rs文件的开始处添加以下内容
extern crate neural_network;
然后您就可以使用了!
反向传播的示例
以下是如何使用反向传播来训练神经网络以返回数字的正弦和余弦值的示例。您还将看到如何将网络保存到文件并重新加载。
use neural_network::back_prop::prelude::*;
use std::fs::{File, remove_file};
use rand;
fn main() {
let file_name = "back_prop_net.nn";
let mut net = BackProp::new(1, &[50, 50, 50], 2,
0.05, 0.1, 1.0,
Tanh::activation, Tanh::derivative,
Tanh::activation, Tanh::derivative);
{
// Generate the training data
let mut train_inputs: Vec<[f64;1]> = Vec::with_capacity(1000);
let mut train_targets: Vec<[f64;2]> = Vec::with_capacity(1000);
for _ in 0..1000 {
let num: f64 = rand::random();
train_inputs.push([num]);
train_targets.push([num.sin(), num.cos()]);
}
// Generate the testing data
let mut test_inputs: Vec<[f64;1]> = Vec::with_capacity(100);
let mut test_targets: Vec<[f64;2]> = Vec::with_capacity(100);
for _ in 0..100 {
let num: f64 = rand::random();
let sc = num.sin_cos();
test_inputs.push([num]);
test_targets.push([sc.0, sc.1]);
}
let train_inputs: Vec<&[f64]> = train_inputs.iter().map(|n|n as &[f64]).collect();
let train_targets: Vec<&[f64]> = train_targets.iter().map(|n|n as &[f64]).collect();
let test_inputs: Vec<&[f64]> = test_inputs.iter().map(|n|n as &[f64]).collect();
let test_targets: Vec<&[f64]> = test_targets.iter().map(|n|n as &[f64]).collect();
let result = net.train(0.001, None, None,
&train_inputs, &train_targets,
&test_inputs, &test_targets);
assert!(result.min_error <= 0.001);
}
net.save(&mut File::create(file_name).unwrap()).unwrap();
let loaded = BackProp::load(&mut File::open(file_name).unwrap(), Tanh::activation, Tanh::derivative,
Tanh::activation, Tanh::derivative, ).unwrap();
assert_eq!(net, loaded);
remove_file(file_name).unwrap();
}
依赖项
~330–560KB