2 个版本
0.1.1 | 2020 年 12 月 26 日 |
---|---|
0.1.0 | 2020 年 12 月 21 日 |
在 机器学习 中排名 #485
24KB
537 行
neuralnetwork
使用 Rust 从头编写的小型神经网络库。
XOR 示例
创建一个具有 2 个输入节点、2 个隐藏层(每层 2 个节点)和 1 个输出节点的网络,使用 sigmoid 作为激活函数,学习率设置为 1。
使用随机梯度下降和均方误差作为损失函数进行 10k 个周期的训练。
use neuralnetwork::neuralnetwork::NeuralNetwork;
use neuralnetwork::{current_millis, parse_csv};
fn main() {
let mut nn = NeuralNetwork::new(2, vec![2, 2], 1, 1.0, "sigmoid");
let (inputs, outputs) = parse_csv("xor.csv", 2, 1);
let start_time = current_millis();
let epochs = 10_000;
for _ in 0..epochs {
for i in 0..inputs.len() {
nn.train(&inputs[i], &outputs[i]);
}
}
let end_time = (current_millis() - start_time) as f32 / 1000 as f32;
println!("Training {} epochs took {}s", epochs, end_time);
for i in 0..inputs.len() {
println!(
"Input {} {} Prediction {:.8} Goal {:.}",
inputs[i][0][0],
inputs[i][1][0],
nn.predict(&inputs[i])[0][0],
outputs[i][0][0]
);
}
}
xor.csv 的内容
0,0,0
0,1,1
1,0,1
1,1,0
MNIST 示例
创建一个具有 784 个输入节点、2 个隐藏层(每层 8 个节点)和 10 个输出节点的网络,使用 sigmoid 作为激活函数,学习率设置为 0.1。
use neuralnetwork::neuralnetwork::NeuralNetwork;
use neuralnetwork::{get_accuracy, train_on_dataset};
fn main() {
let mut nn = NeuralNetwork::new(784, vec![8, 8], 10, 0.1, "sigmoid");
train_on_dataset(
&mut nn,
"mnist_train.csv",
10,
);
print!(
"Accuarcy: {}%\n",
get_accuracy(&nn, "mnist_test.csv") * 100.0
);
}
使用的数据集是 MNIST 的修改版,其中每行的前 784 个值是缩放到 [0,1] 范围的输入,最后 10 个代表输出,使用 one-hot 编码。
train_on_dataset
接收网络、数据集路径和周期数作为输入,并通过随机梯度下降训练网络。
get_accuracy
计算测试集上的准确率,它假定 one-hot 编码并检查具有最高值的节点是否正确。
训练 10 个周期大约需要 30 秒,准确率 >90%。