5 个版本

0.1.4 2022 年 10 月 14 日
0.1.3 2022 年 10 月 14 日
0.1.2 2022 年 10 月 14 日
0.1.1 2022 年 10 月 14 日
0.1.0 2022 年 10 月 14 日

#453机器学习

MIT/Apache

2MB
709

神经网络

用 Rust 编写的简单神经网络。

关于

这个使用梯度下降实现的神经网络实现完全是用 Rust 从头编写的。可以指定网络的结构,以及网络的学习率。此外,可以选择许多预定义的数据集之一,例如 XOR 和 CIRCLE 数据集,它们代表了并集平方内的相对函数。以及更复杂的数据集,如 RGB_DONUT,它代表了一个像甜甜圈一样带有彩虹色过渡形状。

下面,你可以看到一个训练过程,其中网络试图学习 RGB_DONUT 数据集的颜色值。

特性

以下特性目前已被实现

  • 优化器
    1. Adam
    2. RMSProp
    3. SGD
  • 损失函数
    1. 二次方
  • 激活函数
    1. sigmoid
    2. ReLU
    1. 密集型
  • 绘图
    1. 在训练期间绘制成本历史
    2. 绘制最终的预测,可以是灰度或 RGB 格式

用法

创建和训练神经网络的流程相当直接

carbon

示例训练过程

下面,你可以看到网络是如何学习的

学习动画

https://user-images.githubusercontent.com/54124311/195410077-7a02b075-0269-4ff2-965f-97f224ab2cf1.mp4

最终结果

RGB_DONUT_SGD_ 2,64,64,64,64,64,3

酷的训练结果

RGB_DONUT

大网络

RGB_DONUT_RMS_PROP_ 2,128,128,128,3 RGB_DONUT_RMS_PROP_ 2,128,128,128,3 _history

小网络

RGB_DONUT_SGD_ 2,8,8,8,3 RGB_DONUT_SGD_ 2,8,8,8,3 _history

XOR 问题

XOR_SGD_ 2,8,8,8,1 XOR_SGD_ 2,8,8,8,1 _history

依赖项

~17MB
~153K SLoC