2个不稳定版本

0.2.0 2022年12月29日
0.1.0 2022年9月19日

#38 in #自动微分


用于 aegir

自定义许可证

11KB
124

aegir

Crates.io Build Status Coverage Status

概览

在Rust中进行强类型、编译时自动微分。

aegir是一个实验性的自动微分框架,旨在利用Rust强大的类型系统,并尽可能避免运行时。采用的方法类似于在C++编写的线性代数库中常用的表达式模板。

主要特性

  • 内置算术、线性代数、三角函数和特殊运算符。
  • 无限可微:Jacobian,Hessian等...
  • 自定义运算符扩展DSL。
  • 解耦/通用张量类型。

安装

[dependencies]
aegir = "1.0"

示例

#[macro_use]
extern crate aegir;
extern crate rand;

use aegir::{Differentiable, Function, Identifier, Node, ids::{X, Y, W}};

db!(Database { x: X, y: Y, w: W });

fn main() {
    let mut weights = [0.0, 0.0];

    let x = X.into_var();
    let y = Y.into_var();
    let w = W.into_var();

    let model = x.dot(w);

    // Using standard method calls...
    let sse = model.sub(y).squared();
    let adj = sse.adjoint(W);

    // ...or using aegir! macro
    let sse = aegir!((model - y) ^ 2);
    let adj = sse.adjoint(W);

    for _ in 0..100000 {
        let [x1, x2] = [rand::random::<f64>(), rand::random::<f64>()];

        let g = adj.evaluate(Database {
            // Independent variables:
            x: [x1, x2],

            // Dependent variable:
            y: x1 * 2.0 - x2 * 4.0,

            // Model weights:
            w: &weights,
        }).unwrap();

        weights[0] -= 0.01 * g[0][0];
        weights[1] -= 0.01 * g[0][1];
    }

    println!("{:?}", weights.to_vec());
}

依赖

~1.5MB
~35K SLoC