14个版本 (8个重大更新)

使用旧的Rust 2015

0.9.1 2018年6月2日
0.8.1 2018年5月27日
0.7.2 2018年1月21日
0.2.0 2017年12月26日

#627 in 机器学习

Download history 23/week @ 2024-03-14 32/week @ 2024-03-28 14/week @ 2024-04-04

每月61次下载
用于 sbr

MIT 许可证

180KB
4.5K SLoC

Wyrm

Crates.io badge Docs.rs badge Build Status

一个反向模式、定义时运行、低开销的自动微分库。

特性

通过任意定义时运行的计算图执行反向传播,强调在CPU上对稀疏、小型模型的低开销估计。

亮点

  1. 低开销。
  2. 内置对稀疏梯度的支持。
  3. 定义时运行。
  4. 简单的Hogwild-style并行化,与可用的CPU核心数量成线性比例扩展。

快速入门

以下定义了一个一元线性回归模型,然后通过它进行反向传播。

let slope = ParameterNode::new(random_matrix(1, 1));
let intercept = ParameterNode::new(random_matrix(1, 1));

let x = InputNode::new(random_matrix(1, 1));
let y = InputNode::new(random_matrix(1, 1));

let y_hat = slope.clone() * x.clone() + intercept.clone();
let mut loss = (y.clone() - y_hat).square();

为了优化参数,创建一个优化器对象,并通过几个学习周期进行学习

let mut optimizer = SGD::new(0.1, vec![slope.clone(), intercept.clone()]);

for _ in 0..num_epochs {
    let x_value: f32 = rand::random();
    let y_value = 3.0 * x_value + 5.0;

    // You can re-use the computation graph
    // by giving the input nodes new values.
    x.set_value(x_value);
    y.set_value(y_value);

    loss.forward();
    loss.backward(1.0);

    optimizer.step(loss.parameters());
}

您可以使用 rayon 来并行拟合您的模型,首先创建一组共享参数,然后构建每个线程的模型副本

let slope_param = Arc::new(HogwildParameter::new(random_matrix(1, 1)));
let intercept_param = Arc::new(HogwildParameter::new(random_matrix(1, 1)));
let num_epochs = 10;

(0..rayon::current_num_threads())
    .into_par_iter()
       .for_each(|_| {
           let slope = ParameterNode::shared(slope_param.clone());
           let intercept = ParameterNode::shared(intercept_param.clone());
           let x = InputNode::new(random_matrix(1, 1));
           let y = InputNode::new(random_matrix(1, 1));
           let y_hat = slope.clone() * x.clone() + intercept.clone();
           let mut loss = (y.clone() - y_hat).square();

           let mut optimizer = SGD::new(0.1, vec![slope.clone(), intercept.clone()]);

           for _ in 0..num_epochs {
               let x_value: f32 = rand::random();
               let y_value = 3.0 * x_value + 5.0;

               x.set_value(x_value);
               y.set_value(y_value);

               loss.forward();
               loss.backward(1.0);

               optimizer.step(loss.parameters());
           }
       });

BLAS支持

您应该启用BLAS支持以从矩阵乘法密集型工作负载中获得(大大)更好的性能。要做到这一点,请将以下内容添加到您的 Cargo.toml

ndarray = { version = "0.11.0", features = ["blas", "serde-1"] }
blas-src = { version = "0.1.2", default-features = false, features = ["openblas"] }
openblas-src = { version = "0.5.6", default-features = false, features = ["cblas"] }

依赖项

~4.5MB
~91K SLoC