1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
//! Provides a Rust wrapper around Cuda's context.

use super::api::DriverFFI;
use super::memory::*;
use super::{Device, Driver, DriverError};
use crate::device::Error as DeviceError;
use crate::device::{IDevice, MemorySync};
use crate::frameworks::native::device::Cpu;
use crate::frameworks::native::flatbox::FlatBox;
use std::any::Any;
use std::hash::{Hash, Hasher};
use std::rc::Rc;

/// Defines a Cuda Context.
///
/// The raw driver context handle is stored as an `isize` behind an `Rc`, so cloning a
/// `Context` shares the same handle instead of creating a new driver context.
#[derive(Debug, Clone)]
pub struct Context {
    /// Raw driver context handle, shared between all clones of this value.
    id: Rc<isize>,
    /// Devices associated with this context (exposed through `IDevice::hardwares`).
    devices: Vec<Device>,
}

impl Drop for Context {
    #[allow(unused_must_use)]
    fn drop(&mut self) {
        let id_c = self.id_c();
        if Rc::get_mut(&mut self.id).is_some() {
            Driver::destroy_context(id_c);
        }
    }
}

impl Context {
    /// Creates a new Cuda context for a single device via `Driver::create_context`.
    ///
    /// Despite its name, `devices` is one `Device`; it becomes the sole entry of the
    /// context's device list.
    ///
    /// # Errors
    /// Returns the `DriverError` reported by `Driver::create_context`.
    pub fn new(devices: Device) -> Result<Context, DriverError> {
        let handle = Driver::create_context(devices.clone())?;
        Ok(Self::from_c(handle, vec![devices]))
    }

    /// Wraps an existing raw driver context handle together with its devices.
    ///
    /// The handle is taken over by the returned value: the `Drop` impl destroys it once
    /// the last clone sharing it is dropped.
    pub fn from_c(id: DriverFFI::CUcontext, devices: Vec<Device>) -> Context {
        let shared_id = Rc::new(id as isize);
        Self {
            id: shared_id,
            devices,
        }
    }

    /// Returns the raw context handle as an `isize`.
    pub fn id(&self) -> isize {
        *self.id
    }

    /// Returns the raw context handle converted back to its C pointer type.
    pub fn id_c(&self) -> DriverFFI::CUcontext {
        self.id() as DriverFFI::CUcontext
    }

    /// Synchronizes via `Driver::synchronize_context`.
    ///
    /// NOTE(review): this context's handle is never passed to the driver, so this
    /// presumably synchronizes whichever context is current on the calling thread,
    /// not necessarily `self` — confirm.
    pub fn synchronize(&self) -> Result<(), DriverError> {
        Driver::synchronize_context()
    }
}

// #[cfg(feature = "native")]
// impl IDeviceSyncOut<FlatBox> for Context {
//     type M = Memory;
//     fn sync_out(&self, source_data: &Memory, dest_data: &mut FlatBox) -> Result<(), DeviceError> {
//         Ok(Driver::mem_cpy_d_to_h(source_data, dest_data)?)
//     }
// }

impl IDevice for Context {
    type H = Device;
    type M = Memory;

    /// Borrows the raw context handle stored inside the shared `Rc`.
    fn id(&self) -> &isize {
        &*self.id
    }

    /// Returns the devices associated with this context.
    fn hardwares(&self) -> &Vec<Device> {
        &self.devices
    }

    /// Allocates device memory of `size` (presumably bytes — confirm) via the driver.
    ///
    /// NOTE(review): `Driver::mem_alloc` is not given this context's handle, so the
    /// allocation presumably lands in the thread's current context — confirm.
    fn alloc_memory(&self, size: DriverFFI::size_t) -> Result<Memory, DeviceError> {
        let memory = Driver::mem_alloc(size)?;
        Ok(memory)
    }
}

impl MemorySync for Context {
    /// Copies data from a native `Cpu` source into this context's `Memory`.
    ///
    /// Any other source device yields `DeviceError::NoMemorySyncRoute`.
    ///
    /// # Panics
    /// Panics if `my_memory` is not a `Memory` or `src_memory` is not a `FlatBox`.
    fn sync_in(
        &self,
        my_memory: &mut dyn Any,
        src_device: &dyn Any,
        src_memory: &dyn Any,
    ) -> Result<(), DeviceError> {
        if !src_device.is::<Cpu>() {
            return Err(DeviceError::NoMemorySyncRoute);
        }
        let device_buf = my_memory.downcast_mut::<Memory>().unwrap();
        let host_buf = src_memory.downcast_ref::<FlatBox>().unwrap();
        Driver::mem_cpy_h_to_d(host_buf, device_buf)?;
        Ok(())
    }

    /// Copies this context's `Memory` out to a native `Cpu` destination.
    ///
    /// Any other destination device yields `DeviceError::NoMemorySyncRoute`.
    ///
    /// # Panics
    /// Panics if `my_memory` is not a `Memory` or `dst_memory` is not a `FlatBox`.
    fn sync_out(
        &self,
        my_memory: &dyn Any,
        dst_device: &dyn Any,
        dst_memory: &mut dyn Any,
    ) -> Result<(), DeviceError> {
        if !dst_device.is::<Cpu>() {
            return Err(DeviceError::NoMemorySyncRoute);
        }
        let device_buf = my_memory.downcast_ref::<Memory>().unwrap();
        let host_buf = dst_memory.downcast_mut::<FlatBox>().unwrap();
        Driver::mem_cpy_d_to_h(device_buf, host_buf)?;
        Ok(())
    }
}

/// Contexts compare equal when their device lists are equal; the underlying driver
/// handles are not compared, so two distinct contexts on the same device are equal.
impl PartialEq for Context {
    fn eq(&self, other: &Self) -> bool {
        self.devices == other.devices
    }
}

// Relies on `Device`'s `PartialEq` being a full equivalence (reflexive) — presumably
// true for device descriptors; confirm.
impl Eq for Context {}

impl Hash for Context {
    /// Hashes the context consistently with its `PartialEq` impl.
    ///
    /// `PartialEq` treats two contexts as equal when their device lists are equal, even
    /// if they wrap different driver handles. Hashing the driver handle would therefore
    /// break the `a == b => hash(a) == hash(b)` contract that `HashMap`/`HashSet` rely on.
    /// Only the device count is hashed because `Device` is not shown to implement `Hash`
    /// here. This is coarse, but it is always consistent with equality, and a process
    /// typically holds very few contexts.
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.hardwares().len().hash(state);
    }
}