12个版本
0.0.12 | 2022年7月24日 |
---|---|
0.0.11 | 2022年7月5日 |
0.0.10 | 2022年6月19日 |
0.0.8 | 2022年4月3日 |
0.0.1 | 2021年7月19日 |
#921 in 机器学习
2.5MB
70K SLoC
ors - Rust语言的onnxruntime绑定库
本项目提供了微软的 onnxruntime 的Rust绑定库,它是一个机器学习推理和训练框架。
警告:本项目处于非常早期阶段,尚未完成。据我所知,仍然存在许多错误。请勿在生产环境中使用。
先决条件
此crate需要您系统中有onnxruntime的C库版本v1.12.0。您可以使用 initialize_runtime()
读取C库
use ors::api::initialize_runtime;
use std::path::Path;
fn setup_runtime() {
#[cfg(target_os = "windows")]
let path = "/path/to/onnxruntime.dll";
#[cfg(target_os = "macos")]
let path = "/path/to/libonnxruntime.1.12.0.dylib";
#[cfg(target_os = "linux")]
let path = "/path/to/libonnxruntime.so";
initialize_runtime(Path::new(path)).unwrap();
}
示例
首先,将此crate添加到您的 cargo.toml
ors = "0.0.12"
此crate提供了 SessionBuilder
,它可以帮助您创建推理会话。您不需要创建onnxruntime推理环境,这由此crate处理
use ors::{
config::SessionGraphOptimizationLevel,
session::{SessionBuilder, run},
tensor::{create_tensor_with_ndarray, Tensor},
}
setup_runtime();
let session_builder = SessionBuilder::new().unwrap();
// Create an inference session from a model
let mut session = session_builder
.graph_optimization_level(SessionGraphOptimizationLevel::All)
.unwrap()
// Model conversion script can be found here: https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/python/tools/transformers/notebooks/Inference_GPT2_with_OnnxRuntime_on_CPU.ipynb
.build_with_model_from_file("./gpt2.onnx")
.unwrap();
从 ndarray::ArrayD
创建张量并将其添加到模型输入
// Suppose that input_ids, position_ids and attention_mask are all ndarray::ArrayD
let mut inputs: Vec<Tensor> = vec![];
let input_ids_tensor = create_tensor_with_ndarray::<i64>(input_ids).unwrap();
let position_ids_tensor = create_tensor_with_ndarray::<i64>(positions_ids).unwrap();
let attention_mask_tensor = create_tensor_with_ndarray::<f32>(attension_mask).unwrap();
inputs.push(input_ids_tensor);
inputs.push(position_ids_tensor);
inputs.push(attention_mask_tensor);
// Add other inputs
// ...
对模型输出执行相同操作
let mut outputs: Vec<Tensor> = vec![];
// You should specify the output shape when creating the ndarray
let mut logits = ArrayD::<f32>::from_shape_vec(IxDyn(&[2, 9, 50257]), vec![0.0; 2 * 9 * 50257]).unwrap();
// Create tensor from logits and add it to output
let logits_tensor = create_tensor_with_ndarray::<f32>(logits).unwrap();
outputs.push(logits_tensor);
// Add other outputs
// ...
运行推理会话,模型的输出将写入 ndarray::ArrayD
,这些用于创建输出张量。
run(&mut session, &inputs, &mut outputs);
// Check the result
println!("inference result: logits: {:?}", outputs[0]);
输出
inference result: logits: [[[-15.88228, -15.500423, -17.979624, -18.302347, -17.527521, ..., -23.000717, -23.806093, -22.637945, -22.227428, -15.411578],
...
[-89.78022, -89.84351, -94.203995, -95.20875, -96.05158, ..., -101.95325, -103.50048, -101.1202, -98.740845, -90.956375]],
[[-33.367958, -32.94488, -36.2036, -36.568382, -35.434883, ..., -41.491924, -42.189476, -42.094162, -40.86978, -33.79733],
...
[-101.5143, -101.56593, -103.117065, -105.66759, -104.360954, ..., -104.53616, -107.3546, -109.82067, -110.87442, -101.61766]]], shape=[2, 9, 50257], strides=[452313, 50257, 1], layout=Cc (0x5), dynamic ndim=3
鸣谢
本项目最初是基于 onnxruntime-rs 的分支。大量代码来自onnxruntime-rs。感谢nbigaouette的出色工作。
许可证
本项目受以下其中之一许可协议的许可:
Apache许可证,版本2.0,(LICENSE-APACHE 或 https://apache.ac.cn/licenses/LICENSE-2.0)MIT许可证(LICENSE-MIT 或 http://opensource.org/licenses/MIT),由您选择。
依赖项
~2.9–5MB
~97K SLoC