#自动微分 #逆模式 #模式 #函数

reverse_differentiable

函数的自动微分

3个版本

0.1.2 2021年8月2日
0.1.1 2021年8月2日
0.1.0 2021年8月2日

#39 in #自动微分

MIT/Apache

7KB
81

reverse

Crates.io Documentation License

Rust中的逆模式自动微分。

要在您的crate中使用它,请在Cargo.toml中添加以下内容:

[dependencies]
reverse = "0.1"

示例

use reverse::*;

fn main() {
  let graph = Graph::new();
  let a = graph.add_var(2.5);
  let b = graph.add_var(14.);
  let c = (a.sin().powi(2) + b.ln() * 3.) - 5.;
  let gradients = c.grad();

  assert_eq!(gradients.wrt(&a), (2. * 2.5).sin());
  assert_eq!(gradients.wrt(&b), 3. / 14.);
}

可微分函数

有一个可选的diff功能,它激活一个宏,将函数转换为正确类型,以便它们是可微分的。也就是说,对f64进行操作的函数可以在不更改的情况下用于可微分变量,而无需指定(不简单)正确的类型。

要使用它,请在Cargo.toml中添加以下内容:

reverse = { version = "0.1", features = ["diff"] }

函数必须具有以下类型 Fn(&[f64], &[&[f64]]) -> f64,其中第一个参数包含可微分参数,第二个参数包含任意数据数组。

示例

以下是一个示例,说明该功能允许您做什么

use reverse::*;

fn main() {
    let graph = Graph::new();
    let a = graph.add_var(5.);
    let b = graph.add_var(2.);

    // you can track gradients through the function as usual!
    let res = addmul(&[a, b], &[&[4.]]);
    let grad = res.grad();

    assert_eq!(grad.wrt(&a), 1.);
    assert_eq!(grad.wrt(&b), 4.);
}

// function must have these argument types but can be arbitrarily complex
#[differentiable]
fn addmul(params: &[f64], data: &[&[f64]]) -> f64 {
    params[0] + data[0][0] * params[1]
}

依赖项

~1.5MB
~35K SLoC