#卷积 #ndarray #fft #反射

程序+库 ndarray-conv

N 维卷积(带 FFT)库,用于 ndarray

3 个版本 (破坏性更新)

0.3.3 2024年4月8日
0.3.2 2024年4月8日
0.2.0 2023年3月28日
0.1.3 2022年12月23日

#117 in 数学

Download history 563/week @ 2024-04-28 422/week @ 2024-05-05 279/week @ 2024-05-12 660/week @ 2024-05-19 456/week @ 2024-05-26 682/week @ 2024-06-02 281/week @ 2024-06-09 626/week @ 2024-06-16 352/week @ 2024-06-23 387/week @ 2024-06-30 452/week @ 2024-07-07 141/week @ 2024-07-14 194/week @ 2024-07-21 379/week @ 2024-07-28 279/week @ 2024-08-04 502/week @ 2024-08-11

1,361 每月下载次数
proseg 中使用

MIT/Apache

66KB
1.5K SLoC

ndarray-conv

ndarray-conv 是一个提供纯 Rust 编写的 N 维卷积(带 FFT 加速)库的 crate。

受以下项目启发

ndarray-vision (https://github.com/rust-cv/ndarray-vision)

convolutions-rs (https://github.com/Conzel/convolutions-rs#readme)

pocketfft (https://github.com/mreineck/pocketfft)

路线图

  • N 维基本卷积 Array/ArrayView
  • N 维卷积带 FFT 加速 Array/ArrayView
  • 实现 ConvModePaddingMode
    • ConvMode: 全部 相同 有效 自定义 明确
    • PaddingMode: 零 常数 反射 复制 循环 自定义 明确
  • 带步长的卷积
  • 带膨胀的核
  • 处理输入大小错误
  • 明确的错误类型
  • 与类似库进行基准测试

示例

use ndarray_conv::*;

x_nd.conv(
    &k_n,
    PaddingSize::Full,
    PaddingMode::Circular,
);

x_1d.view().conv_fft(
    &k_1d,
    ConvMode::Same,
    PaddingMode::Explicit([[BorderType::Replicate, BorderType::Reflect]]),
);

x_2d.conv_fft(
    k_2d.with_dilation(2),
    PaddingSize::Same,
    PaddingMode::Custom([BorderType::Reflect, BorderType::Circular]),
);

// avoid loss of accuracy for fft ver
// convert Integer to Float before caculate.
x_3d.map(|&x| x as f32)
    .conv_fft(
        &kernel.map(|&x| x as f32),
        ConvMode::Same,
        PaddingMode::Zeros,
    )
    .unwrap()
    .map(|x| x.round() as i32);
fn main() {
    use ndarray_conv::*;
    use ndarray::prelude::*;
    use ndarray_rand::rand_distr::Uniform;
    use ndarray_rand::RandomExt;
    use std::time::Instant;

    let mut small_duration = 0u128;
    let test_cycles_small = 1;
    // small input data
    for _ in 0..test_cycles_small {
        let x = Array::random((2000, 4000), Uniform::new(0., 1.));
        let k = Array::random((9, 9), Uniform::new(0., 1.));

        let now = Instant::now();
        // or use x.conv_fft() for large input data
        x.conv(
            &k,
            ConvMode::Same,
            PaddingMode::Custom([BorderType::Reflect, BorderType::Circular]),
        );
        small_duration += now.elapsed().as_nanos();
    }

    println!(
        "Time for small arrays, {} iterations: {} milliseconds",
        test_cycles_small,
        small_duration / 1_000_000
    );
}

版本

  • 0.3.3 - 修复错误:纠正 conv_fft 的输出形状。
  • 0.3.2 - 通过修改 good_fft_sizetranspose 来提高性能。
  • 0.3.1 - 实现 basic error 类型。修复一些错误。
  • 0.3.0 - 更新为 N 维卷积。
  • 0.2.0 - 完成 conv_2dconv_2d_fft

依赖项

~7.5MB
~142K SLoC