7个不稳定版本 (3个破坏性版本)
0.4.0 | 2024年6月25日 |
---|---|
0.3.0 | 2024年2月7日 |
0.2.1 | 2024年1月23日 |
0.1.3 | 2023年10月16日 |
0.1.0 | 2023年4月21日 |
#215 in 机器学习
每月 23 次下载
用于 4 个crate(2个直接使用)
205KB
5.5K SLoC
Lace:一种用于科学发现的概率机器学习工具
Lace是一个用Rust编写的概率交叉分类引擎,具有可选的Python接口。与传统的机器学习方法不同,它学习数据集上的联合概率分布,使用户能够...
- 预测或计算任何数量特征在给定任何数量其他特征条件下的可能性
- 从数据的方差、模型中的认知不确定性以及缺失特征中识别、量化和归因不确定性
- 确定哪些变量可以预测其他哪些变量
- 确定哪些记录/行在整体或特定上下文中与其他哪些记录/行相似
- 模拟和操作合成数据
- 原生地处理缺失数据,并对缺失性(非随机缺失)进行推理
- 原生地处理连续和分类数据,无需转换
- 识别数据中的异常、错误和不一致性
- 编辑、填充和追加数据,无需重新训练
更多,所有这些都在一个地方,无需任何显式的模型构建。
import pandas as pd
import lace
# Create an engine from a dataframe
df = pd.read_csv("animals.csv", index_col=0)
engine = lace.Engine.from_df(df)
# Fit a model to the dataframe over 5000 steps of the fitting procedure
engine.update(5000)
# Show the statistical structure of the data -- which features are likely
# dependent (predictive) on each other
engine.clustermap("depprob", zmin=0, zmax=1)
问题
蕾丝的目标是弥合标准机器学习(ML)方法(如深度学习和随机森林)与统计方法(如概率编程语言)之间巨大的差距。我们希望开发一种机器,让用户能够体验到发现的乐趣,并且确实为此进行了优化。
简版
基于标准的、基于优化的ML方法无法帮助您了解您的数据。概率编程工具假设您已经对您的数据有了很多了解。这两种方法都没有针对我们认为数据科学最重要的部分进行优化:科学部分:提问和回答问题。
长版
标准的ML方法易于使用。您可以将数据扔进随机森林,几乎不用思考就能开始预测。这些方法试图学习一个函数f(x) -> y,该函数将输入x映射到输出y。这种易用性是有代价的。通常f(x)并不反映生成您数据的现实过程,而是由开发该方法的任何人选择,以便足够表达以更好地实现优化目标。这使得大多数标准的ML完全无法解释,并且无法提供合理的不确定性估计。
在另一端,您有概率工具,如概率编程语言(PPL)。用户以参数θ的概率分布层次结构的形式指定一个模型给PPL。然后,PPL使用一个程序(通常是马尔可夫链蒙特卡洛)来学习给定数据p(θ|x)的参数的后验分布。PPL都关于可解释性和不确定性量化,但它们对用户提出了许多相当严格的要求。PPL用户必须从头开始指定模型,这意味着他们必须知道(或者至少猜测)模型。PPL用户还必须知道如何以与底层推理过程兼容的方式指定此类模型。
示例用例
- 合并数据源并了解它们之间的交互。例如,我们可能希望从人口统计、调查或任务表现、心电图数据和其他临床数据中预测认知能力下降。结合这些数据,通常会很稀疏(大多数患者不会填写所有字段),很难知道如何明确地建模这些数据层的交互。在Lace中,我们只需连接这些层并运行即可。
- 了解随时间推移的不确定性的量和原因。例如,一个农民可能希望了解在整个生长季节实现特定产量的可能性。随着季节的进行,可以添加新的天气数据以条件的形式进行预测。不确定性可以表示为预测中的方差、后验样本之间的不一致性或预测分布的多模态性(有关不确定性的更多信息,请参阅这篇博客文章)
- 数据质量控制。使用
surprisal
在表中查找异常数据,并使用-logp
在它们进入表之前识别异常。由于Lace创建了数据模型,我们还可以设计方法来找到与该模型 不一致 的数据,我们已有效地将其用于错误查找。
谁不应该使用Lace
有一些用例不适合Lace
- 非表格数据,如图像和文本
- 高度优化特定预测
- Lace宁愿泛化过度也不愿过拟合。
快速入门
安装
Lace需要Rust。
安装CLI
$ cargo install --locked lace-cli
安装pylace
$ pip install pylace
示例
Lace附带两个预配的示例数据集:卫星和动物。
>>> from lace.examples import Satellites
>>> engine = Satellites()
# Predict the class of orbit given the satellite has a 75-minute
# orbital period and that it has a missing value of geosynchronous
# orbit longitude, and return epistemic uncertainty via Jensen-
# Shannon divergence.
>>> engine.predict(
... 'Class_of_Orbit',
... given={
... 'Period_minutes': 75.0,
... 'longitude_radians_of_geo': None,
... },
... )
('LEO', 0.023981898950561048)
# Find the top 10 most surprising (anomalous) orbital periods in
# the table
>>> engine.surprisal('Period_minutes') \
... .sort('surprisal', reverse=True) \
... .head(10)
shape: (10, 3)
┌─────────────────────────────────────┬────────────────┬───────────┐
│ index ┆ Period_minutes ┆ surprisal │
│ --- ┆ --- ┆ --- │
│ str ┆ f64 ┆ f64 │
╞═════════════════════════════════════╪════════════════╪═══════════╡
│ Wind (International Solar-Terres... ┆ 19700.45 ┆ 11.019368 │
│ Integral (INTErnational Gamma-Ra... ┆ 4032.86 ┆ 9.556746 │
│ Chandra X-Ray Observatory (CXO) ┆ 3808.92 ┆ 9.477986 │
│ Tango (part of Cluster quartet, ... ┆ 3442.0 ┆ 9.346999 │
│ ... ┆ ... ┆ ... │
│ Salsa (part of Cluster quartet, ... ┆ 3418.2 ┆ 9.338377 │
│ XMM Newton (High Throughput X-ra... ┆ 2872.15 ┆ 9.13493 │
│ Geotail (Geomagnetic Tail Labora... ┆ 2474.83 ┆ 8.981458 │
│ Interstellar Boundary EXplorer (... ┆ 0.22 ┆ 8.884579 │
└─────────────────────────────────────┴────────────────┴───────────┘
同样在Rust中
use lace::prelude::*;
use lace::examples::Example;
fn main() {
// In rust, you can create an Engine or and Oracle. The Oracle is an
// immutable version of an Engine; it has the same inference functions as
// the Engine, but you cannot train or edit data.
let mut engine = Example::Satellites.engine().unwrap();
// Predict the class of orbit given the satellite has a 75-minute
// orbital period and that it has a missing value of geosynchronous
// orbit longitude, and return epistemic uncertainty via Jensen-
// Shannon divergence.
engine.predict(
"Class_of_Orbit",
&Given::Conditions(vec![
("Period_minutes", Datum:Continuous(75.0)),
("Longitude_of_radians_geo", Datum::Missing),
]),
Some(PredictUncertaintyType::JsDivergence),
None,
)
}
拟合模型
要将模型拟合到您自己的数据中,您可以使用命令行界面。
$ lace run --csv my-data.csv -n 1000 my-data.lace
...或者从文件或数据框初始化引擎。
>>> import pandas as pd # Lace supports polars as well
>>> from lace import Engine
>>> engine = Engine.from_df(pd.read_csv("my-data.csv", index_col=0))
>>> engine.update(1_000)
>>> engine.save("my-data.lace")
您可以使用诊断图来监控训练进度。
>>> from lace.plot import diagnostics
>>> diagnostics(engine)
许可证
Lace遵循商业源代码许可证v1.1,该许可证限制了商业使用。有关详细信息,请参阅LICENSE
。
如果您想要用于商业用途的许可证,请联系lace@redpoll.ai
。
学术用途
Lace适用于学术用途是免费的。请根据CITATION.cff
元数据引用lace。
lib.rs
:
Geweke(联合分布)测试
依赖项
~12–20MB
~295K SLoC