2 个版本

0.0.6 2023年9月19日
0.0.5 2022年1月29日

#969 in 科学


border 中使用

GPL-2.0-or-later

115KB
2.5K SLoC

atari-env 的薄包装,用于 Border

由于在 crates.io 中注册的 crate 没有实现 atari_env::AtariEnv::lives() 方法,该方法对于周期性生命环境是必需的,因此代码在 [atari_env] 下进行了修改。

此环境对观测值应用了一些预处理,如 atari_wrapper.py 中所述。

您需要将 Atari Rom 目录放置在由环境变量 ATARI_ROM_DIR 指定的目录下。一种简单的方法是使用 AutoROM Python 包。

pip install autorom
mkdir $HOME/atari_rom
AutoROM --install-dir $HOME/atari_rom
export ATARI_ROM_DIR=$HOME/atari_rom

以下是一个使用随机策略运行 Pong 环境的示例。

use anyhow::Result;
use border_atari_env::{
    BorderAtariAct, BorderAtariActRawFilter, BorderAtariEnv, BorderAtariEnvConfig,
    BorderAtariObs, BorderAtariObsRawFilter,
};
use border_core::{util, Env as _, Policy, DefaultEvaluator, Evaluator as _};

type Obs = BorderAtariObs;
type Act = BorderAtariAct;
type ObsFilter = BorderAtariObsRawFilter<Obs>;
type ActFilter = BorderAtariActRawFilter<Act>;
type EnvConfig = BorderAtariEnvConfig<Obs, Act, ObsFilter, ActFilter>;
type Env = BorderAtariEnv<Obs, Act, ObsFilter, ActFilter>;

#[derive(Clone)]
struct RandomPolicyConfig {
    pub n_acts: usize,
}

struct RandomPolicy {
    n_acts: usize,
}

impl Policy<Env> for RandomPolicy {
    type Config = RandomPolicyConfig;

    fn build(config: Self::Config) -> Self {
        Self {
            n_acts: config.n_acts,
        }
    }

    fn sample(&mut self, _: &Obs) -> Act {
        fastrand::u8(..self.n_acts as u8).into()
    }
}

fn env_config(name: String) -> EnvConfig {
    EnvConfig::default().name(name)
}

fn main() -> Result<()> {
    env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init();
    fastrand::seed(42);

    // Creates Pong environment
    let env_config = env_config("pong".to_string());

    // Creates a random policy
    let n_acts = 4; // number of actions;
    let policy_config = RandomPolicyConfig {
        n_acts: n_acts as _,
    };
    let mut policy = RandomPolicy::build(policy_config);

    // Runs evaluation
    let env_config = env_config.render(true);
    let _ = DefaultEvaluator::new(&env_config, 0, 5)?.evaluate(&mut policy);

    Ok(())
}

依赖项

~31–50MB
~887K SLoC