10 个重大版本发布
0.11.0 | 2024 年 5 月 14 日 |
---|---|
0.10.0 | 2024 年 3 月 26 日 |
0.9.0 | 2024 年 2 月 24 日 |
0.2.0 | 2023 年 12 月 28 日 |
#85 in 机器学习
在 3 个 crate 中使用 (通过 rai)
325KB
10K SLoC
RAI
在 Rust 中具有舒适 API 的 ML 框架。具有类似 JAX 的惰性计算和可组合转换。
安装
cargo add rai
代码片段
函数转换 (jvp, vjp, grad, value_and_grad)
use rai::{grad, Cpu, Tensor, F32};
fn f(x: &Tensor) -> Tensor {
x.sin()
}
fn main() {
let grad_fn = grad(grad(f));
let x = &Tensor::ones([1], F32, &Cpu);
let grad = grad_fn(x);
println!("{}", grad.dot_graph());
println!("{}", grad);
}
NN 模块、优化器和损失函数
fn loss_fn<M: TrainableModule<Input = Tensor, Output = Tensor>>(
model: &M,
input: &Tensor,
labels: &Tensor,
) -> (Tensor, Aux<Tensor>) {
let logits = model.forward(input);
let loss = softmax_cross_entropy(&logits, labels).mean(..);
(loss, Aux(logits))
}
fn train_step<M: TrainableModule<Input = Tensor, Output = Tensor>, O: Optimizer>(
optimizer: &mut O,
model: &M,
input: &Tensor,
labels: &Tensor,
) {
let vg_fn = value_and_grad(loss_fn);
let ((_loss, Aux(_logits)), (grads, ..)) = vg_fn((model, input, labels));
let mut params = optimizer.step(&grads);
eval(¶ms);
model.update_params(&mut params);
}
示例
- 线性回归
cargo运行 --bin线性回归 --发布
- mnist
cargo运行 --binmnist --发布
cargo运行 --binmnist --发布 --功能=cuda
- mnist-cnn
cargo运行 --binmnist-cnn --发布
cargo运行 --binmnist-cnn --发布 --功能=cuda
- phi2
cargo运行 --binphi2 --发布
cargo运行 --binphi2 --发布 --功能=cuda
- phi3
cargo运行 --binphi3 --发布
cargo运行 --binphi3 --发布 --功能=cuda
- qwen2
cargo运行 --binqwen2 --发布
cargo运行 --binqwen2 --发布 --功能=cuda
- gemma
- 在 https://hugging-face.cn/google/gemma-2b 上接受许可协议
pip install huggingface_hub
- 登录到 hf
huggingface-cli login
cargo运行 --bingemma --发布
cargo运行 --bingemma --发布 --功能=cuda
- vit
cargo运行 --binvit --发布
cargo运行 --binvit --发布 --功能=cuda
许可证
此项目受以下其中之一许可:
- Apache 许可证 2.0 版 (LICENSE-APACHE 或 https://apache.ac.cn/licenses/LICENSE-2.0)
- MIT 许可证 (LICENSE-MIT 或 http://opensource.org/licenses/MIT)
由您选择。
依赖项
~5–20MB
~263K SLoC