7个不稳定版本 (3个破坏性版本)

0.4.0 2024年6月25日
0.3.0 2024年2月7日
0.2.1 2024年1月23日
0.1.3 2023年10月16日
0.1.0 2023年4月21日

#215 in 机器学习

每月 23 次下载
用于 4 个crate(2个直接使用)

BUSL-1.1

205KB
5.5K SLoC

Lace:一种用于科学发现的概率机器学习工具



安装Rust | Python | CLI


Lace是一个用Rust编写的概率交叉分类引擎,具有可选的Python接口。与传统的机器学习方法不同,它学习数据集上的联合概率分布,使用户能够...

  • 预测或计算任何数量特征在给定任何数量其他特征条件下的可能性
  • 从数据的方差、模型中的认知不确定性以及缺失特征中识别、量化和归因不确定性
  • 确定哪些变量可以预测其他哪些变量
  • 确定哪些记录/行在整体或特定上下文中与其他哪些记录/行相似
  • 模拟和操作合成数据
  • 原生地处理缺失数据,并对缺失性(非随机缺失)进行推理
  • 原生地处理连续和分类数据,无需转换
  • 识别数据中的异常、错误和不一致性
  • 编辑、填充和追加数据,无需重新训练

更多,所有这些都在一个地方,无需任何显式的模型构建。

import pandas as pd
import lace

# Create an engine from a dataframe
df = pd.read_csv("animals.csv", index_col=0)
engine = lace.Engine.from_df(df)

# Fit a model to the dataframe over 5000 steps of the fitting procedure
engine.update(5000)

# Show the statistical structure of the data -- which features are likely
# dependent (predictive) on each other
engine.clustermap("depprob", zmin=0, zmax=1)

Animals dataset dependence probability

问题

蕾丝的目标是弥合标准机器学习(ML)方法(如深度学习和随机森林)与统计方法(如概率编程语言)之间巨大的差距。我们希望开发一种机器,让用户能够体验到发现的乐趣,并且确实为此进行了优化。

简版

基于标准的、基于优化的ML方法无法帮助您了解您的数据。概率编程工具假设您已经对您的数据有了很多了解。这两种方法都没有针对我们认为数据科学最重要的部分进行优化:科学部分:提问和回答问题。

长版

标准的ML方法易于使用。您可以将数据扔进随机森林,几乎不用思考就能开始预测。这些方法试图学习一个函数f(x) -> y,该函数将输入x映射到输出y。这种易用性是有代价的。通常f(x)并不反映生成您数据的现实过程,而是由开发该方法的任何人选择,以便足够表达以更好地实现优化目标。这使得大多数标准的ML完全无法解释,并且无法提供合理的不确定性估计。

在另一端,您有概率工具,如概率编程语言(PPL)。用户以参数θ的概率分布层次结构的形式指定一个模型给PPL。然后,PPL使用一个程序(通常是马尔可夫链蒙特卡洛)来学习给定数据p(θ|x)的参数的后验分布。PPL都关于可解释性和不确定性量化,但它们对用户提出了许多相当严格的要求。PPL用户必须从头开始指定模型,这意味着他们必须知道(或者至少猜测)模型。PPL用户还必须知道如何以与底层推理过程兼容的方式指定此类模型。

示例用例

  • 合并数据源并了解它们之间的交互。例如,我们可能希望从人口统计、调查或任务表现、心电图数据和其他临床数据中预测认知能力下降。结合这些数据,通常会很稀疏(大多数患者不会填写所有字段),很难知道如何明确地建模这些数据层的交互。在Lace中,我们只需连接这些层并运行即可。
  • 了解随时间推移的不确定性的量和原因。例如,一个农民可能希望了解在整个生长季节实现特定产量的可能性。随着季节的进行,可以添加新的天气数据以条件的形式进行预测。不确定性可以表示为预测中的方差、后验样本之间的不一致性或预测分布的多模态性(有关不确定性的更多信息,请参阅这篇博客文章
  • 数据质量控制。使用 surprisal 在表中查找异常数据,并使用 -logp 在它们进入表之前识别异常。由于Lace创建了数据模型,我们还可以设计方法来找到与该模型 不一致 的数据,我们已有效地将其用于错误查找。

谁不应该使用Lace

有一些用例不适合Lace

  • 非表格数据,如图像和文本
  • 高度优化特定预测
    • Lace宁愿泛化过度也不愿过拟合。

快速入门

安装

Lace需要Rust。

安装CLI

$ cargo install --locked lace-cli

安装pylace

$ pip install pylace

示例

Lace附带两个预配的示例数据集:卫星和动物。

>>> from lace.examples import Satellites
>>> engine = Satellites()

# Predict the class of orbit given the satellite has a 75-minute
# orbital period and that it has a missing value of geosynchronous
# orbit longitude, and return epistemic uncertainty via Jensen-
# Shannon divergence.
>>> engine.predict(
...     'Class_of_Orbit',
...     given={
...         'Period_minutes': 75.0,
...         'longitude_radians_of_geo': None,
...     },
... )
('LEO', 0.023981898950561048)

# Find the top 10 most surprising (anomalous) orbital periods in
# the table
>>> engine.surprisal('Period_minutes') \
...     .sort('surprisal', reverse=True) \
...     .head(10)
shape: (10, 3)
┌─────────────────────────────────────┬────────────────┬───────────┐
│ index                               ┆ Period_minutes ┆ surprisal │
│ ---                                 ┆ ---            ┆ ---       │
│ str                                 ┆ f64            ┆ f64       │
╞═════════════════════════════════════╪════════════════╪═══════════╡
│ Wind (International Solar-Terres... ┆ 19700.45       ┆ 11.019368 │
│ Integral (INTErnational Gamma-Ra... ┆ 4032.86        ┆ 9.556746  │
│ Chandra X-Ray Observatory (CXO)     ┆ 3808.92        ┆ 9.477986  │
│ Tango (part of Cluster quartet, ... ┆ 3442.0         ┆ 9.346999  │
│ ...                                 ┆ ...            ┆ ...       │
│ Salsa (part of Cluster quartet, ... ┆ 3418.2         ┆ 9.338377  │
│ XMM Newton (High Throughput X-ra... ┆ 2872.15        ┆ 9.13493   │
│ Geotail (Geomagnetic Tail Labora... ┆ 2474.83        ┆ 8.981458  │
│ Interstellar Boundary EXplorer (... ┆ 0.22           ┆ 8.884579  │
└─────────────────────────────────────┴────────────────┴───────────┘

同样在Rust中

use lace::prelude::*;
use lace::examples::Example;

fn main() {	
    // In rust, you can create an Engine or and Oracle. The Oracle is an
    // immutable version of an Engine; it has the same inference functions as
    // the Engine, but you cannot train or edit data.
    let mut engine = Example::Satellites.engine().unwrap();
	
    // Predict the class of orbit given the satellite has a 75-minute
    // orbital period and that it has a missing value of geosynchronous
    // orbit longitude, and return epistemic uncertainty via Jensen-
    // Shannon divergence.
    engine.predict(
        "Class_of_Orbit",
        &Given::Conditions(vec![
            ("Period_minutes", Datum:Continuous(75.0)),
            ("Longitude_of_radians_geo", Datum::Missing),
        ]),
        Some(PredictUncertaintyType::JsDivergence),
        None,
    )
}

拟合模型

要将模型拟合到您自己的数据中,您可以使用命令行界面。

$ lace run --csv my-data.csv -n 1000 my-data.lace

...或者从文件或数据框初始化引擎。

>>> import pandas as pd  # Lace supports polars as well
>>> from lace import Engine
>>> engine = Engine.from_df(pd.read_csv("my-data.csv", index_col=0))
>>> engine.update(1_000)
>>> engine.save("my-data.lace")

您可以使用诊断图来监控训练进度。

>>> from lace.plot import diagnostics
>>> diagnostics(engine)

Animals MCMC convergence

许可证

Lace遵循商业源代码许可证v1.1,该许可证限制了商业使用。有关详细信息,请参阅LICENSE

如果您想要用于商业用途的许可证,请联系lace@redpoll.ai

学术用途

Lace适用于学术用途是免费的。请根据CITATION.cff元数据引用lace。


lib.rs:

Geweke(联合分布)测试

依赖项

~12–20MB
~295K SLoC