8 个版本
0.1.7 | 2023 年 12 月 12 日 |
---|---|
0.1.6 | 2023 年 12 月 3 日 |
0.1.5 | 2023 年 11 月 15 日 |
#292 in 机器学习
42KB
946 行
Candle 扩展
Candle 的扩展库,提供 Candle 中尚未提供的 PyTorch 函数
use candle_ext::{
candle::{ D, DType, Device, Result, Tensor},
TensorExt, F,
};
fn main() -> Result<()> {
let device = Device::Cpu;
let q = Tensor::randn(0., 1., (3, 3, 2, 4), &device)?;
let k = Tensor::randn(0., 1., (1, 3, 3, 4), &device)?;
let v = Tensor::randn(0., 1., (1, 3, 3, 4), &device)?;
let m = Tensor::ones((q.dim(D::Minus2)?, k.dim(D::Minus2)?), DType::U8, &device)?.tril(0)?;
let o = F::scaled_dot_product_attention(&q, &k, &v, Some(&m), None, None, None)?;
Ok(())
}
目前提供(另请参阅 测试)
-
F::scaled_dot_product_attention
-
F::chunk2..5 / Tensor::chunk2..5
-
F::cumsum / Tensor::cumsum
-
F::equal / Tensor::equal
-
F::eye / Tensor::eye
-
F::full / Tensor::full
-
F::full_like / Tensor::full_like
-
F::triu / Tensor::triu
-
F::tril / Tensor::tril
-
F::masked_fill / Tensor::masked_fill
-
F::logical_not / Tensor::logical_not
-
F::logical_or / Tensor::logical_or
-
F::outer / Tensor::outer
-
F::unbind / Tensor::unbind / F::unbind2..5 / Tensor::unbind2..5
许可证
许可协议为以下之一
- Apache License, Version 2.0, (LICENSE-APACHE 或 https://apache.ac.cn/licenses/LICENSE-2.0)
- MIT 许可协议 (LICENSE-MIT 或 https://opensource.org/licenses/MIT)
任选其一。
贡献
除非你明确声明,否则根据 Apache-2.0 许可协议定义的,任何有意提交以包含在你所做工作的贡献,都将按照上述方式双重许可,无需任何附加条款或条件。
依赖
~8–12MB
~265K SLoC