#实验 #跟踪 #客户端 #记录器 #Python #结果 #m-lflow

bin+lib mlflow_rs

MLflow 的实验跟踪客户端库

12 个版本

0.1.11 2024 年 6 月 26 日
0.1.10 2024 年 6 月 26 日
0.1.4 2024 年 4 月 17 日
0.1.3 2024 年 3 月 28 日

#154 in 机器学习

Apache-2.0

47KB
1.5K SLoC

MLflow_rs

这是 MLflow 的实验跟踪客户端库。比官方 Python 库的改进

  • 未提交的更改将被正确处理以确保可重复性
  • log 兼容的记录器可以与实验结果一起存储
  • 如果用户想要终止实验,实验代码会收到通知,这提供了例如完成当前迭代/保存当前状态等的机会。
  • 编译时配置 disable_experiment_tracking 禁用实验跟踪并删除大部分代码,从而在需要临时禁用实验跟踪时减少开销

使用方法

[dependencies]
mlflow_rs = "0.1"
use std::{error::Error, sync::{atomic::{AtomicBool, Ordering}, Arc}, thread::sleep, time::Duration};

use env_logger::Builder;
use log::{error, info};
use mlflow_rs::{experiment::Experiment, run::{Run, RunTag}};

/// Function that executes the experiment
fn experiment_function(run: &Run, was_killed: Arc<AtomicBool>) -> Result<(), Box<dyn Error>> {
    info!("info message");
    error!("error message");

    run.log_parameter("learning_rate", "0.001")?;
    run.log_metric("metric", 42.0, Some(0))?;
    run.log_artifact_bytes("test data".to_owned().into_bytes(), "test.txt")?;

    for _ in 0..10 {
        if was_killed.load(Ordering::Relaxed) {
            return Ok(())
        }

        sleep(Duration::from_secs(1));
    }

    u32::from_str_radix("a", 10).unwrap();  // panics, to show that panics get caught and handled

    Ok(())
}

fn main() -> Result<(), Box<dyn Error>> {
    let experiment = Experiment::new("https://127.0.0.1:5000", "test")?;

    let mut logger_builder = Builder::from_default_env();
    let logger = logger_builder.build();

    let mut run = experiment.create_run_with_git_diff(
        Some("new run"),
        vec![RunTag {
            key: "tag_name".to_owned(),
            value: "tag_value".to_owned(),
        }],
    )?;

    run.run_experiment_with_logger(experiment_function, logger)?;

    Ok(())
}

当您想暂时禁用跟踪时

创建文件 .cargo/config.toml 并添加

[build]
rustflags = ["--cfg", "disable_experiment_tracking"]

或使用 cargo 运行

cargo rustc --lib -- --cfg disable_experiment_tracking

或设置 RUSTFLAGS 环境变量

RUSTFLAGS="--cfg disable_experiment_tracking" cargo build --lib

其他方法可以在此处找到: https://doc.rust-lang.net.cn/cargo/reference/config.html?highlight=rustflags#buildrustflags

依赖项

~6–18MB
~286K SLoC