7 个不稳定版本 (3 个重大更改)

0.4.1 2024 年 3 月 26 日
0.4.0 2023 年 8 月 21 日
0.3.0 2023 年 7 月 18 日
0.2.2 2023 年 6 月 11 日
0.1.0 2023 年 5 月 9 日

#42 in 机器学习

Download history 130/week @ 2024-04-27 1694/week @ 2024-05-04 3543/week @ 2024-05-11 2102/week @ 2024-05-18 5556/week @ 2024-05-25 3884/week @ 2024-06-01 6565/week @ 2024-06-08 7826/week @ 2024-06-15 7483/week @ 2024-06-22 10286/week @ 2024-06-29 12142/week @ 2024-07-06 11442/week @ 2024-07-13 17141/week @ 2024-07-20 23384/week @ 2024-07-27 14515/week @ 2024-08-03 12832/week @ 2024-08-10

每月下载量 70,123

Apache-2.0

43KB
993

dlpark

Github Actions Crates.io docs.rs

纯 Rust 实现的 dmlc/dlpack

查看 example/with_pyo3 了解用法。

此实现专注于将张量从 Rust 转换到 Python,反之亦然。

它也可以作为不带 pyo3 的 Rust 库使用,默认功能 default-features = false,请查看 example/from_numpy

快速入门

我们提供了一个如何将 image::RgbImage 转换到 Python 和 torch.Tensor 转换到 Rust 的简单示例。

完整代码在这里.

Rust $\rightarrow$ Python

我们必须为结构体实现一些特质,才能将其转换为 PyObject

use std::ffi::c_void;
use dlpark::prelude::*;

struct PyRgbImage(image::RgbImage);

impl ToTensor for PyRgbImage {
    fn data_ptr(&self) -> *mut std::ffi::c_void {
        self.0.as_ptr() as *const c_void as *mut c_void
    }

    fn byte_offset(&self) -> u64 {
        0
    }


    fn device(&self) -> Device {
        Device::CPU
    }

    fn dtype(&self) -> DataType {
        DataType::U8
    }

    fn shape_and_strides(&self) -> ShapeAndStrides {
        ShapeAndStrides::new_contiguous_with_strides(
            &[self.0.height(), self.0.width(), 3].map(|x| x as i64),
        )
    }
}

然后我们可以返回一个 ManagerCtx<PyRgbImage>

#[pyfunction]
fn read_image(filename: &str) -> ManagerCtx<PyRgbImage> {
    let img = image::open(filename).unwrap();
    let rgb_img = img.to_rgb8();
    ManagerCtx::new(PyRgbImage(rgb_img))
}

你可以在 Python 中访问它

import dlparkimg
from torch.utils.dlpack import to_dlpack, from_dlpack
import matplotlib.pyplot as plt

tensor = from_dlpack(dlparkimg.read_image("candy.jpg"))

print(tensor.shape)
plt.imshow(tensor.numpy())
plt.show()

如果你想将其转换为 numpy.ndarray,你可以创建一个简单的包装器,如下所示

import numpy as np
import dlparkimg

class FakeTensor:

    def __init__(self, x):
        self.x = x

    def __dlpack__(self):
        return self.x

arr = np.from_dlpack(FakeTensor(dlparkimg.read_image("candy.jpg")))

Python $\rightarrow$ Rust

ManagedTensor 持有张量的内存并提供访问张量属性的方法。

#[pyfunction]
fn write_image(filename: &str, tensor: ManagedTensor) {
    let buf = tensor.as_slice::<u8>();

    let rgb_img = image::ImageBuffer::<Rgb<u8>, _>::from_raw(
        tensor.shape()[1] as u32,
        tensor.shape()[0] as u32,
        buf,
    )
    .unwrap();

    rgb_img.save(filename).unwrap();
}

你可以在 Python 中调用它

import dlparkimg
from torch.utils.dlpack import to_dlpack, from_dlpack

bgr_img = tensor[..., [2, 1, 0]] # [H, W, C=3]
dlparkimg.write_image('bgr.jpg', to_dlpack(bgr_img))

依赖项

~0–5.5MB
~16K SLoC