#machine-learning #neural-network #machine #learning #network

tsuga

Rust中构建神经网络的早期机器学习库

2个版本

0.1.1 2020年9月20日
0.1.0 2020年8月24日

#677 in 机器学习

MIT 许可证

38KB
566

crates.io Documentation

Tsuga

Rust中的早期机器学习库

Tsuga是一个Rust中的早期机器学习库,用于构建神经网络。它使用ndarray作为线性代数后端,主要在二维f32数组(Array2<f32>类型)上操作。目前,它的主要功能是为API的各种想法进行测试,作为一种教育练习,可能还不适合用于严肃的应用。到目前为止,项目的大部分重点都在图像处理领域,尽管工具和布局通常也适用于更高/更低维度的数据集。

要将tsuga用作库,请将以下内容添加到您的Cargo.toml文件中

[dependencies]
tsuga = "0.1"
ndarray = "0.13"

对于开发,我建议只克隆最新版本--训练数据已包含在过去的提交中,这可能导致整个历史记录文件大小不必要地增大。这可以通过以下方式完成

$ git clone --depth=1 https://github.com/quietlychris/tsuga.git

用于MNIST的完全连接网络示例

Tsuga目前使用Builder模式来构建完全连接网络。由于网络是复杂的复合结构,这种模式有助于使网络布局明确且模块化。

以下是一个简化的代码示例,用于构建一个网络来训练/评估MNIST(或Fashion MNIST)数据集。包括解包MNIST二进制文件,这个网络在3.65秒内达到1000次迭代的准确率约为91.5%,在29.43秒内达到10,000次迭代的准确率约为97.1%。

  • 在3.65秒内进行1000次迭代,准确率约为91.5%
  • 在29.43秒内进行10,000次迭代,准确率约为97.1%

可以使用以下命令运行此示例:$ cargo run --release --example mnist

use ndarray::prelude::*;
use tsuga::prelude::*;

fn main() {
    // Builds the MNIST data from a binary into ndarray Array2<f32> structures
    // Labels are built with one-hot encoding format
    // ([60_000, 784], [60_000, 10], [10_000, 784], [10_000, 10] )
    let (input, output, test_input, test_output) = mnist_as_ndarray();
    println!("Successfully unpacked the MNIST dataset into Array2<f32> format!");

    let mut layers_cfg: Vec<FCLayer> = Vec::new();
    let sigmoid_layer_0 = FCLayer::new("sigmoid", 128);
    layers_cfg.push(sigmoid_layer_0);
    let sigmoid_layer_1 = FCLayer::new("sigmoid", 64);
    layers_cfg.push(sigmoid_layer_1);

    let mut fcn = FullyConnectedNetwork::default(input, output)
        .add_layers(layers_cfg)
        .iterations(1000)
        .learnrate(0.01)
        .batch_size(200)
        .build();

    fcn.train();

    println!("Test input shape = {:?}", test_input.shape());
    println!("Test output shape = {:?}", test_output.shape());

    let test_result = fcn.evaluate(test_input);
    compare_results(test_result, test_output);
}

依赖项

Tsuga使用minifb在开发期间显示样本图像,这意味着您可能需要通过以下方式添加某些依赖项

$ sudo apt install libxkbcommon-dev libwayland-cursor0 libwayland-dev

依赖项

~25–34MB
~363K SLoC