5个版本
0.1.5 | 2023年3月19日 |
---|---|
0.1.4 | 2023年3月17日 |
0.1.3 | 2023年3月2日 |
0.1.2 | 2022年11月26日 |
0.1.0 | 2022年11月20日 |
#794 in 数学
每月下载量:42
91KB
2.5K SLoC
microtensor
张量操作的自动微分。
需要Rust nightly版本。
特性
-
安全的自动微分 — 非可微操作返回一个无法反向传播的独立类型,在编译时揭示计算图中的差距。
-
广播 — 在大多数操作中,形状不同但兼容的张量会自动广播到匹配的维度。
-
任意内部类型 — 张量可以存储几乎任何数据类型,并为满足scalar::Real的任何内部类型计算梯度。
-
零拷贝视图 — 在大多数情况下,张量可以被切片、索引、重塑、转置和广播,而不实际复制任何数据。
-
图回收 — 通过跟踪即时计算创建的计算图可以在稍后时间使用新的输入数据进行重新评估。它们还可以在无权访问原始代码的情况下进行序列化和在其他地方加载。
示例
评估和最小化非线性函数
use microtensor::{prelude::*, Tensor};
// Create variables from tensors
let w = Tensor::randn(&[2, 16]).trained();
let b = Tensor::zeros(&[16]).trained();
for _ in 0..100 {
// Do some computation
let x = Tensor::vec(&[1.0, 2.0]).tracked();
let loss = ((x.mm(&w) + &b).sigmoid() - 0.5).sqr().mean(0);
// Compute gradients
loss.backward();
// Nudge w and b in order to minimize loss
for mut param in loss.parameters() {
param -= param.grad().unwrap() * 0.01;
}
// Reset gradients
loss.reset();
}
自动广播
use microtensor::{prelude::*, Tensor};
let a = Tensor::arrange(&[2, 16], 0., 1.);
let b = Tensor::ones(&[2]);
let c = &a - b.unsqueeze(-1) + 1.;
assert_eq!(a, c);
泛型返回类型
use microtensor::{prelude::*, Tensor};
let t = Tensor::<f32>::randn(&[16]);
let _a: u8 = t.argmax(0).item();
let _b: u16 = t.argmax(0).item(); // argmax will produce a Tensor<u16> here
可选特性
一些特性可以在您的Cargo.toml
中切换。
unsafe
(默认) — 使用[matrixmultiply]crate加速矩阵数学。threading
(默认) — 线程安全及在批处理维度上的多线程操作。
更多示例
查看/examples
文件夹以获取更多示例代码。
许可证
MIT
依赖项
~2.3–9.5MB
~77K SLoC