2 个版本
0.1.1 | 2023 年 3 月 15 日 |
---|---|
0.1.0 | 2023 年 3 月 7 日 |
360 在 机器学习 中排名
2MB
280 行
rustygrad
受 Andrej Karpathy 的 micrograd 启发。此存储库使用 Rust 实现了一个微型 Autograd 引擎
- 具有友好的 API
- 易于理解的实现
- 代码量最少
引擎和神经网络分别用约 150 和 100 行代码实现(与 Andrej 的 100 和 50 行相比!)大约是两倍,但速度也快两倍!
示例用法
值 API
Value
可用于构建任意的 DAG 神经网络(有向无环图)。一种思考方式是,它可以用来模拟常见的数学表达式。例如,
$$ g = [(a + b) * (c + d)] ^ 2 $$
可以翻译成如下神经网络
use rustygrad::Value;
fn main() {
let a = Value::from(1.0);
let b = Value::from(2.0);
let c = Value::from(3.0);
let d = Value::from(4.0);
let g = ((&a + &b) * (&c + &d)).pow(2.0);
g.backwards(); // compute gradients
}
以下是更复杂的示例(来自 micrograd),旨在展示大多数支持的 Value 操作 // 以及它们的 Python micrograd 版本
fn main() {
// a = Value(-4.0)
// b = Value(2.0)
let a = Value::from(-4.0);
let b = Value::from(2.0);
// c = a + b
// d = a * b + b**3
let mut c = &a + &b;
let mut d = &a * &b + &b.pow(3.0);
// c += c + 1
// c += 1 + c + (-a)
// d += d * 2 + (b + a).relu()
// d += 3 * d + (b - a).relu()
c += &c + 1.0;
c += 1.0 + &c + (-&a);
d += &d * 2.0 + (&b + &a).relu();
d += 3.0 * &d + (&b - &a).relu();
// e = c - d
// f = e**2
// g = f / 2.0
// g += 10.0 / f
let e = &c - &d;
let f = e.pow(2.0);
let mut g = &f / 2.0;
g += 10.0 / &f;
// print(f'{g.data:.4f}') # prints 24.7041, the outcome of this forward pass
println!("{:.4}", g.borrow().data); // 24.7041
// g.backward()
// print(f'{a.grad:.4f}') # prints 138.8338, i.e. the numerical value of dg/da
// print(f'{b.grad:.4f}') # prints 645.5773, i.e. the numerical value of dg/db
g.backward();
println!("{:.4}", a.borrow().grad); // 138.8338
println!("{:.4}", b.borrow().grad); // 645.5773
}
cargo run --example engine
神经元和多层感知器 API
该库还公开了 Neuron
和多层感知器 MLP
use rustygrad::{Neuron, MLP};
fn main() {
// Create a Neuron
// With input size of 2
// With ReLu layer (true)
let neuron = Neuron::new(2, true);
// Output node
let g = &neuron.forward(&vec![Value::from(7.0)]);
// Create a 2x2x1 MLP net:
// Input layer of size 2
// Hidden layer of size 2
// Ouput layer of size 1
let model = MLP::new(2, vec![2, 1]);
// Some input vector of size 2
let x = vec![Value::from(7.0), Value::from(8.0)];
// Output Value node
let g = &model.forward(x)[0];
}
cargo run --example graphviz
神经元
MLP
训练神经网络
mlp.rs
文件在玩具 make_moons.csv
数据集上训练了一个 MLP 二元分类器(2 个 16 节点的隐藏层)。由于 Rust 中的绘图比较困难,目前这里是一个学习解空间的 ascii 表示
. . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . .
. . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . .
. . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . .
. . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . .
. . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . .
. . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . .
. . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . .
. . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . .
. . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . .
. . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . .
. . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . .
. . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . .
. . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . .
. . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . .
. . . . . . . . . . . . . . . . . . * * * * * . . . . . . . . . . . . . . . . *
. . . . . . . . . . . . . . . . . * * * * * * * . . . . . . . . . . . . . . * *
. . . . . . . . . . . . . . . . . * * * * * * * * . . . . . . . . . . . . * * *
. . . . . . . . . . . . . . . . . * * * * * * * * * . . . . . . . . . . * * * *
. . . . . . . . . . . . . . . . . * * * * * * * * * * . . . . . . . . . * * * *
. . . . . . . . . . . . . . . . * * * * * * * * * * * . . . . . . . . * * * * *
. . . . . . . . . . . . . . . . * * * * * * * * * * * * . . . . . . * * * * * *
. . . . . . . . . . . . . . . . * * * * * * * * * * * * * . . . . * * * * * * *
. . . . . . . . . . . . . . . * * * * * * * * * * * * * * * * * * * * * * * * *
. . . . . . . . . . . . . . . * * * * * * * * * * * * * * * * * * * * * * * * *
. . . . . . . . . . . . . . * * * * * * * * * * * * * * * * * * * * * * * * * *
. . . . . . . . . . . . . . * * * * * * * * * * * * * * * * * * * * * * * * * *
. . . . . . . . . . . . . . * * * * * * * * * * * * * * * * * * * * * * * * * *
. . . . . . . . . . . . . * * * * * * * * * * * * * * * * * * * * * * * * * * *
. . . . . . . . . . . . . * * * * * * * * * * * * * * * * * * * * * * * * * * *
. . . . . . . . . . . . * * * * * * * * * * * * * * * * * * * * * * * * * * * *
. . . . . . . . . . . . * * * * * * * * * * * * * * * * * * * * * * * * * * * *
. . . . . . . . . . . * * * * * * * * * * * * * * * * * * * * * * * * * * * * *
. . . . . . . . . . . * * * * * * * * * * * * * * * * * * * * * * * * * * * * *
. . . . . . . . . . . * * * * * * * * * * * * * * * * * * * * * * * * * * * * *
. . . . . . . . . . * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *
. . . . . . . . . . * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *
. . . . . . . . . * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *
. . . . . . . . . * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *
. . . . . . . . * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *
. . . . . . . . * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *
cargo run --example mlp
运行测试
cargo test