#神经网络 #深度学习 #机器学习 #PyTorch #Rust

rstorch

从零开始实现受PyTorch启发的神经网络框架的Rust实现

2个不稳定版本

0.2.0 2023年11月25日
0.1.0 2023年5月17日

#329 in 机器学习

MIT/Apache

77KB
2.5K SLoC

RsTorch

从零开始实现一个类似PyTorch API的深度学习框架的Rust实现。该项目仍处于早期阶段,尚未准备好用于生产。因此,API可能随时更改。

目前,该项目实现了最小可行产品,允许用户使用MNIST数据集训练序列模型。此外,它还提供自动从互联网下载的MNIST数据集。

安装

将以下内容添加到您的 Cargo.toml

[dependencies]
rstorch = "0.2.0"

或者如果您想从master分支使用最新版本

[dependencies]
rstorch = { git = "https://github.com/ferranSanchezLlado/rstorch.git" }

用法

如何使用此库和MNIST数据集训练模型的小示例

use rstorch::data::{DataLoader, SequentialSampler};
use rstorch::hub::MNIST;
use rstorch::prelude::*;
use rstorch::utils::{accuracy, flatten, normalize_zero_one, one_hot};
use rstorch::{CrossEntropyLoss, Identity, Linear, ReLU, Sequential, SGD};
use std::fs;
use std::path::PathBuf;

const BATCH_SIZE: usize = 32;
const EPOCHS: usize = 5;

fn main() {
    // Path that gets deleted by tests
    let path: PathBuf = ["data", "mnist"].iter().collect();

    let train_data = MNIST::new(path, true, true)
        .transform(|(x, y)| (flatten(normalize_zero_one(x)), one_hot(y, 10)));
    let sampler = SequentialSampler::new(train_data.len());
    let mut data_loader = DataLoader::new(train_data, BATCH_SIZE, true, sampler);

    let mut model = sequential!(
        Identity(),
        Linear(784, 100),
        ReLU(),
        Linear(100, 100),
        ReLU(),
        Linear(100, 10),
    );
    let mut loss = CrossEntropyLoss::new();
    let mut optim = SGD::new(0.01);

    for i in 0..EPOCHS {
        let n = data_loader.len() as f64;
        let mut total_loss = 0.0;
        let mut total_acc = 0.0;

        for (x, y) in data_loader.iter_array() {
            let pred = model.forward(x);
            let l = loss.forward(pred.clone(), y.clone());
            let acc = accuracy(pred, y);

            total_loss += l;
            total_acc += acc;

            model.backward(loss.backward());
            optim.step(&mut model);
        }

        let avg_loss = total_loss / n;
        let avg_acc = total_acc / n;
        println!("EPOCH {i}: Avarage loss {avg_loss} - Avarage accuracy {avg_acc}");
    }
}

许可协议

本项目采用 MIT许可证Apache许可证第2.0版,任选其一。

依赖关系

~2–17MB
~187K SLoC