3 个版本 (破坏性更新)
| 0.11.0 | 2024 年 5 月 14 日 | 
|---|---|
| 0.10.0 | 2024 年 3 月 26 日 | 
| 0.7.0 | 2024 年 1 月 23 日 | 
在 机器学习 中排名 560
每月下载量 180 次
17KB
136 行 代码行
RAI
RAI 是 Rust 中具有人体工程学 API 的 ML 框架。具有类似 JAX 的懒计算和可组合变换。
安装
cargo add rai
代码片段
函数变换(jvp、vjp、grad、value_and_grad)
use rai::{grad, Cpu, Tensor, F32};
fn f(x: &Tensor) -> Tensor {
    x.sin()
}
fn main() {
    let grad_fn = grad(grad(f));
    let x = &Tensor::ones([1], F32, &Cpu);
    let grad = grad_fn(x);
    println!("{}", grad.dot_graph());
    println!("{}", grad);
}
NN 模块、优化器和损失函数
fn loss_fn<M: TrainableModule<Input = Tensor, Output = Tensor>>(
    model: &M,
    input: &Tensor,
    labels: &Tensor,
) -> (Tensor, Aux<Tensor>) {
    let logits = model.forward(input);
    let loss = softmax_cross_entropy(&logits, labels).mean(..);
    (loss, Aux(logits))
}
fn train_step<M: TrainableModule<Input = Tensor, Output = Tensor>, O: Optimizer>(
    optimizer: &mut O,
    model: &M,
    input: &Tensor,
    labels: &Tensor,
) {
    let vg_fn = value_and_grad(loss_fn);
    let ((_loss, Aux(_logits)), (grads, ..)) = vg_fn((model, input, labels));
    let mut params = optimizer.step(&grads);
    eval(¶ms);
    model.update_params(&mut params);
}
示例
- 线性回归
- cargo运行 --bin线性回归 --发布
 
- mnist
- cargo运行 --binmnist --发布
- cargo运行 --binmnist --发布 --功能=cuda
 
- mnist-cnn
- cargo运行 --binmnist-cnn --发布
- cargo运行 --binmnist-cnn --发布 --功能=cuda
 
- phi2
- cargo运行 --binphi2 --发布
- cargo运行 --binphi2 --发布 --功能=cuda
 
- phi3
- cargo运行 --binphi3 --发布
- cargo运行 --binphi3 --发布 --功能=cuda
 
- qwen2
- cargo运行 --binqwen2 --发布
- cargo运行 --binqwen2 --发布 --功能=cuda
 
- gemma
- 在 https://hugging-face.cn/google/gemma-2b 中接受许可协议
- pip install huggingface_hub
- 登录 hf huggingface-cli login
- cargo运行 --bingemma --发布
- cargo运行 --bingemma --发布 --功能=cuda
 
- vit
- cargo运行 --binvit --发布
- cargo运行 --binvit --发布 --功能=cuda
 
许可证
此项目受以下任一许可证的许可:
- Apache 许可证,版本 2.0 (https://apache.ac.cn/licenses/LICENSE-2.0)
- MIT 许可证 (https://open-source.org.cn/licenses/MIT)
任选其一。
依赖关系
~23–35MB
~665K SLoC