2个不稳定版本
0.1.0 | 2024年2月19日 |
---|---|
0.0.0 | 2023年12月29日 |
#386 in 机器学习
15KB
86 代码行
✨️ tinygrad
一个用于构建和训练神经网络的Rust crate。 tinygrad
提供了一个简单的接口来定义张量、执行前向和反向传播,并实现点积和求和等基本操作。
🚀 快速开始
按照以下简单步骤开始使用tinygrad
库
- 通过将以下行添加到您的
Cargo.toml
文件中安装tinygrad
crate
[dependencies]
tinygrad = "0.1.0"
- 使用
Tensor
和ForwardBackward
trait创建和操作张量
use ndarray::{array, Array1};
use tinygrad::{Tensor, Context, TensorTrait};
// Create a tensor
let value = array![1.0, 2.0, 3.0];
let tensor = Tensor::new(value);
// Perform forward and backward passes
let mut ctx = Context::new();
let result = tensor.forward(&mut ctx, vec![tensor.get_value()]);
tensor.backward(&mut ctx, array![1.0, 1.0, 1.0].view());
- 通过定义实现
ForwardBackward
trait的结构体来实现自定义操作
use ndarray::ArrayView1;
use tinygrad::{ForwardBackward, Context, TensorTrait};
// Example operation: Dot product
struct Dot;
impl ForwardBackward for Dot {
fn forward(&self, _ctx: &mut Context, inputs: Vec<ArrayView1<f64>>) -> f64 {
let input = &inputs[0];
let weight = &inputs[1];
input.dot(weight)
}
fn backward(&self, ctx: &mut Context, grad_output: ArrayView1<f64>) {
// Implement backward pass
// ...
}
}
🔧 使用示例
use ndarray::{array, Array1};
use tinygrad::{Tensor, Context, TensorTrait};
fn main() {
let input = array![1.0, 2.0, 3.0];
let weight = array![4.0, 5.0, 6.0];
let input_tensor = Box::new(Tensor::new(input));
let weight_tensor = Box::new(Tensor::new(weight));
let dot_fn = Dot;
let mut ctx = Context::new();
let inputs = vec![
input_tensor.get_value(),
weight_tensor.get_value(),
];
let output = dot_fn.forward(&mut ctx, inputs);
println!("Dot product: {:?}", output);
let grad_output = array![1.0, 1.0, 1.0];
dot_fn.backward(&mut ctx, grad_output.view());
let grad_input = &input_tensor.grad.clone();
let grad_weight = &weight_tensor.grad.clone();
println!("Gradient for input: {:?}", grad_input);
println!("Gradient for weight: {:?}", grad_weight);
}
🧪 测试
使用以下命令运行tinygrad
crate的测试
cargo test
🌐 GitHub仓库
您可以在GitHub上访问tinygrad
crate的源代码。
🤝 贡献
欢迎贡献和反馈!如果您想贡献、报告问题或建议改进,请在GitHub上与项目互动。您的贡献有助于为社区改进此crate。
📘 文档
tinygrad
的完整文档可在docs.rs上找到。
📄 许可协议
本项目采用MIT许可协议。
依赖
~1.5MB
~25K SLoC