1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
use crate::ffi::*;

/// Safe Rust mirror of the cuBLAS FFI type `cublasPointerMode_t`.
///
/// Convert from the raw value with [`PointerMode::from_c`], and back with
/// [`PointerMode::as_c`] or the equivalent `From`/`Into` conversion.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
pub enum PointerMode {
    /// Corresponds to `CUBLAS_POINTER_MODE_HOST`.
    Host,
    /// Corresponds to `CUBLAS_POINTER_MODE_DEVICE`.
    Device,
}

impl PointerMode {
    /// Converts a raw `cublasPointerMode_t` into a [`PointerMode`].
    ///
    /// # Panics
    ///
    /// Panics (via `unreachable!`) if `in_mode` is not one of the values handled
    /// here, i.e. when the bindings expose a pointer mode this wrapper does not
    /// know about yet.
    pub fn from_c(in_mode: cublasPointerMode_t) -> PointerMode {
        match in_mode {
            cublasPointerMode_t::CUBLAS_POINTER_MODE_HOST => PointerMode::Host,
            cublasPointerMode_t::CUBLAS_POINTER_MODE_DEVICE => PointerMode::Device,
            // Only reachable if the bindings define more modes than are mapped above.
            _ => unreachable!("wrapping library is newer than this impl, please file a BUG"),
        }
    }

    /// Converts this mode into the raw `cublasPointerMode_t` value for FFI calls.
    pub fn as_c(self) -> cublasPointerMode_t {
        match self {
            PointerMode::Host => cublasPointerMode_t::CUBLAS_POINTER_MODE_HOST,
            PointerMode::Device => cublasPointerMode_t::CUBLAS_POINTER_MODE_DEVICE,
        }
    }
}

/// Infallible conversion into the raw FFI value; equivalent to [`PointerMode::as_c`].
impl From<PointerMode> for cublasPointerMode_t {
    fn from(mode: PointerMode) -> Self {
        mode.as_c()
    }
}

/// Safe Rust mirror of the cuBLAS FFI type `cublasOperation_t`.
///
/// Convert from the raw value with [`Operation::from_c`], and back with
/// [`Operation::as_c`] or the equivalent `From`/`Into` conversion.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
pub enum Operation {
    /// Corresponds to `CUBLAS_OP_N` (no transpose).
    NoTrans,
    /// Corresponds to `CUBLAS_OP_T` (transpose).
    Trans,
    /// Corresponds to `CUBLAS_OP_C` (conjugate transpose).
    ConjTrans,
}

impl Operation {
    /// Converts a raw `cublasOperation_t` into an [`Operation`].
    ///
    /// # Panics
    ///
    /// Panics (via `unreachable!`) if `in_mode` is not one of the values handled
    /// here, i.e. when the bindings expose an operation this wrapper does not
    /// know about yet.
    pub fn from_c(in_mode: cublasOperation_t) -> Operation {
        match in_mode {
            cublasOperation_t::CUBLAS_OP_N => Operation::NoTrans,
            cublasOperation_t::CUBLAS_OP_T => Operation::Trans,
            cublasOperation_t::CUBLAS_OP_C => Operation::ConjTrans,
            // Only reachable if the bindings define more operations than are mapped above.
            _ => unreachable!("wrapping library is newer than this impl, please file a BUG"),
        }
    }

    /// Converts this operation into the raw `cublasOperation_t` value for FFI calls.
    pub fn as_c(self) -> cublasOperation_t {
        match self {
            Operation::NoTrans => cublasOperation_t::CUBLAS_OP_N,
            Operation::Trans => cublasOperation_t::CUBLAS_OP_T,
            Operation::ConjTrans => cublasOperation_t::CUBLAS_OP_C,
        }
    }
}

/// Infallible conversion into the raw FFI value; equivalent to [`Operation::as_c`].
impl From<Operation> for cublasOperation_t {
    fn from(op: Operation) -> Self {
        op.as_c()
    }
}

// TODO: cublasFillMode_t
// TODO: cublasDiagType_t
// TODO: cublasSideMode_t
// TODO: cublasAtomicsMode_t
// TODO: cublasDataType_t