3个版本
0.1.2 | 2021年8月2日 |
---|---|
0.1.1 | 2021年8月2日 |
0.1.0 | 2021年8月2日 |
#39 in #自动微分
7KB
81 行
reverse
Rust中的逆模式自动微分。
要在您的crate中使用它,请在Cargo.toml
中添加以下内容:
[dependencies]
reverse = "0.1"
示例
use reverse::*;
fn main() {
let graph = Graph::new();
let a = graph.add_var(2.5);
let b = graph.add_var(14.);
let c = (a.sin().powi(2) + b.ln() * 3.) - 5.;
let gradients = c.grad();
assert_eq!(gradients.wrt(&a), (2. * 2.5).sin());
assert_eq!(gradients.wrt(&b), 3. / 14.);
}
可微分函数
有一个可选的diff
功能,它激活一个宏,将函数转换为正确类型,以便它们是可微分的。也就是说,对f64
进行操作的函数可以在不更改的情况下用于可微分变量,而无需指定(不简单)正确的类型。
要使用它,请在Cargo.toml
中添加以下内容:
reverse = { version = "0.1", features = ["diff"] }
函数必须具有以下类型 Fn(&[f64], &[&[f64]]) -> f64
,其中第一个参数包含可微分参数,第二个参数包含任意数据数组。
示例
以下是一个示例,说明该功能允许您做什么
use reverse::*;
fn main() {
let graph = Graph::new();
let a = graph.add_var(5.);
let b = graph.add_var(2.);
// you can track gradients through the function as usual!
let res = addmul(&[a, b], &[&[4.]]);
let grad = res.grad();
assert_eq!(grad.wrt(&a), 1.);
assert_eq!(grad.wrt(&b), 4.);
}
// function must have these argument types but can be arbitrarily complex
#[differentiable]
fn addmul(params: &[f64], data: &[&[f64]]) -> f64 {
params[0] + data[0][0] * params[1]
}
依赖项
~1.5MB
~35K SLoC