5个不稳定版本
0.3.0 | 2024年7月24日 |
---|---|
0.2.0 | 2024年7月21日 |
0.1.3 | 2024年7月17日 |
0.1.2 | 2024年7月16日 |
0.1.1 | 2024年7月14日 |
#160 在 数据结构
每月483次下载
270KB
8K SLoC
blas-array2
使用 ndarray::Array
(Ix1
或 Ix2
) 实现的Rust中的可选参数BLAS包装器。
现在风对我 步长 (主对角线) 产生了影响
我正在所有 方向 (代码'L' / 'R') 上失去阵地
--- 暗黑太阳...,动画《PERSONA5》的OP2
其他文档
- 开发文档(github链接:github,docs.rs链接:docs.rs)
- BLAS包装器结构列表(github链接:github,docs.rs链接:docs.rs)
- 效率演示(github链接:github,docs.rs链接:docs.rs)
从v0.2开始,这个crate实现了大部分计划的功能。这个crate被认为几乎完成,可能不会积极维护或更新。然而,我们也欢迎问题和PR来进一步增加新功能或修复错误。
简单案例示例
为了简单说明这个包,我们执行 $\mathbf{C} = \mathbf{A} \mathbf{B}$ (代码dgemm)
use blas_array2::prelude::*;
use ndarray::prelude::*;
let a = array![[1.0, 2.0, 3.0], [3.0, 4.0, 5.0]];
let b = array![[-1.0, -2.0], [-3.0, -4.0], [-5.0, -6.0]];
let c_out = DGEMM::default()
.a(a.view())
.b(b.view())
.run().unwrap()
.into_owned();
println!("{:7.3?}", c_out);
重要点是
- 使用
::default()
初始化结构体; .a
,.b
是设置函数;.run().unwrap()
将执行计算;.into_owned()
将返回结果矩阵,类型为Array2<f64>
。
功能
核心功能
- BLAS2/BLAS3 功能:所有(遗留)BLAS2/BLAS3 函数都已实现。
- 可选参数:遵循类似于BLAST Fortran 95 绑定或
scipy.linalg.blas
的约定。矩阵形状和主维度信息将得到检查和正确解析,因此用户无需提供这些值。 - 行主序布局:支持Fortran 77 API的行主序(没有CBLAS函数的CBLAS功能)。例如,可以使用随debian的
blas-sys
提供的默认libopenblas.so
,其中CBLAS未自动集成。 - 泛型:例如,为
f32
、f64
、Complex<f32>
、Complex<f64>
类型,在一个泛型(模板)类中实现了GEMM<F> where F: GEMMNum
。对于SYRK
或GEMV
等也是一样。原始名称如DGEMM
、ZSYR2K
也是可用的。 - 尽可能避免显式复制:所有行主序(或列主序)的输入都不应涉及不必要的转置和显式复制。此外,对于一些BLAS3函数(GEMM、SYRK、TRMM、TRSM),如果转置不涉及
BLASConjTrans
,则混合行主序或列主序也不涉及显式转置。另外,请注意,在许多情况下,子矩阵(切片矩阵)也被视为行主序(或列主序),如果数据在任何维度上连续存储。
其他功能
- 任意布局:支持
ndarray
允许的任何步长。 - FFI:目前,此crate使用其自定义FFI绑定
blas_array2::ffi::blas
作为BLAS绑定,类似于blas-sys。此外,此crate计划(或已)支持一些BLAS扩展和ILP64(通过cargo功能)。
Cargo功能
no_std
:禁用crate功能std
将与#![no_std]
兼容。然而,目前这些no_std
功能将需要alloc
。ilp64
:默认情况下,FFI绑定是LP64(32位整数)。crate功能ilp64
将启用ILP64(64位整数)。- BLAS扩展:一些crate功能将启用BLAS扩展。
gemmt
:GEMMTR(三角输出矩阵乘法)。对于OpenBLAS,需要版本0.3.27(0.3.26将失败一些测试)。
warn_on_copy
:如果输入矩阵布局不一致,且需要显式内存复制/转置/复共轭,则将在stderr上打印警告消息。error_on_copy
:类似于warn_on_copy
,但将直接引发BLASError
。
复杂情况的示例
对于复杂的情况,我们通过SGEMM = GEMM<f32>
执行$\mathbf{C} = \mathbf{A} \mathbf{B}^\mathrm{T}$。
use blas_array2::prelude::*;
use ndarray::prelude::*;
let a = array![[1.0, 2.0, 3.0], [3.0, 4.0, 5.0]];
let b = array![[-1.0, -2.0], [-3.0, -4.0], [-5.0, -6.0]];
let mut c = Array::ones((3, 3).f());
let c_out = GEMM::<f32>::default()
.a(a.slice(s![.., ..2]))
.b(b.view())
.c(c.slice_mut(s![0..3;2, ..]))
.transb('T')
.beta(1.5)
.run()
.unwrap();
// one can get the result as an owned array
// but the result may not refer to the same memory location as `c`
println!("{:4.3?}", c_out.into_owned());
// this modification on `c` is actually performed in-place
// so if `c` is pre-defined, not calling `into_owned` could be more efficient
println!("{:4.3?}", c);
重要点是
.c
是(可选)输出设置器,它消耗ArrayViewMut2<f64>
;此矩阵将就地修改。.transb
和.beta
是可选设置器;默认的transb
是'N'
,而默认的beta
是零,这与scipy对BLAS的Python接口的约定相同。您可以通过将值输入到可选设置器中来更改这些默认值。- 有三种方法可以使用输出
c_out.into_owned()
返回输出(如果c
在传递到设置器时被切片,则为子矩阵)作为Array2<f64>
。请注意,此输出与mut c
不共享相同的内存地址。c_out.view()
或c_out.view_mut()
返回对c
的视图;这些视图与mut c
共享相同的内存地址。- 或者您可以直接使用
c
。如果输出矩阵c
给出,则DGEMM操作将就地执行。
为了使上述代码清晰,此代码spinnet执行就地矩阵乘法
c = alpha * a * transpose(b) + beta * c
where
alpha = 1.0 (by default)
beta = 1.5
a = [[1.0, 2.0, ___],
[3.0, 4.0, ___]]
(sliced by `s![.., ..2]`)
b = [[-1.0, -2.0],
[-3.0, -4.0],
[-5.0, -6.0]]
c = [[1.0, 1.0, 1.0],
[___, ___, ___],
[1.0, 1.0, 1.0]]
(Column-major, sliced by `s![0..3;2, ..]`)
c
的输出是
[[-3.500, -9.500, -15.500],
[ 1.000, 1.000, 1.000],
[-9.500, -23.500, -37.500]]
泛型的示例
从v0.3开始,此crate现在支持(某种程度上)简单的泛型使用。例如GEMM和TRMM的示例,
use blas_array2::prelude::*;
use ndarray::prelude::*;
fn demo<F>()
where
F: GEMMNum + TRMMNum,
{
let a = Array2::<F>::ones((3, 3));
let b = Array2::<F>::ones((3, 3));
let mut c = GEMM::<F>::default().a(a.view()).b(b.view()).run().unwrap().into_owned();
TRMM::<F>::default().a(a.view()).b(c.view_mut()).run().unwrap();
println!("{:}", c);
}
fn main() {
demo::<f64>();
demo::<c64>();
}
这将给出以下结果
[[9, 9, 9],
[6, 6, 6],
[3, 3, 3]]
[[9+0i, 9+0i, 9+0i],
[6+0i, 6+0i, 6+0i],
[3+0i, 3+0i, 3+0i]]
安装
此crate在crates.io上可用。
如果在编译过程中遇到任何困难,请检查BLAS库是否正确链接。可能是通过声明
RUSTFLAGS="-lopenblas"
如果使用OpenBLAS作为后端。
某些功能(如ilp64
、gemmt
)需要BLAS以64位整数编译,或者某些BLAS扩展。
致谢
此项目是作为从REST的副项目开发的。
依赖关系
~2.5MB
~53K SLoC