2 个版本

0.3.1 2019 年 6 月 5 日
0.3.0 2019 年 6 月 5 日

#521 in 机器学习

自定义许可

34KB
807

R.A.I.L: Rust 人工智能库

RAIL 设计成一个易于创建和训练神经网络的库,类似于 Keras API。它旨在快速且易于使用。

依赖项

RAIL 依赖于 arrayfire-rust,所以在使用 RAIL 之前,请确保已经安装了 arrayfire。

一个简单的 XOR 问题

使用 Mold 解决 XOR 问题非常简单!只需将该包添加到您的 Cargo.toml 文件中

rail = { git = "https://github.com/nlsnightmare/rail" }

然后将其添加到您的代码中

use rail::model::Model;
use rail::layers::dense::Dense;
use rail::layers::activations::Activation;

pub fn main() {
    let mut model = Model::new()
        .learning_rate(0.01)
        .input_size(2)
        .layer(Dense::new(2).activation(Activation::Tanh))
        .layer(Dense::new(1).activation(Activation::Tanh))
        .build(true)
        .unwrap();

    let tranining_data = vec![
        (vec![0., 0.], vec![0.]),
        (vec![0., 1.], vec![1.]),
        (vec![1., 0.], vec![1.]),
        (vec![1., 1.], vec![0.]),
    ];

    // Train with a batch of 2 for 4000 epochs
    model.train(&tranining_data, 2, 4000);

    println!("[0, 0] -> {}", model.predict(vec![0., 0.])[0]); // should be close to 0
    println!("[0, 1] -> {}", model.predict(vec![0., 1.])[0]); // should be close to 1
    println!("[1, 0] -> {}", model.predict(vec![1., 0.])[0]); // should be close to 1
    println!("[1, 1] -> {}", model.predict(vec![1., 1.])[0]); // should be close to 0
}

计划

截至目前,RAIL 处于非常早期的阶段,正在积极开发中。API 将会发生很大变化。
目前,仅支持稠密(即全连接)层,并且批量 SGD 是训练网络的唯一方式。但是,计划支持

  • 卷积层
  • RNN 单元
  • LSTM 单元
  • 遗传交叉
  • ADAM 优化器
  • 更多激活函数
  • 更多误差函数
  • 文档

依赖项

~2.5MB
~40K SLoC