10 个稳定版本
2.5.0 | 2024年6月18日 |
---|---|
2.4.1 | 2024年4月17日 |
2.4.0 | 2024年3月25日 |
2.3.0 | 2023年12月11日 |
2.1.1 | 2023年6月30日 |
#232 in 算法
124 个月下载量
37KB
655 行
来自 Rust 的 BridgeStan
在 Github Pages 上查看 BridgeStan 文档.
这是 BridgeStan 的 Rust 封装。它允许用户从 Rust 本地评估 Stan 模型的似然值和相关函数。
内部,它依赖于 bindgen
和 libloading
。
编译模型
Rust 封装可以通过调用 compile_model
函数来调用 make
命令,从而编译 Stan 模型。
这需要一个 C++ 工具链和 BridgeStan 源代码副本。可以通过启用 download-bridgestan-src
功能并调用 download_bridgestan_src
自动下载源代码。或者,可以手动提供 BridgeStan 源代码的路径。
出于安全考虑,所有 Stan 模型都需要使用 STAN_THREADS=true
构建。这是 compile_model
函数中的默认行为,但在其他上下文中编译模型时可能需要手动设置。
如果在构建模型时未指定 STAN_THREADS
,Rust 封装在加载模型时将抛出错误。
使用方法
使用以下命令运行此示例: cargo run --example=example
。
use std::ffi::CString;
use std::path::{Path, PathBuf};
use bridgestan::{BridgeStanError, Model, open_library, compile_model};
// The path to the Stan model
let path = Path::new(env!["CARGO_MANIFEST_DIR"])
.parent()
.unwrap()
.join("test_models/simple/simple.stan");
// You can manually set the BridgeStan src path or
// automatically download it (but remember to
// enable the download-bridgestan-src feature first)
let bs_path: PathBuf = "..".into();
// let bs_path = bridgestan::download_bridgestan_src().unwrap();
// The path to the compiled model
let path = compile_model(&bs_path, &path, &[], &[]).expect("Could not compile Stan model.");
println!("Compiled model: {:?}", path);
let lib = open_library(path).expect("Could not load compiled Stan model.");
// The dataset as json
let data = r#"{"N": 7}"#;
let data = CString::new(data.to_string().into_bytes()).unwrap();
// The seed is used in case the model contains a transformed data section
// that uses rng functions.
let seed = 42;
let model = match Model::new(&lib, Some(data), seed) {
Ok(model) => model,
Err(BridgeStanError::ConstructFailed(msg)) => {
panic!("Model initialization failed. Error message from Stan was {msg}")
}
Err(e) => {
panic!("Unexpected error:\n{e}")
}
};
let n_dim = model.param_unc_num();
assert_eq!(n_dim, 7);
let point = vec![1f64; n_dim];
let mut gradient_out = vec![0f64; n_dim];
let logp = model.log_density_gradient(&point[..], true, true, &mut gradient_out[..])
.expect("Stan failed to evaluate the logp function.");
// gradient_out contains the gradient of the logp density
依赖项
~0.6–12MB
~142K SLoC