6 个版本

0.7.3 2024年7月13日
0.7.2 2024年3月19日
0.7.0 2024年2月20日
0.6.1 2024年1月26日

#512科学

Download history 134/week @ 2024-07-12 4/week @ 2024-07-19 45/week @ 2024-07-26 5/week @ 2024-08-02

56 每月下载量

Apache-2.0 OR MIT

600KB
8K SLoC

Kyanite

Crates.io kn-graph Crates.io kn-cuda-sys Crates.io kn-cuda-eval docs.rs CI status

概述

Kyanite 是一个用 Rust 编写/为 Rust 编写的神经网络推理库。它可以使用 cuda/cudnn/cublas 在 CPU 或 Nvidia GPU 上运行 ONNX 文件。

它足够通用,可以运行各种网络,它已经与以下网络进行了测试

  • 简单的全连接网络
  • 基于 ResNet 的 CNN
  • LLaMA 的大型语言模型
  • 图像生成模型,如 Stable Diffusion。有关演示,请参阅 stable_diffusion 示例,位于 kn-runtime 包中。

该框架由以下包组成

  • kn-graph:核心包,包含中间表示和 CPU 执行器。
  • kn-cuda-sys:Cuda FFI 绑定,使用 rust-bindgen 生成。
  • kn-cuda-eval:Cuda 执行器和规划器。
  • kn-runtime:围绕其他包的包装器,允许在运行时选择 CPU 和 GPU 执行。
  • kn-python:运行时包的实验性 Python 封装,使用 PyO3

快速演示

// Graph operations (using kn-graph)
// Load on onnx file into a graph
let graph = load_graph_from_onnx_path("test.onnx", false)?;
// Optimize the graph
let graph = optimize_graph(&graph, Default::default());
// Render the graph as an svg file
graph_to_svg("test.svg", &graph, false, false)?;

// Build the inputs
let batch_size = 8;
let inputs = [DTensor::F32(Tensor::zeros(IxDyn(&[batch_size, 16])))];

// CPU: (using kn-graph)
// just evaluate the graph
let outputs: Vec<DTensor> = cpu_eval_graph(&graph, batch_size, &inputs);

// GPU: (using kn-cuda-eval)
// build an executor
let device = CudaDevice::new(0).unwrap();
let mut executor = CudaExecutor::new(device, &graph, batch_size);
// run the executor on the inputs
let outputs: &[DTensor] = executor.evaluate(&inputs);

// Runtime device selection: (using kn-runtime)
let device = Device::best();
let mut prepared = device.prepare(graph, batch_size);
let outputs: Vec<DTensor> = prepared.eval( & inputs);

系统需求

要使用 CUDA 包,需要在系统上安装适当的库;它们不会自动下载

  • CUDA(包括 CUDA、cuBLAS、NVRTC):安装程序,遵循说明。确保环境变量 CUDA_PATH 指向安装的根目录(即,CUDA_PATH/bin/ 应该存在)。
  • cuDNN: 存档文件,需要解压缩到您选择的目录。如果您选择与CUDA_PATH相同的目录,则无需进行其他操作。否则,将环境变量CUDNN_PATH设置为cuDNN安装的根目录(即CUDNN_PATH/bin应存在)。

该项目已在CUDA v12.2和cuDNN版本v8.9.5上进行了测试。较新版本可能也可以工作,但这是没有保证的,因为CUDA有时会更改某些函数的名称或删除它们。

内部结构

典型的流程如图1所示。图2显示了在简单的NN架构上运行此流程的结果。

NN inference diagram

conv_bn_sm_flow.svg

图 IR

中心是图IR,神经网络图的中间表示。

该结构是一个SSA风格的定向无环图,其中节点是有形状、数据类型和计算它的操作的值。这些值是抽象的;它们还没有步长或内存位置。

操作类似于其他框架,但尽可能地保持正交。一些示例操作:卷积、矩阵乘法、重塑、广播、切片、一元、二元、归约、softmax等。有关图操作的完整列表,请参阅文档

可以使用图构建器API直接在代码中构建图,但为了方便起见,存在一个ONNX加载器。它可以读取ONNX文件并将支持的子集操作转换为IR支持的操作。

由于图IR比ONNX规范更为正交,因此许多ONNX操作被分解为单独的步骤,以下是一些示例

  • ONNX二进制操作隐式广播它们的操作数,但在IR中这是一个单独的操作。
  • ONNX卷积和矩阵乘法有一个内置的可选偏置操作数;这也成为一个单独的广播加二元加法操作。

要确定ONNX操作是否受支持,请检查load.rsvisit_node函数顶层匹配语句的分支。许多常见操作已经实现,添加更多操作也不应该太困难。

有关典型图的较大示例,请参阅stable_diffusion_piece.svg,这是从稳定扩散模型的开头取出的一个小部分。

优化器

图可以选择性地由优化器进行优化。由于图是只添加的,因此返回一个新的图。

目前实现的一些优化是

  • 常量折叠
  • 将连续的仿射(偏置、缩放、批量归一化)操作融合为单个偏置+缩放操作。
  • 将连续的夹逼操作(relu、min、max)融合为单个min+max操作。
  • 强度降低:用除以常数的操作替换乘以倒数常数的操作。
  • 识别layernorm模板(reduce、subtract、power、reduce、divide)并将其替换为layernorm运算符。

CPU 执行器

最后,需要执行图。有一个简单的CPU执行器,它只是直接运行每个操作。这里没有尝试进行重大优化,除了使用BLAS例程进行矩阵乘法和im2col进行卷积。重要的是,此执行器尽可能简单,因为它作为检查GPU执行器正确性的单元测试的基准。

Cuda 执行器

运行这些图的第二种(也是更有用)方式是使用Cuda executor。这涉及到通过Cuda Planner运行图,它输出预定的Cuda操作调度,并分配必要的内存缓冲区。这一步被单独分离出来,这样昂贵的规划步骤只需要在每个网络架构中执行一次;然后,生成的计划可以在executor中多次重用。

规划器有以下主要职责

  • 确定张量的内存布局:步长和内存偏移量
    • 这隐式地处理了大多数reshape、broadcast、stride等操作。
    • 如果可能,还会重用缓冲区,以最大限度地减少总内存使用。这里有很大的改进空间;目前,这只是一个单遍算法。
  • 决定运行卷积和矩阵乘法所需的cuDNN/cuBLAS操作。如果可能,将这些操作组合在一起。以下是一些例子
    • cuDNN支持一个“卷积 + 剩余 + 偏置 + relu”操作
    • cuBLAS的矩阵乘法可以包括任意一个输入矩阵的转置,以及等效地通过交换输入来转置输出。
    • cuDNN和cuBLAS操作有时包括一个“标量”参数,该参数与某些操作数相乘
  • 使用基于NVRTC(运行时编译)autokernel框架编译剩余的标量和复合操作的定制内核。
    • autokernel处理的操作包括:标量操作、reduce、softmax、layernorm、gather。
    • 使用手工核模板,在编译前将张量形状、步长、标量操作等细节替换进去。
    • 这里发生了更多的操作融合
      • 多个标量操作被编译为单个内核
      • 常量标量被内联
      • 一些复合内核支持融合输入或输出标量操作

这种最终的运算融合可以非常显著,并节省大量从主内存到主内存的冗余传输。虽然可以通过手动为每种使用的操作组合编写内核来实现相同的表现,但组合爆炸及其相关的维护将是巨大的。

下面展示了一个具有一些手工澄清注释的生成标量内核的示例

示例标量autokernel用于残差 + 批归一化 + relu6
#include "util.cu"

// constants that got inserted into the template
// this scalar operation happens on a tensor of rank 4, with 7 operands
const int RANK = 4;
const int OPERANDS = 7;
const int STRIDES_DENSE[RANK] = {648, 81, 9, 1};
const int STRIDES[OPERANDS][RANK] = {
    // these are full input tensors with normal, dense strides
    {648, 81, 9, 1},
    {648, 81, 9, 1},
    // these values have zero strides for all axes except the channel one,
    //    so these are probably biases and scaling factors
    //    that are broadcast across the other axes
    {0, 1, 0, 0},
    {0, 1, 0, 0},
    {0, 1, 0, 0},
    {0, 1, 0, 0},
    // the output tensor is just another operand
    {648, 81, 9, 1}
};

// the template function, the body of which is generated at runtime
__device__ void operation(void *pointers[OPERANDS], int offsets[OPERANDS]) {
    // all input operand memory locations are cast to the right type
    float *x0 = &((float *) pointers[0])[offsets[0]];
    float *x1 = &((float *) pointers[1])[offsets[1]];
    float *x2 = &((float *) pointers[2])[offsets[2]];
    float *x3 = &((float *) pointers[3])[offsets[3]];
    float *x4 = &((float *) pointers[4])[offsets[4]];
    float *x5 = &((float *) pointers[5])[offsets[5]];
    float *x6 = &((float *) pointers[6])[offsets[6]];
    
    // input operands are loaded
    float y0 = *x0;
    float y1 = *x1;
    
    // this is probably a residual connection
    float y2 = y0 + y1;
    
    // these 4 steps look like they're implementing a batchnorm layer  
    float y3 = *x2;
    float y4 = y2 - y3;
    float y5 = *x3;
    float y6 = y4 / y5;
    float y7 = *x4;
    float y8 = y6 * y7;
    float y9 = *x5;
    float y10 = y8 + y9;
    
    // this implements a relu6 activation function
    float y11 = 6;
    float y12 = min(y10, y11);
    float y13 = (0.0);
    float y14 = max(y12, y13);
    
    // finally the output is stored
    *x6 = y14;
}

// the kernel main function is the same for all scalar kernels
__global__ void scalar_kernel(
        int batch_size,
        Array<void *, OPERANDS> pointers
) {
    KernelInfo info = kernel_info();
    int size = batch_size * STRIDES_DENSE[0];

    // the main loop, following https://developer.nvidia.com/blog/cuda-pro-tip-write-flexible-kernels-grid-stride-loops/
    for (int flat = info.global_thread_id; flat < size; flat += info.thread_count) {
        Array<int, OPERANDS> offsets = flat_index_to_offsets<RANK, OPERANDS>(flat, STRIDES_DENSE, STRIDES);
        operation(pointers.data, &offsets[0]);
    }
}

与其他 Crates 的比较

有关所有潜在替代方案的完整列表,请参阅Are We Learning Yet?

Rust 封装现有运行时

  • PyTorch包装器:tch
  • TensorFlow包装器:tensorflow
  • ONNXRuntime包装器:ort

优点

  • 广泛支持许多神经网络操作
  • 支持许多不同的后端(CPU、GPU(Nvidia + AMD)、TPU等)

缺点

  • 并非总是对加载ONNX文件提供很好的支持(ort在这方面做得很好,正如其名称所示)
  • 大型且有点黑盒的外部依赖
  • 在许多情况下,运算融合较少,尽管预计未来会得到改善

对于运算融合不是很重要的情况,性能应该与Kyanite大致相同;所有库大多使用相同的底层cuDNN和cuBLAS内核。

从头开始的 Rust 项目

  • tract:对ONNX规范的覆盖范围更大,但仅支持CPU推理

开发

在开发此crate时,为了更新ONNX proto,使用了prost-build crate。这要求已安装protoc,并且将PROTOC环境变量设置为指向可执行文件。有关更多详细信息,请参阅他们的安装说明(或构建脚本显示的错误消息,如果有任何错误)。

实际上更新proto定义,请将kn-graph/proto/onnx.proto3替换为较新版本,并运行cargo run --bin proto-to-rust。然后提交onnx.proto3文件和生成的onnx.rs文件。

依赖项

~9MB
~168K SLoC