6 个版本 (3 个重大更新)
0.8.1 | 2023 年 7 月 10 日 |
---|---|
0.8.0 | 2023 年 7 月 10 日 |
0.7.1 | 2023 年 7 月 9 日 |
0.6.0 | 2023 年 7 月 8 日 |
0.5.0 | 2023 年 7 月 4 日 |
#106 in 机器学习
每月 22 次下载
310KB
8K SLoC
Rust 中的低摩擦高细节深度学习框架
使用方法
安装与 cargo 功能
在您的项目 Cargo.toml
文件中添加以下内容,用您想使用的后端替换 <BACKEND>
(见 后端)
[dependencies]
jiro_nn = {
version = "*",
default-features = false,
features = ["<BACKEND>", "data"] # "data" is optional and enables the preprocessing and dataframes features
}
功能 | 描述 | 编译时成本 |
---|---|---|
data (默认功能) |
添加 DataTable ,一个更简单的 polars DataFrame API;启用 Kfolds 训练;添加 preprocessing 模块以创建依赖于数据集配置的管道 |
高 |
parquet |
为与 data 功能相关的所有内容添加 Apache Parquet 文件支持 |
中等 |
ipc |
为与 data 功能相关的所有内容添加 Arrow 文件支持 |
中等 |
ndarray (默认功能) |
将 Matrix 和 Image 类型更改为由 ndarray crate 提供的 CPU-bound 后端。 Image 和卷积操作未使用此后端完全实现,但正在进行中。 |
低 |
nalgebra |
将 Matrix 和 Image 类型更改为由 nalgebra crate 提供的 CPU-bound 后端。 Image 和卷积操作未使用此后端完全实现,可能永远也不会实现。 |
低 |
arrayfire |
将 Matrix 和 Image 类型更改为由 arrayfire crate 提供的 GPU 和 CPU 后端。 理想用于卷积网络。 需要 ArrayFire C++ 库。见 安装 Arrayfire |
安装难度低,但较难安装 |
f64 |
将 Scalar 类型从由 f32 支持改为由 f64 支持 |
None |
裸骨 XOR 示例
使用简单的神经网络预测XOR函数
let x = vec![
vec![0.0, 0.0],
vec![1.0, 0.0],
vec![0.0, 1.0],
vec![1.0, 1.0],
];
let y = vec![
vec![0.0],
vec![1.0],
vec![1.0],
vec![0.0]
];
let network_model = NetworkModelBuilder::new()
.full_dense(3)
.tanh()
.end()
.full_dense(1)
.tanh()
.end()
.build();
let in_size = 2;
let mut network = network_model.to_network(in_size);
let loss = Losses::MSE.to_loss();
for epoch in 0..1000 {
let error = network.train(epoch, &x, &y, &loss, 1);
println!("Epoch: {} Error: {}", epoch, error);
}
预处理 + CNNs 示例
MNIST(手写数字识别)工作流程示例
// Step 1: Enrich the features of your data (eg. the "columns") with metadata using a Dataset configuration
// The configuration is necessary for guiding further steps (preprocessing, training...)
// Extract features from a spreadsheet to start building a dataset configuration
// You could also start blank and add the columns and metadata manually
let mut dataset_config = Dataset::from_file("dataset/train.csv");
// Now we can add metadata to our features
dataset_config
// Flag useless features for removal
.remove_features(&["size"])
// Tell the framework which column is an ID (so it can be ignored in training, used in joins, and so on)
.tag_feature("id", IsId)
// Tell the framework which column is the feature to predict
// You could very well declare multiple features as Predicted
.tag_feature("label", Predicted)
// Since it is a classification problem, indicate the label needs One-Hot encoding during preprocessing
.tag_feature("label", OneHotEncode)
// You may also want to normalize everything except the ID & label during preprocessing
.tag_all(Normalized.except(&["id", "label"]));
// Step 2: Preprocess the data
// Create a pipeline with all the necessary steps
let mut pipeline = Pipeline::basic_single_pass();
// Run it on the data
let (dataset_config, data) = pipeline
.load_data("dataset/train.csv", Some(dataset_config))
.run();
// Step 3: Specify and build your model
// A model is tied to a dataset configuration
let model = ModelBuilder::new(dataset_config)
// Some configuration is also tied to the model
// All the configuration calls are optional, defaults are picked otherwise
.batch_size(128)
.loss(Losses::BCE)
.epochs(20)
// Then you can start building the neural network
.neural_network()
// Specify all your layers
// A convolution network is considered a layer of a neural network in this framework
.conv_network(1)
// Now the convolution layers
.full_dense(32, 5)
// You can set the activation function for any layer and many other parameters
// Otherwise defaults are picked
.relu()
.adam()
.dropout(0.4)
.end()
.avg_pooling(2)
.full_dense(64, 5)
.relu()
.adam()
.dropout(0.5)
.end()
.avg_pooling(2)
.end()
// Now we go back to configuring the top-level neural network
.full_dense(128)
.relu()
.adam()
.end()
.full_dense(10)
.softmax()
.adam()
.end()
.end()
.build();
println!(
"Model parameters count: {}",
model.to_network().get_params().count()
);
// Step 4: Train the model
// Monitor the progress of the training on a nice TUI (with other options coming soon)
TM::start_monitoring();
// Use a SplitTraining to split the data into a training and validation set (k-fold also available)
let mut training = SplitTraining::new(0.8);
let (preds_and_ids, model_eval) = training.run(&model, &data);
TM::stop_monitoring();
// Step 5: Save the resulting predictions, weights and model evaluation
// Save the model evaluation per epoch
model_eval.to_json_file("mnist_eval.json");
// Save the weights
let model_params = training.take_model();
model_params.to_json_file("mnist_weights.json");
// Save the predictions alongside the original data
let preds_and_ids = pipeline.revert(&preds_and_ids);
pipeline
.revert(&data)
.inner_join(&preds_and_ids, "id", "id", Some("pred"))
.to_csv_file("mnist_values_and_preds.csv");
您可以使用第三方crate如 gnuplot
(推荐),plotly
(也推荐)或甚至 plotters
来绘制结果。
对于更深入的示例,包含更多可配置的工作流程和多个脚本,请查看 examples
文件夹。
功能
作为一个框架,它相当有见地,并且具有很多功能。但这里是一些主要的功能
NNs(密集层,全连接层...),CNNs(密集层,直接层,平均池化...),一切批处理,SGD,Adam,动量,Glorot,许多激活函数(Softmax,Tanh,ReLU...),学习率调度,K折,分割训练,可缓存和可回滚的流水线(归一化,特征提取,异常值过滤,值映射,独热编码,对数缩放...),损失函数(二元交叉熵,均方误差),代码化模型构建,代码化预处理配置,性能指标(R²...),任务监控(进度,日志),多后端(CPU,GPU,见 后端),多精度(见 精度)。
范围和目标
主要目标
- 实现足够的算法,使其适用于大多数用例
- 不仅限于NNs,还要实现CNNs,RNNs以及我们可以实现的任何东西
- 处理可以作为“后端”考虑的辅助用例(模型构建,训练,预处理)
- 为核心功能制定有见地的API,并为不简单的支持库提供包装器(DataFrames,线性代数...)
- 在严谨性和错误处理之上简化API(但不是不安全的Rust)
- 使其能够实现工业化和配置工作流程(例如,数据预处理,模型构建,训练,评估...)
- 不要比Python难1000倍,比C++慢10倍(否则还有什么意义?)
辅助/未来目标
- 实现罕见但有趣的算法(直接层,前向层...)
- WebAssembly和WebGPU支持
- 通过wgpu的Rust本地GPU后端(我可以梦想,对吧?)
- 通过PyO3的Python绑定
- 模型构建的图形工具
非目标
- 数据可视化
- 符合其他框架/标准
- 完美的错误处理(不要羞于使用
unwrap
和expect
)
后端
通过Cargo功能切换后端
arrayfire
(CPU/GPU)- ✅ 可视化可用
- ✅ GPU支持
- 🫤 CPU支持较慢
- 🫤 安装难度大(见 安装Arrayfire)
- 🫤 C++库,段错误...
ndarray
- ✅ 最快的CPU后端
- ✅ 纯Rust
- 🫤 可视化不可用(尚不可用)
- 🫤 仅CPU
nalgebra
- ✅ 纯Rust
- 🫤 可视化不可用(并且尚未计划)
- 🫤 仅CPU
精度
您可以使用 f64
功能启用高达 f64
的精度。
不支持低于 f32
的精度(尚不支持)。
安装 Arrayfire
您需要首先 安装Arrayfire,以便使用 arrayfire
功能在CPU或GPU上使用Arrayfire的C++/CUDA/OpenCL后端进行快速计算(如果已安装,它将首先尝试OpenCL,然后是CUDA,然后是C++)。确保安装的所有步骤都能100%正常工作,没有任何奇怪的警告,因为它们可能会以相当微妙的方式失败。
一旦安装了Arrayfire,您
- 将
AF_PATH
设置为您的Arrayfire安装目录(例如:/opt/Arrayfire
)。 - 将lib文件路径添加到环境变量中
- Linux:
LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$AF_PATH/lib64
- OSX:
DYLD_LIBRARY_PATH=$DYLD_LIBRARY_PATH:$AF_PATH/lib
- Windows:将
%AF_PATH%\lib
添加到PATH
- Linux:
- 如果是在Linux上,运行
sudo ldconfig
- 运行
cargo clean
- 禁用默认功能并激活
arrayfire
功能
[dependencies]
jiro_nn = {
version = "*",
default_features = false,
features = ["arrayfire"]
}
如果您想在Linux上使用Arrayfire的CUDA功能(已在Windows 11 WSL2与Ubuntu以及RTX 3060上测试过),请查看此指南https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#ubuntu-installation。
依赖项
~5–23MB
~326K SLoC