3 个不稳定版本
0.2.1 | 2019 年 3 月 19 日 |
---|---|
0.2.0 | 2019 年 3 月 19 日 |
0.1.0 | 2019 年 3 月 17 日 |
#588 in 机器学习
13KB
198 行
一个简单的神经网络,作为学习练习构建。
入门
use artha::{
NeuralNetwork,
neural_net::{
normaize_val,
mean_loss,
find_max,
}
};
use ndarray::array;
fn main() {
let mut xs = array![[2.,9.],[1.,5.],[3.,6.]];
normaize_val(find_max(&xs), &mut xs);
let mut ys = array![[92.], [86.], [89.]];
normaize_val(vec![100.], &mut ys);
let mut nn = NeuralNetwork::new(2,1,vec![3]);
let predicted = nn.train(&xs, &ys, 10000);
let loss = mean_loss(&ys, &predicted);
use artha::logln;
logln!("Input: ", xs);
logln!("Actual Output: ", ys);
logln!("Predicted Output: ", predicted);
logln!("Loss: ", loss);
}
这个程序是将 https://dev.to/shamdasani/build-a-flexible-neural-network-with-backpropagation-in-python 直接翻译成 rust。
还可以查看 3Blue1Browns 关于神经网络的出色系列 https://www.youtube.com/playlist?list=PLZHQObOWTQDNU6R1_67000Dx_ZCJB-3pi
我发现这个网络比教程慢得多。也许 ndarray 不如 numpy 快,或者也许我的 rust 代码没有优化。我肯定会调查一下。
除了优化之外,我还希望实现一个能够识别手写数字以及其他什么的网络。但到目前为止,这是一个相当不准确的新手版本,我可以在此基础上进行构建。
- 如果您有任何问题或建议,请随时提交问题,或以其他方式与我联系。
- 如果您觉得我的 rust 技能不够,请提供一些有建设性的批评。
依赖项
~3MB
~50K SLoC