3 个不稳定版本
0.2.0 | 2022 年 2 月 16 日 |
---|---|
0.1.1 | 2022 年 1 月 12 日 |
0.1.0 | 2022 年 1 月 12 日 |
#455 in 机器学习
每月 55 次下载
62KB
1.5K SLoC
Rust 优化运输
此库为 Rust 中执行正则化和非正则化优化运输提供求解器。
受 Python 优化运输 的启发,此库提供以下求解器
- 网络单纯形 算法用于线性规划/地球搬运工距离
- 包括 Sinkhorn Knopp 和贪婪 Sinkhorn 在内的熵正则化 OT 求解器
- 不平衡的 Sinkhorn Knopp
安装
该库已在 macOS 上进行了测试。它需要 C++ 编译器来构建 EMD 求解器,并依赖于以下 Rust 库
- cxx 1.0
- thiserror 1.0
- ndarray 0.15
Cargo 安装
使用以下内容编辑您的 Cargo.toml 以在项目中使用 rust-optimal-transport。
[dependencies]
rust-optimal-transport = "0.1"
功能
如果您想启用 LAPACK 后端(目前支持 OpenBLAS)
[dependencies]
rust-optimal-transport = { version = "0.1", features = ["blas"] }
这将链接到您系统上安装的 OpenBLAS 实例。有关更多详细信息,请参阅 ndarray-linalg 包。
示例
简短示例
- 导入库
use rust_optimal_transport as ot;
use ot::prelude::*;
- 计算 OT 矩阵作为地球搬运工距离
// Generate data
let n_samples = 100;
// Mean, Covariance of the source distribution
let mu_source = array![0., 0.];
let cov_source = array![[1., 0.], [0., 1.]];
// Mean, Covariance of the target distribution
let mu_target = array![4., 4.];
let cov_target = array![[1., -0.8], [-0.8, 1.]];
// Samples of a 2D gaussian distribution
let source = ot::utils::sample_2D_gauss(n_samples, &mu_source, &cov_source).unwrap();
let target = ot::utils::sample_2D_gauss(n_samples, &mu_target, &cov_target).unwrap();
// Uniform weights on the source and target distributions
let mut source_weights = Array1::<f64>::from_elem(n, 1. / (n as f64));
let mut target_weights = Array1::<f64>::from_elem(n, 1. / (n as f64));
// Compute ground cost matrix - Squared Euclidean distance
let mut cost = dist(&source, &target, SqEuclidean);
let max_cost = cost.max().unwrap();
// Normalize cost matrix for numerical stability
cost = &cost / *max_cost;
// Compute optimal transport matrix as the Earth Mover's Distance
let ot_matrix = match EarthMovers::new(
&mut source_weights,
&mut target_weights,
&mut ground_cost
).solve()?;
致谢
此库受 Python 优化运输的启发。该项目的原始作者和贡献者名单列在 POT。
依赖项
~71MB
~1M SLoC