use super::Context;
use crate::ffi::*;
use crate::{Error, API};
impl API {
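    /// Computes the sum of the absolute values of the elements of the vector
    /// `x` (cuBLAS `cublasSasum_v2`) and writes it to `result`.
    ///
    /// `stride` defaults to 1 when `None`. `x` must point to device memory
    /// holding at least `n` elements; `result` may live on the host or the
    /// device, depending on the pointer mode of `context`.
    ///
    /// A minimal usage sketch, assuming `x` and `result` are valid device
    /// pointers and `ctx` is a context in device pointer mode:
    ///
    /// ```ignore
    /// API::asum(&ctx, x, result, n, None)?;
    /// ```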
pub fn asum(
context: &Context,
x: *mut f32,
result: *mut f32,
n: i32,
stride: Option<i32>,
) -> Result<(), Error> {
let stride_x = stride.unwrap_or(1);
unsafe { Self::ffi_sasum(*context.id_c(), n, x, stride_x, result) }
}
unsafe fn ffi_sasum(
handle: cublasHandle_t,
n: i32,
x: *mut f32,
incx: i32,
result: *mut f32,
) -> Result<(), Error> {
match cublasSasum_v2(handle, n, x, incx, result) {
cublasStatus_t::CUBLAS_STATUS_SUCCESS => Ok(()),
cublasStatus_t::CUBLAS_STATUS_NOT_INITIALIZED => Err(Error::NotInitialized),
cublasStatus_t::CUBLAS_STATUS_ALLOC_FAILED => Err(Error::AllocFailed),
cublasStatus_t::CUBLAS_STATUS_ARCH_MISMATCH => Err(Error::ArchMismatch),
cublasStatus_t::CUBLAS_STATUS_EXECUTION_FAILED => Err(Error::ExecutionFailed),
status => Err(Error::Unknown(
"Unable to calculate sum of x.",
status as i32 as u64,
)),
}
}
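    /// Computes `y = alpha * x + y` over `n` elements
    /// (cuBLAS `cublasSaxpy_v2`).
    ///
    /// `stride_x` and `stride_y` default to 1 when `None`. `alpha` may live
    /// on the host or the device, depending on the pointer mode of `context`.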
pub fn axpy(
context: &Context,
alpha: *mut f32,
x: *mut f32,
y: *mut f32,
n: i32,
stride_x: Option<i32>,
stride_y: Option<i32>,
) -> Result<(), Error> {
let stride_x = stride_x.unwrap_or(1);
let stride_y = stride_y.unwrap_or(1);
unsafe { Self::ffi_saxpy(*context.id_c(), n, alpha, x, stride_x, y, stride_y) }
}
unsafe fn ffi_saxpy(
handle: cublasHandle_t,
n: i32,
alpha: *mut f32,
x: *mut f32,
incx: i32,
y: *mut f32,
incy: i32,
) -> Result<(), Error> {
match cublasSaxpy_v2(handle, n, alpha, x, incx, y, incy) {
cublasStatus_t::CUBLAS_STATUS_SUCCESS => Ok(()),
cublasStatus_t::CUBLAS_STATUS_NOT_INITIALIZED => Err(Error::NotInitialized),
cublasStatus_t::CUBLAS_STATUS_ARCH_MISMATCH => Err(Error::ArchMismatch),
cublasStatus_t::CUBLAS_STATUS_EXECUTION_FAILED => Err(Error::ExecutionFailed),
status => Err(Error::Unknown(
"Unable to calculate axpy (alpha * x + y).",
status as i32 as u64,
)),
}
}
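    /// Copies `n` elements of the vector `x` into the vector `y`
    /// (cuBLAS `cublasScopy_v2`).
    ///
    /// `stride_x` and `stride_y` default to 1 when `None`.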
pub fn copy(
context: &Context,
x: *mut f32,
y: *mut f32,
n: i32,
stride_x: Option<i32>,
stride_y: Option<i32>,
) -> Result<(), Error> {
let stride_x = stride_x.unwrap_or(1);
let stride_y = stride_y.unwrap_or(1);
unsafe { Self::ffi_scopy(*context.id_c(), n, x, stride_x, y, stride_y) }
}
unsafe fn ffi_scopy(
handle: cublasHandle_t,
n: i32,
x: *mut f32,
incx: i32,
y: *mut f32,
incy: i32,
) -> Result<(), Error> {
match cublasScopy_v2(handle, n, x, incx, y, incy) {
cublasStatus_t::CUBLAS_STATUS_SUCCESS => Ok(()),
cublasStatus_t::CUBLAS_STATUS_NOT_INITIALIZED => Err(Error::NotInitialized),
cublasStatus_t::CUBLAS_STATUS_ARCH_MISMATCH => Err(Error::ArchMismatch),
cublasStatus_t::CUBLAS_STATUS_EXECUTION_FAILED => Err(Error::ExecutionFailed),
status => Err(Error::Unknown(
"Unable to calculate copy from x to y.",
status as i32 as u64,
)),
}
}
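    /// Computes the dot product of the vectors `x` and `y` over `n` elements
    /// (cuBLAS `cublasSdot_v2`) and writes it to `result`.
    ///
    /// `stride_x` and `stride_y` default to 1 when `None`.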
pub fn dot(
context: &Context,
x: *mut f32,
y: *mut f32,
result: *mut f32,
n: i32,
stride_x: Option<i32>,
stride_y: Option<i32>,
) -> Result<(), Error> {
let stride_x = stride_x.unwrap_or(1);
let stride_y = stride_y.unwrap_or(1);
unsafe { Self::ffi_sdot(*context.id_c(), n, x, stride_x, y, stride_y, result) }
}
unsafe fn ffi_sdot(
handle: cublasHandle_t,
n: i32,
x: *mut f32,
incx: i32,
y: *mut f32,
incy: i32,
result: *mut f32,
) -> Result<(), Error> {
match cublasSdot_v2(handle, n, x, incx, y, incy, result) {
cublasStatus_t::CUBLAS_STATUS_SUCCESS => Ok(()),
cublasStatus_t::CUBLAS_STATUS_NOT_INITIALIZED => Err(Error::NotInitialized),
cublasStatus_t::CUBLAS_STATUS_ARCH_MISMATCH => Err(Error::ArchMismatch),
cublasStatus_t::CUBLAS_STATUS_EXECUTION_FAILED => Err(Error::ExecutionFailed),
status => Err(Error::Unknown(
"Unable to calculate dot product of x and y.",
status as i32 as u64,
)),
}
}
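    /// Computes the Euclidean norm of the vector `x` over `n` elements
    /// (cuBLAS `cublasSnrm2_v2`) and writes it to `result`.
    ///
    /// `stride_x` defaults to 1 when `None`.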
pub fn nrm2(
context: &Context,
x: *mut f32,
result: *mut f32,
n: i32,
stride_x: Option<i32>,
) -> Result<(), Error> {
let stride_x = stride_x.unwrap_or(1);
unsafe { Self::ffi_snrm2(*context.id_c(), n, x, stride_x, result) }
}
unsafe fn ffi_snrm2(
handle: cublasHandle_t,
n: i32,
x: *mut f32,
incx: i32,
result: *mut f32,
) -> Result<(), Error> {
match cublasSnrm2_v2(handle, n, x, incx, result) {
cublasStatus_t::CUBLAS_STATUS_SUCCESS => Ok(()),
cublasStatus_t::CUBLAS_STATUS_NOT_INITIALIZED => Err(Error::NotInitialized),
            cublasStatus_t::CUBLAS_STATUS_ALLOC_FAILED => Err(Error::AllocFailed),
            cublasStatus_t::CUBLAS_STATUS_ARCH_MISMATCH => Err(Error::ArchMismatch),
            cublasStatus_t::CUBLAS_STATUS_EXECUTION_FAILED => Err(Error::ExecutionFailed),
            status => Err(Error::Unknown(
                "Unable to calculate the Euclidean norm of x.",
                status as i32 as u64,
            )),
}
}
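    /// Scales the vector `x` by `alpha` in place over `n` elements
    /// (cuBLAS `cublasSscal_v2`).
    ///
    /// `stride_x` defaults to 1 when `None`. `alpha` may live on the host or
    /// the device, depending on the pointer mode of `context`.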
pub fn scal(
context: &Context,
alpha: *mut f32,
x: *mut f32,
n: i32,
stride_x: Option<i32>,
) -> Result<(), Error> {
let stride_x = stride_x.unwrap_or(1);
unsafe { Self::ffi_sscal(*context.id_c(), n, alpha, x, stride_x) }
}
unsafe fn ffi_sscal(
handle: cublasHandle_t,
n: i32,
alpha: *mut f32,
x: *mut f32,
incx: i32,
) -> Result<(), Error> {
match cublasSscal_v2(handle, n, alpha, x, incx) {
cublasStatus_t::CUBLAS_STATUS_SUCCESS => Ok(()),
cublasStatus_t::CUBLAS_STATUS_NOT_INITIALIZED => Err(Error::NotInitialized),
cublasStatus_t::CUBLAS_STATUS_ARCH_MISMATCH => Err(Error::ArchMismatch),
cublasStatus_t::CUBLAS_STATUS_EXECUTION_FAILED => Err(Error::ExecutionFailed),
status => Err(Error::Unknown(
"Unable to scale the vector x.",
status as i32 as u64,
)),
}
}
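    /// Swaps the contents of the vectors `x` and `y` over `n` elements
    /// (cuBLAS `cublasSswap_v2`).
    ///
    /// `stride_x` and `stride_y` default to 1 when `None`.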
pub fn swap(
context: &Context,
x: *mut f32,
y: *mut f32,
n: i32,
stride_x: Option<i32>,
stride_y: Option<i32>,
) -> Result<(), Error> {
let stride_x = stride_x.unwrap_or(1);
let stride_y = stride_y.unwrap_or(1);
unsafe { Self::ffi_sswap(*context.id_c(), n, x, stride_x, y, stride_y) }
}
unsafe fn ffi_sswap(
handle: cublasHandle_t,
n: i32,
x: *mut f32,
incx: i32,
y: *mut f32,
incy: i32,
) -> Result<(), Error> {
match cublasSswap_v2(handle, n, x, incx, y, incy) {
cublasStatus_t::CUBLAS_STATUS_SUCCESS => Ok(()),
cublasStatus_t::CUBLAS_STATUS_NOT_INITIALIZED => Err(Error::NotInitialized),
cublasStatus_t::CUBLAS_STATUS_ARCH_MISMATCH => Err(Error::ArchMismatch),
cublasStatus_t::CUBLAS_STATUS_EXECUTION_FAILED => Err(Error::ExecutionFailed),
status => Err(Error::Unknown(
"Unable to swap vector x and y.",
status as i32 as u64,
)),
}
}
}
#[cfg(test)]
mod test {
use crate::api::context::Context;
use crate::api::enums::PointerMode;
use crate::chore::*;
use crate::co::tensor::SharedTensor;
use crate::API;
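    // All tests share one pattern: fill a SharedTensor on the native backend,
    // sync it to the CUDA device, switch the cuBLAS context to device pointer
    // mode, call the raw FFI wrapper with device pointers, and read the
    // result back on the native backend to assert on it.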
#[test]
fn use_cuda_memory_for_asum() {
test_setup();
let native = get_native_backend();
let cuda = get_cuda_backend();
let n = 20i32;
let val = 2f32;
let x = filled_tensor(&native, n as usize, val);
let mut result = SharedTensor::<f32>::new(&vec![1]);
{
let cuda_mem = x.read(cuda.device()).unwrap();
let cuda_mem_result = result.write_only(cuda.device()).unwrap();
let mut ctx = Context::new().unwrap();
ctx.set_pointer_mode(PointerMode::Device).unwrap();
unsafe {
                let x_addr = *cuda_mem.id_c() as *mut f32;
                let res_addr = *cuda_mem_result.id_c() as *mut f32;
API::ffi_sasum(*ctx.id_c(), n, x_addr, 1, res_addr).unwrap();
}
}
let native_res = result.read(native.device()).unwrap();
assert_eq!(&[40f32], native_res.as_slice::<f32>());
test_teardown();
}
#[test]
fn use_cuda_memory_for_axpy() {
test_setup();
let native = get_native_backend();
let cuda = get_cuda_backend();
let alpha = filled_tensor(&native, 1, 1.5f32);
let n = 5i32;
let val = 2f32;
let x = filled_tensor(&native, n as usize, val);
let val = 4f32;
let mut y = filled_tensor(&native, n as usize, val);
{
let cuda_mem_alpha = alpha.read(cuda.device()).unwrap();
let cuda_mem_x = x.read(cuda.device()).unwrap();
let cuda_mem_y = y.read_write(cuda.device()).unwrap();
let mut ctx = Context::new().unwrap();
ctx.set_pointer_mode(PointerMode::Device).unwrap();
unsafe {
                let alpha_addr = *cuda_mem_alpha.id_c() as *mut f32;
                let x_addr = *cuda_mem_x.id_c() as *mut f32;
                let y_addr = *cuda_mem_y.id_c() as *mut f32;
API::ffi_saxpy(*ctx.id_c(), n, alpha_addr, x_addr, 1, y_addr, 1).unwrap();
}
}
let native_y = y.read(native.device()).unwrap();
assert_eq!(&[7f32, 7f32, 7f32, 7f32, 7f32], native_y.as_slice::<f32>());
test_teardown();
}
#[test]
fn use_cuda_memory_for_copy() {
test_setup();
let native = get_native_backend();
let cuda = get_cuda_backend();
let n = 5i32;
let val = 2f32;
let x = filled_tensor(&native, n as usize, val);
let val = 4f32;
let mut y = filled_tensor(&native, n as usize, val);
{
let cuda_mem_x = x.read(cuda.device()).unwrap();
let cuda_mem_y = y.write_only(cuda.device()).unwrap();
let mut ctx = Context::new().unwrap();
ctx.set_pointer_mode(PointerMode::Device).unwrap();
unsafe {
                let x_addr = *cuda_mem_x.id_c() as *mut f32;
                let y_addr = *cuda_mem_y.id_c() as *mut f32;
API::ffi_scopy(*ctx.id_c(), n, x_addr, 1, y_addr, 1).unwrap();
}
}
let native_y = y.read(native.device()).unwrap();
assert_eq!(&[2f32, 2f32, 2f32, 2f32, 2f32], native_y.as_slice::<f32>());
test_teardown();
}
#[test]
fn use_cuda_memory_for_dot() {
test_setup();
let native = get_native_backend();
let cuda = get_cuda_backend();
let n = 5i32;
let val = 2f32;
let x = filled_tensor(&native, n as usize, val);
let val = 4f32;
let y = filled_tensor(&native, n as usize, val);
let mut result = SharedTensor::<f32>::new(&vec![1]);
{
let cuda_mem_x = x.read(cuda.device()).unwrap();
let cuda_mem_y = y.read(cuda.device()).unwrap();
let cuda_mem_result = result.write_only(cuda.device()).unwrap();
let mut ctx = Context::new().unwrap();
ctx.set_pointer_mode(PointerMode::Device).unwrap();
unsafe {
                let x_addr = *cuda_mem_x.id_c() as *mut f32;
                let y_addr = *cuda_mem_y.id_c() as *mut f32;
                let result_addr = *cuda_mem_result.id_c() as *mut f32;
API::ffi_sdot(*ctx.id_c(), n, x_addr, 1, y_addr, 1, result_addr).unwrap();
}
}
let native_result = result.read(native.device()).unwrap();
assert_eq!(&[40f32], native_result.as_slice::<f32>());
test_teardown();
}
#[test]
fn use_cuda_memory_for_nrm2() {
test_setup();
let native = get_native_backend();
let cuda = get_cuda_backend();
let n = 3i32;
let val = 2f32;
let mut x = filled_tensor(&native, n as usize, val);
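        // Overwrite x with [2, 2, 1]; its Euclidean norm is sqrt(4 + 4 + 1) = 3.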
write_to_memory(x.write_only(native.device()).unwrap(), &[2f32, 2f32, 1f32]);
let mut result = SharedTensor::<f32>::new(&vec![1]);
{
let cuda_mem_x = x.read(cuda.device()).unwrap();
let cuda_mem_result = result.write_only(cuda.device()).unwrap();
let mut ctx = Context::new().unwrap();
ctx.set_pointer_mode(PointerMode::Device).unwrap();
unsafe {
                let x_addr = *cuda_mem_x.id_c() as *mut f32;
                let result_addr = *cuda_mem_result.id_c() as *mut f32;
API::ffi_snrm2(*ctx.id_c(), n, x_addr, 1, result_addr).unwrap();
}
}
let native_result = result.read(native.device()).unwrap();
assert_eq!(&[3f32], native_result.as_slice::<f32>());
test_teardown();
}
#[test]
fn use_cuda_memory_for_scal() {
test_setup();
let native = get_native_backend();
let cuda = get_cuda_backend();
let alpha = filled_tensor(&native, 1, 2.5f32);
let n = 3i32;
let val = 2f32;
let mut x = filled_tensor(&native, n as usize, val);
{
let cuda_mem_alpha = alpha.read(cuda.device()).unwrap();
let cuda_mem_x = x.read_write(cuda.device()).unwrap();
let mut ctx = Context::new().unwrap();
ctx.set_pointer_mode(PointerMode::Device).unwrap();
unsafe {
                let alpha_addr = *cuda_mem_alpha.id_c() as *mut f32;
                let x_addr = *cuda_mem_x.id_c() as *mut f32;
API::ffi_sscal(*ctx.id_c(), n, alpha_addr, x_addr, 1).unwrap();
}
}
let native_x = x.read(native.device()).unwrap();
assert_eq!(&[5f32, 5f32, 5f32], native_x.as_slice::<f32>());
test_teardown();
}
#[test]
fn use_cuda_memory_for_swap() {
test_setup();
let native = get_native_backend();
let cuda = get_cuda_backend();
let n = 5i32;
let val = 2f32;
let mut x = filled_tensor(&native, n as usize, val);
let val = 4f32;
let mut y = filled_tensor(&native, n as usize, val);
{
let cuda_mem_x = x.read_write(cuda.device()).unwrap();
let cuda_mem_y = y.read_write(cuda.device()).unwrap();
let mut ctx = Context::new().unwrap();
ctx.set_pointer_mode(PointerMode::Device).unwrap();
unsafe {
                let x_addr = *cuda_mem_x.id_c() as *mut f32;
                let y_addr = *cuda_mem_y.id_c() as *mut f32;
API::ffi_sswap(*ctx.id_c(), n, x_addr, 1, y_addr, 1).unwrap();
}
}
let native_x = x.read(native.device()).unwrap();
assert_eq!(&[4f32, 4f32, 4f32, 4f32, 4f32], native_x.as_slice::<f32>());
let native_y = y.read(native.device()).unwrap();
assert_eq!(&[2f32, 2f32, 2f32, 2f32, 2f32], native_y.as_slice::<f32>());
test_teardown();
}
}