#data-loader #machine-learning #tensorflow #ai #pytorch

ai-dataloader

Rust实现的PyTorch DataLoader

12个不稳定版本 (4个破坏性更新)

0.6.1 2023年8月13日
0.6.0 2023年5月22日
0.5.4 2023年4月29日
0.4.0 2023年3月16日
0.2.1 2022年9月27日

#293 in 机器学习

每月40次下载

MIT/Apache

175KB
2.5K SLoC

CI Crates.io Documentation

ai-dataloader

pytorch dataloader库的Rust端口。

注意:该项目仍在积极开发中,处于早期阶段。

亮点

  • 可迭代或可索引(映射风格)的DataLoader
  • 可自定义的SamplerBatchSamplercollate_fn
  • 使用rayon的并行数据加载器(用于可索引数据加载器,实验性)。
  • ndarraytch-rs集成,支持CPU和GPU。
  • 默认的合并函数可以自动合并大多数类型(支持嵌套)。
  • 对可迭代和可索引的DataLoader进行洗牌。

更多信息请参阅文档

示例

示例可以在示例文件夹中找到,但这里有一个简单的示例

use ai_dataloader::DataLoader;
let loader = DataLoader::builder(vec![(0, "hola"), (1, "hello"), (2, "hallo"), (3, "bonjour")]).batch_size(2).shuffle().build();

for (label, text) in &loader {     
    println!("Label {label:?}");
    println!("Text {text:?}");
}

tch-rs集成

为了将您的数据合并成可以在GPU上运行的torch张量,您必须激活tch特性。

此特性依赖于tch crate提供的对C++ libTorch API的绑定。需要libtorch库,可以通过自动或手动方式下载。以下提供如何设置您的环境以使用这些绑定,请参阅tch以获取详细信息或支持。

下一特性

以下特性可能在未来添加

  • 具有替换的RandomSampler
  • 可迭代数据集的并行dataloader
  • 分布式dataloader

MSRV

当前MSRV为1.60。

依赖关系

~3.5-6MB
~120K SLoC