2 个版本

0.1.1 2020 年 12 月 26 日
0.1.0 2020 年 12 月 21 日

机器学习 中排名 #485

MIT 许可证

24KB
537

neuralnetwork

使用 Rust 从头编写的小型神经网络库。

XOR 示例

创建一个具有 2 个输入节点、2 个隐藏层(每层 2 个节点)和 1 个输出节点的网络,使用 sigmoid 作为激活函数,学习率设置为 1。

使用随机梯度下降和均方误差作为损失函数进行 10k 个周期的训练。

use neuralnetwork::neuralnetwork::NeuralNetwork;
use neuralnetwork::{current_millis, parse_csv};

fn main() {
    let mut nn = NeuralNetwork::new(2, vec![2, 2], 1, 1.0, "sigmoid");
    let (inputs, outputs) = parse_csv("xor.csv", 2, 1);
    let start_time = current_millis();
    let epochs = 10_000;
    for _ in 0..epochs {
        for i in 0..inputs.len() {
            nn.train(&inputs[i], &outputs[i]);
        }
    }
    let end_time = (current_millis() - start_time) as f32 / 1000 as f32;
    println!("Training {} epochs took {}s", epochs, end_time);
    for i in 0..inputs.len() {
        println!(
            "Input {} {} Prediction {:.8} Goal {:.}",
            inputs[i][0][0],
            inputs[i][1][0],
            nn.predict(&inputs[i])[0][0],
            outputs[i][0][0]
        );
    }
}

xor.csv 的内容

0,0,0
0,1,1
1,0,1
1,1,0

MNIST 示例

创建一个具有 784 个输入节点、2 个隐藏层(每层 8 个节点)和 10 个输出节点的网络,使用 sigmoid 作为激活函数,学习率设置为 0.1。

use neuralnetwork::neuralnetwork::NeuralNetwork;
use neuralnetwork::{get_accuracy, train_on_dataset};

fn main() {
    let mut nn = NeuralNetwork::new(784, vec![8, 8], 10, 0.1, "sigmoid");
    train_on_dataset(
        &mut nn,
        "mnist_train.csv",
        10,
    );
    print!(
        "Accuarcy: {}%\n",
        get_accuracy(&nn, "mnist_test.csv") * 100.0
    );
}

使用的数据集是 MNIST 的修改版,其中每行的前 784 个值是缩放到 [0,1] 范围的输入,最后 10 个代表输出,使用 one-hot 编码。

train_on_dataset 接收网络、数据集路径和周期数作为输入,并通过随机梯度下降训练网络。

get_accuracy 计算测试集上的准确率,它假定 one-hot 编码并检查具有最高值的节点是否正确。

训练 10 个周期大约需要 30 秒,准确率 >90%。

无运行时依赖项