2 个版本
0.0.6 | 2023年9月19日 |
---|---|
0.0.5 | 2022年1月29日 |
#711 in 科学
用于 2 crates
115KB
2K SLoC
具有并行采样过程的异步训练器。
代码可能如下所示。
fn train() {
let agent_configs: Vec<_> = vec![agent_config()];
let env_config_train = env_config(name);
let env_config_eval = env_config(name).eval();
let replay_buffer_config = load_replay_buffer_config(model_dir.as_str())?;
let step_proc_config = SimpleStepProcessorConfig::default();
let actor_man_config = ActorManagerConfig::default();
let async_trainer_config = load_async_trainer_config(model_dir.as_str())?;
let mut recorder = TensorboardRecorder::new(model_dir);
let mut evaluator = Evaluator::new(&env_config_eval, 0, 1)?;
// Shared flag to stop actor threads
let stop = Arc::new(Mutex::new(false));
// Creates channels
let (item_s, item_r) = unbounded(); // items pushed to replay buffer
let (model_s, model_r) = unbounded(); // model_info
// guard for initialization of envs in multiple threads
let guard_init_env = Arc::new(Mutex::new(true));
// Actor manager and async trainer
let mut actors = ActorManager::build(
&actor_man_config,
&agent_configs,
&env_config_train,
&step_proc_config,
item_s,
model_r,
stop.clone(),
);
let mut trainer = AsyncTrainer::build(
&async_trainer_config,
&agent_config,
&env_config_eval,
&replay_buffer_config,
item_r,
model_s,
stop.clone(),
);
// Set the number of threads
tch::set_num_threads(1);
// Starts sampling and training
actors.run(guard_init_env.clone());
let stats = trainer.train(&mut recorder, &mut evaluator, guard_init_env);
println!("Stats of async trainer");
println!("{}", stats.fmt());
let stats = actors.stop_and_join();
println!("Stats of generated samples in actors");
println!("{}", actor_stats_fmt(&stats));
}
训练过程由以下两个组件组成
ActorManager
管理Actor
,每个Agent
运行一个线程与Env
交互并获取样本。这些样本将被发送到AsyncTrainer
的重放缓冲区中。AsyncTrainer
负责代理的训练。它还运行一个线程将来自ActorManager
的样本推送到重放缓冲区。
Agent
必须实现 SyncModel
特性,以便将 Actor
中的代理模型与 AsyncTrainer
中训练的代理同步。该特性具有将模型信息作为 SyncModel
::ModelInfo
导入和导出的能力。
Agent
在 AsyncTrainer
中负责训练,通常使用 GPU,而 Actor
中的 ActorManager
负责使用 CPU 进行采样。
AsyncTrainer
和 ActorManager
都在同一台机器上运行,并通过通道进行通信。
依赖项
~21–30MB
~286K SLoC