1 stable release
22.5.23 | May 23, 2022 |
---|---|
22.5.17 |
|
#303 in Machine learning
34KB
517 lines
scratch_genetic
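Description
A from-scratch genetic algorithm library used in my march-madness-predictor project
API Reference
Contents
genetic
module
The genetic module is the only public module in the scratch_genetic library.
It contains a set of functions that implement a "genetic algorithm", which mimics the concept of natural selection to create a model that can be used to predict certain outcomes.
The way it works is that you convert your data into streams of input and output bits, create a set of random networks based on the size of the input data and your settings, train them by running the test and reproduce functions, and finally export the best network at the end.
You can then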
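The conversion step described above can be sketched as follows. This is a minimal illustration, not code from scratch_genetic: the `Matchup` struct, its fields, and the one-byte-per-field packing are all assumptions made for the example. The only details taken from the page are that samples are `(Vec<u8>, Vec<u8>)` pairs and that the bit counts must be divisible by 8.

```rust
/// Hypothetical record; the field names are illustrative only.
struct Matchup {
    seed_a: u8,
    seed_b: u8,
    round: u8,
}

impl Matchup {
    /// Packs each field into one byte, so the input is a whole number of bytes.
    fn to_input_bits(&self) -> [u8; 3] {
        [self.seed_a, self.seed_b, self.round]
    }

    /// Packs the two known scores into one byte each.
    fn to_output_bits(score_a: u8, score_b: u8) -> [u8; 2] {
        [score_a, score_b]
    }
}

fn main() {
    let game = Matchup { seed_a: 1, seed_b: 16, round: 1 };
    // One training sample: input bytes paired with expected output bytes.
    let sample: (Vec<u8>, Vec<u8>) = (
        game.to_input_bits().to_vec(),
        Matchup::to_output_bits(78, 56).to_vec(),
    );
    // Bit counts are byte counts times 8, so they are always divisible by 8.
    let num_inputs = sample.0.len() * 8;
    let num_outputs = sample.1.len() * 8;
    println!("{} input bits, {} output bits", num_inputs, num_outputs);
}
```

Packing whole bytes is one simple way to guarantee the divisible-by-8 requirement; a real encoding may spend its bits differently.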
gen_pop
function
pub async fn gen_pop(
pop_size: usize,
layer_sizes: Vec<usize>, num_inputs: usize, num_outputs: usize,
activation_thresh: f64, trait_swap_chance: f64,
weight_mutate_chance: f64, weight_mutate_amount: f64,
offset_mutate_chance: f64, offset_mutate_amount: f64) -> Vec<Network> {
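This function generates a random vector of the underlying private struct Network. It is private because you do not need to manipulate it by hand; you only pass it between functions.
As you can see, it takes quite a few parameters. These are all manual settings you provide to your networks to tune how they train.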
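Parameters
pop_size - the number of networks to train. Bigger is better, but bigger is also slower
layer_sizes - a vector containing the size of each layer of the neural network
num_inputs - the number of bits your converted input data produces (must be divisible by 8)
num_outputs - the number of bits expected in the output (must be divisible by 8)
activation_thresh - how hard it is for a neuron to turn on
trait_swap_chance - controls the variability in how a child shares different traits from each parent when reproducing
weight_mutate_chance - the chance that the weight of a connection between neurons changes
weight_mutate_amount - how strong that change is
offset_mutate_chance and offset_mutate_amount - same as the two above, but applied to the base value of a connection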
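Because the bit counts have a hard constraint, it can help to check settings before calling gen_pop. The sketch below is a hypothetical helper, not part of scratch_genetic. The divisible-by-8 rule comes from the parameter list above; treating the `*_chance` settings as probabilities between 0 and 1 is an assumption based on the example constants.

```rust
/// Hypothetical pre-flight check; not part of scratch_genetic.
fn check_settings(
    num_inputs: usize,
    num_outputs: usize,
    chances: &[f64],
) -> Result<(), String> {
    // Documented constraint: both bit counts must be divisible by 8.
    if num_inputs % 8 != 0 || num_outputs % 8 != 0 {
        return Err(format!(
            "bit counts must be divisible by 8, got {} and {}",
            num_inputs, num_outputs
        ));
    }
    // Assumption: the *_chance settings are probabilities in 0.0..=1.0.
    if let Some(bad) = chances.iter().find(|c| !(0.0..=1.0).contains(*c)) {
        return Err(format!("chance {} is outside 0.0..=1.0", bad));
    }
    Ok(())
}

fn main() {
    // Trait swap, weight mutate, and offset mutate chances.
    let chances = [0.80, 0.65, 0.25];
    match check_settings(64, 24, &chances) {
        Ok(()) => println!("settings look usable"),
        Err(msg) => println!("bad settings: {}", msg),
    }
}
```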
test_and_sort
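function
pub async fn test_and_sort(pop: &mut Vec<Network>, data_set: &Vec<(Vec<u8>, Vec<u8>)>) {
This function takes the "population" (a vector of Networks) created by gen_pop together with your test data, checks how close each network comes to reproducing the output of each test sample, and then sorts the networks by that performance.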
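To make "how close" concrete, here is a minimal sketch that scores candidates by the number of differing output bits and sorts them best-first. The page does not say which metric test_and_sort uses, so the bit-difference (Hamming) score is an assumption, and `Candidate` is a hypothetical stand-in for the private Network type that simply stores precomputed predictions.

```rust
/// Counts the bits that differ between a predicted and an expected output.
fn bit_errors(predicted: &[u8], expected: &[u8]) -> u32 {
    predicted
        .iter()
        .zip(expected)
        .map(|(p, e)| (p ^ e).count_ones())
        .sum()
}

/// Hypothetical stand-in for a network: one stored prediction per sample.
struct Candidate {
    name: &'static str,
    predictions: Vec<Vec<u8>>,
}

/// Sums the bit errors of one candidate over the whole data set.
fn total_errors(c: &Candidate, data_set: &[(Vec<u8>, Vec<u8>)]) -> u32 {
    c.predictions
        .iter()
        .zip(data_set)
        .map(|(p, (_, expected))| bit_errors(p, expected))
        .sum()
}

/// Sorts candidates so the one with the fewest bit errors comes first.
fn sort_by_fitness(pop: &mut [Candidate], data_set: &[(Vec<u8>, Vec<u8>)]) {
    pop.sort_by_key(|c| total_errors(c, data_set));
}

fn main() {
    let data_set = vec![(vec![1u8], vec![0b0000_1111u8])];
    let mut pop = vec![
        Candidate { name: "far", predictions: vec![vec![0b1111_0000]] },
        Candidate { name: "near", predictions: vec![vec![0b0000_1110]] },
    ];
    sort_by_fitness(&mut pop, &data_set);
    println!("best candidate: {}", pop[0].name);
}
```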
reproduce
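function
pub async fn reproduce(pop: &mut Vec<Network>) {
After sorting, you will want to reproduce. This takes your set of networks, keeps the top half, and uses those networks to replace the bottom half. The networks in the new bottom half share mixed genes and carry mutations based on the parameters you provided in gen_pop.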
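The keep-the-top-half idea can be sketched on plain `Vec<f64>` genomes. Everything here is an assumption made for illustration: the genome representation, the neighbour pairing of parents, the per-gene reading of the swap chance, and the small `Lcg` generator (used only so the example needs no external crates). It is not the library's implementation.

```rust
/// Tiny deterministic generator so the sketch needs no external crates.
struct Lcg(u64);

impl Lcg {
    /// Returns a pseudo-random value in [0, 1).
    fn next_f64(&mut self) -> f64 {
        self.0 = self
            .0
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        (self.0 >> 11) as f64 / (1u64 << 53) as f64
    }
}

/// Replaces the bottom half of a sorted population with children of the top half.
fn reproduce_sketch(
    pop: &mut Vec<Vec<f64>>,
    swap_chance: f64,
    mutate_chance: f64,
    mutate_amount: f64,
    rng: &mut Lcg,
) {
    let half = pop.len() / 2;
    for i in 0..half {
        // Pair each survivor with its neighbour in the top half.
        let parent_a = pop[i].clone();
        let parent_b = pop[(i + 1) % half].clone();
        let child: Vec<f64> = parent_a
            .iter()
            .zip(&parent_b)
            .map(|(a, b)| {
                // Take the gene from parent B with probability swap_chance.
                let mut gene = if rng.next_f64() < swap_chance { *b } else { *a };
                // With probability mutate_chance, nudge by up to +/- mutate_amount.
                if rng.next_f64() < mutate_chance {
                    gene += (rng.next_f64() * 2.0 - 1.0) * mutate_amount;
                }
                gene
            })
            .collect();
        // Overwrite one member of the bottom half with the new child.
        pop[half + i] = child;
    }
}

fn main() {
    let mut pop = vec![vec![1.0, 1.0], vec![2.0, 2.0], vec![9.0, 9.0], vec![9.0, 9.0]];
    let mut rng = Lcg(42);
    reproduce_sketch(&mut pop, 0.8, 0.65, 0.5, &mut rng);
    println!("{:?}", pop);
}
```

The top half is left untouched, so the best candidate from the last sort stays at index 0 after reproduction.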
load_and_predict
function
pub async fn load_and_predict(file_name: &'static str, input_bits: &Vec<u8>) -> Vec<u8> {
Loads a model that was exported and, given the input bits you pass in, generates output bits.
export_model
function
pub async fn export_model(file_name: &'static str, pop: &Network) {
Exports a network to a file.
Examples
The examples below use these constants
// Neuron connection settings
pub const NEURON_ACTIVATION_THRESH: f64 = 0.60;
pub const TRAIT_SWAP_CHANCE: f64 = 0.80;
pub const WEIGHT_MUTATE_CHANCE: f64 = 0.65;
pub const WEIGHT_MUTATE_AMOUNT: f64 = 0.5;
pub const OFFSET_MUTATE_CHANCE: f64 = 0.25;
pub const OFFSET_MUTATE_AMOUNT: f64 = 0.05;
// Neural network settings
pub const LAYER_SIZES: [usize; 4] = [ 8, 32, 32, 16 ];
// Algorithm settings
const POP_SIZE: usize = 2000;
const DATA_FILE_NAME: &'static str = "NCAA Mens March Madness Historical Results.csv";
const MODEL_FILE_NAME: &'static str = "model.mmp";
const NUM_GENS: usize = 1000;
Training
println!("Training new March Madness Predictor Model");
// Custom class that structures CSV data and allows turning into bits.
println!("Loading training data from {}", DATA_FILE_NAME);
let games = GameInfo::collection_from_file(DATA_FILE_NAME);
let games: Vec<(Vec<u8>, Vec<u8>)> = games.iter().map(|game| { // Redefines games
(game.to_input_bits().to_vec(), game.to_output_bits().to_vec())
}).collect();
println!("Generating randomized population");
let now = std::time::Instant::now();
let mut pop = gen_pop(
POP_SIZE,
LAYER_SIZES.to_vec(), NUM_INPUTS, NUM_OUTPUTS,
NEURON_ACTIVATION_THRESH, TRAIT_SWAP_CHANCE,
WEIGHT_MUTATE_CHANCE, WEIGHT_MUTATE_AMOUNT,
OFFSET_MUTATE_CHANCE, OFFSET_MUTATE_AMOUNT
).await;
let elapsed = now.elapsed();
println!("Generation took {}s", elapsed.as_secs_f64());
println!("Starting training");
for i in 0..NUM_GENS {
println!("Generation {} / {}", i, NUM_GENS);
test_and_sort(&mut pop, &games).await;
reproduce(&mut pop).await;
}
// Save algorithm
println!("Saving model to {}", MODEL_FILE_NAME);
export_model(MODEL_FILE_NAME, &pop[0]).await;
Prediction
pub async fn predict(team_names: &str) {
// Collect the comma-separated fields so they can be indexed.
let indexable_table_data: Vec<&str> = team_names.split(',').collect();
// A team, A seed, B team, B seed, date, round, region
if indexable_table_data.len() != 7 {
println!("Invalid input string!");
return;
}
// Like the other example, this stuff is converting CSV data into a useable form
println!("Converting input into data...");
let entry = TableEntry {
winner: String::from(indexable_table_data[0]),
win_seed: String::from(indexable_table_data[1]),
loser: String::from(indexable_table_data[2]),
lose_seed: String::from(indexable_table_data[3]),
date: String::from(indexable_table_data[4]),
round: String::from(indexable_table_data[5]),
region: String::from(indexable_table_data[6]),
win_score: String::from("0"),
lose_score: String::from("0"),
overtime: String::from("")
};
let game = GameInfo::from_table_entry(&entry);
// Here's where the code is used
println!("Predicting!");
let result = load_and_predict(MODEL_FILE_NAME, &game.to_input_bits().to_vec()).await;
println!("Predicted score for {}: {}", indexable_table_data[0], result[0]);
println!("Predicted score for {}: {}", indexable_table_data[2], result[1]);
println!("Expected overtimes: {}", result[2]);
}
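The field-count check at the top of predict could also be expressed as a function that returns a Result, which makes the failure case easier to test. This is a sketch only: `split_fields` is a hypothetical helper, the trimming of whitespace is an addition the original example does not perform, and the sample input values (team names, date format) are made up for illustration.

```rust
use std::convert::TryInto;

/// Splits "A team,A seed,B team,B seed,date,round,region" into exactly seven fields.
fn split_fields(input: &str) -> Result<[&str; 7], String> {
    // Trim each field so "Duke, 2" and "Duke,2" behave the same.
    let fields: Vec<&str> = input.split(',').map(str::trim).collect();
    let count = fields.len();
    // Converting to a fixed-size array fails unless there are exactly 7 fields.
    fields
        .try_into()
        .map_err(|_| format!("expected 7 fields, got {}", count))
}

fn main() {
    // Hypothetical input; the date format is an assumption.
    match split_fields("Duke, 2, UNC, 8, 2022-04-02, Final Four, East") {
        Ok(fields) => println!("{} vs {}", fields[0], fields[2]),
        Err(msg) => println!("Invalid input string: {}", msg),
    }
}
```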
Dependencies
~3–9.5MB
~77K SLoC