#classification #machine-learning #xmc #multi-label

bin+lib omikuji

极端多标签分类中分区标签树及其变体的有效实现

14 个版本

0.5.1 2023 年 10 月 30 日
0.5.0 2022 年 2 月 2 日
0.4.1 2021 年 12 月 6 日
0.3.4 2021 年 12 月 4 日
0.1.3 2019 年 8 月 25 日

#99 in 机器学习

Download history 58/week @ 2024-07-29

58 每月下载量

MIT 许可证

130KB
2.5K SLoC

Omikuji

Build Status Crate version PyPI version

这是对分区标签树(Prabhu 等人,2018 年)及其在极端多标签分类中的应用进行了有效实现的程序,是用 Rust 编写的🦀,充满爱心💖。

特性 & 性能

Omikuji 在 极端分类仓库 的数据集上进行了测试。以下所有测试均在四核英特尔® 酷睿™ i7-6700 CPU 上运行,我们尽可能利用尽可能多的核心。我们测量了训练时间,并计算了 1、3 和 5 时的精确度。(请注意,由于随机性,结果可能因运行而异,特别是对于较小的数据集。)

Parabel,更好的并行化

Omikuji 提供了 Parabel(Prabhu 等人,2018 年)的更并行化实现,当有更多 CPU 核心可用时,训练速度更快。与仅能利用与树数量相同数量的 CPU 核心数(默认为 3)的 C++ 编写的 原始实现 相比,Omikuji 在我们的四核机器上保持了相同的精确度,但训练速度快了 1.3 倍到 1.7 倍。如果还有更多 CPU 核心可用,还可以进一步加速

数据集 指标 Parabel Omikuji
(平衡的,
cluster.k=2)
EURLex-4K P@1 82.2 82.1
P@3 68.8 68.8
P@5 57.6 57.7
训练时间 18 秒 14 秒
Amazon-670K P@1 44.9 44.8
P@3 39.8 39.8
P@5 36.0 36.0
训练时间 404 秒 234 秒
WikiLSHTC-325K P@1 65.0 64.8
P@3 43.2 43.1
P@5 32.0 32.1
训练时间 959 秒 659 秒

浅树的正则 k-means

遵循 Bonsai(Khandagale 等人,2019 年),Omikuji 支持在树构建中使用正则 k-means 而不是平衡的 2-means 聚类,这会导致更宽、更浅且不平衡的树,训练速度较慢但精度更高。与 原始 Bonsai 实现 相比,Omikuji 在我们的四核机器上也实现了相同的精确度,同时训练速度快了 2.6 倍到 4.6 倍。(同样,如果还有更多 CPU 核心可用,还可以进一步加速。)

数据集 指标 Bonsai Omikuji
(不平衡,
cluster.k=100,
max_depth=3)
EURLex-4K P@1 82.8 83.0
P@3 69.4 69.5
P@5 58.1 58.3
训练时间 87秒 19秒
Amazon-670K P@1 45.5* 45.6
P@3 40.3* 40.4
P@5 36.5* 36.6
训练时间 5,759秒 1,753秒
WikiLSHTC-325K P@1 66.6* 66.6
P@3 44.5* 44.4
P@5 33.0* 33.0
训练时间 11,156秒 4,259秒

**精度数值如论文中报告的;我们的机器没有足够的内存来使用他们的实现进行完整预测。

平衡k-means用于平衡浅树

有时我们希望拥有既浅又宽且平衡的树,在这种情况下,Omikuji也支持用于聚类的HOMER(Tsoumakas等,2008)使用的平衡k-means算法。

数据集 指标 Omikuji
(平衡的,
cluster.k=100)
EURLex-4K P@1 82.1
P@3 69.4
P@5 58.1
训练时间 19秒
Amazon-670K P@1 45.4
P@3 40.3
P@5 36.5
训练时间 1,153秒
WikiLSHTC-325K P@1 65.6
P@3 43.6
P@5 32.5
训练时间 3,028秒

平衡浅树的层折叠

构建平衡、浅和宽树的另一种方法是折叠相邻的层,类似于AttentionXML(You等,2019)中使用的树压缩步骤:移除中间层,并将它们的子节点作为它们父节点的子节点。例如,使用平衡2-means聚类,如果我们每层折叠5层,我们可以将树的度数从2增加到2^5+1 = 64。

数据集 指标 Omikuji
(平衡的,
cluster.k=2,
折叠5层)
EURLex-4K P@1 82.4
P@3 69.3
P@5 58.0
训练时间 16秒
Amazon-670K P@1 45.3
P@3 40.2
P@5 36.4
训练时间 460秒
WikiLSHTC-325K P@1 64.9
P@3 43.3
P@5 32.3
训练时间 1,649秒

构建 & 安装

Omikuji可以用Cargo作为CLI应用程序轻松构建和安装

cargo install omikuji --features cli --locked

或从最新源安装

cargo install --git https://github.com/tomtung/omikuji.git --features cli --locked

CLI应用程序将作为omikuji可用。例如,要在EURLex-4K数据集上重现结果

omikuji train eurlex_train.txt --model_path ./model
omikuji test ./model eurlex_test.txt --out_path predictions.txt

Python绑定

还提供简单的Python绑定,用于训练和预测。可以通过pip安装

pip install omikuji

请注意,如果需要编译,您可能还需要安装Cargo。

您也可以从最新源安装

pip install git+https://github.com/tomtung/omikuji.git -v

以下脚本演示了如何使用Python绑定来训练模型并进行预测

import omikuji

# Train
hyper_param = omikuji.Model.default_hyper_param()
# Adjust hyper-parameters as needed
hyper_param.n_trees = 5
model = omikuji.Model.train_on_data("./eurlex_train.txt", hyper_param)

# Serialize & de-serialize
model.save("./model")
model = omikuji.Model.load("./model")
# Optionally densify model weights to trade off between prediction speed and memory usage
model.densify_weights(0.05)

# Predict
feature_value_pairs = [
    (0, 0.101468),
    (1, 0.554374),
    (2, 0.235760),
    (3, 0.065255),
    (8, 0.152305),
    (10, 0.155051),
    # ...
]
label_score_pairs =  model.predict(feature_value_pairs)

用法

$ omikuji train --help
Train a new omikuji model

USAGE:
    omikuji train [OPTIONS] <TRAINING_DATA_PATH>

ARGS:
    <TRAINING_DATA_PATH>
            Path to training dataset file

            The dataset file is expected to be in the format of the Extreme Classification
            Repository.

OPTIONS:
        --centroid_threshold <THRESHOLD>
            Threshold for pruning label centroid vectors

            [default: 0]

        --cluster.eps <CLUSTER_EPS>
            Epsilon value for determining linear classifier convergence

            [default: 0.0001]

        --cluster.k <K>
            Number of clusters

            [default: 2]

        --cluster.min_size <MIN_SIZE>
            Labels in clusters with sizes smaller than this threshold are reassigned to other
            clusters instead

            [default: 2]

        --cluster.unbalanced
            Perform regular k-means clustering instead of balanced k-means clustering

        --collapse_every_n_layers <N_LAYERS>
            Number of adjacent layers to collapse

            This increases tree arity and decreases tree depth.

            [default: 0]

    -h, --help
            Print help information

        --linear.c <C>
            Cost coefficient for regularizing linear classifiers

            [default: 1]

        --linear.eps <LINEAR_EPS>
            Epsilon value for determining linear classifier convergence

            [default: 0.1]

        --linear.loss <LOSS>
            Loss function used by linear classifiers

            [default: hinge]
            [possible values: hinge, log]

        --linear.max_iter <M>
            Max number of iterations for training each linear classifier

            [default: 20]

        --linear.weight_threshold <MIN_WEIGHT>
            Threshold for pruning weight vectors of linear classifiers

            [default: 0.1]

        --max_depth <DEPTH>
            Maximum tree depth

            [default: 20]

        --min_branch_size <SIZE>
            Number of labels below which no further clustering & branching is done

            [default: 100]

        --model_path <MODEL_PATH>
            Optional path of the directory where the trained model will be saved if provided

            If an model with compatible settings is already saved in the given directory, the newly
            trained trees will be added to the existing model")

        --n_threads <N_THREADS>
            Number of worker threads

            If 0, the number is selected automatically.

            [default: 0]

        --n_trees <N_TREES>
            Number of trees

            [default: 3]

        --train_trees_1_by_1
            Finish training each tree before start training the next

            This limits initial parallelization but saves memory.

        --tree_structure_only
            Build the trees without training classifiers

            Might be useful when a downstream user needs the tree structures only.
$ omikuji test --help
Test an existing omikuji model

USAGE:
    omikuji test [OPTIONS] <MODEL_PATH> <TEST_DATA_PATH>

ARGS:
    <MODEL_PATH>
            Path of the directory where the trained model is saved

    <TEST_DATA_PATH>
            Path to test dataset file

            The dataset file is expected to be in the format of the Extreme Classification
            Repository.

OPTIONS:
        --beam_size <BEAM_SIZE>
            Beam size for beam search

            [default: 10]

    -h, --help
            Print help information

        --k_top <K>
            Number of top predictions to write out for each test example

            [default: 5]

        --max_sparse_density <DENSITY>
            Density threshold above which sparse weight vectors are converted to dense format

            Lower values speed up prediction at the cost of more memory usage.

            [default: 0.1]

        --n_threads <N_THREADS>
            Number of worker threads

            If 0, the number is selected automatically.

            [default: 0]

        --out_path <OUT_PATH>
            Path to the which predictions will be written, if provided

数据格式

我们的实现接受格式与Extreme Classification Repository中提供的格式相同的数据集文件。数据文件以包含三个用空格分隔的整数的标题行开始:示例总数、特征数和标签数。在标题行之后,每行有一个示例,以逗号分隔的标签开头,后跟用空格分隔的特征:值对

label1,label2,...labelk ft1:ft1_val ft2:ft2_val ft3:ft3_val .. ftd:ftd_val

趣闻

项目名称来自o-mikuji(御神签),这是在日本的神社和寺庙中写在纸条上的关于个人未来的预测(标签?),通常在阅读后系在松树枝上。

参考文献

  • Y. Prabhu,A. Kag,S. Harsola,R. Agrawal,和M. Varma,“Parabel:用于动态搜索广告的极端分类的分区标签树”,2018年世界 Wide Web 会议论文集,2018年,第993–1002页。
  • S. Khandagale,H. Xiao,和R. Babbar,“Bonsai - 用于极端多标签分类的多样化和浅树”,2019年4月。
  • G. Tsoumakas,I. Katakis,和I. Vlahavas,“在具有大量标签的领域中有效的多标签分类”,ECML,2008。
  • R. You,S. Dai,Z. Zhang,H. Mamitsuka,和S. Zhu,“AttentionXML:基于多标签注意力的循环神经网络进行极端多标签文本分类”,2019年6月。

许可证

Omikuji遵循MIT许可证。

依赖关系

~7–19MB
~215K SLoC