2个不稳定版本
0.2.0 | 2023年11月25日 |
---|---|
0.1.0 | 2023年5月17日 |
#329 in 机器学习
77KB
2.5K SLoC
RsTorch
从零开始实现一个类似PyTorch API的深度学习框架的Rust实现。该项目仍处于早期阶段,尚未准备好用于生产。因此,API可能随时更改。
目前,该项目实现了最小可行产品,允许用户使用MNIST数据集训练序列模型。此外,它还提供自动从互联网下载的MNIST数据集。
安装
将以下内容添加到您的 Cargo.toml
[dependencies]
rstorch = "0.2.0"
或者如果您想从master分支使用最新版本
[dependencies]
rstorch = { git = "https://github.com/ferranSanchezLlado/rstorch.git" }
用法
如何使用此库和MNIST数据集训练模型的小示例
use rstorch::data::{DataLoader, SequentialSampler};
use rstorch::hub::MNIST;
use rstorch::prelude::*;
use rstorch::utils::{accuracy, flatten, normalize_zero_one, one_hot};
use rstorch::{CrossEntropyLoss, Identity, Linear, ReLU, Sequential, SGD};
use std::fs;
use std::path::PathBuf;
const BATCH_SIZE: usize = 32;
const EPOCHS: usize = 5;
fn main() {
// Path that gets deleted by tests
let path: PathBuf = ["data", "mnist"].iter().collect();
let train_data = MNIST::new(path, true, true)
.transform(|(x, y)| (flatten(normalize_zero_one(x)), one_hot(y, 10)));
let sampler = SequentialSampler::new(train_data.len());
let mut data_loader = DataLoader::new(train_data, BATCH_SIZE, true, sampler);
let mut model = sequential!(
Identity(),
Linear(784, 100),
ReLU(),
Linear(100, 100),
ReLU(),
Linear(100, 10),
);
let mut loss = CrossEntropyLoss::new();
let mut optim = SGD::new(0.01);
for i in 0..EPOCHS {
let n = data_loader.len() as f64;
let mut total_loss = 0.0;
let mut total_acc = 0.0;
for (x, y) in data_loader.iter_array() {
let pred = model.forward(x);
let l = loss.forward(pred.clone(), y.clone());
let acc = accuracy(pred, y);
total_loss += l;
total_acc += acc;
model.backward(loss.backward());
optim.step(&mut model);
}
let avg_loss = total_loss / n;
let avg_acc = total_acc / n;
println!("EPOCH {i}: Avarage loss {avg_loss} - Avarage accuracy {avg_acc}");
}
}
许可协议
本项目采用 MIT许可证 或 Apache许可证第2.0版,任选其一。
依赖关系
~2–17MB
~187K SLoC