#机器学习 #绑定 #构建

bin+lib shogun

Shogun 的 Rust 绑定

2 个版本

0.1.1 2020 年 5 月 31 日
0.1.0 2020 年 5 月 29 日

#862 in 机器学习

BSD-3-Clause

19KB
326

shogun-rust

这是一个 Rust 包,用于绑定到 Shogun 机器学习框架。

注意:这个包处于非常早期的开发阶段,并且仅支持 Shogun 库的一小部分。
注意:这是一个 shogun C++ 库的 Rust 包装器,因此内部/API 并非非常 Rust 风格。

有关设计的信息可以在 这里 找到。

构建

假设您已在本地上安装了 shogun-static 以及 spdlog。如果未找到,CMake 将引发错误。

要构建,只需

cargo build

然后从另一个包中

extern crate shogun;

示例

基本 API

use shogun::shogun::{Kernel, Version};

fn main() {
    let version = Version::new();
    println!("Shogun version {}", version.main_version().unwrap());

    // shogun-rust supports Shogun's factory functions
    let k = match Kernel::new("GaussianKernel") {
        Ok(obj) => obj,
        Err(msg) => {
            panic!("No can do: {}", msg);
        },
    };

    // also supports put
    match k.put("log_width", &1.0) {
        Err(msg) => println!("Failed to put value."),
        _ => (),
    }

    // and get
    match k.get("log_width") {
        Ok(value) => match value.downcast_ref::<f64>() {
            Some(fvalue) => println!("GaussianKernel::log_width: {}", fvalue),
            None => println!("GaussianKernel::log_width not of type f64"),
        },
        Err(msg) => panic!("{}", msg),
    }
}

训练随机森林

let f_feats_train = File::read_csv("classifier_4class_2d_linear_features_train.dat".to_string())?;
let f_feats_test = File::read_csv("classifier_4class_2d_linear_features_test.dat".to_string())?;
let f_labels_train = File::read_csv("classifier_4class_2d_linear_labels_train.dat".to_string())?;
let f_labels_test = File::read_csv("classifier_4class_2d_linear_labels_test.dat".to_string())?;

let features_train = Features::from_file(&f_feats_train)?;
let features_test = Features::from_file(&f_feats_test)?;
let labels_train = Labels::from_file(&f_labels_train)?;
let labels_test = Labels::from_file(&f_labels_test)?;

let mut rand_forest = Machine::new("RandomForest")?;
let m_vote = CombinationRule::new("MajorityVote")?;

rand_forest.put("labels", &labels_train)?;
rand_forest.put("num_bags", &100)?;
rand_forest.put("combination_rule", &m_vote)?;
rand_forest.put("seed", &1)?;

rand_forest.train(&features_train)?;

let predictions = rand_forest.apply(&features_test)?;

let acc = Evaluation::new("MulticlassAccuracy")?;
rand_forest.put("oob_evaluation_metric", &acc)?;
let accuracy = acc.evaluate(&predictions, &labels_test)?;

println!("Model accuracy: {}", accuracy);

依赖项

~4.5–7.5MB
~149K SLoC