3 releases (breaking)
Uses old Rust 2015
Version | Release date
---|---
0.3.0 | Mar 9, 2019
0.2.0 | Jan 25, 2018
0.1.0 | Jan 14, 2018
46KB
1K SLoC
mcts
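This is a library for Monte Carlo tree search (MCTS) in Rust.
The implementation is parallel and lock-free. Its generic design allows it to be used in many domains. It can accommodate different search methods (for example, rollout-based leaf evaluation or leaf evaluation with a neural network).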
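To illustrate what "lock-free" can mean for parallel search, here is a minimal standalone sketch of per-node statistics that several search threads update concurrently using atomic counters instead of a mutex. It is not this crate's internal code: `NodeStats`, `record`, and `parallel_backprop` are hypothetical names, and the alternating 0/1 reward stands in for a real playout result.

```rust
use std::sync::atomic::{AtomicI64, AtomicU64, Ordering};
use std::sync::Arc;
use std::thread;

/// Hypothetical per-node statistics: visit count and summed reward.
/// Each field is an independent atomic, so threads never block on a lock.
#[derive(Default)]
struct NodeStats {
    visits: AtomicU64,
    total_reward: AtomicI64,
}

impl NodeStats {
    /// Record the result of one playout (the backpropagation step).
    fn record(&self, reward: i64) {
        self.visits.fetch_add(1, Ordering::Relaxed);
        self.total_reward.fetch_add(reward, Ordering::Relaxed);
    }

    /// Read (visits, total reward). The two loads are separate, so while
    /// other threads are still writing the pair may be slightly out of sync.
    fn snapshot(&self) -> (u64, i64) {
        (
            self.visits.load(Ordering::Relaxed),
            self.total_reward.load(Ordering::Relaxed),
        )
    }
}

/// Spawn `threads` workers that each record `playouts_per_thread` results
/// into the same node, then return the final statistics.
fn parallel_backprop(threads: usize, playouts_per_thread: u64) -> (u64, i64) {
    let stats = Arc::new(NodeStats::default());
    let handles: Vec<_> = (0..threads)
        .map(|_| {
            let stats = Arc::clone(&stats);
            thread::spawn(move || {
                for i in 0..playouts_per_thread {
                    // Stand-in reward: alternates 0 and 1.
                    stats.record((i % 2) as i64);
                }
            })
        })
        .collect();
    // Joining makes every worker's updates visible to this thread.
    for h in handles {
        h.join().unwrap();
    }
    stats.snapshot()
}

fn main() {
    let (visits, total) = parallel_backprop(4, 1000);
    println!("visits = {visits}, total reward = {total}");
}
```

`Relaxed` ordering is sufficient here because each counter only needs atomic increments; the final read happens after `join`, which already orders it after all worker writes.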
lib.rs:
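This is a library for Monte Carlo tree search.
It is still under development and the documentation is sparse. However, the following example may help: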
// Needed when this code is compiled as a Rust 2015 crate; harmless in later editions.
extern crate mcts;

use mcts::*;
use mcts::tree_policy::*;
use mcts::transposition_table::*;

// A really simple game. There's one player and one number. In each move the player can
// increase or decrease the number. The player's score is the number.
// The game ends when the number reaches 100.
//
// The best strategy is to increase the number at every step.
#[derive(Clone, Debug, PartialEq)]
struct CountingGame(i64);

#[derive(Clone, Debug, PartialEq)]
enum Move {
    Add, Sub
}

impl GameState for CountingGame {
    type Move = Move;
    type Player = ();
    type MoveList = Vec<Move>;

    fn current_player(&self) -> Self::Player {
        ()
    }
    fn available_moves(&self) -> Vec<Move> {
        let x = self.0;
        if x == 100 {
            // Terminal state: no legal moves.
            vec![]
        } else {
            vec![Move::Add, Move::Sub]
        }
    }
    fn make_move(&mut self, mov: &Self::Move) {
        match *mov {
            Move::Add => self.0 += 1,
            Move::Sub => self.0 -= 1,
        }
    }
}

// Equal numbers are the same position, so the number itself serves as the hash.
impl TranspositionHash for CountingGame {
    fn hash(&self) -> u64 {
        self.0 as u64
    }
}

struct MyEvaluator;

impl Evaluator<MyMCTS> for MyEvaluator {
    type StateEvaluation = i64;

    // A new leaf is scored by its current number; every move gets a unit (empty) evaluation.
    fn evaluate_new_state(&self, state: &CountingGame, moves: &Vec<Move>,
        _: Option<SearchHandle<MyMCTS>>)
        -> (Vec<()>, i64) {
        (vec![(); moves.len()], state.0)
    }
    fn interpret_evaluation_for_player(&self, evaln: &i64, _player: &()) -> i64 {
        *evaln
    }
    fn evaluate_existing_state(&self, _: &CountingGame, evaln: &i64, _: SearchHandle<MyMCTS>) -> i64 {
        *evaln
    }
}

#[derive(Default)]
struct MyMCTS;

impl MCTS for MyMCTS {
    type State = CountingGame;
    type Eval = MyEvaluator;
    type NodeData = ();
    type ExtraThreadData = ();
    type TreePolicy = UCTPolicy;
    type TranspositionTable = ApproxTable<Self>;

    fn cycle_behaviour(&self) -> CycleBehaviour<Self> {
        CycleBehaviour::UseCurrentEvalWhenCycleDetected
    }
}
fn main() {
    // Search from 0 with UCT (exploration constant 0.5) and a 1024-entry transposition table.
    let game = CountingGame(0);
    let mut mcts = MCTSManager::new(game, MyMCTS, MyEvaluator, UCTPolicy::new(0.5),
        ApproxTable::new(1024));
    mcts.playout_n_parallel(10000, 4); // 10000 playouts, 4 search threads
    mcts.tree().debug_moves();
    // The search should discover that always adding is optimal.
    assert_eq!(mcts.best_move().unwrap(), Move::Add);
    assert_eq!(mcts.principal_variation(50),
        vec![Move::Add; 50]);
    assert_eq!(mcts.principal_variation_states(5),
        vec![
            CountingGame(0),
            CountingGame(1),
            CountingGame(2),
            CountingGame(3),
            CountingGame(4),
            CountingGame(5)]);
}
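The example selects `UCTPolicy::new(0.5)`. As a rough illustration of what such an exploration constant controls, here is a standalone sketch of the classic UCT (UCB1 applied to trees) selection rule: mean reward plus `c * sqrt(ln N / n)`. This is an assumption about the general technique, not a copy of the crate's `UCTPolicy`, whose exact formula and scaling may differ; `uct_score` and `select_child` are hypothetical names, and the parent visit count is approximated as the sum of the children's visits.

```rust
/// UCT score of one child: average reward (exploitation) plus an
/// exploration bonus that shrinks as the child is visited more often.
fn uct_score(child_reward_sum: f64, child_visits: u64, parent_visits: u64, c: f64) -> f64 {
    if child_visits == 0 {
        // Unvisited children are always tried first.
        return f64::INFINITY;
    }
    let n = child_visits as f64;
    let mean = child_reward_sum / n;
    mean + c * ((parent_visits as f64).ln() / n).sqrt()
}

/// Return the index of the child with the highest UCT score, or None if
/// there are no children. Each child is given as (reward sum, visits).
fn select_child(children: &[(f64, u64)], c: f64) -> Option<usize> {
    // Approximation: parent visits = total visits of its children.
    let parent_visits: u64 = children.iter().map(|&(_, v)| v).sum();
    children
        .iter()
        .enumerate()
        .map(|(i, &(sum, visits))| (i, uct_score(sum, visits, parent_visits, c)))
        .max_by(|a, b| a.1.total_cmp(&b.1))
        .map(|(i, _)| i)
}

fn main() {
    let children: [(f64, u64); 2] = [(9.0, 10), (0.4, 1)];
    // With c = 1.0 the rarely visited child wins; with c = 0.0 the best mean wins.
    println!("{:?} {:?}", select_child(&children, 1.0), select_child(&children, 0.0));
}
```

A larger `c` spends more playouts on rarely tried moves; a smaller one concentrates effort on moves that already look good.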
Dependencies
~455–640KB