#mnist #file-format #machine-learning #encoding

mnist_read

读取 MNIST 文件格式的通用数据和标签文件

2 个稳定版本

1.0.1 2020 年 11 月 22 日

#5 in #mnist


用于 2 crates

Apache-2.0

6KB

MNIST Read

Crates.io lib.rs.io docs

用于 Rust 的 MNIST 文件格式的通用数据和标签文件读取器。

就是这样简单。

// Raw format
let train_labels: Vec<u8> = mnist_read::read_labels("train-labels.idx1-ubyte");
let train_data: Vec<u8> = mnist_read::read_data("train-images.idx3-ubyte");

// Ndarray (Maths lib)
let usize_labels:Vec<usize> = train__labels.into_iter().map(|l|l as usize).collect();
let mut array_labels:ndarray::Array2<usize> = ndarray::Array::from_shape_vec((10000, 1), usize_labels).expect("Bad labels");

let f32_data:Vec<f32> = train_data.into_iter().map(|d|d as f32 / 255f32).collect();
let mut array_data:ndarray::Array2<f32> = ndarray::Array::from_shape_vec((10000, 28*28), f32_data).expect("Bad data");

// Cogent (Neural network library)
let mut net = cogent::NeuralNetwork::new(784,&[
    cogent::Layer::Dense(1000, cogent::Activation::ReLU),
    cogent::Layer::Dropout(0.2),
    cogent::Layer::Dense(500, cogent::Activation::ReLU),
    cogent::Layer::Dropout(0.2),
    cogent::Layer::Dense(10, cogent::Activation::Softmax)
])
net.train(&mut array_data, &mut array_labels).go()

依赖关系

~1.5MB
~25K SLoC