1 个不稳定版本

0.1.0 2023 年 7 月 4 日

698机器学习

MIT 许可证

19KB
259

TreeRustler

TreeRustler 是使用 Rust 编程语言实现的决策树分类器的一个简单实现。该项目在构建决策树分类器的同时,可以作为 Rust 语言入门的探索。该项目中的主要模块是 data 模块和 tree 模块。

数据模块

data 模块提供 Data 结构体,它表示树实现所需的数据类型。该 Data 结构体保存构建决策树分类器所需的训练特征。该模块仅处理决策树(即,目前还没有加载、预处理或分割)所必需的数据操作。

树模块

tree 模块包含 DecisionTreeClassifier 结构体。这个结构体表示决策树分类器,并提供以下参数

  • max_depth:指定决策树的最大深度。它控制树在训练过程中可以长多深。设置较小的值可以帮助防止过拟合,但太小可能导致欠拟合。
  • min_samples_split:指定在训练过程中分裂内部节点所需的样本的最小数量。它控制何时停止进一步分裂节点。设置较高的值可以防止过拟合,但太高可能导致欠拟合。

DecisionTreeClassifier 结构体有以下方法

  • fit(x: &Data, y: &Vec<u8>):将决策树分类器拟合到提供的训练数据。此方法使用来自 Data 结构体的特征和来自不同向量的标签来训练分类器。
  • predict_proba(x: &Data) -> Vec<f64>:使用训练好的决策树预测提供数据的类别概率。它返回每个类别的概率向量。

用法

使用TreeRustler项目,请按照以下步骤操作:

  1. 克隆仓库: git clone https://github.com/EduardoPach/treerustler.git
  2. 导航到项目目录: cd treerustler
  3. 请确保您已安装Rust。如果没有,请从https://rust-lang.net.cn/安装Rust。
  4. 加载您的数据,并将特征数据转换为Data结构,标签转换为Vec<u8>
  5. tree模块创建一个DecisionTreeClassifier实例,指定所需的max_depthmin_samples_split值。
  6. 在分类器实例上调用fit方法,传递您的训练数据。
  7. 使用predict_proba方法预测新数据点的类别概率。
use treerustler::data::Data;
use treerustler::tree::DecisionTreeClassifier;

fn main() {
    // Load your features data and labels
    let x: data::Data = data::Data::from_string(&"1 3; 2 3; 3 1; 3 1; 2 3");
    let y: Vec<u8> = vec![0, 0, 1, 1, 2];

    // Create an instance of DecisionTreeClassifier
    let mut classifier = DecisionTreeClassifier::new(max_depth, min_samples_split);

    // Fit the classifier to the training data
    classifier.fit(&x, &y);

    // Predict class probabilities for new data points
    let probabilities = classifier.predict_proba(&x);

    // Process the predicted probabilities as needed
}

依赖项

~310KB