#自动微分 #梯度 #微分 #自动 #机器学习

radient

Radient是一个为自动微分设计的Rust库。它利用计算图的力量来执行梯度计算的前向和反向传递。

2个不稳定版本

0.2.0 2023年11月9日
0.1.0 2023年11月7日

#298 in 机器学习

MIT/Apache

30KB
720

Radient

Radient是一个为自动微分设计的Rust库。它利用计算图的力量来执行梯度计算的前向和反向传递。

功能

  • 计算图实现。
  • 用于梯度计算的前向和反向传播。
  • 支持各种操作,如指数、对数、幂和三角函数。

示例

示例1:符号的基本操作

use radient::prelude::*;

// Example with symbol : ln(x + y) * tanh(x - y)^2
fn main() {
    let mut graph = Graph::default();

    let x = graph.var(2.0);
    let y = graph.var(1.0);
    let x_sym = Expr::Symbol(x);
    let y_sym = Expr::Symbol(y);
    let expr_sym = (&x_sym + &y_sym).ln() * (&x_sym - &y_sym).tanh().powi(2);

    graph.compile(expr_sym);

    let result = graph.forward();
    println!("Result: {}", result);

    graph.backward();
    let gradient_x = graph.get_gradient(x);
    println!("Gradient x: {}", gradient_x);
}

示例2:获取函数的梯度

对于梯度,您有两种选择

  1. gradient:简洁但相对较慢(但不是太慢)
  2. gradient_cached:快速但稍微啰嗦

2.1: gradient

use radient::prelude::*;

fn main() {
    let value = vec![2f64, 1f64];
    // No cached gradient - concise but relatively slow
    let (result, gradient) = gradient(f, &value);
    println!("result: {}, gradient: {:?}", result, gradient);
}

fn f(x_vec: &[Expr]) -> Expr {
    let x = &x_vec[0];
    let y = &x_vec[1];

    (x.powi(2) + y.powi(2)).sqrt()
}

2.2: gradient_cached

use radient::prelude::*;

fn main() {
    // Compile the graph
    let mut graph = Graph::default();
    graph.touch_vars(2);
    let symbols = graph.get_symbols();
    let expr = f(&symbols);
    graph.compile(expr);

    // Compute
    let value = vec![2f64, 1f64];
    let (result, grads) = gradient_cached(&mut graph, &value);

    println!("result: {}, gradient: {:?}", result, grads);
}

fn f(x_vec: &[Expr]) -> Expr {
    let x = &x_vec[0];
    let y = &x_vec[1];

    (x.powi(2) + y.powi(2)).sqrt()
}

入门指南

要在项目中使用Radient,请将以下内容添加到您的 Cargo.toml

[dependencies]
radient = "0.2"

然后,在您的Rust文件中添加以下代码

use radient::*;

许可证

Radient采用Apache2.0或MIT许可证 - 请参阅LICENSE-APACHE & LICENSE-MIT 文件以获取详细信息。

依赖项

~1.5MB
~35K SLoC