2个版本
0.1.1 | 2020年9月20日 |
---|---|
0.1.0 | 2020年8月24日 |
#677 in 机器学习
38KB
566 行
Tsuga
Rust中的早期机器学习库
Tsuga是一个Rust中的早期机器学习库,用于构建神经网络。它使用ndarray
作为线性代数后端,主要在二维f32
数组(Array2<f32>
类型)上操作。目前,它的主要功能是为API的各种想法进行测试,作为一种教育练习,可能还不适合用于严肃的应用。到目前为止,项目的大部分重点都在图像处理领域,尽管工具和布局通常也适用于更高/更低维度的数据集。
要将tsuga
用作库,请将以下内容添加到您的Cargo.toml
文件中
[dependencies]
tsuga = "0.1"
ndarray = "0.13"
对于开发,我建议只克隆最新版本--训练数据已包含在过去的提交中,这可能导致整个历史记录文件大小不必要地增大。这可以通过以下方式完成
$ git clone --depth=1 https://github.com/quietlychris/tsuga.git
用于MNIST的完全连接网络示例
Tsuga目前使用Builder模式来构建完全连接网络。由于网络是复杂的复合结构,这种模式有助于使网络布局明确且模块化。
以下是一个简化的代码示例,用于构建一个网络来训练/评估MNIST(或Fashion MNIST)数据集。包括解包MNIST二进制文件,这个网络在3.65秒内达到1000次迭代的准确率约为91.5%,在29.43秒内达到10,000次迭代的准确率约为97.1%。
- 在3.65秒内进行1000次迭代,准确率约为91.5%
- 在29.43秒内进行10,000次迭代,准确率约为97.1%
可以使用以下命令运行此示例:$ cargo run --release --example mnist
use ndarray::prelude::*;
use tsuga::prelude::*;
fn main() {
// Builds the MNIST data from a binary into ndarray Array2<f32> structures
// Labels are built with one-hot encoding format
// ([60_000, 784], [60_000, 10], [10_000, 784], [10_000, 10] )
let (input, output, test_input, test_output) = mnist_as_ndarray();
println!("Successfully unpacked the MNIST dataset into Array2<f32> format!");
let mut layers_cfg: Vec<FCLayer> = Vec::new();
let sigmoid_layer_0 = FCLayer::new("sigmoid", 128);
layers_cfg.push(sigmoid_layer_0);
let sigmoid_layer_1 = FCLayer::new("sigmoid", 64);
layers_cfg.push(sigmoid_layer_1);
let mut fcn = FullyConnectedNetwork::default(input, output)
.add_layers(layers_cfg)
.iterations(1000)
.learnrate(0.01)
.batch_size(200)
.build();
fcn.train();
println!("Test input shape = {:?}", test_input.shape());
println!("Test output shape = {:?}", test_output.shape());
let test_result = fcn.evaluate(test_input);
compare_results(test_result, test_output);
}
依赖项
Tsuga使用minifb
在开发期间显示样本图像,这意味着您可能需要通过以下方式添加某些依赖项
$ sudo apt install libxkbcommon-dev libwayland-cursor0 libwayland-dev
依赖项
~25–34MB
~363K SLoC