#candle #pytorch #safetensors #mpnet

mpnet-rs

这是将 MPNet 从 PyTorch 转换为 Rust Candle 的翻译。

4 个版本

0.1.3 2024 年 4 月 2 日
0.1.2 2024 年 3 月 8 日
0.1.1 2024 年 2 月 24 日
0.1.0 2024 年 2 月 9 日

#268Cargo 插件

MIT/Apache

36KB
561

mpnet-rs

这是什么?

这是将 MPNet 从 PyTorch 转换为 Rust Candle 的翻译。

  • 我使用的训练模型是 PatentSBERTa,它旨在获取针对专利领域的优化嵌入。
  • 训练流程尚未准备。
  • 如果您有自己的 MPNet 权重,可以使用此卡加载它们。

更新

v.0.1.3

  • 一些数据类型已更改:get_embeddings(), get_embeddings_parallel()

v.0.1.2

  • candle 版本升级:0.3.3 -> 0.4.1

v.0.1.1

  • get_embeddings() 的并行版本:get_embedding_parallel()

如何使用

获取训练模型

  • huggingface 下载模型
  • Candle v0.4.0 支持直接加载 pytorch_model.bin,但 v0.3.3 不支持。
  • 如果您想从 .safetensors 加载模型,您必须自行转换。 此实现 可能有所帮助。

加载模型和权重

use mpnet_rs::mpnet::load_model;
let (model, tokenizer, pooler) = load_model("/path/to/model/and/tokenizer").unwrap();

获取嵌入(带池化器):参见下面的测试函数

这是关于如何获取嵌入和余弦相似度的

use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{VarBuilder,  Module};

use mpnet_rs::mpnet::{MPNetEmbeddings, MPNetConfig, create_position_ids_from_input_ids, cumsum, load_model, get_embeddings, normalize_l2, PoolingConfig, MPNetPooler};


fn test_get_embeddings() ->Result<()>{
    let path_to_checkpoints_folder = "D:/RustWorkspace/checkpoints/AI-Growth-Lab_PatentSBERTa".to_string();

    let (model, mut tokenizer, pooler) = load_model(path_to_checkpoints_folder).unwrap();

    let sentences = vec![
        "an invention that targets GLP-1",
        "new chemical that targets glucagon like peptide-1 ",
        "de novo chemical that targets GLP-1",
        "invention about GLP-1 receptor",
        "new chemical synthesis for glp-1 inhibitors",
        "It feels like I'm in America",
        "It's rainy. all day long.",
    ];
    let n_sentences = sentences.len();
    let embeddings = get_embeddings(&model, &tokenizer, Some(&pooler), &sentences).unwrap();

    let l2norm_embeds = normalize_l2(&embeddings).unwrap();
    println!("pooled embeddings {:?}", l2norm_embeds.shape());

    let mut similarities = vec![];
    for i in 0..n_sentences {
        let e_i = l2norm_embeds.get(i)?;
        for j in (i + 1)..n_sentences {
            let e_j = l2norm_embeds.get(j)?;
            let sum_ij = (&e_i * &e_j)?.sum_all()?.to_scalar::<f32>()?;
            let sum_i2 = (&e_i * &e_i)?.sum_all()?.to_scalar::<f32>()?;
            let sum_j2 = (&e_j * &e_j)?.sum_all()?.to_scalar::<f32>()?;
            let cosine_similarity = sum_ij / (sum_i2 * sum_j2).sqrt();
            similarities.push((cosine_similarity, i, j))
        }
    }
    similarities.sort_by(|u, v| v.0.total_cmp(&u.0));
    for &(score, i, j) in similarities[..5].iter() {
        println!("score: {score:.2} '{}' '{}'", sentences[i], sentences[j])
    }

    Ok(())
}

注意

池化层

  • 在 Transformers 中的原始 PyTorch 实现中,池化层在 MPNetModel 类中声明。
  • 我已经独立实现了池化层,将其从 MPNetModel 类中分离出来。

激活函数

  • 在原始实现中,tanh 被用作池化层的激活函数。
  • 然而,由于很难找到 Candle 中的 tanh 实现,我已将其设置为默认的 gelu。

参考文献

依赖项

~23MB
~465K SLoC