2 个版本
0.1.1 | 2023 年 11 月 22 日 |
---|---|
0.1.0 | 2023 年 11 月 22 日 |
#338 在 机器学习 中
17KB
162 行
ort-batcher
使用 ort 批处理 ONNX 模型推理的小型 crate。受 batched_fn 启发。
请注意,它只能与以下模型一起使用
- 其第一个维度是动态的(-1),因此可以批处理。
- 输入和输出都是类型为
float32
的张量。
使用方法
let max_batch_size = 32;
let max_wait_time = Duration::from_millis(80);
let batcher = Batcher::spawn(session, max_batch_size, max_wait_time);
// in some thread
let inputs = vec![ArrayD::<f32>::zeros(vec![7, 8, 9])];
let outputs = batcher.run(inputs).unwrap();
示例
查看 example.rs
use ndarray::{ArrayD, Axis};
use ort::{CUDAExecutionProvider, Environment, SessionBuilder, Value};
use ort_batcher::batcher::Batcher;
use std::time::Duration;
fn main() -> ort::Result<()> {
tracing_subscriber::fmt::init();
ort::init()
.with_execution_providers([CUDAExecutionProvider::default().build()])
.commit()?;
let session = Session::builder()?
.with_intra_threads(1)?
.with_model_from_memory(include_bytes!("../tests/model.onnx"))?;
{
let start = std::time::Instant::now();
// 128 threads
// 256 inferences each
// sequential
std::thread::scope(|s| {
for _ in 0..128 {
let session = &session;
let input = ArrayD::<f32>::zeros(vec![7, 8, 9]);
s.spawn(move || {
for _ in 0..256 {
let value = Value::from_array(input.clone().insert_axis(Axis(0))).unwrap();
let _output = session.run([value]).unwrap()[0]
.extract_tensor::<f32>()
.unwrap()
.view()
.index_axis(Axis(0), 0)
.to_owned();
}
});
}
});
println!("sequential: {:?}", start.elapsed());
}
let max_batch_size = 32;
let max_wait_time = Duration::from_millis(10);
let batcher = Batcher::spawn(session, max_batch_size, max_wait_time);
{
let start = std::time::Instant::now();
// 128 threads
// 256 inferences each
// batched
std::thread::scope(|s| {
for _ in 0..128 {
let batcher = &batcher;
let input = ArrayD::<f32>::zeros(vec![7, 8, 9]);
s.spawn(move || {
for _ in 0..256 {
let _output = batcher.run(vec![input.clone()]).unwrap();
}
});
}
});
println!("batched: {:?}", start.elapsed());
}
Ok(())
}
请注意,为了获得良好的结果,您必须在 GPU 上使用重量级模型,否则您可能看不到任何区别。
依赖关系
~2.7–4.5MB
~87K SLoC