1 个不稳定版本

使用旧的 Rust 2015

0.1.0 2017 年 3 月 27 日

#18#deep

MIT 许可证

28KB
547

scarecrow

Build Status

scarecrow 是一个人工神经网络的简单实现。

示例

这展示了神经网络学习非线性函数的能力,即 XOR。它使用 梯度下降 在真值表中训练。

首先定义输入 X 和目标 T

// Two binary input values, 4 possible combinations
let inputs = vec![0.0, 0.0,
                  0.0, 1.0,
                  1.0, 0.0,
                  1.0, 1.0];
// Four binary output targets, one for each possible input value
let targets = vec![0.0,
                   1.0,
                   1.0,
                   0.0];

然后,通过向列表中添加一些层来构建一个神经网络

let mut layers: LinkedList<Box<WeightedLayer>> = LinkedList::new();
// We start by a hidden "dense" layer of 6 neurons which should
// accept 2 input values.
layers.push_back(Box::new(DenseLayer::random(2, 6)));
// We attach hyperbolic activation functions to the dense layer
layers.push_back(Box::new(HyperbolicLayer { size: 6 }));
// We follow this with a final "dense" layer with a single neuron,
// expecting 6 inputs from the preceeding layer.
layers.push_back(Box::new(DenseLayer::random(6, 1)));
// This will be output neuron so we attach a sigmoid activation function
// to get an output between 0 and 1.
layers.push_back(Box::new(SigmoidLayer { size: 1 }));

由于这是在训练之前,我们预计网络的输出将是完全随机的。这可以通过将输入通过网络来查看

for (x, t) in inputs.chunks(2).zip(targets.chunks(1)) {
    let mut o = x.to_vec();
    for l in layers.iter() {
        o = l.output(&o);
    }
    println!("X: {:?}, Y: {:?}, T: {:?}", x, o, t);
}

网络输出的示例 Y

X: [0, 0], Y: [0.4244223], T: [0]
X: [0, 1], Y: [0.049231697], T: [1]
X: [1, 0], Y: [0.12347225], T: [1]
X: [1, 1], Y: [0.02869209], T: [0]

要训练网络,首先创建一个合适的训练器,然后调用其训练方法

// A trainer which uses stochastic gradient descent. Run for
// 1000 iterations with a learning rate of 0.1.
let trainer = SGDTrainer::new(1000, 0.1);
// Train the network on the given inputs and targets
trainer.train(&mut layers, &inputs, &targets);

现在计算训练后网络的输出

for (x, t) in inputs.chunks(2).zip(targets.chunks(1)) {
    let mut o = x.to_vec();
    for l in layers.iter() {
        o = l.output(&o);
    }
    println!("X: {:?}, Y: {:?}, T: {:?}", x, o, t);
}

最终结果,注意网络输出 Y 与目标 T 非常接近

X: [0, 0], Y: [0.03515992], T: [0]
X: [0, 1], Y: [0.96479124], T: [1]
X: [1, 0], Y: [0.96392107], T: [1]
X: [1, 1], Y: [0.03710678], T: [0]

依赖项

~315–540KB