3个版本 (重大变更)
新 0.3.1 | 2024年8月24日 |
---|---|
0.3.0 |
|
0.2.1 | 2024年8月24日 |
0.2.0 |
|
0.1.2 |
|
在数学类别中排名第149位
每月下载量约为153次
220KB
2K SLoC
microgemm
使用Rust进行通用矩阵乘法,具有自定义配置。
支持no_std
和no_alloc
环境。
实现基于BLIS微内核方法。
内容
安装
cargo add microgemm
使用方法
Kernel
trait是microgemm
的主要抽象。您可以自行实现它或使用已提供的核心。
gemm
use microgemm::{kernels::GenericKernel8x8, Kernel as _, MatMut, MatRef, PackSizes};
fn main() {
let kernel = GenericKernel8x8::<f32>::new();
assert_eq!(kernel.mr(), 8);
assert_eq!(kernel.nr(), 8);
let pack_sizes = PackSizes {
mc: 5 * kernel.mr(), // MC must be divisible by MR
kc: 190,
nc: 9 * kernel.nr(), // NC must be divisible by NR
};
let mut packing_buf = vec![0.0; pack_sizes.buf_len()];
let (alpha, beta) = (2.0, -3.0);
let (m, k, n) = (100, 380, 250);
let a = vec![2.0; m * k];
let b = vec![3.0; k * n];
let mut c = vec![4.0; m * n];
let a = MatRef::row_major(m, k, &a);
let b = MatRef::row_major(k, n, &b);
let mut c = MatMut::row_major(m, n, &mut c);
// c <- alpha a b + beta c
kernel.gemm(alpha, a, b, beta, &mut c, pack_sizes, &mut packing_buf);
println!("{:?}", c.as_slice());
}
还可以参考无分配示例,用于不使用Vec
的情况。
实现的核心
名称 | 标量类型 | 目标 |
---|---|---|
GenericKernelNxN (N: 2, 4, 8, 16, 32) |
T: Copy + Zero + One + Mul + Add | Any |
NeonKernel4x4 | f32 | aarch64和具有neon特性的目标 |
NeonKernel8x8 | f32 | aarch64和具有neon特性的目标 |
自定义核心实现
use microgemm::{typenum::U4, Kernel, MatMut, MatRef};
struct CustomKernel;
impl Kernel for CustomKernel {
type Scalar = f64;
type Mr = U4;
type Nr = U4;
// dst <- alpha lhs rhs + beta dst
fn microkernel(
&self,
alpha: f64,
lhs: MatRef<f64>,
rhs: MatRef<f64>,
beta: f64,
dst: &mut MatMut<f64>,
) {
// lhs is col-major
assert_eq!(lhs.row_stride(), 1);
assert_eq!(lhs.nrows(), Self::MR);
// rhs is row-major
assert_eq!(rhs.col_stride(), 1);
assert_eq!(rhs.ncols(), Self::NR);
// dst is col-major
assert_eq!(dst.row_stride(), 1);
assert_eq!(dst.nrows(), Self::MR);
assert_eq!(dst.ncols(), Self::NR);
// your microkernel implementation...
}
}
基准测试
所有基准测试都在单线程上对维度为n
的正方形矩阵进行。
f32
打包大小{mc:n,kc:n,nc:n}
aarch64 (M1)
n NeonKernel8x8 faer matrixmultiply
128 75.5µs 242.6µs 46.2µs
256 466.3µs 3.2ms 518.2µs
512 3ms 15.9ms 2.7ms
1024 23.9ms 128.4ms 22ms
2048 191ms 1s 182.8ms
许可证
根据您的选择,许可协议为Apache License,版本2.0或MIT许可。
依赖项
~410KB