2 个版本
0.1.1 | 2020 年 5 月 31 日 |
---|---|
0.1.0 | 2020 年 5 月 29 日 |
#862 in 机器学习
19KB
326 行
shogun-rust
这是一个 Rust 包,用于绑定到 Shogun 机器学习框架。
注意:这个包处于非常早期的开发阶段,并且仅支持 Shogun 库的一小部分。
注意:这是一个 shogun C++ 库的 Rust 包装器,因此内部/API 并非非常 Rust 风格。
有关设计的信息可以在 这里 找到。
构建
假设您已在本地上安装了 shogun-static 以及 spdlog。如果未找到,CMake 将引发错误。
要构建,只需
cargo build
然后从另一个包中
extern crate shogun;
示例
基本 API
use shogun::shogun::{Kernel, Version};
fn main() {
let version = Version::new();
println!("Shogun version {}", version.main_version().unwrap());
// shogun-rust supports Shogun's factory functions
let k = match Kernel::new("GaussianKernel") {
Ok(obj) => obj,
Err(msg) => {
panic!("No can do: {}", msg);
},
};
// also supports put
match k.put("log_width", &1.0) {
Err(msg) => println!("Failed to put value."),
_ => (),
}
// and get
match k.get("log_width") {
Ok(value) => match value.downcast_ref::<f64>() {
Some(fvalue) => println!("GaussianKernel::log_width: {}", fvalue),
None => println!("GaussianKernel::log_width not of type f64"),
},
Err(msg) => panic!("{}", msg),
}
}
训练随机森林
let f_feats_train = File::read_csv("classifier_4class_2d_linear_features_train.dat".to_string())?;
let f_feats_test = File::read_csv("classifier_4class_2d_linear_features_test.dat".to_string())?;
let f_labels_train = File::read_csv("classifier_4class_2d_linear_labels_train.dat".to_string())?;
let f_labels_test = File::read_csv("classifier_4class_2d_linear_labels_test.dat".to_string())?;
let features_train = Features::from_file(&f_feats_train)?;
let features_test = Features::from_file(&f_feats_test)?;
let labels_train = Labels::from_file(&f_labels_train)?;
let labels_test = Labels::from_file(&f_labels_test)?;
let mut rand_forest = Machine::new("RandomForest")?;
let m_vote = CombinationRule::new("MajorityVote")?;
rand_forest.put("labels", &labels_train)?;
rand_forest.put("num_bags", &100)?;
rand_forest.put("combination_rule", &m_vote)?;
rand_forest.put("seed", &1)?;
rand_forest.train(&features_train)?;
let predictions = rand_forest.apply(&features_test)?;
let acc = Evaluation::new("MulticlassAccuracy")?;
rand_forest.put("oob_evaluation_metric", &acc)?;
let accuracy = acc.evaluate(&predictions, &labels_test)?;
println!("Model accuracy: {}", accuracy);
依赖项
~4.5–7.5MB
~149K SLoC