7个版本 (破坏性更新)

0.7.0 2023年5月1日
0.6.0 2023年4月2日
0.5.0 2023年3月8日
0.4.0 2023年3月6日
0.1.0 2023年2月21日

#415 in 机器学习

每月21次下载
用于 learnwell

MIT 许可证

81KB
1.5K SLoC

runnt (rust neural net)

非常简单的全连接神经网络。
当你只想用最少的依赖和代码行数拼凑一些东西时。
目标是创建一个全连接网络,在数据上运行,用大约10行代码得到结果
由于找不到一个没有外部依赖且易于使用的rust库,所以创建了此库。

如果您发现任何错误或优化,欢迎提出问题或PR。

功能

  • 全连接神经网络
  • 最小依赖
  • 不需要外部静态库/dlls
  • 回归和分类
  • 可以定义层的大小
  • 可以定义激活类型
  • 可以保存/加载模型
  • 随机、小批量、梯度下降
  • 正则化
  • 数据集管理器
    • csv
    • onehot编码
    • 归一化
  • 报告

如何使用

简单示例

你需要的是NN和数据

   //XOR
    use runnt::{nn::NN,activation::ActivationType};
    let inputs = [[0., 0.], [0., 1.], [1., 0.], [1., 1.]];
    let outputs = [[0.], [1.], [1.], [0.]];

    let mut nn = NN::new(&[2, 8, 1])
        .with_learning_rate(0.2)
        .with_hidden_type(ActivationType::Tanh)
        .with_output_type(ActivationType::Linear);

    for i in 0..5000 {
        nn.fit_one(&inputs[i % 4], &outputs[i % 4]);
    }

带数据集和报告的简单示例

Dataset使得加载数据和转换数据变得容易一些
train使得运行时代和报告变得容易
在< 10行代码内完成完整的神经网络并报告

let set = Dataset::builder()
    .read_csv("examples/data/iris.csv")
    .add_input_columns(&[0, 1, 2, 3], Conversion::NormaliseMean)
    .add_target_columns(&[4], Conversion::OneHot)
    .allocate_to_test_data(0.2)
    .build();

    let mut net = NN::new(&[set.input_size(), 32, set.target_size()]).with_learning_rate(0.15);
    net.train(&set, 1000, 8, 100, ReportAccuracy::CorrectClassification);

带有数据集、报告和保存

let set = Dataset::builder()
        .read_csv(r"/temp/diamonds.csv")
        .allocate_to_test_data(0.2)
        .add_input_columns(&[0, 4, 5, 7, 8, 9], Conversion::NormaliseMean)
        .add_input_columns(&[1, 2, 3], Conversion::OneHot)
        .add_target_columns(
            &[6],
            Conversion::Function(|f| f.parse::<f32>().unwrap_or_default() / 1_000.),
        )
        .build();

    let save_path = r"network.txt";
    let mut net = if std::path::PathBuf::from_str(save_path).unwrap().exists() {
        NN::load(save_path)
    } else {
        NN::new(&[set.input_size(), 32, set.target_size()])
    };
    //run for 100 epochs, with batch size 32 and report every 10 epochs
    net.train(&set,  100, 32, 10, ReportAccuracy::RSquared);
    net.save(save_path);

依赖关系

~1.5MB
~26K SLoC