#pytorch #deep-learning #machine-learning

tch-ext

使用 tch 与 PyTorch 交互的 Python 扩展示例

3 个版本

0.1.2 2024 年 2 月 5 日
0.1.1 2023 年 5 月 30 日
0.1.0 2023 年 5 月 18 日

#955 in 机器学习

MIT/Apache

1.5MB
34K SLoC

C++ 18K SLoC // 0.0% comments Rust 17K SLoC // 0.0% comments Python 6 SLoC

使用 tch 与 PyTorch 交互的 Python 扩展

这个示例 crate 展示了如何使用 tch 编写一个 Python 扩展,通过 PyO3 操作 PyTorch 张量。

Python 扩展仅公开了一个函数,该函数将输入张量的值加一。相关代码如下

#[pyfunction]
fn add_one(tensor: PyTensor) -> PyResult<PyTensor> {
    let tensor = tensor.f_add_scalar(1.0).map_err(wrap_tch_err)?;
    Ok(PyTensor(tensor))
}

建议使用 f_ 方法,以便在 tch crate 中潜在的错误不会导致 Python 解释器崩溃。

编译扩展

要从 GitHub 仓库的根目录构建扩展并测试插件,请运行以下命令。这需要一个安装了适当 PyTorch 版本的 Python 环境。

LIBTORCH_USE_PYTORCH=1 cargo build && cp -f target/debug/libtch_ext.so tch_ext.so
python test.py

设置 LIBTORCH_USE_PYTORCH 将导致在 tch 中使用 Python 安装的 libtorch C++ 库,并确保使用正确的版本(如果 tch 使用与 Python 运行时不同的 libtorch 版本,可能会导致段错误)。

Colab 笔记本

基于 tch 的插件可以轻松地从 Colab 中使用(尽管下载所有 crate 并编译可能有点慢),请参阅这个 示例笔记本

依赖项

~12–18MB
~271K SLoC