4 个版本 (1 个稳定版)
1.0.0 | 2023 年 5 月 16 日 |
---|---|
0.2.0 | 2023 年 5 月 11 日 |
0.1.1 | 2023 年 5 月 9 日 |
0.1.0 | 2023 年 5 月 8 日 |
#288 in 机器学习
每月 38 次下载
20KB
314 行
tiny ml
一个简单、快速用于基本神经网络的 Rust 库。
这有什么用?
- 学习 ML
- 进化模拟
- 你可以用它来做任何事情
这不是什么?
- 像 Tensorflow 或 PyTorch 这样的大型 ML 库。这是简单和基础的,或者说只是“小型”。
如何使用这个库?
例如,下面是如何创建一个可以判断一个点是否在圆内的模型的方法!
use tiny_ml::prelude::*;
// how many input data-points the model has
const NET_INPUTS: usize = 2;
// how many data-points the model outputs
const NET_OUTPUTS: usize = 1;
// radius of the circle
const RADIUS: f32 = 30.0;
fn main() {
// create a network
let mut net: NeuralNetwork<NET_INPUTS, NET_OUTPUTS> = NeuralNetwork::new()
.add_layer(3, ActivationFunction::ReLU)
.add_layer(3, ActivationFunction::ReLU)
.add_layer(1, ActivationFunction::Linear);
// this network has no weights yet but we can fix that by training it
// for training we first need a dataset
let mut inputs = vec![];
let mut outputs = vec![];
// well just generate some samples
for x in 0..100 {
for y in 0..100 {
inputs.push([x as f32, y as f32]);
// we want this to be a classifier, so we will give -1 for in the circle
// and +1 for in the circle
outputs.push(
if (x as f32).abs() + (y as f32).abs() < RADIUS{
[1.0
} else {
[-1.0]
}
)
}
}
let data = DataSet {
inputs,
outputs,
};
// get ourselves a trainer
let trainer = BasicTrainer::new(data);
// let it train 10 times, 50 iterations each
for _ in 0..10 {
trainer.train(&mut net, 50);
// print the total error, lower is better
println!("{}", trainer.get_total_error(&net))
}
}
功能
serialization
启用 Serde 支持。 parallelization
启用 rayon,默认功能。
速度如何?
以下是在 AMD Ryzen 5 2600X (12) @ 3.6 GHz 上使用 'bench' 示例的一些基准测试。使用 --release
标志进行构建。基准测试在该网络上进行 1000 万次运行,然后汇总结果
fn main() {
let mut net: NeuralNetwork<1, 1> = NeuralNetwork::new()
.add_layer(5, ActivationFunction::ReLU)
.add_layer(5, ActivationFunction::ReLU)
.add_layer(5, ActivationFunction::ReLU)
.add_layer(5, ActivationFunction::ReLU)
.add_layer(5, ActivationFunction::ReLU)
.add_layer(5, ActivationFunction::ReLU)
.add_layer(1, ActivationFunction::Linear);
}
方法 | 时间 | 描述 |
---|---|---|
运行 |
1.045s | 单线程,但缓冲了一些 Vecs |
unbuffered_run |
1.251s | 可以同时运行多个线程,但需要更频繁地分配 |
par_run |
240ms | 一次接收多个输入。使用 rayon 并行化计算,底层使用 unbuffered_run |
依赖项
~245–740KB
~14K SLoC