use super::{Error, API};
use crate::cudnn::Cudnn;
use crate::ffi::*;
/// Owning wrapper around a raw cuDNN dropout descriptor handle.
///
/// The wrapped handle is passed to `API::destroy_dropout_descriptor` when this
/// value is dropped.
///
/// NOTE(review): the derived `Clone` duplicates the `id` field. Presumably
/// `cudnnDropoutDescriptor_t` is a raw pointer, so a clone and its original
/// would both destroy the same descriptor on drop. Confirm this is acceptable,
/// or avoid cloning, before relying on it.
#[derive(Clone, Debug)]
pub struct DropoutDescriptor {
    /// Raw cuDNN handle owned by this wrapper.
    id: cudnnDropoutDescriptor_t,
}
impl Drop for DropoutDescriptor {
    /// Releases the wrapped cuDNN dropout descriptor.
    fn drop(&mut self) {
        // `drop` has no way to report failure, so any error from the destroy
        // call is intentionally discarded.
        let _ = API::destroy_dropout_descriptor(self.id);
    }
}
impl DropoutDescriptor {
    /// Creates a new cuDNN dropout descriptor and configures it.
    ///
    /// # Arguments
    /// * `handle` - cuDNN context; its raw handle is passed to the configuration call.
    /// * `dropout` - dropout value, forwarded unchanged to cuDNN.
    /// * `seed` - random seed, forwarded unchanged to cuDNN.
    /// * `reserve` - buffer pointer forwarded to cuDNN (presumably caller-allocated
    ///   device memory for RNG states — TODO confirm with callers).
    /// * `reserve_size` - size of `reserve`, forwarded to cuDNN (assumed to be in
    ///   bytes — TODO confirm).
    ///
    /// # Errors
    /// Returns an error if creating or configuring the descriptor fails. If the
    /// configuration step fails, the newly created descriptor is destroyed
    /// instead of leaked.
    pub fn new(
        handle: &Cudnn,
        dropout: f32,
        seed: u64,
        reserve: *mut libc::c_void,
        reserve_size: usize,
    ) -> Result<DropoutDescriptor, Error> {
        // Take ownership immediately so `Drop` releases the raw descriptor if
        // the configuration call below fails and `?` returns early.
        let desc = DropoutDescriptor::from_c(API::create_dropout_descriptor()?);
        API::set_dropout_descriptor(
            *desc.id_c(),
            *handle.id_c(),
            dropout,
            reserve,
            reserve_size,
            seed,
        )?;
        Ok(desc)
    }

    /// Presumably intended to report the reserve-buffer size needed by
    /// [`DropoutDescriptor::new`]; not implemented yet.
    ///
    /// # Panics
    /// Always panics, because the function body is `unimplemented!()`.
    // TODO(review): cuDNN's `cudnnDropoutGetStatesSize` takes a handle, so this
    // likely needs a `&Cudnn` parameter before it can be implemented — confirm.
    pub fn get_required_size() -> usize {
        unimplemented!()
    }

    /// Wraps an existing raw dropout descriptor.
    ///
    /// The returned value takes ownership: the handle is destroyed when the
    /// wrapper is dropped.
    pub fn from_c(id: cudnnDropoutDescriptor_t) -> DropoutDescriptor {
        DropoutDescriptor { id }
    }

    /// Returns a reference to the underlying raw cuDNN descriptor handle.
    pub fn id_c(&self) -> &cudnnDropoutDescriptor_t {
        &self.id
    }
}