12个不稳定版本 (4个破坏性更新)
0.6.1 | 2023年8月13日 |
---|---|
0.6.0 | 2023年5月22日 |
0.5.4 | 2023年4月29日 |
0.4.0 | 2023年3月16日 |
0.2.1 | 2022年9月27日 |
#293 in 机器学习
每月40次下载
175KB
2.5K SLoC
ai-dataloader
是pytorch
dataloader
库的Rust端口。
注意:该项目仍在积极开发中,处于早期阶段。
亮点
- 可迭代或可索引(映射风格)的
DataLoader
。 - 可自定义的
Sampler
、BatchSampler
和collate_fn
。 - 使用
rayon
的并行数据加载器(用于可索引数据加载器,实验性)。 - 与
ndarray
和tch-rs
集成,支持CPU和GPU。 - 默认的合并函数可以自动合并大多数类型(支持嵌套)。
- 对可迭代和可索引的
DataLoader
进行洗牌。
更多信息请参阅文档。
示例
示例可以在示例文件夹中找到,但这里有一个简单的示例
use ai_dataloader::DataLoader;
let loader = DataLoader::builder(vec![(0, "hola"), (1, "hello"), (2, "hallo"), (3, "bonjour")]).batch_size(2).shuffle().build();
for (label, text) in &loader {
println!("Label {label:?}");
println!("Text {text:?}");
}
tch-rs
集成
为了将您的数据合并成可以在GPU上运行的torch张量,您必须激活tch
特性。
此特性依赖于tch crate提供的对C++ libTorch
API的绑定。需要libtorch
库,可以通过自动或手动方式下载。以下提供如何设置您的环境以使用这些绑定,请参阅tch以获取详细信息或支持。
下一特性
以下特性可能在未来添加
- 具有替换的
RandomSampler
- 可迭代数据集的并行
dataloader
- 分布式
dataloader
MSRV
当前MSRV为1.60。
依赖关系
~3.5-6MB
~120K SLoC