5个版本
0.2.0 | 2022年4月3日 |
---|---|
0.1.3 | 2021年11月14日 |
0.1.2 | 2021年6月26日 |
0.1.1 | 2021年3月26日 |
0.1.0 | 2020年11月16日 |
#91 in #complex
在tch-tensor-like中使用
17KB
362 行代码(不含注释)
tch-tensor-like: 为tch-rs实现类似Tensor的类型
关于此crate
如果你是tch-rs的用户,你可能曾经处理过类似以下复杂的模型输入类型。
struct ModelInput {
pub images: Vec<Tensor>,
pub kind: Tensor,
pub label: Option<Tensor>,
}
在将此类类型的批输入馈送到模型之前,你必须将其移动到适当的设备。对于类型中的每个成员调用 tensor.to_device()
可能会很繁琐。此时,TensorLike
派生宏就会出现。
use tch_tensor_like::TensorLike;
#[derive(TensorLike)]
struct ModelInput {
pub images: Vec<Tensor>,
pub kind: Tensor,
pub label: Option<Tensor>,
}
通过派生宏,你将获得 to_device()
,to_kind()
和 shallow_clone()
等功能。
let input: ModelInput = fetch_data();
let input = input.to_device(Device::cuda_if_available())
.to_kind(Kind::Float)
.shallow_clone();
对于非tensor成员,你可以标记属性来克隆值。
#[derive(TensorLike)]
struct ModelInput {
// primitives are copied by default
pub number: i32,
// copy the field
#[tensor_like(copy)]
pub text: &'static str,
// clone the field
#[tensor_like(clone)]
pub desc: String,
}
使用方法
此crate尚未发布到crates.io。将仓库链接添加到项目中以包含此crate。
[dependencies]
tch-tensor-like = { git = "https://github.com/jerry73204/tch-tensor-like.git", features = ["derive"] }
许可证
MIT许可证。请参阅LICENSE文件。
依赖关系
~1.5MB
~35K SLoC