12 个版本
0.1.11 | 2024 年 6 月 26 日 |
---|---|
0.1.10 | 2024 年 6 月 26 日 |
0.1.4 | 2024 年 4 月 17 日 |
0.1.3 | 2024 年 3 月 28 日 |
#154 in 机器学习
47KB
1.5K SLoC
MLflow_rs
这是 MLflow 的实验跟踪客户端库。比官方 Python 库的改进
- 未提交的更改将被正确处理以确保可重复性
- 与 log 兼容的记录器可以与实验结果一起存储
- 如果用户想要终止实验,实验代码会收到通知,这提供了例如完成当前迭代/保存当前状态等的机会。
- 编译时配置
disable_experiment_tracking
禁用实验跟踪并删除大部分代码,从而在需要临时禁用实验跟踪时减少开销
使用方法
[dependencies]
mlflow_rs = "0.1"
use std::{error::Error, sync::{atomic::{AtomicBool, Ordering}, Arc}, thread::sleep, time::Duration};
use env_logger::Builder;
use log::{error, info};
use mlflow_rs::{experiment::Experiment, run::{Run, RunTag}};
/// Function that executes the experiment
fn experiment_function(run: &Run, was_killed: Arc<AtomicBool>) -> Result<(), Box<dyn Error>> {
info!("info message");
error!("error message");
run.log_parameter("learning_rate", "0.001")?;
run.log_metric("metric", 42.0, Some(0))?;
run.log_artifact_bytes("test data".to_owned().into_bytes(), "test.txt")?;
for _ in 0..10 {
if was_killed.load(Ordering::Relaxed) {
return Ok(())
}
sleep(Duration::from_secs(1));
}
u32::from_str_radix("a", 10).unwrap(); // panics, to show that panics get caught and handled
Ok(())
}
fn main() -> Result<(), Box<dyn Error>> {
let experiment = Experiment::new("https://127.0.0.1:5000", "test")?;
let mut logger_builder = Builder::from_default_env();
let logger = logger_builder.build();
let mut run = experiment.create_run_with_git_diff(
Some("new run"),
vec![RunTag {
key: "tag_name".to_owned(),
value: "tag_value".to_owned(),
}],
)?;
run.run_experiment_with_logger(experiment_function, logger)?;
Ok(())
}
当您想暂时禁用跟踪时
创建文件 .cargo/config.toml
并添加
[build]
rustflags = ["--cfg", "disable_experiment_tracking"]
或使用 cargo 运行
cargo rustc --lib -- --cfg disable_experiment_tracking
或设置 RUSTFLAGS 环境变量
RUSTFLAGS="--cfg disable_experiment_tracking" cargo build --lib
其他方法可以在此处找到: https://doc.rust-lang.net.cn/cargo/reference/config.html?highlight=rustflags#buildrustflags
依赖项
~6–18MB
~286K SLoC