1个不稳定版本
0.2.0 | 2022年9月20日 |
---|---|
0.1.0 |
|
#405 in 机器学习
34KB
673 行
dlpackrs
这个crate提供了一个安全的Rust绑定到DLPack,它是一个在内存中(主要是)硬件无关的数据格式,被主要的深度学习框架如PyTorch、TensorFlow、MXNet、TVM和主要的数组处理框架如NumPy和CuPy所识别。这个标准的一个重要特性是在特定支持的硬件上提供跨框架的零成本张量转换。
最低支持的Rust版本(MSRV)是稳定工具链1.57.0。
用法
有两个主要的情况与张量底层数据/存储的所有者以及要进行的操作类型相关。
内存管理张量
在这种情况下,ManagedTensor
是由ManagedTensorProxy
构建的,它是不可安全ffi::DLManagedTensor
的安全代理。
普通非内存管理张量
在这种情况下,可以使用(不可变的)Rust包装器Tensor
,如果需要,也可以使用不可安全的ffi::DLTensor
。
示例
当涉及所有者时,可以使用ManagedTensor
。以下是一个双向转换的示例
ndarray::ArrayD <---> ManagedTensor
是在零成本下完成的。
impl<'tensor, C> From<&'tensor mut ArrayD<f32>> for ManagedContext<'tensor, C> {
fn from(t: &'tensor mut ArrayD<f32>) -> Self {
let dlt: Tensor<'tensor> = Tensor::from(t);
let inner = DLManagedTensor::new(dlt.0, None);
ManagedContext(inner)
}
}
impl<'tensor, C> From<&mut ManagedContext<'tensor, C>> for ArrayD<f32> {
fn from(mt: &mut ManagedContext<'tensor, C>) -> Self {
let dlt: DLTensor = mt.0.inner.dl_tensor.into();
unsafe {
let arr = RawArrayViewMut::from_shape_ptr(dlt.shape().unwrap(), dlt.data() as *mut f32);
arr.deref_into_view_mut().into_dyn().to_owned()
}
}
}
如果不涉及所有者,则可以使用Tensor
作为视图。以下是一个双向转换的示例
ndarray::ArrayD <---> Tensor
是在零成本下完成的。
impl<'tensor> From<&'tensor mut ArrayD<f32>> for Tensor<'tensor> {
fn from(arr: &'tensor mut ArrayD<f32>) -> Self {
let inner = DLTensor::new(
arr.as_mut_ptr() as *mut c_void,
Device::default(),
arr.ndim() as i32,
DataType::f32(),
arr.shape().as_ptr() as *const _ as *mut i64,
arr.strides().as_ptr() as *const _ as *mut i64,
0,
);
Tensor(inner)
}
}
impl<'tensor> From<&'tensor mut Tensor<'tensor>> for ArrayD<f32> {
fn from(t: &'tensor mut Tensor<'tensor>) -> Self {
unsafe {
let arr = RawArrayViewMut::from_shape_ptr(t.0.shape().unwrap(), t.0.data() as *mut f32);
arr.deref_into_view_mut().into_dyn().to_owned()
}
}
}
请参阅完整的示例,其中上述情况已被模拟为Rust ndarray转换。
许可证
许可协议为以下之一
- Apache License,版本2.0,(LICENSE-APACHE或https://apache.ac.cn/licenses/LICENSE-2.0)
- 麻省理工学院许可证(《LICENSE-MIT》或http://opensource.org/licenses/MIT)
由你选择。
贡献
除非你明确表示,否则根据Apache-2.0许可证定义,你故意提交以包含在作品中的任何贡献,将按上述方式双重许可,不附加任何额外条款或条件。
依赖项
~0.4–2.7MB
~53K SLoC