2 个版本
0.1.1 | 2024 年 4 月 1 日 |
---|---|
0.1.0 | 2024 年 3 月 2 日 |
#248 in 机器学习
每月 24 次下载
21KB
411 行
ZeNu
ZeNu 是一个用 Rust 编写的简单直观的深度学习库。它提供了创建和训练神经网络的基本模块,注重易用性和灵活性。
特性
- 高级 API 用于定义和训练神经网络
- 与 MNIST 等流行数据集的集成
- 模块化设计,易于扩展
- 使用底层 zenu-matrix 和 zenu-autograd 库进行高效计算
入门指南
要在 Rust 项目中使用 ZeNu,请将以下内容添加到您的 Cargo.toml
文件中
[dependencies]
zenu = "0.1.0"
以下是一个使用 ZeNu 定义和训练模型的简单示例
use zenu::{
dataset::{train_val_split, DataLoader, Dataset},
mnist::minist_dataset,
update_parameters, Model,
};
use zenu_autograd::{
creator::from_vec::from_vec,
functions::{activation::sigmoid::sigmoid, loss::cross_entropy::cross_entropy},
Variable,
};
use zenu_layer::{layers::linear::Linear, Layer};
use zenu_matrix::{
matrix::{IndexItem, ToViewMatrix},
operation::max::MaxIdx,
};
use zenu_optimizer::sgd::SGD;
// Define your model
struct SingleLayerModel {
linear: Linear<f32>,
}
impl SingleLayerModel {
fn new() -> Self {
let mut linear = Linear::new(784, 10);
linear.init_parameters(None);
Self { linear }
}
}
impl Model<f32> for SingleLayerModel {
fn predict(&self, inputs: &[Variable<f32>]) -> Variable<f32> {
let x = &inputs[0];
let x = self.linear.call(x.clone());
sigmoid(x)
}
}
// Define your dataset
struct MnistDataset {
data: Vec<(Vec<u8>, u8)>,
}
impl Dataset<f32> for MnistDataset {
type Item = (Vec<u8>, u8);
fn item(&self, item: usize) -> Vec<Variable<f32>> {
// ... Implement your dataset logic
}
fn len(&self) -> usize {
self.data.len()
}
fn all_data(&mut self) -> &mut [Self::Item] {
&mut self.data as &mut [Self::Item]
}
}
fn main() {
// Load and prepare your data
let (train, test) = minist_dataset().unwrap();
let (train, val) = train_val_split(&train, 0.8, true);
let test_dataloader = DataLoader::new(MnistDataset { data: test }, 1);
// Create your model and optimizer
let sgd = SGD::new(0.01);
let model = SingleLayerModel::new();
// Train your model
for epoch in 0..10 {
// ... Implement your training loop
}
// Evaluate your model
let mut test_loss = 0.;
let mut num_iter_test = 0;
let mut correct = 0;
let mut total = 0;
for batch in test_dataloader {
// ... Implement your evaluation logic
}
println!("Accuracy: {}", correct as f32 / total as f32);
println!("Test Loss: {}", test_loss / num_iter_test as f32);
}
有关更多详细信息和方法,请参阅 文档。
许可证
ZeNu 根据 MIT 许可证 许可。
依赖项
~6–21MB
~288K SLoC