3 个版本
0.1.2 | 2024 年 2 月 5 日 |
---|---|
0.1.1 | 2023 年 5 月 30 日 |
0.1.0 | 2023 年 5 月 18 日 |
#955 in 机器学习
1.5MB
34K SLoC
使用 tch 与 PyTorch 交互的 Python 扩展
这个示例 crate 展示了如何使用 tch 编写一个 Python 扩展,通过 PyO3 操作 PyTorch 张量。
Python 扩展仅公开了一个函数,该函数将输入张量的值加一。相关代码如下
#[pyfunction]
fn add_one(tensor: PyTensor) -> PyResult<PyTensor> {
let tensor = tensor.f_add_scalar(1.0).map_err(wrap_tch_err)?;
Ok(PyTensor(tensor))
}
建议使用 f_
方法,以便在 tch
crate 中潜在的错误不会导致 Python 解释器崩溃。
编译扩展
要从 GitHub 仓库的根目录构建扩展并测试插件,请运行以下命令。这需要一个安装了适当 PyTorch 版本的 Python 环境。
LIBTORCH_USE_PYTORCH=1 cargo build && cp -f target/debug/libtch_ext.so tch_ext.so
python test.py
设置 LIBTORCH_USE_PYTORCH
将导致在 tch
中使用 Python 安装的 libtorch C++ 库,并确保使用正确的版本(如果 tch
使用与 Python 运行时不同的 libtorch 版本,可能会导致段错误)。
Colab 笔记本
基于 tch
的插件可以轻松地从 Colab 中使用(尽管下载所有 crate 并编译可能有点慢),请参阅这个 示例笔记本。
依赖项
~12–18MB
~271K SLoC