12个版本
0.0.12 | 2023年2月9日 |
---|---|
0.0.11 | 2022年7月6日 |
0.0.10 | 2021年10月18日 |
0.0.9 | 2021年7月18日 |
0.0.1 | 2020年3月20日 |
#172 in 数学
41 每月下载次数
125KB
2K SLoC
ndarray-glm
使用迭代加权最小二乘法(IRLS)通过ndarray-linalg模块解决线性、逻辑和广义线性模型问题的Rust库。
状态
此包处于alpha版本,接口可能发生变化。甚至某些函数的返回值也可能在版本之间发生变化。不保证正确性。
回归算法使用迭代加权最小二乘法(IRLS),并在猜测的下一迭代没有增加似然时应用步长减半过程。
欢迎提出建议(通过问题)和拉取请求。
先决条件
建议使用系统BLAS实现。例如,在Debian/Ubuntu上安装OpenBLAS
sudo apt update && sudo apt install -y libopenblas-dev
然后使用带有openblas-system
功能的crate。
要使用备用后端或构建静态BLAS实现,请参阅ndarray-linalg
文档。使用此crate并使用适当的特性标志,它将被转发到ndarray-linalg
。
示例
要在您的crate中使用,请将以下内容添加到Cargo.toml
ndarray = { version = "0.15", features = ["blas"]}
ndarray-glm = { version = "0.0.12", features = ["openblas-system"] }
以下是一个线性回归的示例。
use ndarray_glm::{array, Linear, ModelBuilder, utility::standardize};
// define some test data
let data_y = array![0.3, 1.3, 0.7];
let data_x = array![[0.1, 0.2], [-0.4, 0.1], [0.2, 0.4]];
// The design matrix can optionally be standardized, where the mean of each independent
// variable is subtracted and each is then divided by the standard deviation of that variable.
let data_x = standardize(data_x);
let model = ModelBuilder::<Linear>::data(&data_y, &data_x).build()?;
// L2 (ridge) regularization can be applied with l2_reg().
let fit = model.fit_options().l2_reg(1e-5).fit()?;
// Currently the result is a simple array of the MLE estimators, including the intercept term.
println!("Fit result: {}", fit.result);
用户可以定义自定义非典型链接函数,尽管当前接口的可用性不是特别人性化。有关示例,请参阅tests/custom_link.rs
。
功能
- 线性回归
- 逻辑回归
- 广义线性模型IRLS
- 线性偏移量
- 通用于浮点类型
- 非浮点域类型
- 正则化
- L2(岭回归)
- L1(Lasso)
- Elastic Net
- 其他指数族分布
- Poisson
- 二项式
- 指数
- Gamma
- 逆高斯
- 数据标准化/归一化
- 外部实用函数
- 自动内部转换
- 加权(和相关?)回归
- 非典型链接函数
- 拟合优度测试
故障排除
在某些情况下,Lasso/L1正则化可能收敛缓慢,尤其是在数据表现不佳、可分离等情况。
以下建议是在遇到收敛问题时的推荐尝试事项,但在L1正则化问题中更有可能成为必需的。
- 标准化特征数据
- 使用f32而不是f64
- 增加容差和/或最大迭代次数
- 还包括一个小型的L2正则化。
如果在应用了这些技术之后问题仍然存在,请提交一个问题,以便改进算法。
参考
- 广义线性模型笔记
- Hardin & Hilbe著的《广义线性模型及其扩展》
依赖项
~71MB
~1M SLoC