2个不稳定版本
0.2.0 | 2022年12月29日 |
---|---|
0.1.0 | 2022年9月19日 |
#38 in #自动微分
用于 aegir
11KB
124 行
aegir
概览
在Rust中进行强类型、编译时自动微分。
aegir
是一个实验性的自动微分框架,旨在利用Rust强大的类型系统,并尽可能避免运行时。采用的方法类似于在C++编写的线性代数库中常用的表达式模板。
主要特性
- 内置算术、线性代数、三角函数和特殊运算符。
- 无限可微:Jacobian,Hessian等...
- 自定义运算符扩展DSL。
- 解耦/通用张量类型。
安装
[dependencies]
aegir = "1.0"
示例
#[macro_use]
extern crate aegir;
extern crate rand;
use aegir::{Differentiable, Function, Identifier, Node, ids::{X, Y, W}};
db!(Database { x: X, y: Y, w: W });
fn main() {
let mut weights = [0.0, 0.0];
let x = X.into_var();
let y = Y.into_var();
let w = W.into_var();
let model = x.dot(w);
// Using standard method calls...
let sse = model.sub(y).squared();
let adj = sse.adjoint(W);
// ...or using aegir! macro
let sse = aegir!((model - y) ^ 2);
let adj = sse.adjoint(W);
for _ in 0..100000 {
let [x1, x2] = [rand::random::<f64>(), rand::random::<f64>()];
let g = adj.evaluate(Database {
// Independent variables:
x: [x1, x2],
// Dependent variable:
y: x1 * 2.0 - x2 * 4.0,
// Model weights:
w: &weights,
}).unwrap();
weights[0] -= 0.01 * g[0][0];
weights[1] -= 0.01 * g[0][1];
}
println!("{:?}", weights.to_vec());
}
依赖
~1.5MB
~35K SLoC