use super::utils::DataType;
use super::{Error, API};
use crate::ffi::*;
/// Wrapper around a raw cuDNN tensor descriptor handle.
///
/// The wrapped handle is passed to `API::destroy_tensor_descriptor` when the
/// wrapper is dropped.
///
/// NOTE(review): the derived `Clone` copies the raw handle value instead of
/// creating a new descriptor, and every clone runs the same `Drop`, so one
/// handle would be handed to the destroy call once per clone. Confirm whether
/// any caller relies on `Clone` before changing it.
#[derive(Debug, Clone)]
pub struct TensorDescriptor {
/// Raw cuDNN descriptor handle held by this wrapper.
id: cudnnTensorDescriptor_t,
}
/// Collects the raw cuDNN handles of `tensor_vec` into a new vector.
///
/// The returned handles are in the same order as the descriptors in the input
/// slice. The handles are copied; the descriptors keep holding them.
pub fn tensor_vec_id_c(tensor_vec: &[TensorDescriptor]) -> Vec<cudnnTensorDescriptor_t> {
    let mut handles = Vec::with_capacity(tensor_vec.len());
    for descriptor in tensor_vec {
        handles.push(*descriptor.id_c());
    }
    handles
}
impl Drop for TensorDescriptor {
    /// Hands the wrapped handle to `API::destroy_tensor_descriptor`.
    fn drop(&mut self) {
        // The outcome of the destroy call is deliberately discarded: `drop`
        // has no way to report it to the caller.
        let _ = API::destroy_tensor_descriptor(self.id);
    }
}
impl TensorDescriptor {
    /// Creates a new cuDNN tensor descriptor and configures it.
    ///
    /// * `dims` - size of each tensor dimension; must hold 3 to 8 entries.
    /// * `strides` - stride of each dimension; must hold exactly one entry per
    ///   entry in `dims`.
    /// * `data_type` - element type, converted with `API::cudnn_data_type`.
    ///
    /// # Errors
    ///
    /// Returns `Error::BadParam` when `dims` has fewer than 3 or more than 8
    /// entries, or when `strides` and `dims` differ in length. Errors from
    /// `API::create_tensor_descriptor` and `API::set_tensor_descriptor` are
    /// propagated unchanged.
    pub fn new(
        dims: &[i32],
        strides: &[i32],
        data_type: DataType,
    ) -> Result<TensorDescriptor, Error> {
        if dims.len() < 3 || dims.len() > 8 {
            return Err(Error::BadParam(
                "CUDA cuDNN only supports Tensors with 3 to 8 dimensions.",
            ));
        }
        // The same element count is passed for both raw pointers below, so a
        // shorter `strides` slice would be read past its end.
        if strides.len() != dims.len() {
            return Err(Error::BadParam(
                "CUDA cuDNN requires exactly one stride per Tensor dimension.",
            ));
        }
        // Lossless: the range check above bounds the length to at most 8.
        let nb_dims = dims.len() as i32;

        // Wrap the raw handle immediately so that `Drop` destroys it if the
        // configuration step below fails and we return early.
        let descriptor = TensorDescriptor::from_c(API::create_tensor_descriptor()?);
        let data_type = API::cudnn_data_type(data_type);
        API::set_tensor_descriptor(
            descriptor.id,
            data_type,
            nb_dims,
            dims.as_ptr(),
            strides.as_ptr(),
        )?;
        Ok(descriptor)
    }

    /// Wraps an existing raw cuDNN tensor descriptor handle.
    ///
    /// The returned wrapper passes `id` to `API::destroy_tensor_descriptor`
    /// when it is dropped, so the caller should not destroy the handle itself.
    pub fn from_c(id: cudnnTensorDescriptor_t) -> TensorDescriptor {
        TensorDescriptor { id }
    }

    /// Returns a reference to the wrapped raw cuDNN handle.
    pub fn id_c(&self) -> &cudnnTensorDescriptor_t {
        &self.id
    }
}