5个不稳定版本
0.3.0 | 2024年2月7日 |
---|---|
0.2.0 | 2024年1月23日 |
0.1.2 | 2023年6月8日 |
0.1.1 | 2023年6月5日 |
0.1.0 | 2023年4月21日 |
#335 in 机器学习
93 每月下载量
在7 个crate(5个直接)中使用
110KB
3K SLoC
Lace:一种用于科学发现的概率机器学习工具
Lace是一个用Rust编写的概率跨分类引擎,可选的Python接口。与传统机器学习方法不同,后者学习将输入映射到输出的某些函数,Lace学习您的数据集上的联合概率分布,使用户能够...
- 预测或计算任何数量的特征在给定其他任何数量的特征条件下的可能性
- 从数据中的方差、模型中的认识不确定性以及缺失特征中识别、量化并归因于不确定性
- 确定哪些变量是预测其他变量的
- 确定哪些记录/行在整个或特定上下文中与哪些其他记录相似
- 模拟和操作合成数据
- 与缺失数据原生交互,并就缺失性(缺失非随机)进行推理
- 与连续和分类数据原生交互,无需转换
- 在数据中识别异常、错误和不一致性
- 编辑、回填和追加数据而无需重新训练
等等,所有都在一个地方,无需任何显式模型构建。
import pandas as pd
import lace
# Create an engine from a dataframe
df = pd.read_csv("animals.csv", index_col=0)
engine = lace.Engine.from_df(df)
# Fit a model to the dataframe over 5000 steps of the fitting procedure
engine.update(5000)
# Show the statistical structure of the data -- which features are likely
# dependent (predictive) on each other
engine.clustermap("depprob", zmin=0, zmax=1)
问题
Lace的目标是填补标准机器学习(ML)方法(如深度学习和随机森林)与统计方法(如概率编程语言)之间巨大的差距。我们希望开发一种机器,使用户能够体验到发现的乐趣,并且确实对此进行了优化。
简短版本
标准的基于优化的机器学习方法并不能帮助你了解你的数据。概率编程工具则假设你已经对数据有了很多了解。这两种方法都没有针对数据科学最重要的部分——科学部分进行优化:提出和回答问题。
详细版本
标准的机器学习方法易于使用。你只需将数据扔进随机森林,就可以开始预测而无需过多思考。这些方法试图学习一个函数 f(x) -> y,将输入 x 映射到输出 y。这种易用性是有代价的。一般来说,f(x) 并不能反映生成你数据的现实过程,而是由开发者根据需要表达足够以更好地实现优化目标而选择。这使得大多数标准机器学习模型无法解释,也无法提供合理的不确定性估计。
在另一个极端,你有概率工具,如概率编程语言(PPL)。用户以概率分布的层次结构(带有参数 θ)的形式指定一个模型给 PPL。然后,PPL 使用一种程序(通常是马尔可夫链蒙特卡罗方法)来学习参数的后验分布 p(θ|x)。PPL 强调可解释性和不确定性量化,但对用户提出了一系列相当严格的要求。PPL 用户必须从头开始自己指定模型,这意味着他们必须知道(或者至少猜测)模型。PPL 用户还必须知道如何以与底层推理程序兼容的方式指定此类模型。
示例用例
- 结合数据源并理解它们如何相互作用。例如,我们可能希望根据人口统计、调查或任务表现、心电图数据和其他临床数据来预测认知下降。结合这些数据通常非常稀疏(大多数患者不会填写所有字段),很难知道如何显式地建模这些数据层之间的相互作用。在 Lace 中,我们只需连接这些层并将它们运行通过。
- 理解随时间的不确定性和原因。例如,一个农民可能希望了解在整个生长季节实现特定产量的可能性。随着季节的进行,可以将新的天气数据作为条件添加到预测中。不确定性可以可视化预测中的方差、后验样本之间的不一致性或预测分布的多模态(有关不确定性的更多信息,请参阅这篇博客文章)
- 数据质量控制。使用
surprisal
来找到表格中的异常数据,并使用-logp
在它们进入表格之前识别异常。因为 Lace 会创建一个数据模型,所以我们也可以设计方法来找到与该模型不一致的数据,我们已经在错误发现中有效地使用了这种方法。
谁不应使用 Lace
Lace 并不适合以下用例
- 非表格数据,如图像和文本
- 高度优化特定预测
- Lace 宁愿泛化过度也不愿过拟合。
快速入门
安装
Lace 需要 rust。
安装 CLI
$ cargo install --locked lace-cli
安装 pylace
$ pip install pylace
示例
Lace 随附两个预配准示例数据集:卫星和动物。
>>> from lace.examples import Satellites
>>> engine = Satellites()
# Predict the class of orbit given the satellite has a 75-minute
# orbital period and that it has a missing value of geosynchronous
# orbit longitude, and return epistemic uncertainty via Jensen-
# Shannon divergence.
>>> engine.predict(
... 'Class_of_Orbit',
... given={
... 'Period_minutes': 75.0,
... 'longitude_radians_of_geo': None,
... },
... )
('LEO', 0.023981898950561048)
# Find the top 10 most surprising (anomalous) orbital periods in
# the table
>>> engine.surprisal('Period_minutes') \
... .sort('surprisal', reverse=True) \
... .head(10)
shape: (10, 3)
┌─────────────────────────────────────┬────────────────┬───────────┐
│ index ┆ Period_minutes ┆ surprisal │
│ --- ┆ --- ┆ --- │
│ str ┆ f64 ┆ f64 │
╞═════════════════════════════════════╪════════════════╪═══════════╡
│ Wind (International Solar-Terres... ┆ 19700.45 ┆ 11.019368 │
│ Integral (INTErnational Gamma-Ra... ┆ 4032.86 ┆ 9.556746 │
│ Chandra X-Ray Observatory (CXO) ┆ 3808.92 ┆ 9.477986 │
│ Tango (part of Cluster quartet, ... ┆ 3442.0 ┆ 9.346999 │
│ ... ┆ ... ┆ ... │
│ Salsa (part of Cluster quartet, ... ┆ 3418.2 ┆ 9.338377 │
│ XMM Newton (High Throughput X-ra... ┆ 2872.15 ┆ 9.13493 │
│ Geotail (Geomagnetic Tail Labora... ┆ 2474.83 ┆ 8.981458 │
│ Interstellar Boundary EXplorer (... ┆ 0.22 ┆ 8.884579 │
└─────────────────────────────────────┴────────────────┴───────────┘
在 rust 中也是类似的
use lace::prelude::*;
use lace::examples::Example;
fn main() {
// In rust, you can create an Engine or and Oracle. The Oracle is an
// immutable version of an Engine; it has the same inference functions as
// the Engine, but you cannot train or edit data.
let mut engine = Example::Satellites.engine().unwrap();
// Predict the class of orbit given the satellite has a 75-minute
// orbital period and that it has a missing value of geosynchronous
// orbit longitude, and return epistemic uncertainty via Jensen-
// Shannon divergence.
engine.predict(
"Class_of_Orbit",
&Given::Conditions(vec![
("Period_minutes", Datum:Continuous(75.0)),
("Longitude_of_radians_geo", Datum::Missing),
]),
Some(PredictUncertaintyType::JsDivergence),
None,
)
}
拟合模型
要使用 CLI 将模型拟合到自己的数据
$ lace run --csv my-data.csv -n 1000 my-data.lace
...或者从一个文件或数据框初始化引擎。
>>> import pandas as pd # Lace supports polars as well
>>> from lace import Engine
>>> engine = Engine.from_df(pd.read_csv("my-data.csv", index_col=0))
>>> engine.update(1_000)
>>> engine.save("my-data.lace")
你可以使用诊断图监控训练进度
>>> from lace.plot import diagnostics
>>> diagnostics(engine)
许可证
蕾丝许可协议为商业源代码许可证v1.1,限制商业用途。请参阅 LICENSE
获取详细信息。
如需用于商业用途的许可,请联系 lace@redpoll.ai
学术用途
蕾丝可用于学术用途,请根据 CITATION.cff
元数据引用蕾丝。
依赖项
~1.8–2.6MB
~53K SLoC