3 个版本 (破坏性)
0.2.0 | 2021 年 4 月 27 日 |
---|---|
0.1.0 | 2021 年 4 月 22 日 |
0.0.0 | 2021 年 4 月 20 日 |
#913 in 算法
160KB
3.5K SLoC
泛型自动微分 (GAD)
此库通过反向传播(即“autograd”)在 Rust 中提供自动微分。它旨在允许一等用户扩展(例如,使用新的数组类型或新的运算符)并支持多种执行模式,以最小的开销。
目前支持以下执行模式,适用于所有库定义的运算符
- 一阶微分,
- 高阶微分,
- 仅前向评估,和
- 维度检查。
设计原则
此库的核心是一个基于带子的实现,实现反向模式的自动微分。
我们选择优先考虑 Rust 的惯用用法,以便尽可能使此库可重用。
-
核心微分算法不使用不安全的 Rust 功能或内部可变性(例如
RefCell
)。所有可微分的表达式在构建时都会显式地修改带子。(以下,带子变量记为graph
或g
。) -
可失败操作永远不会崩溃,并且总是返回一个
Result
类型。例如,两个数组x
和y
的和可以写成g.add(&x, &y)?
。 -
所有结构和值都实现了
Send
和Sync
以支持并发编程。 -
鼓励泛型编程,以便用户公式可以以最小的开销在不同的执行模式(前向评估、维度检查等)中解释。(请参阅下面的代码示例部分。)
虽然这个库主要是由机器学习应用驱动的,但它的目的是涵盖反向模式自动微分的其他用例。在下面的章节中,我们将展示用户如何定义新的算子并添加新的执行模式,同时保持自动微分的能力。
局限性
-
目前,算子
+
、-
、*
等 的常用语法对于可微值不可用。所有操作都是以g.op(x1, .. xN)
(或通常为可出错操作,格式为g.op(x1, .. xN)?
)的形式进行的函数调用。 -
由于当前 Rust 借用检查器的一个 限制,表达式不能嵌套:必须将
g.add(&x, &g.mul(&y, &z)?)?
写成let v = g.mul(&y, &z)?; g.add(&x, &v)?
。
我们相信,这种情况在未来可以通过 Rust 宏得到改善。或者,库的将来扩展可以定义一个新的可微值类别,该类别包含对公共带的隐式 RefCell
引用,并为这些值提供(隐式可出错、线程不安全)的操作特质。
快速入门
要计算梯度,我们首先使用类型为 Graph1
的新带子提供的操作构建一个表达式 g
。连续的代数操作会修改 g
的内部状态,以跟踪所有相关的计算,并启用未来的反向传播传递。
然后,我们调用 g.evaluate_gradients(..)
来运行从所需起始点开始的反向传播算法,并使用初始梯度值 direction
。
除非使用了一次性优化的变体 g.evaluate_gradients_once(..)
,否则使用 g.evaluate_gradients(..)
的反向传播不会修改 g
。这允许从不同的起始点或使用不同的梯度值运行连续(或并发)的反向传播。
// A new tape supporting first-order differentials (aka gradients)
let mut g = Graph1::new();
// Compute forward values.
let a = g.variable(1f32);
let b = g.variable(2f32);
let c = g.mul(&a, &b)?;
// Compute the derivatives of `c` relative to `a` and `b`
let gradients = g.evaluate_gradients(c.gid()?, 1f32)?;
// Read the `dc/da` component.
assert_eq!(*gradients.get(a.gid()?).unwrap(), 2.0);
因为 Graph1
,即 g
的类型,将代数操作作为方法提供,所以我们下面将此类类型称为“代数”。GAD 使用特定的 Rust 特质来表示给定代数支持的运算集合。
使用 Arrayfire 进行计算
库的默认数组操作目前基于Arrayfire,这是一个支持GPU和即时编译的可移植数组库。
use arrayfire as af;
// A new tape supporting first-order differentials (aka gradients)
let mut g = Graph1::new();
// Compute forward values using Arrayfire arrays
let dims = af::Dim4::new(&[4, 3, 1, 1]);
let a = g.variable(af::randu::<f32>(dims));
let b = g.variable(af::randu::<f32>(dims));
let c = g.mul(&a, &b)?;
// Compute gradient of c
let direction = af::constant(1f32, dims);
let gradients = g.evaluate_gradients_once(c.gid()?, direction)?;
在您的系统上安装arrayfire库后,请确保
-
在您的构建文件
Cargo.toml
中选中包功能“arrayfire”(例如:gad = { version = "XX", features = ["arrayfire"]}
), -
使用适当的
AF_PATH
环境变量运行cargo
(例如,在export AF_PATH=/usr/local
之后)。
使用泛型进行前向评估和快速维度检查
上面示例中使用的代数Graph1
是库提供的几个“默认”代数中选择之一
-
我们还提供了一个特殊的代数
Eval
,用于前向评估,即只运行原始操作和维度检查(无记录,无梯度); -
类似地,使用代数
Check
将检查维度,而不会评估或分配任何数组数据; -
最后,通过使用
Graph1
进行一阶微分,以及GraphN
进行高阶微分来获得微分。
鼓励用户以泛型方式编写公式,以便可以选择任何默认代数。
以下示例说明了在数组操作的情况下这种编程风格
use arrayfire as af;
fn get_value<A>(g: &mut A) -> Result<<A as AfAlgebra<f32>>::Value>
where A : AfAlgebra<f32>
{
let dims = af::Dim4::new(&[4, 3, 1, 1]);
let a = g.variable(af::randu::<f32>(dims));
let b = g.variable(af::randu::<f32>(dims));
g.mul(&a, &b)
}
// Direct evaluation. The result type is a primitive (non-differentiable) value.
let mut g = Eval::default();
let c : af::Array<f32> = get_value(&mut g)?;
// Fast dimension-checking. The result type is a dimension.
let mut g = Check::default();
let d : af::Dim4 = get_value(&mut g)?;
assert_eq!(c.dims(), d);
高阶微分
高阶微分使用代数GraphN
计算。在这种情况下,梯度是计算也跟踪的值。
// A new tape supporting differentials of any order.
let mut g = GraphN::new();
// Compute forward values using floats.
let x = g.variable(1.0f32);
let y = g.variable(0.4f32);
// z = x * y^2
let z = {
let h = g.mul(&x, &y)?;
g.mul(&h, &y)?
};
// Use short names for gradient ids.
let (x, y, z) = (x.gid()?, y.gid()?, z.gid()?);
// Compute gradient.
let dz = g.constant(1f32);
let dz_d = g.compute_gradients(z, dz)?;
let dz_dx = dz_d.get(x).unwrap();
// Compute some 2nd-order differentials.
let ddz = g.constant(1f32);
let ddz_dxd = g.compute_gradients(dz_dx.gid()?, ddz)?;
let ddz_dxdy = ddz_dxd.get(y).unwrap().data();
assert_eq!(*ddz_dxdy, 0.8); // 2y
// Compute some 3rd-order differentials.
let dddz = g.constant(1f32);
let dddz_dxdyd = g.compute_gradients(ddz_dxd.get(y).unwrap().gid()?, dddz)?;
let dddz_dxdydy = dddz_dxdyd.get(y).unwrap().data();
assert_eq!(*dddz_dxdydy, 2.0);
扩展自动微分
操作和代数
默认代数Eval
、Check
、Graph1
、GraphN
旨在为每种默认操作模式(分别,评估、维度检查、一阶微分和高阶微分)提供可互换的操作集。
默认操作被分组到几个名为*Algebra
的特性和由上述四个默认代数之一实现。
-
特殊特性
CoreAlgebra<Data>
定义了从底层数据(例如数组)到可微值的映射。特别是,方法fn variable(&mut self, data: &Data) -> Self::Value
创建了可微变量x
,其梯度值可以在以后通过写入idx.gid()?
来引用(假设代数是Graph1
或GraphN
)。 -
其他特征是通过一个或多个值类型参数化的。例如,
ArithAlgebra<Value>
提供了逐点取反、乘法、减法等操作,这些操作在Value
上进行。
使用多个 *Algebra
特征的动机有两个方面
-
用户可以定义自己的操作(见下一段)。
-
某些操作比其他操作更广泛地适用。
以下示例说明了整数上的梯度计算
let mut g = Graph1::new();
let a = g.variable(1i32);
let b = g.variable(2i32);
let c = g.sub(&a, &b)?;
assert_eq!(*c.data(), -1);
let gradients = g.evaluate_gradients_once(c.gid()?, 1)?;
assert_eq!(*gradients.get(a.gid()?).unwrap(), 1);
assert_eq!(*gradients.get(b.gid()?).unwrap(), -1);
用户定义的操作
用户可以通过定义自己的 *Algebra
特征,并提供对默认代数 Eval
、Check
、Graph1
、GraphN
的实现来定义新的可微操作。
在以下示例中,我们定义了一个新的操作 square
,该操作在整数和 af-数组上执行,并添加了对一阶导数支持
use arrayfire as af;
pub trait UserAlgebra<Value> {
fn square(&mut self, v: &Value) -> Result<Value>;
}
impl UserAlgebra<i32> for Eval
{
fn square(&mut self, v: &i32) -> Result<i32> { Ok(v * v) }
}
impl<T> UserAlgebra<af::Array<T>> for Eval
where
T: af::HasAfEnum + af::ImplicitPromote<T, Output = T>
{
fn square(&mut self, v: &af::Array<T>) -> Result<af::Array<T>> { Ok(v * v) }
}
impl<D> UserAlgebra<Value<D>> for Graph1
where
Eval: CoreAlgebra<D, Value = D>
+ UserAlgebra<D>
+ ArithAlgebra<D>
+ LinkedAlgebra<Value<D>, D>,
D: HasDims + Clone + 'static + Send + Sync,
D::Dims: PartialEq + std::fmt::Debug + Clone + 'static + Send + Sync,
{
fn square(&mut self, v: &Value<D>) -> Result<Value<D>> {
let result = self.eval().square(v.data())?;
let value = self.make_node(result, vec![v.input()], {
let v = v.clone();
move |graph, store, gradient| {
if let Some(id) = v.id() {
let c = graph.link(&v);
let grad1 = graph.mul(&gradient, c)?;
let grad2 = graph.mul(c, &gradient)?;
let grad = graph.add(&grad1, &grad2)?;
store.add_gradient(graph, id, &grad)?;
}
Ok(())
}
});
Ok(value)
}
}
fn main() -> Result<()> {
let mut g = Graph1::new();
let a = g.variable(3i32);
let b = g.square(&a)?;
assert_eq!(*b.data(), 9);
let gradients = g.evaluate_gradients_once(b.gid()?, 1)?;
assert_eq!(*gradients.get(a.gid()?).unwrap(), 6);
Ok(())
}
GraphN
的实现将与 Graph1
相同。为了简化,我们省略了维度检查。我们建议读者参考库中的测试文件以获取更完整的示例。
用户定义的代数
用户可以通过实现每个支持的 Data
类型中的 CoreAlgebra<Data, Value=Data>
操作特征子集来定义新的“评估”代数(类似于 Eval
)。
可以使用库提供的 Graph
构造将仅评估的代数(类似于 Graph1
和 GraphN
)转换为支持微分的代数。
以下示例说明了如何定义一个新的评估代数 SymEval
然后推导其对应 SymGraph1
/// A custom algebra for forward-only symbolic evaluation.
#[derive(Clone, Default)]
struct SymEval;
/// Symbolic expressions of type T.
#[derive(Debug, PartialEq)]
enum Exp_<T> {
Num(T),
Neg(Exp<T>),
Add(Exp<T>, Exp<T>),
Mul(Exp<T>, Exp<T>),
// ...
}
type Exp<T> = Arc<Exp_<T>>;
impl<T> CoreAlgebra<Exp<T>> for SymEval {
type Value = Exp<T>;
fn variable(&mut self, data: Exp<T>) -> Self::Value {
data
}
fn constant(&mut self, data: Exp<T>) -> Self::Value {
data
}
fn add(&mut self, v1: &Self::Value, v2: &Self::Value) -> Result<Self::Value> {
Ok(Arc::new(Exp_::Add(v1.clone(), v2.clone())))
}
}
impl<T> ArithAlgebra<Exp<T>> for SymEval {
fn neg(&mut self, v: &Exp<T>) -> Exp<T> {
Arc::new(Exp_::Neg(v.clone()))
}
fn sub(&mut self, v1: &Exp<T>, v2: &Exp<T>) -> Result<Exp<T>> {
let v2 = self.neg(v2);
Ok(Arc::new(Exp_::Add(v1.clone(), v2)))
}
fn mul(&mut self, v1: &Exp<T>, v2: &Exp<T>) -> Result<Exp<T>> {
Ok(Arc::new(Exp_::Mul(v1.clone(), v2.clone())))
}
// ...
}
// No dimension checks.
impl<T> HasDims for Exp_<T> {
type Dims = ();
fn dims(&self) {}
}
impl<T: std::fmt::Display> std::fmt::Display for Exp_<T> {
// ...
}
/// Apply `graph` module to Derive an algebra supporting gradients.
type SymGraph1 = Graph<Config1<SymEval>>;
fn main() -> Result<()> {
let mut g = SymGraph1::new();
let a = g.variable(Arc::new(Exp_::Num("a")));
let b = g.variable(Arc::new(Exp_::Num("b")));
let c = g.mul(&a, &b)?;
let d = g.mul(&a, &c)?;
assert_eq!(format!("{}", d.data()), "aab");
let gradients = g.evaluate_gradients_once(d.gid()?, Arc::new(Exp_::Num("1")))?;
assert_eq!(format!("{}", gradients.get(a.gid()?).unwrap()), "(1ab+a1b)");
assert_eq!(format!("{}", gradients.get(b.gid()?).unwrap()), "aa1");
Ok(())
}
贡献
有关如何帮助的说明,请参阅 CONTRIBUTING 文件。
许可
本项目可以在 Apache 2.0 许可证或 MIT 许可证的条款下使用。
依赖
~3.5–5.5MB
~117K SLoC