#tensor #derive #macro #field #complex #tch-tensor-like #implementations

tch-tensor-like-derive

tch-tensor-like crate使用的派生宏实现

5个版本

0.2.0 2022年4月3日
0.1.3 2021年11月14日
0.1.2 2021年6月26日
0.1.1 2021年3月26日
0.1.0 2020年11月16日

#91 in #complex


tch-tensor-like中使用

MIT许可证

17KB
362 行代码(不含注释)

tch-tensor-like: 为tch-rs实现类似Tensor的类型

关于此crate

如果你是tch-rs的用户,你可能曾经处理过类似以下复杂的模型输入类型。

struct ModelInput {
    pub images: Vec<Tensor>,
    pub kind: Tensor,
    pub label: Option<Tensor>,
}

在将此类类型的批输入馈送到模型之前,你必须将其移动到适当的设备。对于类型中的每个成员调用 tensor.to_device() 可能会很繁琐。此时,TensorLike 派生宏就会出现。

use tch_tensor_like::TensorLike;

#[derive(TensorLike)]
struct ModelInput {
    pub images: Vec<Tensor>,
    pub kind: Tensor,
    pub label: Option<Tensor>,
}

通过派生宏,你将获得 to_device()to_kind()shallow_clone() 等功能。

let input: ModelInput = fetch_data();
let input = input.to_device(Device::cuda_if_available())
                 .to_kind(Kind::Float)
                 .shallow_clone();

对于非tensor成员,你可以标记属性来克隆值。

#[derive(TensorLike)]
struct ModelInput {
    // primitives are copied by default
    pub number: i32,

    // copy the field
    #[tensor_like(copy)]
    pub text: &'static str,

    // clone the field
    #[tensor_like(clone)]
    pub desc: String,
}

使用方法

此crate尚未发布到crates.io。将仓库链接添加到项目中以包含此crate。

[dependencies]
tch-tensor-like = { git = "https://github.com/jerry73204/tch-tensor-like.git", features = ["derive"] }

许可证

MIT许可证。请参阅LICENSE文件。

依赖关系

~1.5MB
~35K SLoC