2个不稳定版本
0.2.0 | 2023年11月9日 |
---|---|
0.1.0 | 2023年11月7日 |
#298 in 机器学习
30KB
720 行
Radient
Radient是一个为自动微分设计的Rust库。它利用计算图的力量来执行梯度计算的前向和反向传递。
功能
- 计算图实现。
- 用于梯度计算的前向和反向传播。
- 支持各种操作,如指数、对数、幂和三角函数。
示例
示例1:符号的基本操作
use radient::prelude::*;
// Example with symbol : ln(x + y) * tanh(x - y)^2
fn main() {
let mut graph = Graph::default();
let x = graph.var(2.0);
let y = graph.var(1.0);
let x_sym = Expr::Symbol(x);
let y_sym = Expr::Symbol(y);
let expr_sym = (&x_sym + &y_sym).ln() * (&x_sym - &y_sym).tanh().powi(2);
graph.compile(expr_sym);
let result = graph.forward();
println!("Result: {}", result);
graph.backward();
let gradient_x = graph.get_gradient(x);
println!("Gradient x: {}", gradient_x);
}
示例2:获取函数的梯度
对于梯度,您有两种选择
gradient
:简洁但相对较慢(但不是太慢)gradient_cached
:快速但稍微啰嗦
2.1: gradient
use radient::prelude::*;
fn main() {
let value = vec![2f64, 1f64];
// No cached gradient - concise but relatively slow
let (result, gradient) = gradient(f, &value);
println!("result: {}, gradient: {:?}", result, gradient);
}
fn f(x_vec: &[Expr]) -> Expr {
let x = &x_vec[0];
let y = &x_vec[1];
(x.powi(2) + y.powi(2)).sqrt()
}
2.2: gradient_cached
use radient::prelude::*;
fn main() {
// Compile the graph
let mut graph = Graph::default();
graph.touch_vars(2);
let symbols = graph.get_symbols();
let expr = f(&symbols);
graph.compile(expr);
// Compute
let value = vec![2f64, 1f64];
let (result, grads) = gradient_cached(&mut graph, &value);
println!("result: {}, gradient: {:?}", result, grads);
}
fn f(x_vec: &[Expr]) -> Expr {
let x = &x_vec[0];
let y = &x_vec[1];
(x.powi(2) + y.powi(2)).sqrt()
}
入门指南
要在项目中使用Radient,请将以下内容添加到您的 Cargo.toml
[dependencies]
radient = "0.2"
然后,在您的Rust文件中添加以下代码
use radient::*;
许可证
Radient采用Apache2.0或MIT许可证 - 请参阅LICENSE-APACHE & LICENSE-MIT 文件以获取详细信息。
依赖项
~1.5MB
~35K SLoC