#batch #onnx-runtime #ort

ort_batcher

使用 ort (onnxruntime) 批处理 ONNX 模型推理的小型 crate

2 个版本

0.1.1 2023 年 11 月 22 日
0.1.0 2023 年 11 月 22 日

#338机器学习

Apache-2.0

17KB
162

ort-batcher

Crates.io

使用 ort 批处理 ONNX 模型推理的小型 crate。受 batched_fn 启发。

请注意,它只能与以下模型一起使用

  • 其第一个维度是动态的(-1),因此可以批处理。
  • 输入和输出都是类型为 float32 的张量。

使用方法

let max_batch_size = 32;
let max_wait_time = Duration::from_millis(80);
let batcher = Batcher::spawn(session, max_batch_size, max_wait_time);

// in some thread
let inputs = vec![ArrayD::<f32>::zeros(vec![7, 8, 9])];
let outputs = batcher.run(inputs).unwrap();

示例

查看 example.rs

use ndarray::{ArrayD, Axis};
use ort::{CUDAExecutionProvider, Environment, SessionBuilder, Value};
use ort_batcher::batcher::Batcher;
use std::time::Duration;

fn main() -> ort::Result<()> {
    tracing_subscriber::fmt::init();

    ort::init()
        .with_execution_providers([CUDAExecutionProvider::default().build()])
        .commit()?;

    let session = Session::builder()?
        .with_intra_threads(1)?
        .with_model_from_memory(include_bytes!("../tests/model.onnx"))?;

    {
        let start = std::time::Instant::now();

        // 128 threads
        // 256 inferences each
        // sequential
        std::thread::scope(|s| {
            for _ in 0..128 {
                let session = &session;
                let input = ArrayD::<f32>::zeros(vec![7, 8, 9]);

                s.spawn(move || {
                    for _ in 0..256 {
                        let value = Value::from_array(input.clone().insert_axis(Axis(0))).unwrap();
                        let _output = session.run([value]).unwrap()[0]
                            .extract_tensor::<f32>()
                            .unwrap()
                            .view()
                            .index_axis(Axis(0), 0)
                            .to_owned();
                    }
                });
            }
        });

        println!("sequential: {:?}", start.elapsed());
    }

    let max_batch_size = 32;
    let max_wait_time = Duration::from_millis(10);
    let batcher = Batcher::spawn(session, max_batch_size, max_wait_time);

    {
        let start = std::time::Instant::now();

        // 128 threads
        // 256 inferences each
        // batched
        std::thread::scope(|s| {
            for _ in 0..128 {
                let batcher = &batcher;
                let input = ArrayD::<f32>::zeros(vec![7, 8, 9]);

                s.spawn(move || {
                    for _ in 0..256 {
                        let _output = batcher.run(vec![input.clone()]).unwrap();
                    }
                });
            }
        });

        println!("batched: {:?}", start.elapsed());
    }

    Ok(())
}

请注意,为了获得良好的结果,您必须在 GPU 上使用重量级模型,否则您可能看不到任何区别。

依赖关系

~2.7–4.5MB
~87K SLoC