4个版本
0.3.0 | 2023年7月3日 |
---|---|
0.2.2 | 2021年10月19日 |
0.2.1 | 2021年10月19日 |
0.2.0 | 2021年5月19日 |
#359 在 WebAssembly
每月 30 次下载
36KB
836 行
WasmEdge Tensorflow Interface
一个Rust库,为在WasmEdge上执行Wasm时使用TensorFlow功能的Rust到WebAssembly的开发者提供语法。
从高层次概述来看,我们实际上正在构建一个tensorflow接口,这将允许本地操作系统(WasmEdge在其中运行)在运行时执行中发挥作用。具体来说,在Wasm执行的部分,使用图和输入输出张量来推断TensorFlow或TensorFlow-Lite。
如何使用此库
Rust依赖
开发者将在他们的Rust -> Wasm
应用程序中将wasmedge_tensorflow_interface
crate作为依赖项添加。例如,将以下行添加到应用程序的Cargo.toml
文件中。
[dependencies]
wasmedge_tensorflow_interface = "0.3.0"
开发者将把wasmedge_tensorflow_interface
的功能引入到他们的Rust -> Wasm
应用程序代码中。例如,将以下代码添加到他们的main.rs
顶部。
use wasmedge_tensorflow_interface;
图像加载和转换
在这个crate中,我们通过使用WasmEdge-Image
宿主函数提供几个将图像解码和转换为张量的函数。
对于解码JPEG图像,有以下函数:
// Function to decode JPEG from buffer and resize to RGB8 format.
pub fn load_jpg_image_to_rgb8(img_buf: &[u8], w: u32, h: u32) -> Vec<u8>
// Function to decode JPEG from buffer and resize to BGR8 format.
pub fn load_jpg_image_to_bgr8(img_buf: &[u8], w: u32, h: u32) -> Vec<u8>
// Function to decode JPEG from buffer and resize to RGB32F format.
pub fn load_jpg_image_to_rgb32f(img_buf: &[u8], w: u32, h: u32) -> Vec<f32>
// Function to decode JPEG from buffer and resize to BGR32F format.
pub fn load_jpg_image_to_rgb32f(img_buf: &[u8], w: u32, h: u32) -> Vec<f32>
对于解码PNG图像,有以下函数:
// Function to decode PNG from buffer and resize to RGB8 format.
pub fn load_png_image_to_rgb8(img_buf: &[u8], w: u32, h: u32) -> Vec<u8>
// Function to decode PNG from buffer and resize to BGR8 format.
pub fn load_png_image_to_bgr8(img_buf: &[u8], w: u32, h: u32) -> Vec<u8>
// Function to decode PNG from buffer and resize to RGB32F format.
pub fn load_png_image_to_rgb32f(img_buf: &[u8], w: u32, h: u32) -> Vec<f32>
// Function to decode PNG from buffer and resize to BGR32F format.
pub fn load_png_image_to_rgb32f(img_buf: &[u8], w: u32, h: u32) -> Vec<f32>
开发者可以按照以下方式加载、解码和调整图像大小:
let mut file_img = File::open("sample.jpg").unwrap();
let mut img_buf = Vec::new();
file_img.read_to_end(&mut img_buf).unwrap();
let flat_img = wasmedge_tensorflow_interface::load_jpg_image_to_rgb32f(&img_buf, 224, 224);
// The flat_img is a vec<f32> which contains normalized image in rgb32f format and resized to 224x224.
要使用上述函数在WASM中执行并在WasmEdge中运行,用户应安装WasmEdge-Image插件。
推断TensorFlow和TensorFlow-Lite模型
创建会话
首先,开发者应该创建一个会话来加载TensorFlow或TensorFlow-Lite模型。
// The mod_buf is a vec<u8> which contains the model data.
let mut session = wasmedge_tensorflow_interface::TFSession::new(&mod_buf);
上述函数是创建TensorFlow冻结模型的会话。开发者可以使用new_from_saved_model
函数从保存的模型中创建。
// The mod_path is a &str which is the path to saved-model directory.
// The second argument is the list of tags.
let mut session = wasmedge_tensorflow_interface::TFSession::new_from_saved_model(model_path, &["serve"]);
或者使用 TFLiteSession
来创建用于推断 tflite
模型的会话。
// The mod_buf is a vec<u8> which contains the model data.
let mut session = wasmedge_tensorflow_interface::TFLiteSession::new(&mod_buf);
要使用 TFSession
结构并在 WasmEdge 中执行,用户应安装 WasmEdge-TensorFlow 插件及其依赖项。
要使用 TFLiteSession
结构并在 WasmEdge 中执行,用户应安装 WasmEdge-TensorFlowLite 插件及其依赖项。
准备输入张量
// The flat_img is a vec<f32> which contains normalized image in rgb32f format.
session.add_input("input", &flat_img, &[1, 224, 224, 3])
.add_output("MobilenetV2/Predictions/Softmax");
运行 TensorFlow 模型
session.run();
转换输出张量
let res_vec: Vec<f32> = session.get_output("MobilenetV2/Predictions/Softmax");
构建和执行
cargo build --target=wasm32-wasi
输出的 WASM 文件将在 target/wasm32-wasi/debug/
或 target/wasm32-wasi/release
。
请参考 WasmEdge 安装 来安装带有必要插件的 WasmEdge,以及 WasmEdge CLI WASM 执行。
Crate.io
官方 crate 可在 crates.io 获取。