6 个版本 (3 个重大更新)
0.6.0 | 2024 年 8 月 4 日 |
---|---|
0.5.0 | 2024 年 5 月 4 日 |
0.4.0 | 2024 年 2 月 28 日 |
0.3.2 | 2024 年 1 月 7 日 |
0.3.1 | 2023 年 12 月 20 日 |
#83 在 机器学习 中
每月 243 次下载
165KB
3K SLoC
烛光优化器
一个用于烛光,极简机器学习框架的优化器 crate
实现的优化器包括
-
SGD(包括动量和权重衰减)
-
RMSprop
自适应方法
-
AdaDelta
-
AdaGrad
-
AdaMax
-
Adam
-
AdamW(作为
decoupled_weight_decay
包含在 Adam 中) -
NAdam
-
RAdam
这些优化器都与 PyTorch 的实现进行了核对(见 pytorch_test.ipynb),应该实现相同的功能(尽管没有进行一些输入检查)。
此外,所有列出的自适应方法和 SGD 都实现了与 PyTorch 中实现的权重衰减相同的解耦权重衰减,以及 PyTorch 中实现的标准的权重衰减。
伪二阶方法
- LBFGS
这不是与 PyTorch 等效的实现,但在 2D rosenbrock 函数上进行了检查
示例
有一个 mnist 玩具程序以及 adagrad 的简单示例。虽然每种方法的参数都没有调整(所有默认学习率由用户输入),但以下收敛相当好
cargo r -r --example mnist mlp --optim r-adam --epochs 2000 --learning-rate 0.025
为了更快地训练,尝试
cargo r -r --features cuda --example mnist mlp --optim r-adam --epochs 2000 --learning-rate 0.025
使用 cuda 后端。
用法
cargo add --git https://github.com/KGrewal1/optimisers.git candle-optimisers
待定
当前未实现 PyTorch 的功能
-
SparseAdam(不确定如何在烛光中处理稀疏张量)
-
ASGD(没有伪代码)
-
Rprop(需要用张量重新表述)
注意
为了开发,跟踪 PyTorch 方法的状态,请使用
print(optimiser.state)
依赖关系
~9–18MB
~286K SLoC