use super::utils::DataType;
use super::{Error, API};
use crate::ffi::*;
/// Owning wrapper around a raw cuDNN filter descriptor handle
/// (`cudnnFilterDescriptor_t`).
///
/// The wrapped handle is passed to `API::destroy_filter_descriptor` when this
/// value is dropped.
///
/// NOTE(review): `Clone` is derived, so cloning copies the raw handle rather
/// than creating a new descriptor. Dropping both the original and a clone will
/// destroy the same handle twice. Consider removing `Clone` (a breaking change)
/// or implementing it by building a fresh descriptor.
#[derive(Clone, Debug)]
pub struct FilterDescriptor {
    /// Raw cuDNN handle owned by this wrapper.
    id: cudnnFilterDescriptor_t,
}
impl Drop for FilterDescriptor {
    /// Releases the underlying cuDNN filter descriptor handle.
    fn drop(&mut self) {
        // `drop` cannot return an error, so a failure to destroy the handle is
        // intentionally ignored (best-effort cleanup) rather than panicking.
        let _ = API::destroy_filter_descriptor(*self.id_c());
    }
}
impl FilterDescriptor {
    /// Creates a cuDNN filter descriptor with the given per-dimension sizes and
    /// element data type, using the `CUDNN_TENSOR_NCHW` tensor format.
    ///
    /// # Arguments
    /// * `filter_dim` - size of each filter dimension; the slice length is passed
    ///   to cuDNN as the number of dimensions.
    /// * `data_type` - element type, converted with `API::cudnn_data_type`.
    ///
    /// # Errors
    /// Returns an error if the descriptor cannot be created or configured. If
    /// configuration fails, the already-created handle is released on drop
    /// instead of being leaked.
    pub fn new(filter_dim: &[i32], data_type: DataType) -> Result<FilterDescriptor, Error> {
        // NOTE(review): `as` truncates for slice lengths above `i32::MAX`; real
        // filter ranks are far smaller, so this is not expected to matter.
        let nb_dims = filter_dim.len() as i32;
        let tensor_format = cudnnTensorFormat_t::CUDNN_TENSOR_NCHW;
        let data_type = API::cudnn_data_type(data_type);
        // Take ownership of the raw handle immediately so that, if the
        // `set_filter_descriptor` call below fails, `Drop` destroys it.
        let desc = FilterDescriptor::from_c(API::create_filter_descriptor()?);
        API::set_filter_descriptor(
            *desc.id_c(),
            data_type,
            tensor_format,
            nb_dims,
            filter_dim.as_ptr(),
        )?;
        Ok(desc)
    }

    /// Wraps an existing raw filter descriptor handle, taking ownership of it.
    ///
    /// The handle is destroyed when the returned value is dropped. The caller
    /// must not destroy it elsewhere or wrap the same handle twice.
    pub fn from_c(id: cudnnFilterDescriptor_t) -> FilterDescriptor {
        FilterDescriptor { id }
    }

    /// Returns a reference to the underlying raw cuDNN handle, e.g. for FFI
    /// calls. The handle remains owned by `self`.
    pub fn id_c(&self) -> &cudnnFilterDescriptor_t {
        &self.id
    }
}