1 个不稳定版本
0.1.0 | 2023 年 7 月 4 日 |
---|
698 在 机器学习
19KB
259 行
TreeRustler
TreeRustler 是使用 Rust 编程语言实现的决策树分类器的一个简单实现。该项目在构建决策树分类器的同时,可以作为 Rust 语言入门的探索。该项目中的主要模块是 data
模块和 tree
模块。
数据模块
data
模块提供 Data
结构体,它表示树实现所需的数据类型。该 Data
结构体保存构建决策树分类器所需的训练特征。该模块仅处理决策树(即,目前还没有加载、预处理或分割)所必需的数据操作。
树模块
tree
模块包含 DecisionTreeClassifier
结构体。这个结构体表示决策树分类器,并提供以下参数
max_depth
:指定决策树的最大深度。它控制树在训练过程中可以长多深。设置较小的值可以帮助防止过拟合,但太小可能导致欠拟合。min_samples_split
:指定在训练过程中分裂内部节点所需的样本的最小数量。它控制何时停止进一步分裂节点。设置较高的值可以防止过拟合,但太高可能导致欠拟合。
DecisionTreeClassifier
结构体有以下方法
fit(x: &Data, y: &Vec<u8>)
:将决策树分类器拟合到提供的训练数据。此方法使用来自Data
结构体的特征和来自不同向量的标签来训练分类器。predict_proba(x: &Data) -> Vec<f64>
:使用训练好的决策树预测提供数据的类别概率。它返回每个类别的概率向量。
用法
使用TreeRustler项目,请按照以下步骤操作:
- 克隆仓库:
git clone https://github.com/EduardoPach/treerustler.git
- 导航到项目目录:
cd treerustler
- 请确保您已安装Rust。如果没有,请从https://rust-lang.net.cn/安装Rust。
- 加载您的数据,并将特征数据转换为
Data
结构,标签转换为Vec<u8>
。 - 从
tree
模块创建一个DecisionTreeClassifier
实例,指定所需的max_depth
和min_samples_split
值。 - 在分类器实例上调用
fit
方法,传递您的训练数据。 - 使用
predict_proba
方法预测新数据点的类别概率。
use treerustler::data::Data;
use treerustler::tree::DecisionTreeClassifier;
fn main() {
// Load your features data and labels
let x: data::Data = data::Data::from_string(&"1 3; 2 3; 3 1; 3 1; 2 3");
let y: Vec<u8> = vec![0, 0, 1, 1, 2];
// Create an instance of DecisionTreeClassifier
let mut classifier = DecisionTreeClassifier::new(max_depth, min_samples_split);
// Fit the classifier to the training data
classifier.fit(&x, &y);
// Predict class probabilities for new data points
let probabilities = classifier.predict_proba(&x);
// Process the predicted probabilities as needed
}
依赖项
~310KB