1个不稳定版本
0.1.0 | 2023年2月4日 |
---|
#1582 在 编码
18KB
272 行
Banana Rust SDK
入门
通过 cargo 添加：`cargo add banana-rust-sdk`
获取您的API密钥
简单示例
以下两个示例都以调用此模板中的模型为基础。
注意：由于模型的输入格式取决于您自己的模型，Banana SDK 无法检查您的模型输入是否正确。`banana_rust_sdk::run()` 接受任何有效的 JSON（`serde_json::Value`）作为模型输入。下文还提供了一个带有类型检查的更详细的示例。
use banana_rust_sdk;
use serde::Serialize;
/// Minimal example: serialize a prompt, send it to the model with
/// `banana_rust_sdk::run`, and print the response as JSON.
///
/// Every fallible step is unwrapped, so any failure panics.
#[tokio::main]
async fn main() {
    /// Input payload sent to the model; it is converted to JSON before the call.
    #[derive(Serialize)]
    struct ModelInputs {
        prompt: String,
    }

    // Placeholder credentials; substitute real keys before running.
    let api_key = "API_KEY";
    let model_key = "MODEL_KEY";

    // `run` takes the inputs as a `serde_json::Value`, so serialize the struct first.
    let payload = serde_json::to_value(ModelInputs {
        prompt: String::from("try to predict the next [MASK] of this sentence."),
    })
    .unwrap();

    let response = banana_rust_sdk::run(api_key, model_key, payload)
        .await
        .unwrap();

    // Convert the response back into JSON for printing.
    let response_json = serde_json::to_value(response).unwrap();
    println!("{:?}", response_json);
}
具有输入类型检查的示例
use banana_rust_sdk;
use serde::Serialize;
use serde::Deserialize;
use std::{error::Error, fmt};
/// Opaque error returned by this example when calling the model fails.
///
/// It carries no payload; every failure is reported with the same message.
#[derive(Debug)]
struct CustomError;

impl fmt::Display for CustomError {
    /// Writes the fixed, human-readable error message.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        formatter.write_str("Oh no, something bad went down")
    }
}

// The default `Error` methods are sufficient: there is no underlying source.
impl Error for CustomError {}
/// Typed input for the model; serialized to JSON before being passed to the SDK.
#[derive(Serialize)]
struct ModelInputs {
    /// Text prompt sent to the model.
    prompt: String,
}
/// One entry of the output the model is expected to produce.
///
/// The field names must match the keys of the JSON returned by the model,
/// because the value is deserialized with serde.
// NOTE(review): field meanings are inferred from their names — confirm against the model.
#[derive(Serialize, Deserialize, Debug)]
struct ResponseObject {
    /// Score attached to this entry.
    score: f64,
    /// Text sequence for this entry.
    sequence: String,
    /// Numeric token value.
    token: usize,
    /// String form of the token.
    token_str: String,
}
/// Typed view of the model's output: a list of `ResponseObject` entries
/// stored under the `response_object` key.
#[derive(Serialize, Deserialize, Debug)]
struct ModelOutputs {
    /// Entries returned by the model, in the order they appear in the JSON.
    response_object: Vec<ResponseObject>,
}
/// Typed example: send a prompt to the model via `call_banana` and print the
/// `sequence` of the first returned prediction.
///
/// Serialization and the model call are still unwrapped, so those failures
/// panic. An empty prediction list is reported on stderr instead of panicking
/// on an out-of-bounds index.
#[tokio::main]
async fn main() {
    // Placeholder credentials; substitute real keys before running.
    let api_key = "API_KEY";
    let model_key = "MODEL_KEY";

    let model_inputs = ModelInputs {
        prompt: "try to predict the next [MASK] of this sentence.".to_string(),
    };
    // The SDK takes the inputs as a `serde_json::Value`.
    let model_inputs = serde_json::to_value(model_inputs).unwrap();

    let model_outputs = call_banana(api_key, model_key, model_inputs)
        .await
        .unwrap();

    // Take the first prediction.
    // NOTE(review): presumably the model returns predictions sorted by
    // descending score, making this the highest-scoring one — confirm.
    match model_outputs.response_object.first() {
        Some(item) => println!("{:?}", item.sequence),
        None => eprintln!("the model returned no predictions"),
    }
}
/// Runs the model through the Banana SDK and decodes its output into
/// [`ModelOutputs`].
///
/// # Errors
///
/// Returns [`CustomError`] when:
/// - `banana_rust_sdk::run` fails,
/// - the response carries no `model_outputs`, or
/// - `model_outputs` does not deserialize into [`ModelOutputs`]
///   (this case used to panic via `unwrap`).
async fn call_banana(
    api_key: &str,
    model_key: &str,
    model_inputs: serde_json::Value,
) -> Result<ModelOutputs, CustomError> {
    let res = banana_rust_sdk::run(api_key, model_key, model_inputs)
        .await
        .map_err(|_| CustomError)?;

    // A response without outputs is treated as a failure.
    let value = res.model_outputs.ok_or(CustomError)?;

    // A malformed payload becomes an error value rather than a panic, so
    // callers can handle it through the `Result` they already receive.
    serde_json::from_value(value).map_err(|_| CustomError)
}
依赖项
~6–20MB
~274K SLoC