4 个版本 (2 个重大更改)
0.3.4 | 2022 年 9 月 10 日 |
---|---|
0.3.3 |
|
0.2.0 | 2022 年 8 月 7 日 |
0.1.1 | 2022 年 7 月 18 日 |
在 机器学习 类别中排名第 526
39KB
908 代码行
gradients
使用 custos 和 custos-math 的深度学习库。
外部 (C) 依赖:OpenCL、CUDA、nvrtc、cublas、BLAS 库 (OpenBLAS、Intel MKL 等)
安装
默认启用了两个功能
- cuda ... 需要安装 CUDA、nvrtc 和 cublas
- opencl ... 需要OpenCL
如果您禁用它们(添加 default-features = false
并不提供其他功能),则只能使用 CPU 设备。
对于所有功能配置,系统上都需要安装一个 BLAS 库。
[dependencies]
gradients = "0.3.4"
# to disable the default features (cuda, opencl) and use an own set of features:
#gradients = {version = "0.3.4", default-features = false, features=["opencl"]}
MNIST 示例
(如果此示例无法编译,请考虑查看 此处)
使用实现 NeuralNetwork 特性的结构体(通过 network
属性实现)来定义您想使用的层
use gradients::purpur::{CSVLoader, CSVReturn, Converter};
use gradients::OneHotMat;
use gradients::{
correct_classes,
nn::{cce, cce_grad},
range, Adam, CLDevice, Linear, network, ReLU, Softmax,
};
#[network]
pub struct Network {
lin1: Linear<784, 128>,
relu1: ReLU,
lin2: Linear<128, 10>,
relu2: ReLU,
lin3: Linear<10, 10>,
softmax: Softmax,
}
加载 数据 并创建 Network 实例
您可以从 这里 下载 mnist 数据集。
// use cpu (no features enabled): let device = gradients::CPU::new().select();
// use cuda device (cuda feature enabled): let device = gradients::CudaDevice::new(0).unwrap().select();
// use opencl device (opencl feature enabled):
let device = CLDevice::new(0)?;
let mut net = Network::with_device(&device);
let loader = CSVLoader::new(true);
let loaded_data: CSVReturn<f32> = loader.load("PATH/TO/DATASET/mnist_train.csv")?;
let i = Matrix::from((
&device,
(loaded_data.sample_count, loaded_data.features),
&loaded_data.x,
));
let i = i / 255.;
let y = Matrix::from((&device, (loaded_data.sample_count, 1), &loaded_data.y));
let y = y.onehot();
训练循环
let mut opt = Adam::new(0.01);
for epoch in range(200) {
let preds = net.forward(&i);
let correct_training = correct_classes(&loaded_data.y.as_usize(), &preds) as f32;
let loss = cce(&device, &preds, &y);
println!(
"epoch: {epoch}, loss: {loss}, training_acc: {acc}",
acc = correct_training / loaded_data.sample_count() as f32
);
let grad = cce_grad(&device, &preds, &y);
net.backward(&grad);
opt.step(&device, net.params());
}
依赖项
~14–23MB
~165K SLoC