3 个版本 (破坏性)

0.2.0 2021 年 4 月 27 日
0.1.0 2021 年 4 月 22 日
0.0.0 2021 年 4 月 20 日

#913 in 算法

MIT/Apache

160KB
3.5K SLoC

Build Status gad on crates.io Documentation License License

泛型自动微分 (GAD)

此库通过反向传播(即“autograd”)在 Rust 中提供自动微分。它旨在允许一等用户扩展(例如,使用新的数组类型或新的运算符)并支持多种执行模式,以最小的开销。

目前支持以下执行模式,适用于所有库定义的运算符

  • 一阶微分,
  • 高阶微分,
  • 仅前向评估,和
  • 维度检查。

设计原则

此库的核心是一个基于带子的实现,实现反向模式的自动微分

我们选择优先考虑 Rust 的惯用用法,以便尽可能使此库可重用。

  • 核心微分算法不使用不安全的 Rust 功能或内部可变性(例如 RefCell)。所有可微分的表达式在构建时都会显式地修改带子。(以下,带子变量记为 graphg。)

  • 可失败操作永远不会崩溃,并且总是返回一个 Result 类型。例如,两个数组 xy 的和可以写成 g.add(&x, &y)?

  • 所有结构和值都实现了 SendSync 以支持并发编程。

  • 鼓励泛型编程,以便用户公式可以以最小的开销在不同的执行模式(前向评估、维度检查等)中解释。(请参阅下面的代码示例部分。)

虽然这个库主要是由机器学习应用驱动的,但它的目的是涵盖反向模式自动微分的其他用例。在下面的章节中,我们将展示用户如何定义新的算子并添加新的执行模式,同时保持自动微分的能力。

局限性

  • 目前,算子 +-* 等 的常用语法对于可微值不可用。所有操作都是以 g.op(x1, .. xN)(或通常为可出错操作,格式为 g.op(x1, .. xN)?)的形式进行的函数调用。

  • 由于当前 Rust 借用检查器的一个 限制,表达式不能嵌套:必须将 g.add(&x, &g.mul(&y, &z)?)? 写成 let v = g.mul(&y, &z)?; g.add(&x, &v)?

我们相信,这种情况在未来可以通过 Rust 宏得到改善。或者,库的将来扩展可以定义一个新的可微值类别,该类别包含对公共带的隐式 RefCell 引用,并为这些值提供(隐式可出错、线程不安全)的操作特质。

快速入门

要计算梯度,我们首先使用类型为 Graph1 的新带子提供的操作构建一个表达式 g。连续的代数操作会修改 g 的内部状态,以跟踪所有相关的计算,并启用未来的反向传播传递。

然后,我们调用 g.evaluate_gradients(..) 来运行从所需起始点开始的反向传播算法,并使用初始梯度值 direction

除非使用了一次性优化的变体 g.evaluate_gradients_once(..),否则使用 g.evaluate_gradients(..) 的反向传播不会修改 g。这允许从不同的起始点或使用不同的梯度值运行连续(或并发)的反向传播。

// A new tape supporting first-order differentials (aka gradients)
let mut g = Graph1::new();
// Compute forward values.
let a = g.variable(1f32);
let b = g.variable(2f32);
let c = g.mul(&a, &b)?;
// Compute the derivatives of `c` relative to `a` and `b`
let gradients = g.evaluate_gradients(c.gid()?, 1f32)?;
// Read the `dc/da` component.
assert_eq!(*gradients.get(a.gid()?).unwrap(), 2.0);

因为 Graph1,即 g 的类型,将代数操作作为方法提供,所以我们下面将此类类型称为“代数”。GAD 使用特定的 Rust 特质来表示给定代数支持的运算集合。

使用 Arrayfire 进行计算

库的默认数组操作目前基于Arrayfire,这是一个支持GPU和即时编译的可移植数组库。

use arrayfire as af;
// A new tape supporting first-order differentials (aka gradients)
let mut g = Graph1::new();
// Compute forward values using Arrayfire arrays
let dims = af::Dim4::new(&[4, 3, 1, 1]);
let a = g.variable(af::randu::<f32>(dims));
let b = g.variable(af::randu::<f32>(dims));
let c = g.mul(&a, &b)?;
// Compute gradient of c
let direction = af::constant(1f32, dims);
let gradients = g.evaluate_gradients_once(c.gid()?, direction)?;

在您的系统上安装arrayfire库后,请确保

  • 在您的构建文件Cargo.toml中选中包功能“arrayfire”(例如:gad = { version = "XX", features = ["arrayfire"]}),

  • 使用适当的AF_PATH环境变量运行cargo(例如,在export AF_PATH=/usr/local之后)。

使用泛型进行前向评估和快速维度检查

上面示例中使用的代数Graph1是库提供的几个“默认”代数中选择之一

  • 我们还提供了一个特殊的代数Eval,用于前向评估,即只运行原始操作和维度检查(无记录,无梯度);

  • 类似地,使用代数Check将检查维度,而不会评估或分配任何数组数据;

  • 最后,通过使用Graph1进行一阶微分,以及GraphN进行高阶微分来获得微分。

鼓励用户以泛型方式编写公式,以便可以选择任何默认代数。

以下示例说明了在数组操作的情况下这种编程风格

use arrayfire as af;

fn get_value<A>(g: &mut A) -> Result<<A as AfAlgebra<f32>>::Value>
where A : AfAlgebra<f32>
{
    let dims = af::Dim4::new(&[4, 3, 1, 1]);
    let a = g.variable(af::randu::<f32>(dims));
    let b = g.variable(af::randu::<f32>(dims));
    g.mul(&a, &b)
}

// Direct evaluation. The result type is a primitive (non-differentiable) value.
let mut g = Eval::default();
let c : af::Array<f32> = get_value(&mut g)?;

// Fast dimension-checking. The result type is a dimension.
let mut g = Check::default();
let d : af::Dim4 = get_value(&mut g)?;
assert_eq!(c.dims(), d);

高阶微分

高阶微分使用代数GraphN计算。在这种情况下,梯度是计算也跟踪的值。

// A new tape supporting differentials of any order.
let mut g = GraphN::new();

// Compute forward values using floats.
let x = g.variable(1.0f32);
let y = g.variable(0.4f32);
// z = x * y^2
let z = {
    let h = g.mul(&x, &y)?;
    g.mul(&h, &y)?
};
// Use short names for gradient ids.
let (x, y, z) = (x.gid()?, y.gid()?, z.gid()?);

// Compute gradient.
let dz = g.constant(1f32);
let dz_d = g.compute_gradients(z, dz)?;
let dz_dx = dz_d.get(x).unwrap();

// Compute some 2nd-order differentials.
let ddz = g.constant(1f32);
let ddz_dxd = g.compute_gradients(dz_dx.gid()?, ddz)?;
let ddz_dxdy = ddz_dxd.get(y).unwrap().data();
assert_eq!(*ddz_dxdy, 0.8); // 2y

// Compute some 3rd-order differentials.
let dddz = g.constant(1f32);
let dddz_dxdyd = g.compute_gradients(ddz_dxd.get(y).unwrap().gid()?, dddz)?;
let dddz_dxdydy = dddz_dxdyd.get(y).unwrap().data();
assert_eq!(*dddz_dxdydy, 2.0);

扩展自动微分

操作和代数

默认代数EvalCheckGraph1GraphN旨在为每种默认操作模式(分别,评估、维度检查、一阶微分和高阶微分)提供可互换的操作集。

默认操作被分组到几个名为*Algebra的特性和由上述四个默认代数之一实现。

  • 特殊特性CoreAlgebra<Data>定义了从底层数据(例如数组)到可微值的映射。特别是,方法fn variable(&mut self, data: &Data) -> Self::Value创建了可微变量x,其梯度值可以在以后通过写入idx.gid()?来引用(假设代数是Graph1GraphN)。

  • 其他特征是通过一个或多个值类型参数化的。例如,ArithAlgebra<Value> 提供了逐点取反、乘法、减法等操作,这些操作在 Value 上进行。

使用多个 *Algebra 特征的动机有两个方面

  • 用户可以定义自己的操作(见下一段)。

  • 某些操作比其他操作更广泛地适用。

以下示例说明了整数上的梯度计算

let mut g = Graph1::new();
let a = g.variable(1i32);
let b = g.variable(2i32);
let c = g.sub(&a, &b)?;
assert_eq!(*c.data(), -1);
let gradients = g.evaluate_gradients_once(c.gid()?, 1)?;
assert_eq!(*gradients.get(a.gid()?).unwrap(), 1);
assert_eq!(*gradients.get(b.gid()?).unwrap(), -1);

用户定义的操作

用户可以通过定义自己的 *Algebra 特征,并提供对默认代数 EvalCheckGraph1GraphN 的实现来定义新的可微操作。

在以下示例中,我们定义了一个新的操作 square,该操作在整数和 af-数组上执行,并添加了对一阶导数支持

use arrayfire as af;

pub trait UserAlgebra<Value> {
    fn square(&mut self, v: &Value) -> Result<Value>;
}

impl UserAlgebra<i32> for Eval
{
    fn square(&mut self, v: &i32) -> Result<i32> { Ok(v * v) }
}

impl<T> UserAlgebra<af::Array<T>> for Eval
where
    T: af::HasAfEnum + af::ImplicitPromote<T, Output = T>
{
    fn square(&mut self, v: &af::Array<T>) -> Result<af::Array<T>> { Ok(v * v) }
}

impl<D> UserAlgebra<Value<D>> for Graph1
where
    Eval: CoreAlgebra<D, Value = D>
        + UserAlgebra<D>
        + ArithAlgebra<D>
        + LinkedAlgebra<Value<D>, D>,
    D: HasDims + Clone + 'static + Send + Sync,
    D::Dims: PartialEq + std::fmt::Debug + Clone + 'static + Send + Sync,
{
    fn square(&mut self, v: &Value<D>) -> Result<Value<D>> {
        let result = self.eval().square(v.data())?;
        let value = self.make_node(result, vec![v.input()], {
            let v = v.clone();
            move |graph, store, gradient| {
                if let Some(id) = v.id() {
                    let c = graph.link(&v);
                    let grad1 = graph.mul(&gradient, c)?;
                    let grad2 = graph.mul(c, &gradient)?;
                    let grad = graph.add(&grad1, &grad2)?;
                    store.add_gradient(graph, id, &grad)?;
                }
                Ok(())
            }
        });
        Ok(value)
    }
}

fn main() -> Result<()> {
  let mut g = Graph1::new();
  let a = g.variable(3i32);
  let b = g.square(&a)?;
  assert_eq!(*b.data(), 9);
  let gradients = g.evaluate_gradients_once(b.gid()?, 1)?;
  assert_eq!(*gradients.get(a.gid()?).unwrap(), 6);
  Ok(())
}

GraphN 的实现将与 Graph1 相同。为了简化,我们省略了维度检查。我们建议读者参考库中的测试文件以获取更完整的示例。

用户定义的代数

用户可以通过实现每个支持的 Data 类型中的 CoreAlgebra<Data, Value=Data> 操作特征子集来定义新的“评估”代数(类似于 Eval)。

可以使用库提供的 Graph 构造将仅评估的代数(类似于 Graph1GraphN)转换为支持微分的代数。

以下示例说明了如何定义一个新的评估代数 SymEval 然后推导其对应 SymGraph1

/// A custom algebra for forward-only symbolic evaluation.
#[derive(Clone, Default)]
struct SymEval;

/// Symbolic expressions of type T.
#[derive(Debug, PartialEq)]
enum Exp_<T> {
    Num(T),
    Neg(Exp<T>),
    Add(Exp<T>, Exp<T>),
    Mul(Exp<T>, Exp<T>),
    // ...
}

type Exp<T> = Arc<Exp_<T>>;

impl<T> CoreAlgebra<Exp<T>> for SymEval {
    type Value = Exp<T>;
    fn variable(&mut self, data: Exp<T>) -> Self::Value {
        data
    }
    fn constant(&mut self, data: Exp<T>) -> Self::Value {
        data
    }
    fn add(&mut self, v1: &Self::Value, v2: &Self::Value) -> Result<Self::Value> {
        Ok(Arc::new(Exp_::Add(v1.clone(), v2.clone())))
    }
}

impl<T> ArithAlgebra<Exp<T>> for SymEval {
    fn neg(&mut self, v: &Exp<T>) -> Exp<T> {
        Arc::new(Exp_::Neg(v.clone()))
    }
    fn sub(&mut self, v1: &Exp<T>, v2: &Exp<T>) -> Result<Exp<T>> {
        let v2 = self.neg(v2);
        Ok(Arc::new(Exp_::Add(v1.clone(), v2)))
    }
    fn mul(&mut self, v1: &Exp<T>, v2: &Exp<T>) -> Result<Exp<T>> {
        Ok(Arc::new(Exp_::Mul(v1.clone(), v2.clone())))
    }
    // ...
}

// No dimension checks.
impl<T> HasDims for Exp_<T> {
    type Dims = ();
    fn dims(&self) {}
}

impl<T: std::fmt::Display> std::fmt::Display for Exp_<T> {
    // ...
}

/// Apply `graph` module to Derive an algebra supporting gradients.
type SymGraph1 = Graph<Config1<SymEval>>;

fn main() -> Result<()> {
    let mut g = SymGraph1::new();
    let a = g.variable(Arc::new(Exp_::Num("a")));
    let b = g.variable(Arc::new(Exp_::Num("b")));
    let c = g.mul(&a, &b)?;
    let d = g.mul(&a, &c)?;
    assert_eq!(format!("{}", d.data()), "aab");
    let gradients = g.evaluate_gradients_once(d.gid()?, Arc::new(Exp_::Num("1")))?;
    assert_eq!(format!("{}", gradients.get(a.gid()?).unwrap()), "(1ab+a1b)");
    assert_eq!(format!("{}", gradients.get(b.gid()?).unwrap()), "aa1");
    Ok(())
}

贡献

有关如何帮助的说明,请参阅 CONTRIBUTING 文件。

许可

本项目可以在 Apache 2.0 许可证或 MIT 许可证的条款下使用。

依赖

~3.5–5.5MB
~117K SLoC