5 个版本 (破坏性更新)
0.6.0 | 2022年4月3日 |
---|---|
0.5.0 | 2021年11月14日 |
0.4.0 | 2021年6月26日 |
0.3.0 | 2021年3月26日 |
0.2.0 | 2020年11月16日 |
#1138 in Rust 模式
12KB
337 行
tch-tensor-like: 为 tch-rs 派生 Tensor-like 类型
关于此软件包
如果你是 tch-rs 的用户,你可能曾经处理过类似以下复杂的模型输入类型。
struct ModelInput {
pub images: Vec<Tensor>,
pub kind: Tensor,
pub label: Option<Tensor>,
}
在你将此类类型的批量输入馈送到模型之前,你必须将其移动到适当的设备。对每个类型的成员调用 tensor.to_device()
可能会很繁琐。TensorLike
派生宏正是为了帮助你解决这个问题。
use tch_tensor_like::TensorLike;
#[derive(TensorLike)]
struct ModelInput {
pub images: Vec<Tensor>,
pub kind: Tensor,
pub label: Option<Tensor>,
}
通过派生该宏,你可以直接获得 to_device()
、to_kind()
和 shallow_clone()
方法。
let input: ModelInput = fetch_data();
let input = input.to_device(Device::cuda_if_available())
.to_kind(Kind::Float)
.shallow_clone();
对于非张量成员,你可以标记属性来克隆值。
#[derive(TensorLike)]
struct ModelInput {
// primitives are copied by default
pub number: i32,
// copy the field
#[tensor_like(copy)]
pub text: &'static str,
// clone the field
#[tensor_like(clone)]
pub desc: String,
}
用法
该软件包尚未发布到 crates.io。将仓库链接添加到你的项目中以包含此软件包。
[dependencies]
tch-tensor-like = { git = "https://github.com/jerry73204/tch-tensor-like.git", features = ["derive"] }
许可协议
MIT 许可协议。请参阅 LICENSE 文件。
依赖关系
~7.5–10MB
~204K SLoC