#神经网络 # #模块化 #深度 #优化 #DNN

Rust_Simple_DNN

一个用于在Rust中创建优化模块化神经网络的crate

6个版本

0.1.5 2023年12月27日
0.1.4 2023年12月25日

#169 in 机器学习

每月 33 次下载

MIT 许可证

77KB
582

Rust DNN

在Rust中轻松创建模块化深度神经网络

进行中

如果有人实际上为这个项目打星,我将添加卷积层、更多激活函数和反卷积层。如果这个项目获得20个星标,我将添加一切

安装

在运行后

cargo add Rust_Simple_DNN

然后你必须在Rust代码中放入这些

use Rust_Simple_DNN::rdnn::layers::*;
use Rust_Simple_DNN::rdnn::*;

当前实现的层

将层视为神经网络的基本构建块。不同的层以不同的方式处理数据。选择正确的层以适应你的情况很重要。(例如:卷积层用于图像处理)

  • 全连接密集层
FC::new(inputSize, outputSize)

这些层在仅进行直接原始数据处理时表现最佳。使用这些层结合激活函数,理论上可以创建任何你想要的数学模型。但这些层的计算量在扩展时会呈指数级增长。


  • 激活函数
Tanh::new(inputSize); //hyperbolic tangent
Relu::new(inputSize); //if activation > 0
Sig::new(inputSize); //sigmoid

将这些函数放在FC、Conv、Deconv或任何点积层之后,以使网络非线性,否则网络将无法在99%的使用场景中工作。

迷你教程

这是如何创建如下所示的神经网络
image-alt-text-check-github-to-see-image

使用此代码创建它

//FC layers are dense layers.
//Sig layers are sigmoid activation
let mut net = Net::new(
        vec![
            FC::new(3, 4), //input 3, output 4
            Sig::new(4), //sigmoid, input 4 output 4

            FC::new(4, 4),
            Sig::new(4), //sigmoid

            FC::new(4, 1),// input 4 output 1
            Sig::new(1), //sigmoid
        ],
        1, //batch size
        0.1, //learning rate
    );
    //"net" is the variable representing your entire network


这是如何在网络中传播数据
net.forward_data(&vec![1.0, 0.0, -69.0]); //returns the output vector

在传播一些数据后,你可以像这样进行反向传播

 net.backward_data(&vec![0.0]); //a vector of what you want the nn to output

网络将自动存储和应用梯度,因此要训练网络,你所需要做的就是反复前向和反向传播你的数据

let mut x = 0;

    while x < 5000 {
        net.forward_data(&vec![1.0, 0.0, 0.0]);
        net.backward_data(&vec![1.0]);

        net.forward_data(&vec![1.0, 1.0, 0.0]);
        net.backward_data(&vec![0.0]);

        net.forward_data(&vec![0.0, 1.0, 0.0]);
        net.backward_data(&vec![1.0]);

        net.forward_data(&vec![0.0, 0.0, 0.0]);
        net.backward_data(&vec![0.0]);
        x += 1;
    }

//at this point its trained (although this dataset is pretty useless lol)

如果不是无谓的复杂,这将像Pytorch一样

依赖项

~310KB