3 个不稳定版本

0.2.1 2022 年 5 月 15 日
0.2.0 2022 年 5 月 14 日
0.1.2 2022 年 5 月 13 日
0.0.0 2022 年 5 月 12 日

机器学习 分类中排名第 294

MIT 许可证

30KB
332

const_cge: 神经网络编译器

Cover graphic for the ecosystem of crates including EANT2 and const_cge

做什么?

Illustration depicts transformation of a neural network into a rust function

const_cge 在编译时对您的神经网络进行符号评估,生成具有相同行为的有效 Rust 代码。

有了网络内部数据依赖的信息,LLVM 能够进行更高级的优化,如指令消除、流水线感知重排序、SIMD 向量化、寄存器和堆栈大小最小化等。

生成的 Rust 代码

  • 永远不会分配、panic 或依赖于 std(除非使用 std 功能!)
  • 具有完美的确定性
  • 具有静态声明的输入和输出维度
  • 具有静态可分析的内部分数据依赖关系
  • 使用精确的最小循环状态数组,或根本不使用(仅支付您使用的费用)
  • 在类型系统中静态捕获神经网络属性
  • 在运行时没有开销成本

查看 eant2 了解如何训练与 const_cge 兼容的神经网络。

const_cge = "0.2"

#![no_std] 中的浮点数

您可以通过功能选择浮点实现:libm(默认)、stdmicromath,例如

const_cge = "0.2" # use libm
const_cge = { version = "0.2", default-features = false, features = ["std"] } # `no_std` incompatible
const_cge = { version = "0.2", default-features = false, features = ["micromath"] } # use micromath

简单示例

网络

network 宏生成评估我们网络所需的所有字段和函数。

/// Use sensor data to control the limbs of a robot (using f32 only).
#[network("nets/walk.cge", numeric_type = f32)]
struct Walk;

let mut walk = Walk::default();

// I/O is statically sized to match our network
walk.evaluate(&input, &mut output);

编译时保证

非循环的

如果网络可以存储有关其过去状态的信息(循环性),则有时会成为一个问题。

您可以使用 nonrecurrent,如果导入的网络包含任何循环,编译将停止。

/// Predict which lighting color would best 
/// complement the current sunlight color
#[nonrecurrent("nets/color.cge")]
struct Color;

// evaluate is now a static function.
// it has no state, and this is captured in the type system.
Color::evaluate(&input, &mut output);

循环

某些任务最好使用循环架构来解决,包含一个非循环网络将是一个错误。

您可以使用 recurrent,如果导入的网络不包含循环,编译将停止。

/// Detect if our device has just been dropped 
/// and is currently falling through the air
#[recurrent("nets/drop.cge")]
struct Dropped;

let mut d = Dropped::default();
d.evaluate(&input, &mut output);

循环状态?

循环状态存储神经元的上一个值,用于下一次评估(在网络中反向传递)。

循环网络中的状态表示为 [f64; N](或 [f32; N]),并在每次评估时更新。正如之前提到的,它的大小仅限于所需的程度。

如果您愿意,可以读取此状态,修改它,稍后恢复等。

/// Attempt to clarify audio stream
#[recurrent("nets/denoise.cge")]
struct Denoise;

// I want a specific recurrent state, 
// not the `::default()` initially-zero recurrent state.
let mut d = Denoise::with_recurrent_state(&saved_state);

// Some evaluations later, read internal state
let state = d.recurrent_state();

// Or modify internal state
do_something_to_state(d.recurrent_state_mut());

// Or set custom state after construction
d.set_recurrent_state(&saved_state);

numeric_type

  • 您通常不需要 f64 的精度,而 f64 通常比 f32 更大、更慢。使用 f64 的行为将与您的 CGE 文件完全相同,因此这是默认行为。
  • 您可以在您的网络中执行(有损)参数 '降级',使所有参数和操作都使用您请求的类型。
#[network("net.cge", numeric_type = f32)]
struct SmallerFaster;
  • 目前仅支持 f64f32。未来可能会添加对 f16 / 整数 / 固定精度支持。

Netcrates!

什么是 netcrate?

  • const_cge netcrates 是作为 crate 预训练的神经网络!

  • const_cge 作为通用格式,允许社区共享用于常见任务的神经网络。

让我们看看如何使用一个!

use netcrate_ocr::ocr;
#[network(ocr)]
struct HandwritingOCR;

发布 netcrate

在您的 Cargo.toml 文件中,

  • 确保为 const_cge 禁用 default-features
  • 并添加一个 std 功能
[dependencies]
const_cge = { version = "0.2", default-features = false } # <== important!

[features]
std = [] # <== important!

在您的 stc/lib.rs 文件中,

  • 确保 条件性地启用 no_std
#![cfg_attr(not(feature = "std"), no_std)]  // <== important!
const_cge::netcrate!(ocr_english  = "nets/ocr/en.cge");
const_cge::netcrate!(ocr_japanese = "nets/ocr/jp.cge");

完成!

扩展

如果您想提供一个更友好的接口来包装您的网络,请编写一个宏,提供实现,如下所示

#[macro_export]
macro_rules! ocr_ext {
  ($name: ident, $numeric_type: ty) => {
    impl $name {
      /// Returns the unicode char
      pub fn predict_char(&mut self, image: &ImageBuffer) -> char {
        // access everything a `const_cge` struct normally has:
        let output_dim = $name::OUTPUT_SIZE;
        self.recurrent_state_mut()[0] *= -1.0;

        // even access the particluar activation function implementation the end
        // user has chosen:
        const_cge::activations::$numeric_type::relu(x);
      }
    }

    // or produce a new struct, whatever you think is best.
    struct SmolOCR {
      network: $name,
      extra_memory_bank: [$numeric_type; 6 * $name::OUTPUT_SIZE]
    }

    impl SmolOCR {
      //...
  }
}

最终用户可以简单地

use netcrate_ocr::*;
#[network(ocr_japanese, numeric_type = f32)]
struct JapaneseOCR;
ocr_ext!(JapaneseOCR, f32);
这种方法是一个必要的恶,因为我们必须允许用户为 `no_std` 环境选择自己的数值后端,这些选项可能会随着时间的推移而演变。编写扩展宏是我想出的最适合这种特定用例的最不糟糕的方法。

设计目标 & 缺陷

  • 您可以用“小”网络完成很多事情,特别是对于控制任务。 const_cge 并不打算用于“深度学习”任务(语言建模等)。
  • 使嵌入式用例(机器人技术、5¢微控制器)成为可能的优势
  • 在同一个二进制文件中包含大量的单个 const_cge 网络可能会比运行时评估方法更大或更慢。这取决于目标机器和您正在评估的网络。如果您真的关心,请进行测量。这个库应该能够完美地覆盖常见用例。

MIT许可

Copyright © 2022 Will Brickner

Permission is hereby granted, free of charge, to any person obtaining a 
copy of this software and associated documentation files (the "Software"), 
to deal in the Software without restriction, including without limitation 
the rights to use, copy, modify, merge, publish, distribute, sublicense, 
and/or sell copies of the Software, and to permit persons to whom the 
Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in 
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS 
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 
DEALINGS IN THE SOFTWARE.

依赖关系

~2MB
~50K SLoC