15 个版本 (重大变更)
0.13.2 | 2024 年 5 月 3 日 |
---|---|
0.12.1 | 2024 年 2 月 1 日 |
0.11.1 | 2023 年 12 月 4 日 |
0.10.0 | 2023 年 10 月 24 日 |
0.3.0 | 2022 年 11 月 20 日 |
在 机器学习 类别中排名第 173
每月下载量 8,057
在 15 个crate中(直接使用3个)
1MB
17K SLoC
Burn 火焰后端
Burn 火焰后端
此 crate 为 Burn 提供了使用 tch-rs
crate 的 Torch 后端,该 crate 为 PyTorch C++ API 提供了 Rust 接口。
后端支持 CPU(多线程)、CUDA(多个 GPU)和 MPS 设备(MacOS)。
安装
tch-rs
需要 C++ PyTorch 库(LibTorch)在您的系统上可用。
默认情况下,为满足 tch-rs
的要求,安装了 LibTorch v2.2.0 的 CPU 分发。
CUDA
要安装最新兼容的 CUDA 分发,在通过 cargo
获取 tch-rs
依赖项之前,设置 TORCH_CUDA_VERSION
环境变量。
export TORCH_CUDA_VERSION=cu121
在 Windows 上
$Env:TORCH_CUDA_VERSION = "cu121"
例如,第一次运行验证样本可以执行以下命令
export TORCH_CUDA_VERSION=cu121
cargo run --bin cuda --release
重要: 确保您的驱动程序版本与所选 CUDA 版本兼容。由于 LibTorch 附带了适当的 CUDA 运行时,因此不需要安装 CUDA 工具包。建议使用最新版本的驱动程序,但您始终可以查看 工具包驱动程序版本表 或 最低要求的驱动程序版本(功能有限,可能无法与所有操作兼容)。
安装完成后,您应该能够构建/运行您的项目。您还可以通过运行适当的 cpu
、cuda
或 mps
样本来验证您的安装。
cargo run --bin cpu --release
cargo run --bin cuda --release
cargo run --bin mps --release
注意:目前没有可自动下载的 MPS 发行版,请查看 手动说明。
手动下载
要使用不同的 LibTorch 发行版安装 tch-rs
,您必须手动下载所需的 LibTorch 发行版。具体说明在每个平台的下面章节中详细说明。
计算平台 | CPU | GPU | Linux | MacOS | Windows | Android | iOS | WASM |
---|---|---|---|---|---|---|---|---|
CPU | 是 | 否 | 是 | 是 | 是 | 是 | 是 | 否 |
CUDA | 是 [1] | 是 | 是 | 否 | 是 | 否 | 否 | 否 |
Metal (MPS) | 否 | 是 | 否 | 是 | 否 | 否 | 否 | 否 |
Vulkan | 是 | 是 | 是 | 是 | 是 | 是 | 否 | 否 |
LibTorch CUDA 发行版也包含 CPU 支持。
CPU
🐧 Linux
首先,下载 LibTorch CPU 发行版。
wget -O libtorch.zip https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-2.2.0%2Bcpu.zip
unzip libtorch.zip
然后,在构建 burn-tch
或依赖它的 crate 之前,使用 LIBTORCH
和 LD_LIBRARY_PATH
环境变量指向该安装。
export LIBTORCH=/absolute/path/to/libtorch/
export LD_LIBRARY_PATH=/absolute/path/to/libtorch/lib:$LD_LIBRARY_PATH
🍎 Mac
首先,下载 LibTorch CPU 发行版。
wget -O libtorch.zip https://download.pytorch.org/libtorch/cpu/libtorch-macos-x86_64-2.2.0.zip
unzip libtorch.zip
然后,在构建 burn-tch
或依赖它的 crate 之前,使用 LIBTORCH
和 DYLD_LIBRARY_PATH
环境变量指向该安装。
export LIBTORCH=/absolute/path/to/libtorch/
export DYLD_LIBRARY_PATH=/absolute/path/to/libtorch/lib:$DYLD_LIBRARY_PATH
🪟 Windows
首先,下载 LibTorch CPU 发行版。
wget https://download.pytorch.org/libtorch/cpu/libtorch-win-shared-with-deps-2.2.0%2Bcpu.zip -OutFile libtorch.zip
Expand-Archive libtorch.zip
然后,在构建 burn-tch
或依赖它的 crate 之前,设置 LIBTORCH
环境变量并将库添加到您的路径中,如下面的 PowerShell 命令所示。
$Env:LIBTORCH = "/absolute/path/to/libtorch/"
$Env:Path += ";/absolute/path/to/libtorch/"
CUDA
LibTorch 2.2.0 当前包含 CUDA 11.8 或 12.1 运行时的二进制发行版。手动安装说明如下。
CUDA 11.8
🐧 Linux
首先,下载 LibTorch CUDA 11.8 发行版。
wget -O libtorch.zip https://download.pytorch.org/libtorch/cu118/libtorch-cxx11-abi-shared-with-deps-2.2.0%2Bcu118.zip
unzip libtorch.zip
然后,在构建 burn-tch
或依赖它的 crate 之前,使用 LIBTORCH
和 LD_LIBRARY_PATH
环境变量指向该安装。
export LIBTORCH=/absolute/path/to/libtorch/
export LD_LIBRARY_PATH=/absolute/path/to/libtorch/lib:$LD_LIBRARY_PATH
注意:确保您的 CUDA 安装在您的 PATH
和 LD_LIBRARY_PATH
中。
🪟 Windows
首先,下载 LibTorch CUDA 11.8 发行版。
wget https://download.pytorch.org/libtorch/cu118/libtorch-win-shared-with-deps-2.2.0%2Bcu118.zip -OutFile libtorch.zip
Expand-Archive libtorch.zip
然后,在构建 burn-tch
或依赖它的 crate 之前,设置 LIBTORCH
环境变量并将库添加到您的路径中,如下面的 PowerShell 命令所示。
$Env:LIBTORCH = "/absolute/path/to/libtorch/"
$Env:Path += ";/absolute/path/to/libtorch/"
CUDA 12.1
🐧 Linux
首先,下载 LibTorch CUDA 12.1 发行版。
wget -O libtorch.zip https://download.pytorch.org/libtorch/cu121/libtorch-cxx11-abi-shared-with-deps-2.2.0%2Bcu121.zip
unzip libtorch.zip
然后,在构建 burn-tch
或依赖它的 crate 之前,使用 LIBTORCH
和 LD_LIBRARY_PATH
环境变量指向该安装。
export LIBTORCH=/absolute/path/to/libtorch/
export LD_LIBRARY_PATH=/absolute/path/to/libtorch/lib:$LD_LIBRARY_PATH
注意:确保您的 CUDA 安装在您的 PATH
和 LD_LIBRARY_PATH
中。
🪟 Windows
首先,下载 LibTorch CUDA 12.1 发行版。
wget https://download.pytorch.org/libtorch/cu121/libtorch-win-shared-with-deps-2.2.0%2Bcu121.zip -OutFile libtorch.zip
Expand-Archive libtorch.zip
然后,在构建 burn-tch
或依赖它的 crate 之前,设置 LIBTORCH
环境变量并将库添加到您的路径中,如下面的 PowerShell 命令所示。
$Env:LIBTORCH = "/absolute/path/to/libtorch/"
$Env:Path += ";/absolute/path/to/libtorch/"
Metal (MPS)
目前没有官方支持 MPS 的 LibTorch 发行版,因此最简单的替代方案是使用 PyTorch 安装。这需要一个 Python 安装。
注意:MPS 加速在 MacOS 12.3+ 上可用。
pip install torch==2.2.0
export LIBTORCH_USE_PYTORCH=1
export DYLD_LIBRARY_PATH=/path/to/pytorch/lib:$DYLD_LIBRARY_PATH
示例用法
对于简单的示例,请查看 src/bin/
中的任何测试程序。每个程序设置要使用的设备并执行简单的元素级加法。
对于使用 tch
后端更完整的示例,请参阅 Burn mnist 示例。
依赖项
~16MB
~311K SLoC