12个版本 (5个破坏性版本)
0.7.0 | 2024年6月25日 |
---|---|
0.6.0 | 2024年2月7日 |
0.5.0 | 2024年1月23日 |
0.3.0 | 2023年11月21日 |
0.1.4 | 2023年7月26日 |
#351 in 算法
每月58次下载
用于 3 个crate(2个直接使用)
580KB
15K SLoC
Lace:一种用于科学发现的概率机器学习工具
Lace是用Rust编写的概率跨分类引擎,可选地提供了对Python的接口。与传统的机器学习方法不同,传统的机器学习方法学习一些将输入映射到输出的函数,而Lace学习数据集上的联合概率分布,使用户能够...
- 预测或计算任何数量的特征在给定其他任何数量的特征条件下的概率
- 从数据中的方差、模型中的认知不确定性以及缺失特征中识别、量化并归因于不确定性
- 确定哪些变量是哪些其他变量的预测因素
- 确定哪些记录/行在整体或特定上下文中与哪些其他记录/行相似
- 模拟和操作合成数据
- 原生地处理缺失数据,并对缺失性(非随机缺失)进行推理
- 原生地处理连续和分类数据,无需转换
- 识别数据中的异常、错误和不一致性
- 编辑、回填和追加数据,无需重新训练
更多,所有都在一个地方,无需任何显式的模型构建。
import pandas as pd
import lace
# Create an engine from a dataframe
df = pd.read_csv("animals.csv", index_col=0)
engine = lace.Engine.from_df(df)
# Fit a model to the dataframe over 5000 steps of the fitting procedure
engine.update(5000)
# Show the statistical structure of the data -- which features are likely
# dependent (predictive) on each other
engine.clustermap("depprob", zmin=0, zmax=1)
问题
Lace的目标是弥合标准机器学习(ML)方法,如深度学习和随机森林,与统计方法,如概率编程语言之间的巨大差距。我们希望开发一种机器,让用户能够体验到发现的乐趣,并确实优化了它。
简短版本
标准、基于优化的机器学习方法并不能帮助你了解你的数据。概率编程工具假设你已经对数据有了很多了解。这两种方法都没有针对我们认为是数据科学最重要的部分进行优化:科学部分:提出和回答问题。
详细版本
标准的机器学习方法易于使用。你可以将数据扔进随机森林,几乎不用思考就能开始预测。这些方法试图学习一个函数 f(x) -> y,将输入 x 映射到输出 y。这种易用性是有代价的。通常 f(x) 并不能反映生成数据的实际过程,而是由开发这种方法的人选择的,以确保足够表达以更好地实现优化目标。这使得大多数标准的机器学习完全不可解释,也无法提供合理的不确定性估计。
在另一端,你有概率工具,如概率编程语言(PPL)。用户使用参数 θ 的概率分布层次结构将模型指定给 PPL。然后 PPL 使用一种程序(通常是马尔可夫链蒙特卡洛方法)来了解参数在给定数据 p(θ|x) 下的后验分布。PPL 强调可解释性和不确定性量化,但它们对用户提出了许多相当高的要求。PPL 用户必须从头开始指定模型,这意味着他们必须知道(或者至少猜测)模型。PPL 用户还必须知道如何以与底层推理程序兼容的方式指定这样的模型。
示例用例
- 结合数据源并理解它们如何相互作用。例如,我们可能希望从人口统计、调查或任务性能、心电图数据和其他临床数据中预测认知能力下降。结合这些数据通常非常稀疏(大多数患者不会填写所有字段),很难知道如何明确建模这些数据层的相互作用。在 Lace 中,我们只需连接这些层并将它们运行通过。
- 了解随着时间的推移不确定性的数量和原因。例如,一个农民可能希望了解在整个生长季节实现特定产量的可能性。随着季节的推移,可以添加新的天气数据作为条件来更新预测。不确定性可以表示为预测的方差、后验样本之间的不一致性或预测分布的多模态(有关不确定性的更多信息,请参阅这篇博客文章)。
- 数据质量控制。使用
surprisal
查找表中的异常数据,并使用-logp
在它们进入表之前识别异常。因为 Lace 创建了数据模型,我们还可以设计方法来查找与该模型不一致的数据,我们已经在错误查找中有效地使用了这种方法。
谁不应该使用 Lace
Lace 不适合以下用例
- 非表格数据,如图像和文本
- 高度优化特定预测
- Lace 更倾向于泛化而不是过拟合。
快速入门
安装
Lace 需要 rust。
安装 CLI
$ cargo install --locked lace-cli
安装 pylace
$ pip install pylace
示例
Lace 随带两个预配准示例数据集:卫星和动物。
>>> from lace.examples import Satellites
>>> engine = Satellites()
# Predict the class of orbit given the satellite has a 75-minute
# orbital period and that it has a missing value of geosynchronous
# orbit longitude, and return epistemic uncertainty via Jensen-
# Shannon divergence.
>>> engine.predict(
... 'Class_of_Orbit',
... given={
... 'Period_minutes': 75.0,
... 'longitude_radians_of_geo': None,
... },
... )
('LEO', 0.023981898950561048)
# Find the top 10 most surprising (anomalous) orbital periods in
# the table
>>> engine.surprisal('Period_minutes') \
... .sort('surprisal', reverse=True) \
... .head(10)
shape: (10, 3)
┌─────────────────────────────────────┬────────────────┬───────────┐
│ index ┆ Period_minutes ┆ surprisal │
│ --- ┆ --- ┆ --- │
│ str ┆ f64 ┆ f64 │
╞═════════════════════════════════════╪════════════════╪═══════════╡
│ Wind (International Solar-Terres... ┆ 19700.45 ┆ 11.019368 │
│ Integral (INTErnational Gamma-Ra... ┆ 4032.86 ┆ 9.556746 │
│ Chandra X-Ray Observatory (CXO) ┆ 3808.92 ┆ 9.477986 │
│ Tango (part of Cluster quartet, ... ┆ 3442.0 ┆ 9.346999 │
│ ... ┆ ... ┆ ... │
│ Salsa (part of Cluster quartet, ... ┆ 3418.2 ┆ 9.338377 │
│ XMM Newton (High Throughput X-ra... ┆ 2872.15 ┆ 9.13493 │
│ Geotail (Geomagnetic Tail Labora... ┆ 2474.83 ┆ 8.981458 │
│ Interstellar Boundary EXplorer (... ┆ 0.22 ┆ 8.884579 │
└─────────────────────────────────────┴────────────────┴───────────┘
同样在 rust 中
use lace::prelude::*;
use lace::examples::Example;
fn main() {
// In rust, you can create an Engine or and Oracle. The Oracle is an
// immutable version of an Engine; it has the same inference functions as
// the Engine, but you cannot train or edit data.
let mut engine = Example::Satellites.engine().unwrap();
// Predict the class of orbit given the satellite has a 75-minute
// orbital period and that it has a missing value of geosynchronous
// orbit longitude, and return epistemic uncertainty via Jensen-
// Shannon divergence.
engine.predict(
"Class_of_Orbit",
&Given::Conditions(vec![
("Period_minutes", Datum:Continuous(75.0)),
("Longitude_of_radians_geo", Datum::Missing),
]),
Some(PredictUncertaintyType::JsDivergence),
None,
)
}
拟合模型
要将自己的数据拟合到模型,您可以使用 CLI
$ lace run --csv my-data.csv -n 1000 my-data.lace
...或从文件或数据框初始化引擎。
>>> import pandas as pd # Lace supports polars as well
>>> from lace import Engine
>>> engine = Engine.from_df(pd.read_csv("my-data.csv", index_col=0))
>>> engine.update(1_000)
>>> engine.save("my-data.lace")
您可以使用诊断图监控训练进度
>>> from lace.plot import diagnostics
>>> diagnostics(engine)
许可
蕾丝授权基于商业源代码许可证v1.1,该许可证限制了商业用途。请参阅LICENSE
以获取详细信息。
如果您希望在商业用途中使用,请联系lace@redpoll.ai
学术用途
蕾丝对学术用途免费。请根据CITATION.cff
元数据引用蕾丝。
依赖项
~32–62MB
~1M SLoC