2个版本
0.1.1 | 2021年4月8日 |
---|---|
0.1.0 | 2021年4月8日 |
#15 in #artificial
5KB
61 行
Mushin:编译时创建神经网络
Mushin 是一种在武术中使用的日本术语,指的是通过练习获得的思维方式。在此阶段,一个人不依赖他们认为应该是下一步的行动,而是依赖他们训练有素的自然反应(或本能)。
描述
Mushin 允许开发者在编译时构建神经网络,具有预分配的具有良好定义大小的数组。这主要有三个非常重要的好处
- 编译时网络一致性检查:您神经网络中的任何缺陷(例如,层输入/输出不匹配)都会在编译时引发。当您的网络推理或训练过程从未失败时,您可以尽情享受咖啡!
- 出色的Rust编译器优化:由于神经网络完全在编译时定义,因此编译器能够执行智能优化,例如展开循环或注入 SIMD 指令。
- 支持嵌入式:构建神经网络不需要
std
库,因此它可以在 Rust 支持的任何目标上运行。
用法
将此添加到您的 Cargo.toml
[dependencies]
mushin = "0.1"
mushin_derive = "0.1"
这是一个非常简单的示例,供您入门
use rand::distributions::Uniform;
use mushin::{activations::ReLu, layers::Dense, NeuralNetwork};
use mushin_derive::NeuralNetwork;
// Builds a neural network with 2 inputs and 1 output
// Made of 3 feed forward layers, you can have as many as you want and with any name
#[derive(NeuralNetwork, Debug)]
struct MyNetwork {
// LayerType<ActivationType, # inputs, # outputs>
input: Dense<ReLu, 2, 4>,
hidden: Dense<ReLu, 4, 2>,
output: Dense<ReLu, 2, 1>,
}
impl MyNetwork {
// Initialize layer weights with a uniform distribution and set ReLU as activation function
fn new() -> Self {
let mut rng = rand::thread_rng();
let dist = Uniform::from(-1.0..=1.0);
MyNetwork {
input: Dense::random(&mut rng, &dist),
hidden: Dense::random(&mut rng, &dist),
output: Dense::random(&mut rng, &dist),
}
}
}
fn main() {
// Init the weights and perform a forward pass
let nn = MyNetwork::new();
println!("{:#?}", nn);
let input = [0.0, 1.0];
println!("Input: {:#?}", input);
let output = nn.forward(input);
println!("Output: {:#?}", output);
}
您可能会想知道 forward
方法是如何工作的。对于这个特定的示例,NeuralNetwork
derive 宏为您定义了它,它看起来像这样
fn forward(&self, input: [f32; 2]) -> [f32; 1] {
self.output.forward(self.hidden.forward(self.input.forward[input]))
}
请注意,forward 方法期望两个输入值,因为这是第一个(input
)层期望的,并且返回一个单一值,因为这是最后一个层(output
)返回的。
路线图
- 编译时神经网络一致性检查
- 文档、CI/CD & Benchmarks
- 反向传递
- 更多层类型(卷积、dropout、lstm...)
- 更多激活函数(sigmoid、softmax...)
- 也许,CPU和/或GPU并发
贡献
如果您发现漏洞、错误或希望添加新功能,请创建新问题。
要将您的更改引入代码库,请提交拉取请求。
非常感谢!
许可证
Mushin 在 MIT 许可证和 Apache 许可证(版本 2.0)的条款下分发。
有关详细信息,请参阅 LICENSE-APACHE、LICENSE-MIT 和 COPYRIGHT。
依赖项
~1.5MB
~35K SLoC