6 个版本

0.0.7 2023 年 11 月 9 日
0.0.6 2023 年 8 月 6 日
0.0.5 2023 年 7 月 27 日
0.0.3 2023 年 6 月 30 日

#39机器学习

Download history 1098/week @ 2024-04-23 696/week @ 2024-04-30 985/week @ 2024-05-07 1059/week @ 2024-05-14 987/week @ 2024-05-21 796/week @ 2024-05-28 528/week @ 2024-06-04 849/week @ 2024-06-11 1155/week @ 2024-06-18 1088/week @ 2024-06-25 404/week @ 2024-07-02 645/week @ 2024-07-09 876/week @ 2024-07-16 1039/week @ 2024-07-23 1011/week @ 2024-07-30 456/week @ 2024-08-06

3,505 每月下载量
用于 12 个crate(6 个直接)

MIT 许可证

140KB
3.5K SLoC

llm-samplers

大型语言模型的标记采样器,用 Rust 编写!

状态

开发初期,测试不佳。您可以查看 src/tests.rs 以查看一些使用示例。

这里还有一个使用 Mirostat 和我的 RWKV 项目的一个相当简单的例子: https://github.com/KerfuffleV2/smolrsrwkv/blob/60b8e8bfe64f157f1800445128af3b4adbbc64c1/smolrwkv-cli/src/main.rs#L139-L164

有关从 0.0.6 迁移到 0.0.7 的说明,请见下文。

采样器

在这里使用“采样器”这个术语较为宽松,未来可能需要重命名。目前,“采样器”可能是某种操作 logits 列表的东西(例如,top-k 采样器可能会将列表修剪到前 K 个条目),它可能实际上会选择一个标记,或者两者都选择!

  1. 平面偏差 - 根据指定量对标记进行偏差
  2. 频率/存在 - 应用频率和存在惩罚
  3. 贪婪 - 选择概率最高的标记 ID
  4. 局部典型
  5. Mirostat V1
  6. Mirostat V2
  7. 随机分布 - 根据加权概率选择标记 ID
  8. 重复 - 应用重复惩罚
  9. 尾部自由
  10. 温度
  11. Top-K
  12. Top-P
  13. Min-P
  14. Top-A

实说明可能(或可能不发生)最终会出现。目前,您可以查看 llama.cpp main 示例 README 以了解一些采样器类型的大致概述: https://github.com/ggerganov/llama.cpp/blob/master/examples/main/README.md#generation-flags

示例

您通常不会想单独使用 Sampler。最典型的用法是将多个采样器链接在一起。

构建 [SamplerChain] 的简单示例

use anyhow::Result;

use llm_samplers::prelude::*;

pub fn test_chain1() -> Result<()> {

    let mut logits = Logits::try_from_iter([0.1f32, 0.2, 0.3, 0.4].into_iter())?;

    // Demonstrating the different ways you can build a SamplerChain.
    // These are all equivalent.
    let mut sc = SamplerChain::new()
        + SampleFlatBias::new([(3, f32::NEG_INFINITY)]);
    sc += SampleTemperature::new(0.8);
    sc.push_sampler(SampleGreedy::new());

    assert_eq!(
        sc.sample_token(
            // These samplers don't actually need any resources.
            &mut NilSamplerResources::default(),
            &mut logits)?,
        Some(1)
    );

    // () also implements HasSamplerResources
    // so you could use &mut () here.
    assert_eq!(sc.sample_token(&mut (), &mut logits)?, Some(1));
    Ok(())
}

前面的示例很简单,但不太现实:贪婪采样器甚至不关心温度。现在让我们看看一个稍微复杂一点的例子

use anyhow::Result;
use rand::{SeedableRng, rngs::StdRng};

use llm_samplers::prelude::*;

fn test_chain2() -> Result<()> {
    let example_logits = vec![0.1f32, 0.2, 0.3, 0.4];
    let mut res = SimpleSamplerResources::new(
        // Optionally include an RNG resource.
        Some(Box::new(StdRng::seed_from_u64(123))),
        // Optionally include a last tokens resource.
        Some(vec![]),
    );
    let mut logits = Logits::try_from_iter(example_logits.into_iter())?;
    let mut logits2 = logits.clone();

    let mut sc = SamplerChain::new()
        // Bias logits (this example sets bias for token id 3 to -inf)
        + SampleFlatBias::new([(3, f32::NEG_INFINITY)])
        // Apply a repetition penalty.
        + SampleRepetition::new(1.1, 64)
        // Apply frequency and presence penalties.
        + SampleFreqPresence::new(0.05, 0.1, 64)
        // Apply temperature to logits.
        + SampleTemperature::new(0.8)
        // Sample a token using Mirostat1
        + SampleMirostat1::new(4, 5.0, 0.1);

    // Put a value into `last_tokens`, this simulates us having already picked
    // that token (3) previously.
    res.with_last_tokens_mut(&mut |tokens| tokens.push(3u32))?;

    assert_eq!(sc.sample_token(&mut res, &mut logits)?, Some(2));

    // Now add the last selected token to the list.
    res.with_last_tokens_mut(&mut |tokens| tokens.push(2u32))?;

    // And pick the next one. *Important*: Note that we don't reuse `logits`.
    // This is because `logits` already has all the filtering/sorting/permutation
    // from the previous sample call applied to it.
    assert_eq!(sc.sample_token(&mut res, &mut logits2)?, Some(1));
    Ok(())
}

从 0.0.6 迁移到 0.0.7

很遗憾,这涉及到一些重大变更。基本上,采样器和链不再接受token id和logits类型变量。您可以将token id设置为任何喜欢的颜色,只要它是u32即可。同样,对于logits:现在始终是f32

例如,之前您会这样做SampleRandDistrib::<u32>::newSampleMirostat2::<u32, f32>::new,现在只需要SampleRandDistrib::newSampleMirostat2::new。对于创建链也是如此:SamplerChain::<u32, f32>::new将只需要SamplerChain::new

注意:Crate/docs版本可能不会与此存储库匹配。

致谢

初始版本与llama.cpp项目中的采样器密切相关(尽管不是逐行移植)。谢谢!

依赖项

~0.7–1.3MB
~27K SLoC