27 个版本
0.5.9 | 2022 年 3 月 28 日 |
---|---|
0.5.5 | 2022 年 2 月 28 日 |
0.3.8 | 2020 年 8 月 16 日 |
0.3.6 | 2020 年 7 月 7 日 |
#163 在 机器学习 中
用于 2 crates
555KB
13K SLoC
简单的机器学习工具集
简介
这是一个基于自动微分的学习库。
特性
- 无类型的张量。
- 支持张量上的变量,并支持反向传播。
- 支持常见运算符,包括卷积。
示例
use tensor_rs::tensor::Tensor;
use auto_diff::rand::RNG;
use auto_diff::var::{Module};
use auto_diff::optim::{SGD, Optimizer};
fn main() {
fn func(input: &Tensor) -> Tensor {
input.matmul(&Tensor::from_vec_f32(&vec![2., 3.], &vec![2, 1])).add(&Tensor::from_vec_f32(&vec![1.], &vec![1]))
}
let N = 100;
let mut rng = RNG::new();
rng.set_seed(123);
let data = rng.normal(&vec![N, 2], 0., 2.);
let label = func(&data);
let mut m = Module::new();
let op1 = m.linear(Some(2), Some(1), true);
let weights = op1.get_values().unwrap();
rng.normal_(&weights[0], 0., 1.);
rng.normal_(&weights[1], 0., 1.);
op1.set_values(&weights);
let op2 = op1.clone();
let block = m.func(
move |x| {
op2.call(x)
}
);
let loss_func = m.mse_loss();
let mut opt = SGD::new(3.);
for i in 0..200 {
let input = m.var_value(data.clone());
let y = block.call(&[&input]);
let loss = loss_func.call(&[&y, &m.var_value(label.clone())]);
println!("index: {}, loss: {}", i, loss.get().get_scale_f32());
loss.backward(-1.);
opt.step2(&block);
}
let weights = op1.get_values().expect("");
println!("{:?}, {:?}", weights[0], weights[1]);
}
依赖
安装 gfortran,使用 openblas-src = "0.9"。
贡献
欢迎任何贡献,请通过创建 pull request 打开一个问题。
依赖
~2.5MB
~57K SLoC