7 个版本
新 0.2.1 | 2024 年 8 月 19 日 |
---|---|
0.2.0 | 2024 年 3 月 30 日 |
0.1.1 | 2021 年 12 月 12 日 |
0.1.0 | 2021 年 10 月 30 日 |
0.0.1 | 2019 年 11 月 14 日 |
#40 in 机器学习
每月 170 次下载
3MB
12K SLoC
autograph
Rust 的机器学习库。
使用 krnl 实现的 GPGPU 内核。
神经网络
#[derive(Layer, Forward)]
#[autograph(forward(Variable4, Output=Variable2))]
struct LeNet5 {
conv1: Conv2,
relu1: Relu,
pool1: MaxPool2,
conv2: Conv2,
relu2: Relu,
pool2: MaxPool2,
flatten: Flatten,
dense1: Dense,
relu3: Relu,
dense2: Dense,
relu4: Relu,
dense3: Dense,
}
impl LeNet5 {
fn new(device: Device, scalar_type: ScalarType) -> Result<Self> {
let conv1 = Conv2::builder()
.device(device.clone())
.scalar_type(scalar_type)
.inputs(1)
.outputs(6)
.filter([5, 5])
.build()?;
let relu1 = Relu;
let pool1 = MaxPool2::builder().filter([2, 2]).build();
let conv2 = Conv2::builder()
.device(device.clone())
.scalar_type(scalar_type)
.inputs(6)
.outputs(16)
.filter([5, 5])
.build()?;
let relu2 = Relu;
let pool2 = MaxPool2::builder().filter([2, 2]).build();
let flatten = Flatten;
let dense1 = Dense::builder()
.device(device.clone())
.scalar_type(scalar_type)
.inputs(16 * 4 * 4)
.outputs(128)
.build()?;
let relu3 = Relu;
let dense2 = Dense::builder()
.device(device.clone())
.scalar_type(scalar_type)
.inputs(128)
.outputs(84)
.build()?;
let relu4 = Relu;
let dense3 = Dense::builder()
.device(device.clone())
.scalar_type(scalar_type)
.inputs(84)
.outputs(10)
.bias(true)
.build()?;
Ok(Self {
conv1,
relu1,
pool1,
conv2,
relu2,
pool2,
flatten,
dense1,
relu3,
dense2,
relu4,
dense3,
})
}
}
let mut model = LeNet5::new(device.clone(), ScalarType::F32)?;
model.init_parameter_grads()?;
let y = model.forward(x)?;
let loss = y.cross_entropy_loss(t)?;
loss.backward()?;
model.update(learning_rate, &optimizer)?;
参见 神经网络 MNIST 示例。
基准测试
NVIDIA GeForce GTX 1060 with Max-Q Design
LeNet5(训练,批量大小 = 100)
autograph |
tch |
candle |
|
---|---|---|---|
bf16_host |
498.54 ms (✅ 1.00x) |
75.26 ms (🚀 6.62x faster) |
N/A |
f32_host |
8.25 ms (✅ 1.00x) |
3.14 ms (🚀 2.63x faster) |
34.17 ms (❌ 4.14x slower) |
bf16_device |
1.76 ms (✅ 1.00x) |
17.63 ms (❌ 10.02x slower) |
N/A |
f32_device |
1.73 ms (✅ 1.00x) |
1.19 ms (✅ 1.45x faster) |
9.76 ms (❌ 5.64x slower) |
LeNet5(推理,批量大小 = 1,000)
autograph |
tch |
candle |
|
---|---|---|---|
bf16_host |
1.81 s (✅ 1.00x) |
193.60 毫秒 (🚀 9.37倍更快) |
N/A |
f32_host |
15.56 毫秒 (✅ 1.00倍) |
9.46 毫秒 (✅ 1.64倍更快) |
94.23 毫秒 (❌ 6.06倍更慢) |
bf16_device |
4.65 毫秒 (✅ 1.00倍) |
48.63 毫秒 (❌ 10.46倍更慢) |
N/A |
f32_device |
4.65 毫秒 (✅ 1.00倍) |
1.84 毫秒 (🚀 2.52倍更快) |
10.81 毫秒 (❌ 2.33倍更慢) |
查看神经网络基准测试。
许可证
双许可,与Rust项目兼容。
根据Apache License,版本2.0 https://apache.ac.cn/licenses/LICENSE-2.0 或MIT许可证 http://opensource.org/licenses/MIT,由您选择。此文件不得复制、修改或分发,除非符合这些条款。
贡献
除非您明确声明,否则您提交的任何贡献,根据Apache-2.0许可证定义,应与上述双许可一致,不附加任何额外条款或条件。
依赖项
~11–23MB
~390K SLoC