7 个版本

0.2.1 2024 年 8 月 19 日
0.2.0 2024 年 3 月 30 日
0.1.1 2021 年 12 月 12 日
0.1.0 2021 年 10 月 30 日
0.0.1 2019 年 11 月 14 日

#40 in 机器学习

Download history 2/week @ 2024-06-08 1/week @ 2024-06-15 26/week @ 2024-07-06 1/week @ 2024-07-27 169/week @ 2024-08-17

每月 170 次下载

MIT/Apache

3MB
12K SLoC

LicenseBadge DocsBadge build

autograph

Rust 的机器学习库。

使用 krnl 实现的 GPGPU 内核。

  • 主机和设备执行。
  • 张量模拟 ndarray
    • 主机张量可以借用为数组。
  • 张量、模型和优化器可以使用 serde 序列化。
    • 跨平台可移植。
    • 保存和恢复训练进度。
  • 完全可扩展,使用 Rust。

神经网络

#[derive(Layer, Forward)]
#[autograph(forward(Variable4, Output=Variable2))]
struct LeNet5 {
    conv1: Conv2,
    relu1: Relu,
    pool1: MaxPool2,
    conv2: Conv2,
    relu2: Relu,
    pool2: MaxPool2,
    flatten: Flatten,
    dense1: Dense,
    relu3: Relu,
    dense2: Dense,
    relu4: Relu,
    dense3: Dense,
}

impl LeNet5 {
    fn new(device: Device, scalar_type: ScalarType) -> Result<Self> {
        let conv1 = Conv2::builder()
            .device(device.clone())
            .scalar_type(scalar_type)
            .inputs(1)
            .outputs(6)
            .filter([5, 5])
            .build()?;
        let relu1 = Relu;
        let pool1 = MaxPool2::builder().filter([2, 2]).build();
        let conv2 = Conv2::builder()
            .device(device.clone())
            .scalar_type(scalar_type)
            .inputs(6)
            .outputs(16)
            .filter([5, 5])
            .build()?;
        let relu2 = Relu;
        let pool2 = MaxPool2::builder().filter([2, 2]).build();
        let flatten = Flatten;
        let dense1 = Dense::builder()
            .device(device.clone())
            .scalar_type(scalar_type)
            .inputs(16 * 4 * 4)
            .outputs(128)
            .build()?;
        let relu3 = Relu;
        let dense2 = Dense::builder()
            .device(device.clone())
            .scalar_type(scalar_type)
            .inputs(128)
            .outputs(84)
            .build()?;
        let relu4 = Relu;
        let dense3 = Dense::builder()
            .device(device.clone())
            .scalar_type(scalar_type)
            .inputs(84)
            .outputs(10)
            .bias(true)
            .build()?;
        Ok(Self {
            conv1,
            relu1,
            pool1,
            conv2,
            relu2,
            pool2,
            flatten,
            dense1,
            relu3,
            dense2,
            relu4,
            dense3,
        })
    }
}

let mut model = LeNet5::new(device.clone(), ScalarType::F32)?;
model.init_parameter_grads()?;
let y = model.forward(x)?;
let loss = y.cross_entropy_loss(t)?;
loss.backward()?;
model.update(learning_rate, &optimizer)?;

参见 神经网络 MNIST 示例

基准测试

NVIDIA GeForce GTX 1060 with Max-Q Design

LeNet5(训练,批量大小 = 100)

autograph tch candle
bf16_host 498.54 ms (✅ 1.00x) 75.26 ms (🚀 6.62x faster) N/A
f32_host 8.25 ms (✅ 1.00x) 3.14 ms (🚀 2.63x faster) 34.17 ms (❌ 4.14x slower)
bf16_device 1.76 ms (✅ 1.00x) 17.63 ms (❌ 10.02x slower) N/A
f32_device 1.73 ms (✅ 1.00x) 1.19 ms (✅ 1.45x faster) 9.76 ms (❌ 5.64x slower)

LeNet5(推理,批量大小 = 1,000)

autograph tch candle
bf16_host 1.81 s (✅ 1.00x) 193.60 毫秒 (🚀 9.37倍更快) N/A
f32_host 15.56 毫秒 (✅ 1.00倍) 9.46 毫秒 (✅ 1.64倍更快) 94.23 毫秒 (❌ 6.06倍更慢)
bf16_device 4.65 毫秒 (✅ 1.00倍) 48.63 毫秒 (❌ 10.46倍更慢) N/A
f32_device 4.65 毫秒 (✅ 1.00倍) 1.84 毫秒 (🚀 2.52倍更快) 10.81 毫秒 (❌ 2.33倍更慢)

查看神经网络基准测试。

许可证

双许可,与Rust项目兼容。

根据Apache License,版本2.0 https://apache.ac.cn/licenses/LICENSE-2.0 或MIT许可证 http://opensource.org/licenses/MIT,由您选择。此文件不得复制、修改或分发,除非符合这些条款。

贡献

除非您明确声明,否则您提交的任何贡献,根据Apache-2.0许可证定义,应与上述双许可一致,不附加任何额外条款或条件。

依赖项

~11–23MB
~390K SLoC