#blas #array #matrix

无需std blas-array2

可选参数BLAS包装器,由ndarray::Array (Ix1或Ix2) 实现

5个不稳定版本

0.3.0 2024年7月24日
0.2.0 2024年7月21日
0.1.3 2024年7月17日
0.1.2 2024年7月16日
0.1.1 2024年7月14日

#160数据结构

Download history • Rust 包仓库 213/week @ 2024-07-11 • Rust 包仓库 207/week @ 2024-07-18 • Rust 包仓库 58/week @ 2024-07-25 • Rust 包仓库 5/week @ 2024-08-01 • Rust 包仓库

每月483次下载

Apache-2.0

270KB
8K SLoC

blas-array2

codecov crates.io

使用 ndarray::Array (Ix1Ix2) 实现的Rust中的可选参数BLAS包装器。

现在风对我 步长 (主对角线) 产生了影响

我正在所有 方向 (代码'L' / 'R') 上失去阵地

--- 暗黑太阳...,动画《PERSONA5》的OP2

其他文档

  • 开发文档(github链接:github,docs.rs链接:docs.rs
  • BLAS包装器结构列表(github链接:github,docs.rs链接:docs.rs
  • 效率演示(github链接:github,docs.rs链接:docs.rs

从v0.2开始,这个crate实现了大部分计划的功能。这个crate被认为几乎完成,可能不会积极维护或更新。然而,我们也欢迎问题和PR来进一步增加新功能或修复错误。

简单案例示例

为了简单说明这个包,我们执行 $\mathbf{C} = \mathbf{A} \mathbf{B}$ (代码dgemm

use blas_array2::prelude::*;
use ndarray::prelude::*;
let a = array![[1.0, 2.0, 3.0], [3.0, 4.0, 5.0]];
let b = array![[-1.0, -2.0], [-3.0, -4.0], [-5.0, -6.0]];
let c_out = DGEMM::default()
    .a(a.view())
    .b(b.view())
    .run().unwrap()
    .into_owned();
println!("{:7.3?}", c_out);

重要点是

  • 使用 ::default() 初始化结构体;
  • .a.b 是设置函数;
  • .run().unwrap() 将执行计算;
  • .into_owned()将返回结果矩阵,类型为Array2<f64>

功能

核心功能

  • BLAS2/BLAS3 功能:所有(遗留)BLAS2/BLAS3 函数都已实现。
  • 可选参数:遵循类似于BLAST Fortran 95 绑定scipy.linalg.blas的约定。矩阵形状和主维度信息将得到检查和正确解析,因此用户无需提供这些值。
  • 行主序布局:支持Fortran 77 API的行主序(没有CBLAS函数的CBLAS功能)。例如,可以使用随debian的blas-sys提供的默认libopenblas.so,其中CBLAS未自动集成。
  • 泛型:例如,为f32f64Complex<f32>Complex<f64>类型,在一个泛型(模板)类中实现了GEMM<F> where F: GEMMNum。对于SYRKGEMV等也是一样。原始名称如DGEMMZSYR2K也是可用的。
  • 尽可能避免显式复制:所有行主序(或列主序)的输入都不应涉及不必要的转置和显式复制。此外,对于一些BLAS3函数(GEMM、SYRK、TRMM、TRSM),如果转置不涉及BLASConjTrans,则混合行主序或列主序也不涉及显式转置。另外,请注意,在许多情况下,子矩阵(切片矩阵)也被视为行主序(或列主序),如果数据在任何维度上连续存储。

其他功能

  • 任意布局:支持ndarray允许的任何步长。
  • FFI:目前,此crate使用其自定义FFI绑定blas_array2::ffi::blas作为BLAS绑定,类似于blas-sys。此外,此crate计划(或已)支持一些BLAS扩展和ILP64(通过cargo功能)。

Cargo功能

  • no_std:禁用crate功能std将与#![no_std]兼容。然而,目前这些no_std功能将需要alloc
  • ilp64:默认情况下,FFI绑定是LP64(32位整数)。crate功能ilp64将启用ILP64(64位整数)。
  • BLAS扩展:一些crate功能将启用BLAS扩展。
    • gemmt:GEMMTR(三角输出矩阵乘法)。对于OpenBLAS,需要版本0.3.27(0.3.26将失败一些测试)。
  • warn_on_copy:如果输入矩阵布局不一致,且需要显式内存复制/转置/复共轭,则将在stderr上打印警告消息。
  • error_on_copy:类似于warn_on_copy,但将直接引发BLASError

复杂情况的示例

对于复杂的情况,我们通过SGEMM = GEMM<f32>执行$\mathbf{C} = \mathbf{A} \mathbf{B}^\mathrm{T}$。

use blas_array2::prelude::*;
use ndarray::prelude::*;

let a = array![[1.0, 2.0, 3.0], [3.0, 4.0, 5.0]];
let b = array![[-1.0, -2.0], [-3.0, -4.0], [-5.0, -6.0]];
let mut c = Array::ones((3, 3).f());

let c_out = GEMM::<f32>::default()
    .a(a.slice(s![.., ..2]))
    .b(b.view())
    .c(c.slice_mut(s![0..3;2, ..]))
    .transb('T')
    .beta(1.5)
    .run()
    .unwrap();
// one can get the result as an owned array
// but the result may not refer to the same memory location as `c`
println!("{:4.3?}", c_out.into_owned());
// this modification on `c` is actually performed in-place
// so if `c` is pre-defined, not calling `into_owned` could be more efficient
println!("{:4.3?}", c);

重要点是

  • .c是(可选)输出设置器,它消耗ArrayViewMut2<f64>;此矩阵将就地修改。
  • .transb.beta是可选设置器;默认的transb'N',而默认的beta是零,这与scipy对BLAS的Python接口的约定相同。您可以通过将值输入到可选设置器中来更改这些默认值。
  • 有三种方法可以使用输出
    • c_out.into_owned()返回输出(如果c在传递到设置器时被切片,则为子矩阵)作为Array2<f64>。请注意,此输出与mut c不共享相同的内存地址。
    • c_out.view()c_out.view_mut()返回对c的视图;这些视图与mut c共享相同的内存地址。
    • 或者您可以直接使用c。如果输出矩阵c给出,则DGEMM操作将就地执行。

为了使上述代码清晰,此代码spinnet执行就地矩阵乘法

c = alpha * a * transpose(b) + beta * c
where
alpha = 1.0 (by default)
beta = 1.5
a = [[1.0, 2.0, ___],
     [3.0, 4.0, ___]]
        (sliced by `s![.., ..2]`)
b = [[-1.0, -2.0],
     [-3.0, -4.0],
     [-5.0, -6.0]]
c = [[1.0, 1.0, 1.0],
     [___, ___, ___],
     [1.0, 1.0, 1.0]]
        (Column-major, sliced by `s![0..3;2, ..]`)

c的输出是

[[-3.500,  -9.500, -15.500],
 [ 1.000,   1.000,   1.000],
 [-9.500, -23.500, -37.500]]

泛型的示例

从v0.3开始,此crate现在支持(某种程度上)简单的泛型使用。例如GEMM和TRMM的示例,

use blas_array2::prelude::*;
use ndarray::prelude::*;

fn demo<F>()
where
    F: GEMMNum + TRMMNum,
{
    let a = Array2::<F>::ones((3, 3));
    let b = Array2::<F>::ones((3, 3));
    let mut c = GEMM::<F>::default().a(a.view()).b(b.view()).run().unwrap().into_owned();
    TRMM::<F>::default().a(a.view()).b(c.view_mut()).run().unwrap();
    println!("{:}", c);
}

fn main() {
    demo::<f64>();
    demo::<c64>();
}

这将给出以下结果

[[9, 9, 9],
 [6, 6, 6],
 [3, 3, 3]]
[[9+0i, 9+0i, 9+0i],
 [6+0i, 6+0i, 6+0i],
 [3+0i, 3+0i, 3+0i]]

安装

此crate在crates.io上可用。

如果在编译过程中遇到任何困难,请检查BLAS库是否正确链接。可能是通过声明

RUSTFLAGS="-lopenblas"

如果使用OpenBLAS作为后端。

某些功能(如ilp64gemmt)需要BLAS以64位整数编译,或者某些BLAS扩展。

致谢

此项目是作为从REST的副项目开发的。

依赖关系

~2.5MB
~53K SLoC