#deep-learning #machine-learning #neural-network #pytorch

tinygrad

你喜欢PyTorch吗?你喜欢micrograd?你一定会喜欢tinygrad!💖

2个不稳定版本

0.1.0 2024年2月19日
0.0.0 2023年12月29日

#386 in 机器学习

MIT许可协议

15KB
86 代码行

✨️ tinygrad

Crates.io docs License

一个用于构建和训练神经网络的Rust crate。 tinygrad提供了一个简单的接口来定义张量、执行前向和反向传播,并实现点积和求和等基本操作。

🚀 快速开始

按照以下简单步骤开始使用tinygrad

  1. 通过将以下行添加到您的Cargo.toml文件中安装tinygrad crate
[dependencies]
tinygrad = "0.1.0"
  1. 使用TensorForwardBackward trait创建和操作张量
use ndarray::{array, Array1};
use tinygrad::{Tensor, Context, TensorTrait};

// Create a tensor
let value = array![1.0, 2.0, 3.0];
let tensor = Tensor::new(value);

// Perform forward and backward passes
let mut ctx = Context::new();
let result = tensor.forward(&mut ctx, vec![tensor.get_value()]);
tensor.backward(&mut ctx, array![1.0, 1.0, 1.0].view());
  1. 通过定义实现ForwardBackward trait的结构体来实现自定义操作
use ndarray::ArrayView1;
use tinygrad::{ForwardBackward, Context, TensorTrait};

// Example operation: Dot product
struct Dot;

impl ForwardBackward for Dot {
    fn forward(&self, _ctx: &mut Context, inputs: Vec<ArrayView1<f64>>) -> f64 {
        let input = &inputs[0];
        let weight = &inputs[1];
        input.dot(weight)
    }

    fn backward(&self, ctx: &mut Context, grad_output: ArrayView1<f64>) {
        // Implement backward pass
        // ...
    }
}

🔧 使用示例

use ndarray::{array, Array1};
use tinygrad::{Tensor, Context, TensorTrait};

fn main() {
    let input = array![1.0, 2.0, 3.0];
    let weight = array![4.0, 5.0, 6.0];

    let input_tensor = Box::new(Tensor::new(input));
    let weight_tensor = Box::new(Tensor::new(weight));

    let dot_fn = Dot;
    let mut ctx = Context::new();

    let inputs = vec![
        input_tensor.get_value(),
        weight_tensor.get_value(),
    ];
    let output = dot_fn.forward(&mut ctx, inputs);

    println!("Dot product: {:?}", output);

    let grad_output = array![1.0, 1.0, 1.0];
    dot_fn.backward(&mut ctx, grad_output.view());

    let grad_input = &input_tensor.grad.clone();
    let grad_weight = &weight_tensor.grad.clone();

    println!("Gradient for input: {:?}", grad_input);
    println!("Gradient for weight: {:?}", grad_weight);
}

🧪 测试

使用以下命令运行tinygrad crate的测试

cargo test

🌐 GitHub仓库

您可以在GitHub上访问tinygrad crate的源代码。

🤝 贡献

欢迎贡献和反馈!如果您想贡献、报告问题或建议改进,请在GitHub上与项目互动。您的贡献有助于为社区改进此crate。

📘 文档

tinygrad的完整文档可在docs.rs上找到。

📄 许可协议

本项目采用MIT许可协议

依赖

~1.5MB
~25K SLoC