7 个不稳定版本 (3 个重大更改)
0.4.1 | 2024 年 3 月 26 日 |
---|---|
0.4.0 | 2023 年 8 月 21 日 |
0.3.0 | 2023 年 7 月 18 日 |
0.2.2 | 2023 年 6 月 11 日 |
0.1.0 | 2023 年 5 月 9 日 |
#42 in 机器学习
每月下载量 70,123
43KB
993 行
dlpark
纯 Rust 实现的 dmlc/dlpack。
查看 example/with_pyo3 了解用法。
此实现专注于将张量从 Rust 转换到 Python,反之亦然。
它也可以作为不带 pyo3
的 Rust 库使用,默认功能 default-features = false
,请查看 example/from_numpy。
快速入门
我们提供了一个如何将 image::RgbImage
转换到 Python 和 torch.Tensor
转换到 Rust 的简单示例。
Rust $\rightarrow$ Python
我们必须为结构体实现一些特质,才能将其转换为 PyObject
use std::ffi::c_void;
use dlpark::prelude::*;
struct PyRgbImage(image::RgbImage);
impl ToTensor for PyRgbImage {
fn data_ptr(&self) -> *mut std::ffi::c_void {
self.0.as_ptr() as *const c_void as *mut c_void
}
fn byte_offset(&self) -> u64 {
0
}
fn device(&self) -> Device {
Device::CPU
}
fn dtype(&self) -> DataType {
DataType::U8
}
fn shape_and_strides(&self) -> ShapeAndStrides {
ShapeAndStrides::new_contiguous_with_strides(
&[self.0.height(), self.0.width(), 3].map(|x| x as i64),
)
}
}
然后我们可以返回一个 ManagerCtx<PyRgbImage>
#[pyfunction]
fn read_image(filename: &str) -> ManagerCtx<PyRgbImage> {
let img = image::open(filename).unwrap();
let rgb_img = img.to_rgb8();
ManagerCtx::new(PyRgbImage(rgb_img))
}
你可以在 Python 中访问它
import dlparkimg
from torch.utils.dlpack import to_dlpack, from_dlpack
import matplotlib.pyplot as plt
tensor = from_dlpack(dlparkimg.read_image("candy.jpg"))
print(tensor.shape)
plt.imshow(tensor.numpy())
plt.show()
如果你想将其转换为 numpy.ndarray
,你可以创建一个简单的包装器,如下所示
import numpy as np
import dlparkimg
class FakeTensor:
def __init__(self, x):
self.x = x
def __dlpack__(self):
return self.x
arr = np.from_dlpack(FakeTensor(dlparkimg.read_image("candy.jpg")))
Python $\rightarrow$ Rust
ManagedTensor
持有张量的内存并提供访问张量属性的方法。
#[pyfunction]
fn write_image(filename: &str, tensor: ManagedTensor) {
let buf = tensor.as_slice::<u8>();
let rgb_img = image::ImageBuffer::<Rgb<u8>, _>::from_raw(
tensor.shape()[1] as u32,
tensor.shape()[0] as u32,
buf,
)
.unwrap();
rgb_img.save(filename).unwrap();
}
你可以在 Python 中调用它
import dlparkimg
from torch.utils.dlpack import to_dlpack, from_dlpack
bgr_img = tensor[..., [2, 1, 0]] # [H, W, C=3]
dlparkimg.write_image('bgr.jpg', to_dlpack(bgr_img))
依赖项
~0–5.5MB
~16K SLoC