6 个版本
0.0.7 | 2023 年 11 月 9 日 |
---|---|
0.0.6 | 2023 年 8 月 6 日 |
0.0.5 | 2023 年 7 月 27 日 |
0.0.3 | 2023 年 6 月 30 日 |
#39 在 机器学习 中
3,505 每月下载量
用于 12 个crate(6 个直接)
140KB
3.5K SLoC
llm-samplers
大型语言模型的标记采样器,用 Rust 编写!
状态
开发初期,测试不佳。您可以查看 src/tests.rs
以查看一些使用示例。
这里还有一个使用 Mirostat 和我的 RWKV 项目的一个相当简单的例子: https://github.com/KerfuffleV2/smolrsrwkv/blob/60b8e8bfe64f157f1800445128af3b4adbbc64c1/smolrwkv-cli/src/main.rs#L139-L164
有关从 0.0.6 迁移到 0.0.7 的说明,请见下文。
采样器
在这里使用“采样器”这个术语较为宽松,未来可能需要重命名。目前,“采样器”可能是某种操作 logits 列表的东西(例如,top-k 采样器可能会将列表修剪到前 K 个条目),它可能实际上会选择一个标记,或者两者都选择!
- 平面偏差 - 根据指定量对标记进行偏差
- 频率/存在 - 应用频率和存在惩罚
- 贪婪 - 选择概率最高的标记 ID
- 局部典型
- Mirostat V1
- Mirostat V2
- 随机分布 - 根据加权概率选择标记 ID
- 重复 - 应用重复惩罚
- 尾部自由
- 温度
- Top-K
- Top-P
- Min-P
- Top-A
实说明可能(或可能不发生)最终会出现。目前,您可以查看 llama.cpp main
示例 README 以了解一些采样器类型的大致概述: https://github.com/ggerganov/llama.cpp/blob/master/examples/main/README.md#generation-flags
示例
您通常不会想单独使用 Sampler
。最典型的用法是将多个采样器链接在一起。
构建 [SamplerChain] 的简单示例
use anyhow::Result;
use llm_samplers::prelude::*;
pub fn test_chain1() -> Result<()> {
let mut logits = Logits::try_from_iter([0.1f32, 0.2, 0.3, 0.4].into_iter())?;
// Demonstrating the different ways you can build a SamplerChain.
// These are all equivalent.
let mut sc = SamplerChain::new()
+ SampleFlatBias::new([(3, f32::NEG_INFINITY)]);
sc += SampleTemperature::new(0.8);
sc.push_sampler(SampleGreedy::new());
assert_eq!(
sc.sample_token(
// These samplers don't actually need any resources.
&mut NilSamplerResources::default(),
&mut logits)?,
Some(1)
);
// () also implements HasSamplerResources
// so you could use &mut () here.
assert_eq!(sc.sample_token(&mut (), &mut logits)?, Some(1));
Ok(())
}
前面的示例很简单,但不太现实:贪婪采样器甚至不关心温度。现在让我们看看一个稍微复杂一点的例子
use anyhow::Result;
use rand::{SeedableRng, rngs::StdRng};
use llm_samplers::prelude::*;
fn test_chain2() -> Result<()> {
let example_logits = vec![0.1f32, 0.2, 0.3, 0.4];
let mut res = SimpleSamplerResources::new(
// Optionally include an RNG resource.
Some(Box::new(StdRng::seed_from_u64(123))),
// Optionally include a last tokens resource.
Some(vec![]),
);
let mut logits = Logits::try_from_iter(example_logits.into_iter())?;
let mut logits2 = logits.clone();
let mut sc = SamplerChain::new()
// Bias logits (this example sets bias for token id 3 to -inf)
+ SampleFlatBias::new([(3, f32::NEG_INFINITY)])
// Apply a repetition penalty.
+ SampleRepetition::new(1.1, 64)
// Apply frequency and presence penalties.
+ SampleFreqPresence::new(0.05, 0.1, 64)
// Apply temperature to logits.
+ SampleTemperature::new(0.8)
// Sample a token using Mirostat1
+ SampleMirostat1::new(4, 5.0, 0.1);
// Put a value into `last_tokens`, this simulates us having already picked
// that token (3) previously.
res.with_last_tokens_mut(&mut |tokens| tokens.push(3u32))?;
assert_eq!(sc.sample_token(&mut res, &mut logits)?, Some(2));
// Now add the last selected token to the list.
res.with_last_tokens_mut(&mut |tokens| tokens.push(2u32))?;
// And pick the next one. *Important*: Note that we don't reuse `logits`.
// This is because `logits` already has all the filtering/sorting/permutation
// from the previous sample call applied to it.
assert_eq!(sc.sample_token(&mut res, &mut logits2)?, Some(1));
Ok(())
}
从 0.0.6 迁移到 0.0.7
很遗憾,这涉及到一些重大变更。基本上,采样器和链不再接受token id和logits类型变量。您可以将token id设置为任何喜欢的颜色,只要它是u32
即可。同样,对于logits:现在始终是f32
。
例如,之前您会这样做SampleRandDistrib::<u32>::new
或SampleMirostat2::<u32, f32>::new
,现在只需要SampleRandDistrib::new
,SampleMirostat2::new
。对于创建链也是如此:SamplerChain::<u32, f32>::new
将只需要SamplerChain::new
。
链接
注意:Crate/docs版本可能不会与此存储库匹配。
致谢
初始版本与llama.cpp项目中的采样器密切相关(尽管不是逐行移植)。谢谢!
依赖项
~0.7–1.3MB
~27K SLoC