// Empty foreign block: it declares no external functions and carries no attributes.
// NOTE(review): this may once have carried a `#[link(...)]` attribute to link the CUDA
// driver library; as written it has no effect — confirm whether it is still needed.
extern "C" {}
pub use self::api::{Driver, DriverError};
pub use self::context::Context;
pub use self::device::{Device, DeviceInfo};
pub use self::function::Function;
pub use self::memory::Memory;
pub use self::module::Module;
use crate::backend::{Backend, IBackend};
use crate::cublas;
use crate::cudnn::*;
use crate::framework::IFramework;
use crate::BackendConfig;
mod api;
pub mod context;
pub mod device;
pub mod function;
pub mod memory;
pub mod module;
/// Creates a CUDA [`Backend`] bound to the first device reported by the CUDA driver,
/// with its cuBLAS and cuDNN handles initialised on the backend's framework.
///
/// # Panics
///
/// Panics if the CUDA driver cannot be initialised, if the driver reports no devices,
/// if the backend cannot be created, or if cuBLAS/cuDNN initialisation returns an error.
pub fn get_cuda_backend() -> Backend<Cuda> {
    let framework = Cuda::new();
    // Only a single device is passed on: `Cuda::new_device` rejects anything other than
    // exactly one hardware. `get` turns "no device" into a readable panic instead of an
    // opaque slice-index-out-of-range panic.
    let hardwares = framework
        .hardwares()
        .get(0..1)
        .expect("Unable to create a CUDA backend: no CUDA device found")
        .to_vec();
    let backend_config = BackendConfig::new(framework, &hardwares);
    let mut backend = Backend::new(backend_config).expect("Unable to create a CUDA backend");
    backend
        .framework
        .initialise_cublas()
        .expect("Unable to initialise cuBLAS for the CUDA backend");
    backend
        .framework
        .initialise_cudnn()
        .expect("Unable to initialise cuDNN for the CUDA backend");
    backend
}
/// CUDA implementation of [`IFramework`].
///
/// Holds the devices discovered when the framework was created, a binary module, and
/// optional cuDNN/cuBLAS handles that are filled in by `Cuda::initialise_cudnn` and
/// `Cuda::initialise_cublas` respectively.
#[derive(Debug, Clone)]
pub struct Cuda {
    /// Devices reported by the CUDA driver when the framework was created.
    hardwares: Vec<Device>,
    /// Module returned by `IFramework::binary`.
    /// NOTE(review): `Cuda::new` sets this to `Module::from_isize(1)`, which looks like a
    /// placeholder — confirm intent.
    binary: Module,
    /// cuDNN handle; `None` unless the most recent `initialise_cudnn` call succeeded.
    cudnn: Option<Cudnn>,
    /// cuBLAS context; `None` until `initialise_cublas` has stored one.
    cublas: Option<cublas::Context>,
}
impl Cuda {
pub fn initialise_cublas(&mut self) -> Result<(), crate::framework::Error> {
self.cublas = {
let mut context = cublas::Context::new().unwrap();
context
.set_pointer_mode(cublas::api::PointerMode::Device)
.unwrap();
Some(context)
};
Ok(())
}
pub fn initialise_cudnn(&mut self) -> Result<(), crate::framework::Error> {
self.cudnn = match Cudnn::new() {
Ok(cudnn_ptr) => Some(cudnn_ptr),
Err(_) => None,
};
Ok(())
}
pub fn cudnn(&self) -> &Cudnn {
match &self.cudnn {
Some(cudnn) => cudnn,
None => panic!("Couldn't find a CUDNN Handle - Initialise CUDNN has not been called"),
}
}
pub fn cublas(&self) -> &cublas::Context {
match &self.cublas {
Some(cublas) => cublas,
None => panic!("Couldn't find a CUBLAS Handle - Initialise CUBLAS has not been called"),
}
}
}
impl IFramework for Cuda {
type H = Device;
type D = Context;
type B = Module;
fn ID() -> &'static str {
"CUDA"
}
fn new() -> Cuda {
if let Err(err) = Driver::init() {
panic!("Unable to initialize Cuda Framework: {}", err);
}
match Cuda::load_hardwares() {
Ok(hardwares) => Cuda {
hardwares,
binary: Module::from_isize(1),
cudnn: None,
cublas: None,
},
Err(err) => panic!("Could not initialize Cuda Framework, due to: {}", err),
}
}
fn load_hardwares() -> Result<Vec<Device>, crate::framework::Error> {
Ok(Driver::load_devices()?)
}
fn hardwares(&self) -> &[Device] {
&self.hardwares
}
fn binary(&self) -> &Self::B {
&self.binary
}
fn new_device(&self, hardwares: &[Device]) -> Result<Self::D, crate::framework::Error> {
let length = hardwares.len();
match length {
0 => Err(crate::framework::Error::Implementation("No device for context specified.".to_string())),
1 => Ok(Context::new(hardwares[0].clone())?),
_ => Err(crate::framework::Error::Implementation("Cuda's `new_device` method currently supports only one Harware for Device creation.".to_string()))
}
}
}
impl IBackend for Backend<Cuda> {
    type F = Cuda;

    /// Returns the CUDA context this backend runs on.
    // NOTE(review): `self.device()` must resolve to an inherent `Backend::device` accessor
    // (inherent methods are tried before trait methods). If no such accessor exists, this
    // call would recurse into itself — confirm against `Backend`'s definition.
    fn device(&self) -> &Context {
        self.device()
    }

    /// Synchronizes this backend's CUDA context via `Context::synchronize`, converting its
    /// error into [`crate::framework::Error`] with `?`.
    fn synchronize(&self) -> Result<(), crate::framework::Error> {
        Ok(self.device().synchronize()?)
    }
}