#machine-learning #deep-learning #tensor

rai-core

Rust中的ML框架,具有人性化的API

10个重大发布

0.11.0 2024年5月14日
0.10.0 2024年3月26日
0.9.0 2024年2月24日
0.2.0 2023年12月28日

#910 in 机器学习

Download history 191/week @ 2024-05-12 37/week @ 2024-05-19 24/week @ 2024-05-26 16/week @ 2024-06-02 15/week @ 2024-06-09 13/week @ 2024-06-16 9/week @ 2024-06-23 2/week @ 2024-06-30 43/week @ 2024-07-07 16/week @ 2024-07-14 2/week @ 2024-07-21 10/week @ 2024-07-28

71 每月下载量
5 个crate(3 个直接)中使用

MIT/Apache

295KB
9K SLoC

RAI

Rust Docs Status Latest Version Discord

在Rust中具有人性化API的ML框架。具有类似JAX的懒惰计算和可组合转换。

安装

cargo add rai

代码片段

函数转换(jvp、vjp、grad、value_and_grad)

use rai::{grad, Cpu, Tensor, F32};

fn f(x: &Tensor) -> Tensor {
    x.sin()
}

fn main() {
    let grad_fn = grad(grad(f));
    let x = &Tensor::ones([1], F32, &Cpu);
    let grad = grad_fn(x);
    println!("{}", grad.dot_graph());
    println!("{}", grad);
}

NN模块、优化器和损失函数

fn loss_fn<M: TrainableModule<Input = Tensor, Output = Tensor>>(
    model: &M,
    input: &Tensor,
    labels: &Tensor,
) -> (Tensor, Aux<Tensor>) {
    let logits = model.forward(input);
    let loss = softmax_cross_entropy(&logits, labels).mean(..);
    (loss, Aux(logits))
}

fn train_step<M: TrainableModule<Input = Tensor, Output = Tensor>, O: Optimizer>(
    optimizer: &mut O,
    model: &M,
    input: &Tensor,
    labels: &Tensor,
) {
    let vg_fn = value_and_grad(loss_fn);
    let ((_loss, Aux(_logits)), (grads, ..)) = vg_fn((model, input, labels));
    let mut params = optimizer.step(&grads);
    eval(&params);
    model.update_params(&mut params);
}

示例

  • 线性回归
    • cargorun --bin线性回归 --release
  • mnist
    • cargorun --binmnist --release
    • cargorun --binmnist --release --功能=cuda
  • mnist-cnn
    • cargorun --binmnist-cnn --release
    • cargorun --binmnist-cnn --release --功能=cuda
  • phi2
    • cargorun --binphi2 --release
    • cargorun --binphi2 --release --功能=cuda
  • phi3
    • cargorun --binphi3 --release
    • cargorun --binphi3 --release --功能=cuda
  • qwen2
    • cargorun --binqwen2 --release
    • cargorun --binqwen2 --release --功能=cuda
  • gemma
    • https://hugging-face.cn/google/gemma-2b中接受许可协议
    • pip安装huggingface_hub
    • 登录hf huggingface-cli login
    • cargorun --bingemma --release
    • cargorun --bingemma --release --功能=cuda
  • vit
    • cargorun --binvit --release
    • cargorun --binvit --release --功能=cuda

许可证

本项目许可协议为以下之一

任选其一。

依赖项

~2–15MB
~218K SLoC