#矩阵乘法 #矩阵 #线性代数 #无分配 #gemm

无std microgemm

使用Rust进行通用矩阵乘法,具有自定义配置。支持无std和无分配环境

3个版本 (重大变更)

0.3.1 2024年8月24日
0.3.0 2024年8月16日
0.2.1 2024年8月24日
0.2.0 2024年3月4日
0.1.2 2023年10月30日

数学类别中排名第149位

Download history 4/week @ 2024-07-05 6/week @ 2024-07-26 1/week @ 2024-08-02 146/week @ 2024-08-16

每月下载量约为153次

MIT/Apache

220KB
2K SLoC

αAB + βC

microgemm

github latest_version docs.rs dependency status

使用Rust进行通用矩阵乘法,具有自定义配置。
支持no_stdno_alloc环境。

实现基于BLIS微内核方法。

内容

安装

cargo add microgemm

使用方法

Kernel trait是microgemm的主要抽象。您可以自行实现它或使用已提供的核心。

gemm

use microgemm::{kernels::GenericKernel8x8, Kernel as _, MatMut, MatRef, PackSizes};

fn main() {
    let kernel = GenericKernel8x8::<f32>::new();
    assert_eq!(kernel.mr(), 8);
    assert_eq!(kernel.nr(), 8);

    let pack_sizes = PackSizes {
        mc: 5 * kernel.mr(), // MC must be divisible by MR
        kc: 190,
        nc: 9 * kernel.nr(), // NC must be divisible by NR
    };
    let mut packing_buf = vec![0.0; pack_sizes.buf_len()];

    let (alpha, beta) = (2.0, -3.0);
    let (m, k, n) = (100, 380, 250);

    let a = vec![2.0; m * k];
    let b = vec![3.0; k * n];
    let mut c = vec![4.0; m * n];

    let a = MatRef::row_major(m, k, &a);
    let b = MatRef::row_major(k, n, &b);
    let mut c = MatMut::row_major(m, n, &mut c);

    // c <- alpha a b + beta c
    kernel.gemm(alpha, a, b, beta, &mut c, pack_sizes, &mut packing_buf);
    println!("{:?}", c.as_slice());
}

还可以参考无分配示例,用于不使用Vec的情况。

实现的核心

名称 标量类型 目标
GenericKernelNxN
(N: 2, 4, 8, 16, 32)
T: Copy + Zero + One + Mul + Add Any
NeonKernel4x4 f32 aarch64和具有neon特性的目标
NeonKernel8x8 f32 aarch64和具有neon特性的目标

自定义核心实现

use microgemm::{typenum::U4, Kernel, MatMut, MatRef};

struct CustomKernel;

impl Kernel for CustomKernel {
    type Scalar = f64;
    type Mr = U4;
    type Nr = U4;

    // dst <- alpha lhs rhs + beta dst
    fn microkernel(
        &self,
        alpha: f64,
        lhs: MatRef<f64>,
        rhs: MatRef<f64>,
        beta: f64,
        dst: &mut MatMut<f64>,
    ) {
        // lhs is col-major
        assert_eq!(lhs.row_stride(), 1);
        assert_eq!(lhs.nrows(), Self::MR);

        // rhs is row-major
        assert_eq!(rhs.col_stride(), 1);
        assert_eq!(rhs.ncols(), Self::NR);

        // dst is col-major
        assert_eq!(dst.row_stride(), 1);
        assert_eq!(dst.nrows(), Self::MR);
        assert_eq!(dst.ncols(), Self::NR);

        // your microkernel implementation...
    }
}

基准测试

所有基准测试都在单线程上对维度为n的正方形矩阵进行。

f32

打包大小{mc:n,kc:n,nc:n}

aarch64 (M1)

   n  NeonKernel8x8           faer matrixmultiply
 128         75.5µs        242.6µs         46.2µs
 256        466.3µs          3.2ms        518.2µs
 512            3ms         15.9ms          2.7ms
1024         23.9ms        128.4ms           22ms
2048          191ms             1s        182.8ms

许可证

根据您的选择,许可协议为Apache License,版本2.0MIT许可

依赖项

~410KB