#candle #extension #pytorch #function #tensor #ext #devices

candle-ext

Candle 的扩展库,提供 Candle 中尚未提供的 PyTorch 函数

8 个版本

0.1.7 2023 年 12 月 12 日
0.1.6 2023 年 12 月 3 日
0.1.5 2023 年 11 月 15 日

#292 in 机器学习


用于 candlelighter

MIT/Apache

42KB
946

Candle 扩展

Test

Candle 的扩展库,提供 Candle 中尚未提供的 PyTorch 函数

 use candle_ext::{
     candle::{ D, DType, Device, Result, Tensor},
     TensorExt, F,
 };

 fn main() -> Result<()> {
     let device = Device::Cpu;
     let q = Tensor::randn(0., 1., (3, 3, 2, 4), &device)?;
     let k = Tensor::randn(0., 1., (1, 3, 3, 4), &device)?;
     let v = Tensor::randn(0., 1., (1, 3, 3, 4), &device)?;
     let m = Tensor::ones((q.dim(D::Minus2)?, k.dim(D::Minus2)?), DType::U8, &device)?.tril(0)?;

     let o = F::scaled_dot_product_attention(&q, &k, &v, Some(&m), None, None, None)?;

     Ok(())
 }

目前提供(另请参阅 测试

  • F::scaled_dot_product_attention

  • F::chunk2..5 / Tensor::chunk2..5

  • F::cumsum / Tensor::cumsum

  • F::equal / Tensor::equal

  • F::eye / Tensor::eye

  • F::full / Tensor::full

  • F::full_like / Tensor::full_like

  • F::triu / Tensor::triu

  • F::tril / Tensor::tril

  • F::masked_fill / Tensor::masked_fill

  • F::logical_not / Tensor::logical_not

  • F::logical_or / Tensor::logical_or

  • F::outer / Tensor::outer

  • F::unbind / Tensor::unbind / F::unbind2..5 / Tensor::unbind2..5

许可证

许可协议为以下之一

任选其一。

贡献

除非你明确声明,否则根据 Apache-2.0 许可协议定义的,任何有意提交以包含在你所做工作的贡献,都将按照上述方式双重许可,无需任何附加条款或条件。

依赖

~8–12MB
~265K SLoC