4 个版本 (2 个重大更改)

0.3.4 2022 年 9 月 10 日
0.3.3 2022 年 8 月 16 日
0.2.0 2022 年 8 月 7 日
0.1.1 2022 年 7 月 18 日

机器学习 类别中排名第 526

MIT 许可证

39KB
908 代码行

gradients

Crates.io version Docs

使用 custoscustos-math 的深度学习库。

外部 (C) 依赖:OpenCL、CUDA、nvrtc、cublas、BLAS 库 (OpenBLAS、Intel MKL 等)

安装

默认启用了两个功能

  • cuda ... 需要安装 CUDA、nvrtc 和 cublas
  • opencl ... 需要OpenCL

如果您禁用它们(添加 default-features = false 并不提供其他功能),则只能使用 CPU 设备。

对于所有功能配置,系统上都需要安装一个 BLAS 库。

[dependencies]
gradients = "0.3.4"

# to disable the default features (cuda, opencl) and use an own set of features:
#gradients = {version = "0.3.4", default-features = false, features=["opencl"]}

MNIST 示例

(如果此示例无法编译,请考虑查看 此处)

使用实现 NeuralNetwork 特性的结构体(通过 network 属性实现)来定义您想使用的层

use gradients::purpur::{CSVLoader, CSVReturn, Converter};
use gradients::OneHotMat;
use gradients::{
    correct_classes,
    nn::{cce, cce_grad},
    range, Adam, CLDevice, Linear, network, ReLU, Softmax,
};

#[network]
pub struct Network {
    lin1: Linear<784, 128>,
    relu1: ReLU,
    lin2: Linear<128, 10>,
    relu2: ReLU,
    lin3: Linear<10, 10>,
    softmax: Softmax,
}

加载 数据 并创建 Network 实例

您可以从 这里 下载 mnist 数据集。

// use cpu (no features enabled): let device = gradients::CPU::new().select();
// use cuda device (cuda feature enabled): let device = gradients::CudaDevice::new(0).unwrap().select();
// use opencl device (opencl feature enabled):
let device = CLDevice::new(0)?;

let mut net = Network::with_device(&device);

let loader = CSVLoader::new(true);
let loaded_data: CSVReturn<f32> = loader.load("PATH/TO/DATASET/mnist_train.csv")?;

let i = Matrix::from((
    &device,
    (loaded_data.sample_count, loaded_data.features),
    &loaded_data.x,
));
let i = i / 255.;

let y = Matrix::from((&device, (loaded_data.sample_count, 1), &loaded_data.y));
let y = y.onehot();

训练循环

let mut opt = Adam::new(0.01);

for epoch in range(200) {
    let preds = net.forward(&i);
    let correct_training = correct_classes(&loaded_data.y.as_usize(), &preds) as f32;

    let loss = cce(&device, &preds, &y);
    println!(
        "epoch: {epoch}, loss: {loss}, training_acc: {acc}",
        acc = correct_training / loaded_data.sample_count() as f32
    );

    let grad = cce_grad(&device, &preds, &y);
    net.backward(&grad);
    opt.step(&device, net.params());
}

依赖项

~14–23MB
~165K SLoC