#tensor #enums #derive #model #image #complex #tch

tch-tensor-like

为 tch 矩阵的结构体或枚举类型派生方便的方法

5 个版本 (破坏性更新)

0.6.0 2022年4月3日
0.5.0 2021年11月14日
0.4.0 2021年6月26日
0.3.0 2021年3月26日
0.2.0 2020年11月16日

#1138 in Rust 模式

MIT 许可协议

12KB
337

tch-tensor-like: 为 tch-rs 派生 Tensor-like 类型

关于此软件包

如果你是 tch-rs 的用户,你可能曾经处理过类似以下复杂的模型输入类型。

struct ModelInput {
    pub images: Vec<Tensor>,
    pub kind: Tensor,
    pub label: Option<Tensor>,
}

在你将此类类型的批量输入馈送到模型之前,你必须将其移动到适当的设备。对每个类型的成员调用 tensor.to_device() 可能会很繁琐。TensorLike 派生宏正是为了帮助你解决这个问题。

use tch_tensor_like::TensorLike;

#[derive(TensorLike)]
struct ModelInput {
    pub images: Vec<Tensor>,
    pub kind: Tensor,
    pub label: Option<Tensor>,
}

通过派生该宏,你可以直接获得 to_device()to_kind()shallow_clone() 方法。

let input: ModelInput = fetch_data();
let input = input.to_device(Device::cuda_if_available())
                 .to_kind(Kind::Float)
                 .shallow_clone();

对于非张量成员,你可以标记属性来克隆值。

#[derive(TensorLike)]
struct ModelInput {
    // primitives are copied by default
    pub number: i32,

    // copy the field
    #[tensor_like(copy)]
    pub text: &'static str,

    // clone the field
    #[tensor_like(clone)]
    pub desc: String,
}

用法

该软件包尚未发布到 crates.io。将仓库链接添加到你的项目中以包含此软件包。

[dependencies]
tch-tensor-like = { git = "https://github.com/jerry73204/tch-tensor-like.git", features = ["derive"] }

许可协议

MIT 许可协议。请参阅 LICENSE 文件。

依赖关系

~7.5–10MB
~204K SLoC