#遗传算法 #训练 #模型 #网络 #创建 #输出 #march-madness-predictor

scratch_genetic

用于我的march-madness-predictor项目的从零开始的遗传算法库

1个稳定版本

22.5.23 2022年5月23日
22.5.17 2022年5月16日

#303 in 机器学习

GPL-3.0 许可证

34KB
517

scratch_genetic

描述

用于我的march-madness-predictor项目的从零开始的遗传算法库

API参考

内容

  1. genetic 模块
    1. gen_pop 函数
    2. test_and_sort 函数
    3. reproduce 函数
    4. load_and_predict 函数
    5. export_model 函数
  2. 示例
    1. 训练
    2. 预测

genetic 模块

遗传模块是scratch_genetic库下的唯一公开模块。

它包含了一组函数,实现了“遗传算法”,该算法模仿自然选择的概念,以创建一个可用于预测某些结果的模型。

其工作方式是,您将数据转换为输入和输出位流,根据输入数据的大小和设置创建一系列随机网络,通过运行测试和繁殖函数进行训练,最后在结束时导出最佳网络。

然后您可以

gen_pop 函数

pub async fn gen_pop(
        pop_size: usize,
        layer_sizes: Vec<usize>, num_inputs: usize, num_outputs: usize,
        activation_thresh: f64, trait_swap_chance: f64,
        weight_mutate_chance: f64, weight_mutate_amount: f64,
        offset_mutate_chance: f64, offset_mutate_amount: f64) -> Vec<Network> {

此函数生成一个基于底层私有结构体Network的随机向量。它是私有的,因为您不需要手动操作它,只需在函数之间传递它即可。

您可以看到它接受相当多的参数。这些都是在您的网络中提供的手动设置,以调整其训练方式。

参数

  • pop_size - 要训练的网络数量。越大越好,但越大也越慢
  • layer_sizes - 包含神经网络每层大小的向量
  • num_inputs - 转换后的输入数据产生的位数(必须能被8整除)
  • num_outputs - 输出预期的位数(必须能被8整除)
  • activation_thresh - 神经元开启的难度
  • trait_swap_chance - 控制繁殖时子代从每个父代共享不同特质的可变性
  • weight_mutate_chance - 神经元之间连接权重变化的概率
  • weight_mutate_amount - 上述变化的强度
  • offset_mutate_chanceoffset_mutate_amount - 与上述两个相同,但使用连接的基本值

test_and_sort 函数

pubasyncfn test_and_sort(pop: &mut Vec<Network>, data_set: &Vec<(Vec<u8>, Vec<u8>)>) {

此函数接收由 gen_pop 创建的 "population"(Networks 的向量)和您的测试数据,查看每个网络接近重现每个测试数据输出的程度,然后根据该性能对网络进行排序。

reproduce 函数

pubasyncfn reproduce(pop: &mut Vec<Network>) {

sorting 之后,您可能希望重现。这将保留您的网络集,保留上半部分,并使用这些网络来替换下半部分,下半部分的网络将共享混合基因和基于您在 gen_pop 中提供的参数的突变。

load_and_predict 函数

pubasyncfn load_and_predict(file_name: &'static str, input_bits: &Vec<u8>) -> Vec<u8> {

加载一个已导出 exported 的模型,并生成输出位,前提是您传递了输入位

export_model 函数

pubasyncfn export_model(file_name: &'static str, pop: &Network) {

将网络导出到文件。

示例

以下示例使用这些常量

// Neuron connection settings
pub const NEURON_ACTIVATION_THRESH: f64 = 0.60;
pub const TRAIT_SWAP_CHANCE: f64 = 0.80;
pub const WEIGHT_MUTATE_CHANCE: f64 = 0.65;
pub const WEIGHT_MUTATE_AMOUNT: f64 = 0.5;
pub const OFFSET_MUTATE_CHANCE: f64 = 0.25;
pub const OFFSET_MUTATE_AMOUNT: f64 = 0.05;

// Neural network settings
pub const LAYER_SIZES: [usize; 4] = [ 8, 32, 32, 16 ];

// Algortithm settings
const POP_SIZE: usize = 2000;

const DATA_FILE_NAME: &'static str = "NCAA Mens March Madness Historical Results.csv";
const MODEL_FILE_NAME: &'static str = "model.mmp";
const NUM_GENS: usize = 1000;

训练

println!("Training new March Madness Predictor Model");

// Custom class that structures CSV data and allows turning into bits.
println!("Loading training data from {}", DATA_FILE_NAME);
let games = GameInfo::collection_from_file(DATA_FILE_NAME);
let games: Vec<(Vec<u8>, Vec<u8>)> = games.iter().map(|game| { // Redefines games
    (game.to_input_bits().to_vec(), game.to_output_bits().to_vec())
}).collect();

println!("Generating randomized population");
let now = Instant::now();
let mut pop = gen_pop(
    POP_SIZE,
    LAYER_SIZES.to_vec(), NUM_INPUTS, NUM_OUTPUTS,
    NEURON_ACTIVATION_THRESH, TRAIT_SWAP_CHANCE,
    WEIGHT_MUTATE_CHANCE, WEIGHT_MUTATE_AMOUNT,
    OFFSET_MUTATE_CHANCE, OFFSET_MUTATE_AMOUNT
).await;
let elapsed = now.elapsed();
println!("Generation took {}s", elapsed.as_secs_f64());

println!("Starting training");
for i in 0..NUM_GENS {
    println!("Generation {} / {}", i, NUM_GENS);
    test_and_sort(&mut pop, &games).await;
    reproduce(&mut pop).await;
}

// Save algorithm
println!("Saving model to {}", MODEL_FILE_NAME);
export_model(MODEL_FILE_NAME, &pop[0]).await;

预测

pub async fn predict(team_names: &str) {
    let table_data = team_names.split(",");
    let mut indexable_table_data = Vec::new();
    for item in table_data {
        indexable_table_data.push(item);
    }
    
    // A team, A seed, B team, B seed, date, round, region
    if indexable_table_data.len() != 7 {
        println!("Invalid input string!");
        return;
    }

    // Like the other example, this stuff is converting CSV data into a useable form
    println!("Converting input into data...");
    let entry = TableEntry {
        winner: String::from(indexable_table_data[0]),
        win_seed: String::from(indexable_table_data[1]),
        loser: String::from(indexable_table_data[2]),
        lose_seed: String::from(indexable_table_data[3]),
        date: String::from(indexable_table_data[4]),
        round: String::from(indexable_table_data[5]),
        region: String::from(indexable_table_data[6]),

        win_score: String::from("0"),
        lose_score: String::from("0"),
        overtime: String::from("")
    };
    let game = GameInfo::from_table_entry(&entry);

    // Here's where the code is used
    println!("Predicting!");
    let result = load_and_predict(MODEL_FILE_NAME, &game.to_input_bits().to_vec()).await;

    println!("Predicted score for {}: {}", indexable_table_data[0], result[0]);
    println!("Predicted score for {}: {}", indexable_table_data[2], result[1]);
    println!("Expected overtimes: {}", result[2]);
}

Dependencies

~3–9.5MB
~77K SLoC