3个版本
0.1.2 | 2023年4月1日 |
---|---|
0.1.1 | 2023年3月29日 |
0.1.0 | 2023年3月15日 |
#483 在 机器学习
每月54次下载
355KB
8K SLoC
Edge transformers是基于ONNX Runtime后端的Huggingface pipelines 的Rust实现。
特性
- C#和C封装(计划实现合适的C++封装)
- 通过分词器抽象的文本到输出的接口。
- 支持多个ORT提供者
- CPU
- CUDA(需要使用CUDA提供者构建onnxruntime)
- DirectML
- 更多计划中
实现的任务
模型导出功能/任务 | 类名 |
---|---|
causal-lm | ConditionalGenerationPipeline |
causal-lm-with-past | ConditionalGenerationPipelineWithPKVs |
default | EmbeddingPipeline |
seq2seq-lm | Seq2SeqGenerationPipeline或OptimumSeq2SeqPipeline |
seq2seq-lm-with-past | OptimumSeq2SeqPipelineWithPKVs |
sequence-classification | SequenceClassificationPipeline |
token-classification | TokenClassificationPipeline |
使用方法
您的链接器必须能够找到onnxruntime.dll和edge-transformers.dll(或Linux上的*.so)。您可以在分别的c
和csharp
文件夹中找到C和C#封装。
C#
文档正在建设中,目前请参考Rust文档。
using EdgeTransformers;
...
var env = EnvContainer.New();
var conditionalGen = ConditionalGenerationPipelineFFI.FromPretrained(
env.Context, "optimum/gpt2", DeviceFFI.CPU, GraphOptimizationLevelFFI.Level3
);
var outp = conditionalGen.GenerateTopkSampling("Hello", 2, 50, 0.9f);
Assert.IsNotNull(outp);
...
支持批量处理,但有点不直观,需要StringBatch类。
using EdgeTransformers;
...
var env = EnvContainer.New();
var condPipelinePkv = ConditionalGenerationPipelineFFI.FromPretrained(
env.Context, "optimum/gpt2", DeviceFFI.DML, GraphOptimizationLevelFFI.All);
var string_batch = StringBatch.New();
string_batch.Add("Hello world");
string_batch.Add("Hello world");
var o_batch_pkv = condPipelinePkv.GenerateRandomSamplingBatch(string_batch.Context, 10, 0.5f);
Debug.LogFormat("Cond generation output 0: {0} 1: {1}", o_batch_pkv[0].ascii_string, o_batch_pkv[1].ascii_string);
...
C
待办事项
Rust
use std::fs;
use ort::environment::Environment;
use ort::{GraphOptimizationLevel, LoggingLevel};
use edge_transformers::{ConditionalGenerationPipelineWithPKVs, TopKSampler, Device};
let environment = Environment::builder()
.with_name("test")
.with_log_level(LoggingLevel::Verbose)
.build()
.unwrap();
let sampler = TopKSampler::new(50, 0.9);
let pipeline = ConditionalGenerationPipelineWithPKVs::from_pretrained(
environment.into_arc(),
"optimum/gpt2".to_string(),
Device::CPU,
GraphOptimizationLevel::Level3,
).unwrap();
let input = "This is a test";
println!("{}", pipeline.generate(input, 10, &sampler).unwrap());
路线图
- C#封装
- C封装
- 迁移到ort crate
- 适当的CI/CD以测试和构建更多执行提供者
- C++封装
- 更多管道(例如抽取式QA、ASR等)
- 更好的huggingface config.json解析
构建
请参阅ONNX Runtime绑定文档以获取详细信息。
测试
测试至少第一次需要在单线程上运行。原因是它们使用 *::from_pretrained 函数从Huggingface Hub下载数据,并且一些测试依赖于相同文件被下载。第二次它们可以并行运行,因为它们使用缓存文件。
例如第一次命令
cargo test -- --test-threads=1
例如第二次
cargo test
依赖关系
~23–39MB
~684K SLoC