1 个不稳定版本
0.1.0 | 2021年3月22日 |
---|
#5 in #nn
用于 ritemlcr
32KB
511 行
RiteNN - 另一个RustNN分支
RiteNN 提供了用Rust编写的神经网络库的更新版本。
描述
RustNN 是一个 前馈神经网络 库。该库生成完全连接的多层人工神经网络,通过 反向传播 进行训练。网络使用增量训练模式进行训练。
更新
- 由 @PsiACE
- 更更新。
- 使用flexbuffers代替json,更快更紧凑。
- 让一切满意。
- 由 @Felix
- 将L2正则化和几个激活函数添加到原始包中。此外,还有一些小的改进。
- 可以像学习率一样设置Lambda。隐藏层和输出层的激活函数分别在NN::new的第二个和第三个参数中设置。
XOR示例
此示例创建了一个具有 2
个节点的输入层、一个包含 3
个节点的单个隐藏层以及 1
个节点的输出层的神经网络。然后,该网络在 XOR
函数的示例上进行训练。在调用 train(&examples)
之后调用的所有方法都是可选的,只是用来指定各种选项,以指示网络应该如何训练。当调用 go()
方法时,网络将开始训练给定的示例。请参阅 NN
和 Trainer
结构体的文档以获取更多详细信息。
use ritenn::{NN, HaltCondition, Activation};
// create examples of the XOR function
// the network is trained on tuples of vectors where the first vector
// is the inputs and the second vector is the expected outputs
let examples = [
(vec![0f64, 0f64], vec![0f64]),
(vec![0f64, 1f64], vec![1f64]),
(vec![1f64, 0f64], vec![1f64]),
(vec![1f64, 1f64], vec![0f64]),
];
// create a new neural network by passing a pointer to an array
// that specifies the number of layers and the number of nodes in each layer
// in this case we have an input layer with 2 nodes, one hidden layer
// with 3 nodes and the output layer has 1 node
let mut net = NN::new(&[2, 3, 1], Activation::PELU, Activation::Sigmoid);
// train the network on the examples of the XOR function
// all methods seen here are optional except go() which must be called to begin training
// see the documentation for the Trainer struct for more info on what each method does
net.train(&examples)
.halt_condition( HaltCondition::Epochs(10000) )
.log_interval( Some(100) )
.momentum( 0.1 )
.rate( 0.3 )
.go();
// evaluate the network to see if it learned the XOR function
for &(ref inputs, ref outputs) in examples.iter() {
let results = net.run(inputs);
let (result, key) = (results[0].round(), outputs[0]);
assert!(result == key);
}
致谢
它是 nn 和 Felix-Dommes/RustNN 的分支,但已对代码进行了调整和改进。
许可
本库的许可协议为
- MIT许可证 LICENSE-MIT 或 http://opensource.org/licenses/MIT
- Apache License 2.0 LICENSE-APACHE 或 https://opensource.org/licenses/Apache-2.0
由您选择。
依赖项
~1–1.4MB
~29K SLoC