2 个版本
0.3.1 | 2019 年 6 月 5 日 |
---|---|
0.3.0 | 2019 年 6 月 5 日 |
#521 in 机器学习
34KB
807 行
R.A.I.L: Rust 人工智能库
RAIL 设计成一个易于创建和训练神经网络的库,类似于 Keras API。它旨在快速且易于使用。
依赖项
RAIL 依赖于 arrayfire-rust,所以在使用 RAIL 之前,请确保已经安装了 arrayfire。
一个简单的 XOR 问题
使用 Mold 解决 XOR 问题非常简单!只需将该包添加到您的 Cargo.toml 文件中
rail = { git = "https://github.com/nlsnightmare/rail" }
然后将其添加到您的代码中
use rail::model::Model;
use rail::layers::dense::Dense;
use rail::layers::activations::Activation;
pub fn main() {
let mut model = Model::new()
.learning_rate(0.01)
.input_size(2)
.layer(Dense::new(2).activation(Activation::Tanh))
.layer(Dense::new(1).activation(Activation::Tanh))
.build(true)
.unwrap();
let tranining_data = vec![
(vec![0., 0.], vec![0.]),
(vec![0., 1.], vec![1.]),
(vec![1., 0.], vec![1.]),
(vec![1., 1.], vec![0.]),
];
// Train with a batch of 2 for 4000 epochs
model.train(&tranining_data, 2, 4000);
println!("[0, 0] -> {}", model.predict(vec![0., 0.])[0]); // should be close to 0
println!("[0, 1] -> {}", model.predict(vec![0., 1.])[0]); // should be close to 1
println!("[1, 0] -> {}", model.predict(vec![1., 0.])[0]); // should be close to 1
println!("[1, 1] -> {}", model.predict(vec![1., 1.])[0]); // should be close to 0
}
计划
截至目前,RAIL 处于非常早期的阶段,正在积极开发中。API 将会发生很大变化。
目前,仅支持稠密(即全连接)层,并且批量 SGD 是训练网络的唯一方式。但是,计划支持
- 卷积层
- RNN 单元
- LSTM 单元
- 遗传交叉
- ADAM 优化器
- 更多激活函数
- 更多误差函数
- 文档
依赖项
~2.5MB
~40K SLoC