3 个版本
0.6.2 | 2023年10月5日 |
---|---|
0.6.1 | 2023年10月4日 |
0.5.0 |
|
0.4.9 |
|
0.2.7 |
|
在 机器学习 分类中排名第 206
每月下载量 64
38KB
555 行
Rust-Neural-Network
这是一个用 Rust 编写的神经网络库的开端,旨在提供一个灵活高效的平台来构建和训练简单的神经网络。当前实现要求所有内容都在堆上分配,并以块的形式计算。一旦 Rust 在常量表达式中实现更好的泛型,这将得到改变。
我不希望与夜间编译器打交道。
为什么是这个库?
由于所有内容都是用 Rust 编写的,并且不使用任何外部依赖项,因此它可以轻松地集成到任何嵌入式项目中,而无需担心速度或对其他语言的依赖(这是我特定的用例)
项目状态
这个库仍处于早期开发阶段,当前版本处于测试阶段,一旦实现基于栈的分配,将升级到 1.0.0 版本。欢迎贡献和反馈,但请注意,随着库的成熟,其内部结构可能会进行重大更改。
特性
- 训练:该库允许使用反向传播和梯度下降进行低级网络训练。
- 激活函数(Sigmoid、Tanh、ArcTanh、Relu、LeakyRelu、SoftMax、SoftPlus 等)。
- 模型序列化
- 基于栈的分配(即将推出)
- GPU 加速(可能会在将来实现)
但是,以下是一个使用该库创建简单神经网络并对其进行单次迭代的示例
use fast_neural_network::{activation::*, neural_network::*};
use ndarray::*;
fn main() {
let mut network = Network::new(2, 1, ActivationType::LeakyRelu, 0.01); // Create a new network with 2 inputs, 1 output, a LeakyRelu activation function, and a learning rate of 0.01
network.add_hidden_layer_with_size(2); // Add a hidden layer with 2 neurons
network.compile(); // Compile the network to prepare it for training
// (will be done automatically during training)
// The API is exposed so that the user can compile
// the network on a different thread before training if they want to
// Let's create a dataset to represent the XOR function
let mut dataset: Vec<(ndarray::Array1<f64>, ndarray::Array1<f64>)> = Vec::new();
dataset.push((array!(0., 0.), array!(0.)));
dataset.push((array!(1., 0.), array!(1.)));
dataset.push((array!(0., 1.), array!(1.)));
dataset.push((array!(1., 1.), array!(0.)));
network.train(&dataset, 20_000, 1_000); // train the network for 20,000 epochs with a decay_time of 1,000 epochs
let mut res;
// Let's check the result
for i in 0..dataset.len() {
res = network.forward(&dataset[i].0);
let d = &dataset[i];
println!(
"for [{:.3}, {:.3}], [{:.3}] -> [{:.3}]",
d.0[0], d.0[1], d.1[0], res
);
}
network.save("network.json"); // Save the model as a json to a file
// Load the model from a json file using the below line
// let mut loaded_network = Network::load("network.json");
}
预期输出
速度
该库的重点是多线程性能。该库旨在尽可能快,我已经尽力优化代码以获得性能。该库仍处于早期阶段,因此仍有改进的空间,但我已尽力使其尽可能快。我只希望 Rust 能够像 C++ 一样实现更好的常量表达式泛型。
矩阵并行化目前尚未实现,但一旦 Rust 实现了更好的泛型,将会实现。
贡献
鼓励贡献!如果您有兴趣添加新功能、提高性能、修复错误或增强文档,我将非常感激您的帮助。只需提交一个拉取请求,我会查看。
路线图
以下功能可能将在未来的版本中实现
- 支持更多的激活函数
- 使用CUDA或其他类似技术的GPU加速(可能只是着色器,但不知道,这似乎很难)
- 增强模型评估工具(可能还会提供与之配套的GUI。如果我自己编写,它将是Raylib)
依赖项
约5-13MB
约155K SLoC