3 个不稳定版本

0.2.0 2022 年 2 月 16 日
0.1.1 2022 年 1 月 12 日
0.1.0 2022 年 1 月 12 日

#455 in 机器学习

Download history 38/week @ 2024-03-01 30/week @ 2024-03-08 24/week @ 2024-03-15 1/week @ 2024-03-22 17/week @ 2024-03-29 1/week @ 2024-04-05 3/week @ 2024-05-24 11/week @ 2024-05-31 28/week @ 2024-06-07 13/week @ 2024-06-14

每月 55 次下载

MIT 许可证

62KB
1.5K SLoC

Rust 优化运输

此库为 Rust 中执行正则化和非正则化优化运输提供求解器。

Python 优化运输 的启发,此库提供以下求解器

  • 网络单纯形 算法用于线性规划/地球搬运工距离
  • 包括 Sinkhorn Knopp 和贪婪 Sinkhorn 在内的熵正则化 OT 求解器
  • 不平衡的 Sinkhorn Knopp

安装

该库已在 macOS 上进行了测试。它需要 C++ 编译器来构建 EMD 求解器,并依赖于以下 Rust 库

  • cxx 1.0
  • thiserror 1.0
  • ndarray 0.15

Cargo 安装

使用以下内容编辑您的 Cargo.toml 以在项目中使用 rust-optimal-transport。

[dependencies]
rust-optimal-transport = "0.1"

功能

如果您想启用 LAPACK 后端(目前支持 OpenBLAS)

[dependencies]
rust-optimal-transport = { version = "0.1", features = ["blas"] }

这将链接到您系统上安装的 OpenBLAS 实例。有关更多详细信息,请参阅 ndarray-linalg 包。

示例

简短示例

  • 导入库
use rust_optimal_transport as ot;
use ot::prelude::*;

  • 计算 OT 矩阵作为地球搬运工距离
// Generate data
let n_samples = 100;

// Mean, Covariance of the source distribution
let mu_source = array![0., 0.];
let cov_source = array![[1., 0.], [0., 1.]];

// Mean, Covariance of the target distribution
let mu_target = array![4., 4.];
let cov_target = array![[1., -0.8], [-0.8, 1.]];

// Samples of a 2D gaussian distribution
let source = ot::utils::sample_2D_gauss(n_samples, &mu_source, &cov_source).unwrap();
let target = ot::utils::sample_2D_gauss(n_samples, &mu_target, &cov_target).unwrap();

// Uniform weights on the source and target distributions
let mut source_weights = Array1::<f64>::from_elem(n, 1. / (n as f64));
let mut target_weights = Array1::<f64>::from_elem(n, 1. / (n as f64));

// Compute ground cost matrix - Squared Euclidean distance
let mut cost = dist(&source, &target, SqEuclidean);
let max_cost = cost.max().unwrap();

// Normalize cost matrix for numerical stability
cost = &cost / *max_cost;

// Compute optimal transport matrix as the Earth Mover's Distance
let ot_matrix = match EarthMovers::new(
    &mut source_weights,
    &mut target_weights,
    &mut ground_cost
).solve()?;

致谢

此库受 Python 优化运输的启发。该项目的原始作者和贡献者名单列在 POT

依赖项

~71MB
~1M SLoC