2 个版本

0.0.6 2023年9月19日
0.0.5 2022年1月29日

#711 in 科学


用于 2 crates

MIT/Apache

115KB
2K SLoC

具有并行采样过程的异步训练器。

代码可能如下所示。

fn train() {
    let agent_configs: Vec<_> = vec![agent_config()];
    let env_config_train = env_config(name);
    let env_config_eval = env_config(name).eval();
    let replay_buffer_config = load_replay_buffer_config(model_dir.as_str())?;
    let step_proc_config = SimpleStepProcessorConfig::default();
    let actor_man_config = ActorManagerConfig::default();
    let async_trainer_config = load_async_trainer_config(model_dir.as_str())?;
    let mut recorder = TensorboardRecorder::new(model_dir);
    let mut evaluator = Evaluator::new(&env_config_eval, 0, 1)?;

    // Shared flag to stop actor threads
    let stop = Arc::new(Mutex::new(false));

    // Creates channels
    let (item_s, item_r) = unbounded(); // items pushed to replay buffer
    let (model_s, model_r) = unbounded(); // model_info

    // guard for initialization of envs in multiple threads
    let guard_init_env = Arc::new(Mutex::new(true));

    // Actor manager and async trainer
    let mut actors = ActorManager::build(
        &actor_man_config,
        &agent_configs,
        &env_config_train,
        &step_proc_config,
        item_s,
        model_r,
        stop.clone(),
    );
    let mut trainer = AsyncTrainer::build(
        &async_trainer_config,
        &agent_config,
        &env_config_eval,
        &replay_buffer_config,
        item_r,
        model_s,
        stop.clone(),
    );

    // Set the number of threads
    tch::set_num_threads(1);

    // Starts sampling and training
    actors.run(guard_init_env.clone());
    let stats = trainer.train(&mut recorder, &mut evaluator, guard_init_env);
    println!("Stats of async trainer");
    println!("{}", stats.fmt());

    let stats = actors.stop_and_join();
    println!("Stats of generated samples in actors");
    println!("{}", actor_stats_fmt(&stats));
}

训练过程由以下两个组件组成

Agent 必须实现 SyncModel 特性,以便将 Actor 中的代理模型与 AsyncTrainer 中训练的代理同步。该特性具有将模型信息作为 SyncModel::ModelInfo 导入和导出的能力。

AgentAsyncTrainer 中负责训练,通常使用 GPU,而 Actor 中的 ActorManager 负责使用 CPU 进行采样。

AsyncTrainerActorManager 都在同一台机器上运行,并通过通道进行通信。

依赖项

~21–30MB
~286K SLoC