15 个版本 (重大变更)

0.13.2 2024 年 5 月 3 日
0.12.1 2024 年 2 月 1 日
0.11.1 2023 年 12 月 4 日
0.10.0 2023 年 10 月 24 日
0.3.0 2022 年 11 月 20 日

机器学习 类别中排名第 173

Download history 2596/week @ 2024-05-03 2076/week @ 2024-05-10 1996/week @ 2024-05-17 3901/week @ 2024-05-24 2708/week @ 2024-05-31 4993/week @ 2024-06-07 2512/week @ 2024-06-14 1741/week @ 2024-06-21 2875/week @ 2024-06-28 1989/week @ 2024-07-05 1527/week @ 2024-07-12 1542/week @ 2024-07-19 1524/week @ 2024-07-26 3112/week @ 2024-08-02 1781/week @ 2024-08-09 1425/week @ 2024-08-16

每月下载量 8,057
15crate中(直接使用3个)

MIT/Apache

1MB
17K SLoC

Burn 火焰后端

Burn 火焰后端

Current Crates.io Version license

此 crate 为 Burn 提供了使用 tch-rs crate 的 Torch 后端,该 crate 为 PyTorch C++ API 提供了 Rust 接口。

后端支持 CPU(多线程)、CUDA(多个 GPU)和 MPS 设备(MacOS)。

安装

tch-rs 需要 C++ PyTorch 库(LibTorch)在您的系统上可用。

默认情况下,为满足 tch-rs 的要求,安装了 LibTorch v2.2.0 的 CPU 分发。

CUDA

要安装最新兼容的 CUDA 分发,在通过 cargo 获取 tch-rs 依赖项之前,设置 TORCH_CUDA_VERSION 环境变量。

export TORCH_CUDA_VERSION=cu121

在 Windows 上

$Env:TORCH_CUDA_VERSION = "cu121"

例如,第一次运行验证样本可以执行以下命令

export TORCH_CUDA_VERSION=cu121
cargo run --bin cuda --release

重要: 确保您的驱动程序版本与所选 CUDA 版本兼容。由于 LibTorch 附带了适当的 CUDA 运行时,因此不需要安装 CUDA 工具包。建议使用最新版本的驱动程序,但您始终可以查看 工具包驱动程序版本表最低要求的驱动程序版本(功能有限,可能无法与所有操作兼容)。


安装完成后,您应该能够构建/运行您的项目。您还可以通过运行适当的 cpucudamps 样本来验证您的安装。

cargo run --bin cpu --release
cargo run --bin cuda --release
cargo run --bin mps --release

注意:目前没有可自动下载的 MPS 发行版,请查看 手动说明

手动下载

要使用不同的 LibTorch 发行版安装 tch-rs,您必须手动下载所需的 LibTorch 发行版。具体说明在每个平台的下面章节中详细说明。

计算平台 CPU GPU Linux MacOS Windows Android iOS WASM
CPU
CUDA [1]
Metal (MPS)
Vulkan

LibTorch CUDA 发行版也包含 CPU 支持。

CPU

🐧 Linux

首先,下载 LibTorch CPU 发行版。

wget -O libtorch.zip https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-2.2.0%2Bcpu.zip
unzip libtorch.zip

然后,在构建 burn-tch 或依赖它的 crate 之前,使用 LIBTORCHLD_LIBRARY_PATH 环境变量指向该安装。

export LIBTORCH=/absolute/path/to/libtorch/
export LD_LIBRARY_PATH=/absolute/path/to/libtorch/lib:$LD_LIBRARY_PATH

🍎 Mac

首先,下载 LibTorch CPU 发行版。

wget -O libtorch.zip https://download.pytorch.org/libtorch/cpu/libtorch-macos-x86_64-2.2.0.zip
unzip libtorch.zip

然后,在构建 burn-tch 或依赖它的 crate 之前,使用 LIBTORCHDYLD_LIBRARY_PATH 环境变量指向该安装。

export LIBTORCH=/absolute/path/to/libtorch/
export DYLD_LIBRARY_PATH=/absolute/path/to/libtorch/lib:$DYLD_LIBRARY_PATH

🪟 Windows

首先,下载 LibTorch CPU 发行版。

wget https://download.pytorch.org/libtorch/cpu/libtorch-win-shared-with-deps-2.2.0%2Bcpu.zip -OutFile libtorch.zip
Expand-Archive libtorch.zip

然后,在构建 burn-tch 或依赖它的 crate 之前,设置 LIBTORCH 环境变量并将库添加到您的路径中,如下面的 PowerShell 命令所示。

$Env:LIBTORCH = "/absolute/path/to/libtorch/"
$Env:Path += ";/absolute/path/to/libtorch/"

CUDA

LibTorch 2.2.0 当前包含 CUDA 11.8 或 12.1 运行时的二进制发行版。手动安装说明如下。

CUDA 11.8

🐧 Linux

首先,下载 LibTorch CUDA 11.8 发行版。

wget -O libtorch.zip https://download.pytorch.org/libtorch/cu118/libtorch-cxx11-abi-shared-with-deps-2.2.0%2Bcu118.zip
unzip libtorch.zip

然后,在构建 burn-tch 或依赖它的 crate 之前,使用 LIBTORCHLD_LIBRARY_PATH 环境变量指向该安装。

export LIBTORCH=/absolute/path/to/libtorch/
export LD_LIBRARY_PATH=/absolute/path/to/libtorch/lib:$LD_LIBRARY_PATH

注意:确保您的 CUDA 安装在您的 PATHLD_LIBRARY_PATH 中。


🪟 Windows

首先,下载 LibTorch CUDA 11.8 发行版。

wget https://download.pytorch.org/libtorch/cu118/libtorch-win-shared-with-deps-2.2.0%2Bcu118.zip -OutFile libtorch.zip
Expand-Archive libtorch.zip

然后,在构建 burn-tch 或依赖它的 crate 之前,设置 LIBTORCH 环境变量并将库添加到您的路径中,如下面的 PowerShell 命令所示。

$Env:LIBTORCH = "/absolute/path/to/libtorch/"
$Env:Path += ";/absolute/path/to/libtorch/"

CUDA 12.1

🐧 Linux

首先,下载 LibTorch CUDA 12.1 发行版。

wget -O libtorch.zip https://download.pytorch.org/libtorch/cu121/libtorch-cxx11-abi-shared-with-deps-2.2.0%2Bcu121.zip
unzip libtorch.zip

然后,在构建 burn-tch 或依赖它的 crate 之前,使用 LIBTORCHLD_LIBRARY_PATH 环境变量指向该安装。

export LIBTORCH=/absolute/path/to/libtorch/
export LD_LIBRARY_PATH=/absolute/path/to/libtorch/lib:$LD_LIBRARY_PATH

注意:确保您的 CUDA 安装在您的 PATHLD_LIBRARY_PATH 中。


🪟 Windows

首先,下载 LibTorch CUDA 12.1 发行版。

wget https://download.pytorch.org/libtorch/cu121/libtorch-win-shared-with-deps-2.2.0%2Bcu121.zip -OutFile libtorch.zip
Expand-Archive libtorch.zip

然后,在构建 burn-tch 或依赖它的 crate 之前,设置 LIBTORCH 环境变量并将库添加到您的路径中,如下面的 PowerShell 命令所示。

$Env:LIBTORCH = "/absolute/path/to/libtorch/"
$Env:Path += ";/absolute/path/to/libtorch/"

Metal (MPS)

目前没有官方支持 MPS 的 LibTorch 发行版,因此最简单的替代方案是使用 PyTorch 安装。这需要一个 Python 安装。

注意:MPS 加速在 MacOS 12.3+ 上可用。

pip install torch==2.2.0
export LIBTORCH_USE_PYTORCH=1
export DYLD_LIBRARY_PATH=/path/to/pytorch/lib:$DYLD_LIBRARY_PATH

示例用法

对于简单的示例,请查看 src/bin/ 中的任何测试程序。每个程序设置要使用的设备并执行简单的元素级加法。

对于使用 tch 后端更完整的示例,请参阅 Burn mnist 示例

依赖项

~16MB
~311K SLoC