6 个版本
0.2.0 | 2023 年 3 月 29 日 |
---|---|
0.1.4 | 2022 年 3 月 17 日 |
0.1.3 | 2022 年 1 月 2 日 |
0.1.2 | 2021 年 12 月 21 日 |
#1180 in 算法
用于 宇宙学
15KB
228 行
hammer-and-sample
基于 emcee 的简单 MCMC 样本器,实现仿射不变的集合采样,具有串行执行,并可基于 Rayon 并行执行。
实现相对高效,例如,使用来自 hierarchical.rs
的层次模型,使用 emcee
和 multiprocessing
计算使用 1000 次迭代和 100 个游走者大约需要 1 分钟,与在 8 个硬件线程上运行相同的 crate Rayon 一样。
lib.rs
:
基于 emcee 的简单 MCMC 集合样本器
use hammer_and_sample::{sample, MinChainLen, Model, Serial};
use rand::{Rng, SeedableRng};
use rand_pcg::Pcg64;
fn estimate_bias(coin_flips: &[bool]) -> f64 {
struct CoinFlips<'a>(&'a [bool]);
impl Model for CoinFlips<'_> {
type Params = [f64; 1];
// likelihood of Bernoulli distribution and uninformative prior
fn log_prob(&self, &[p]: &Self::Params) -> f64 {
if p < 0. || p > 1. {
return f64::NEG_INFINITY;
}
let ln_p = p.ln();
let ln_1_p = (1. - p).ln();
self.0
.iter()
.map(|coin_flip| if *coin_flip { ln_p } else { ln_1_p })
.sum()
}
}
let model = CoinFlips(coin_flips);
let walkers = (0..10).map(|seed| {
let mut rng = Pcg64::seed_from_u64(seed);
let p = rng.gen_range(0.0..=1.0);
([p], rng)
});
let (chain, _accepted) = sample(&model, walkers, MinChainLen(10 * 1000), Serial);
// 100 iterations of 10 walkers as burn-in
let chain = &chain[10 * 100..];
chain.iter().map(|&[p]| p).sum::<f64>() / chain.len() as f64
}
依赖项
~245–580KB
~11K SLoC