2 个版本

0.1.1 2024 年 4 月 1 日
0.1.0 2024 年 3 月 2 日

#248 in 机器学习

每月 24 次下载

MIT 许可证

21KB
411

ZeNu

ZeNu 是一个用 Rust 编写的简单直观的深度学习库。它提供了创建和训练神经网络的基本模块,注重易用性和灵活性。

特性

  • 高级 API 用于定义和训练神经网络
  • 与 MNIST 等流行数据集的集成
  • 模块化设计,易于扩展
  • 使用底层 zenu-matrix 和 zenu-autograd 库进行高效计算

入门指南

要在 Rust 项目中使用 ZeNu,请将以下内容添加到您的 Cargo.toml 文件中

[dependencies]
zenu = "0.1.0"

以下是一个使用 ZeNu 定义和训练模型的简单示例

use zenu::{
    dataset::{train_val_split, DataLoader, Dataset},
    mnist::minist_dataset,
    update_parameters, Model,
};
use zenu_autograd::{
    creator::from_vec::from_vec,
    functions::{activation::sigmoid::sigmoid, loss::cross_entropy::cross_entropy},
    Variable,
};
use zenu_layer::{layers::linear::Linear, Layer};
use zenu_matrix::{
    matrix::{IndexItem, ToViewMatrix},
    operation::max::MaxIdx,
};
use zenu_optimizer::sgd::SGD;

// Define your model
struct SingleLayerModel {
    linear: Linear<f32>,
}

impl SingleLayerModel {
    fn new() -> Self {
        let mut linear = Linear::new(784, 10);
        linear.init_parameters(None);
        Self { linear }
    }
}

impl Model<f32> for SingleLayerModel {
    fn predict(&self, inputs: &[Variable<f32>]) -> Variable<f32> {
        let x = &inputs[0];
        let x = self.linear.call(x.clone());
        sigmoid(x)
    }
}

// Define your dataset
struct MnistDataset {
    data: Vec<(Vec<u8>, u8)>,
}

impl Dataset<f32> for MnistDataset {
    type Item = (Vec<u8>, u8);

    fn item(&self, item: usize) -> Vec<Variable<f32>> {
        // ... Implement your dataset logic
    }

    fn len(&self) -> usize {
        self.data.len()
    }

    fn all_data(&mut self) -> &mut [Self::Item] {
        &mut self.data as &mut [Self::Item]
    }
}

fn main() {
    // Load and prepare your data
    let (train, test) = minist_dataset().unwrap();
    let (train, val) = train_val_split(&train, 0.8, true);

    let test_dataloader = DataLoader::new(MnistDataset { data: test }, 1);

    // Create your model and optimizer
    let sgd = SGD::new(0.01);
    let model = SingleLayerModel::new();

    // Train your model
    for epoch in 0..10 {
        // ... Implement your training loop
    }

    // Evaluate your model
    let mut test_loss = 0.;
    let mut num_iter_test = 0;
    let mut correct = 0;
    let mut total = 0;
    for batch in test_dataloader {
        // ... Implement your evaluation logic
    }

    println!("Accuracy: {}", correct as f32 / total as f32);
    println!("Test Loss: {}", test_loss / num_iter_test as f32);
}

有关更多详细信息和方法,请参阅 文档

许可证

ZeNu 根据 MIT 许可证 许可。

依赖项

~6–21MB
~288K SLoC