27 个版本

0.5.9 2022 年 3 月 28 日
0.5.5 2022 年 2 月 28 日
0.3.8 2020 年 8 月 16 日
0.3.6 2020 年 7 月 7 日

#163机器学习


用于 2 crates

MIT 许可证

555KB
13K SLoC

简单的机器学习工具集

crates.io version License example workflow doc badge

简介

这是一个基于自动微分的学习库。

特性

  • 无类型的张量。
  • 支持张量上的变量,并支持反向传播。
  • 支持常见运算符,包括卷积。

示例

use tensor_rs::tensor::Tensor;
use auto_diff::rand::RNG;
use auto_diff::var::{Module};
use auto_diff::optim::{SGD, Optimizer};

fn main() {

    fn func(input: &Tensor) -> Tensor {
        input.matmul(&Tensor::from_vec_f32(&vec![2., 3.], &vec![2, 1])).add(&Tensor::from_vec_f32(&vec![1.], &vec![1]))
    }

    let N = 100;
    let mut rng = RNG::new();
    rng.set_seed(123);
    let data = rng.normal(&vec![N, 2], 0., 2.);
    let label = func(&data);


    let mut m = Module::new();
    
    let op1 = m.linear(Some(2), Some(1), true);
    let weights = op1.get_values().unwrap();
    rng.normal_(&weights[0], 0., 1.);
    rng.normal_(&weights[1], 0., 1.);
    op1.set_values(&weights);

    let op2 = op1.clone();
    let block = m.func(
        move |x| {
            op2.call(x)
        }
    );
    
    let loss_func = m.mse_loss();
    
    let mut opt = SGD::new(3.);

    for i in 0..200 {
        let input = m.var_value(data.clone());
        
        let y = block.call(&[&input]);
        
        let loss = loss_func.call(&[&y, &m.var_value(label.clone())]);
        println!("index: {}, loss: {}", i, loss.get().get_scale_f32());
        
        loss.backward(-1.);
        opt.step2(&block);

    }

    let weights = op1.get_values().expect("");
    println!("{:?}, {:?}", weights[0], weights[1]);
}

依赖

安装 gfortran,使用 openblas-src = "0.9"。

贡献

欢迎任何贡献,请通过创建 pull request 打开一个问题。

依赖

~2.5MB
~57K SLoC