15个版本 (破坏性)

0.12.1 2024年7月9日
0.11.0 2024年6月11日
0.8.0 2023年9月20日
0.6.0 2023年7月21日
0.2.1 2022年7月20日

#28 in 科学

Download history 23/week @ 2024-04-22 76/week @ 2024-04-29 11/week @ 2024-05-06 34/week @ 2024-05-13 23/week @ 2024-05-20 288/week @ 2024-05-27 50/week @ 2024-06-03 207/week @ 2024-06-10 106/week @ 2024-06-17 135/week @ 2024-06-24 154/week @ 2024-07-01 213/week @ 2024-07-08 20/week @ 2024-07-15 99/week @ 2024-07-22 104/week @ 2024-07-29 21/week @ 2024-08-05

每月253次 下载

MIT 许可证

190KB
5K SLoC

Workflow Status dependency status

使用No U-turn采样器(NUTS)从后验分布中采样。有关详细信息,请参阅原始的 NUTS论文 和更近期的 介绍

此包是为了作为PyMC中采样器的更快的替代品而开发的,用于与PyTensor的新numba后端一起使用。此采样器的Python包装器为 nutpie

用法

use nuts_rs::{CpuLogpFunc, CpuMath, LogpError, DiagGradNutsSettings, Chain, SampleStats,
Settings};
use thiserror::Error;
use rand::thread_rng;

// Define a function that computes the unnormalized posterior density
// and its gradient.
#[derive(Debug)]
struct PosteriorDensity {}

// The density might fail in a recoverable or non-recoverable manner...
#[derive(Debug, Error)]
enum PosteriorLogpError {}
impl LogpError for PosteriorLogpError {
    fn is_recoverable(&self) -> bool { false }
}

impl CpuLogpFunc for PosteriorDensity {
    type LogpError = PosteriorLogpError;

    // We define a 10 dimensional normal distribution
    fn dim(&self) -> usize { 10 }

    // The normal likelihood with mean 3 and its gradient.
    fn logp(&mut self, position: &[f64], grad: &mut [f64]) -> Result<f64, Self::LogpError> {
        let mu = 3f64;
        let logp = position
            .iter()
            .copied()
            .zip(grad.iter_mut())
            .map(|(x, grad)| {
                let diff = x - mu;
                *grad = -diff;
                -diff * diff / 2f64
            })
            .sum();
        return Ok(logp)
    }
}

// We get the default sampler arguments
let mut settings = DiagGradNutsSettings::default();

// and modify as we like
settings.num_tune = 1000;
settings.maxdepth = 3;  // small value just for testing...

// We instanciate our posterior density function
let logp_func = PosteriorDensity {};
let math = CpuMath::new(logp_func);

let chain = 0;
let mut rng = thread_rng();
let mut sampler = settings.new_chain(0, math, &mut rng);

// Set to some initial position and start drawing samples.
sampler.set_position(&vec![0f64; 10]).expect("Unrecoverable error during init");
let mut trace = vec![];  // Collection of all draws
for _ in 0..2000 {
    let (draw, info) = sampler.draw().expect("Unrecoverable error during sampling");
    trace.push(draw);
}

用户还可以实现 Model 特性,以获得更多控制和并行采样。

实现细节

此包主要遵循在 StanPyMC 中NUTS的实现,只是质量矩阵和步长的调整有所不同。

依赖关系

~22–30MB
~497K SLoC