6个版本
0.3.0 | 2024年7月11日 |
---|---|
0.2.2 | 2022年9月26日 |
0.2.1 | 2021年11月16日 |
0.2.0 | 2021年10月17日 |
0.1.1 | 2021年7月27日 |
#304 在 算法
119 每月下载量
195KB
5K SLoC
LIBMF Rust
LIBMF - 大规模稀疏矩阵分解 - 适用于Rust
查看 Disco 以获取高级协同过滤
安装
将以下行添加到应用程序的 Cargo.toml
中的 [dependencies]
libmf = "0.3"
入门指南
将数据准备成以下格式 row_index, column_index, value
let mut data = libmf::Matrix::new();
data.push(0, 0, 5.0);
data.push(0, 2, 3.5);
data.push(1, 1, 4.0);
拟合模型
let model = libmf::Model::params().fit(&data).unwrap();
进行预测
model.predict(row_index, column_index);
获取潜在因子(这些近似训练矩阵)
model.p(row_index);
model.q(column_index);
// or
model.p_iter();
model.q_iter();
获取偏差(训练矩阵中所有元素的平均值)
model.bias();
将模型保存到文件
model.save("model.txt").unwrap();
从文件加载模型
let model = libmf::Model::load("model.txt").unwrap();
传递验证集
let model = libmf::Model::params().fit_eval(&train_set, &eval_set).unwrap();
交叉验证
执行交叉验证
let avg_error = libmf::Model::params().cv(&data, 5).unwrap();
参数
设置参数 - 默认值如下
libmf::Model::params()
.loss(libmf::Loss::RealL2) // loss function
.factors(8) // number of latent factors
.threads(12) // number of threads
.bins(25) // number of bins
.iterations(20) // number of iterations
.lambda_p1(0.0) // L1-regularization parameter for P
.lambda_p2(0.1) // L2-regularization parameter for P
.lambda_q1(0.0) // L1-regularization parameter for Q
.lambda_q2(0.1) // L2-regularization parameter for Q
.learning_rate(0.1) // learning rate
.alpha(1.0) // importance of negative entries
.c(0.0001) // desired value of negative entries
.nmf(false) // perform non-negative MF (NMF)
.quiet(false); // no outputs to stdout
损失函数
对于实值矩阵分解
Loss::RealL2
- 平方误差(L2范数)Loss::RealL1
- 绝对误差(L1范数)Loss::RealKL
- 广义KL散度
对于二值矩阵分解
Loss::BinaryLog
- 对数误差Loss::BinaryL2
- 平方折耳损失Loss::BinaryL1
- 折耳损失
对于单类矩阵分解
Loss::OneClassRow
- 行向对数损失Loss::OneClassCol
- 列向对数损失Loss::OneClassL2
- 平方误差(L2范数)
度量
计算RMSE(用于实值MF)
model.rmse(&data);
计算MAE(用于实值MF)
model.mae(&data);
计算广义KL散度(用于非负实值MF)
model.gkl(&data);
计算对数损失(用于二值MF)
model.logloss(&data);
计算准确率(用于二值MF)
model.accuracy(&data);
计算 MPR(一类 MF 用)
model.mpr(&data, transpose);
计算 AUC(一类 MF 用)
model.auc(&data, transpose);
示例
将这些行添加到您的应用程序的 Cargo.toml
下的 [dependencies]
csv = "1"
serde = { version = "1", features = ["derive"] }
并使用
use csv::ReaderBuilder;
use serde::Deserialize;
use std::fs::File;
#[derive(Debug, Deserialize)]
struct Row {
user_id: i32,
item_id: i32,
rating: f32,
time: i32,
}
fn main() {
let mut train_set = libmf::Matrix::new();
let mut valid_set = libmf::Matrix::new();
let file = File::open("u.data").unwrap();
let mut rdr = ReaderBuilder::new()
.has_headers(false)
.delimiter(b'\t')
.from_reader(file);
for (i, record) in rdr.records().enumerate() {
let row: Row = record.unwrap().deserialize(None).unwrap();
let matrix = if i < 80000 { &mut train_set } else { &mut valid_set };
matrix.push(row.user_id, row.item_id, row.rating);
}
let model = libmf::Model::params().fit_eval(&train_set, &valid_set).unwrap();
println!("RMSE: {:?}", model.rmse(&valid_set));
}
参考
指定矩阵的初始容量
let mut data = libmf::Matrix::with_capacity(3);
资源
历史
查看 变更日志
贡献
鼓励每个人帮助改进这个项目。以下是一些您可以提供帮助的方式
开始开发
git clone --recursive https://github.com/ankane/libmf-rust.git
cd libmf-rust
cargo test
依赖项
~185KB