//! Provides common utility functions.

use crate::co::frameworks::native::flatbox::FlatBox;
use crate::co::prelude::*;
use crate::coblas::plugin::*;
use crate::conn;
use num::traits::{cast, NumCast};
use std::sync::{Arc, RwLock};

/// Shared Lock used for our tensors
pub type ArcLock<T> = Arc<RwLock<T>>;

/// Create a simple native backend.
///
/// This is handy when you need to sync data to host memory to read/write it.
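///
/// # Example
///
/// A minimal usage sketch; the `juice::util` path is assumed here and the
/// block is not compiled as a doc test:
///
/// ```ignore
/// use juice::util::native_backend;
///
/// let backend = native_backend();
/// // `backend.device()` can be passed to `SharedTensor::read`/`write_only`
/// // to sync tensor data to host memory.
/// ```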
pub fn native_backend() -> Backend<Native> {
    let framework = Native::new();
    let hardwares = &framework.hardwares().to_vec();
    let backend_config = BackendConfig::new(framework, hardwares);
    Backend::new(backend_config).unwrap()
}

/// Write into a native Coaster Memory.
pub fn write_to_memory<T: NumCast + ::std::marker::Copy>(mem: &mut FlatBox, data: &[T]) {
    write_to_memory_offset(mem, data, 0);
}

/// Write into a native Coaster Memory with an offset.
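///
/// # Example
///
/// A minimal sketch (crate paths are assumed, not compiled as a doc test);
/// the integer input is cast element-wise to `f32` before being stored:
///
/// ```ignore
/// use juice::co::prelude::*;
/// use juice::util::{native_backend, write_to_memory_offset};
///
/// let backend = native_backend();
/// let mut tensor = SharedTensor::<f32>::new(&[4]);
/// let mem = tensor.write_only(backend.device()).unwrap();
/// // Write two values starting at element 2 of the f32 buffer.
/// write_to_memory_offset(mem, &[3u8, 4u8], 2);
/// ```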
pub fn write_to_memory_offset<T: NumCast + ::std::marker::Copy>(mem: &mut FlatBox, data: &[T], offset: usize) {
    let mem_buffer = mem.as_mut_slice::<f32>();
    for (index, datum) in data.iter().enumerate() {
        // Each datum is cast to `f32`, since the native memory is viewed as an `f32` buffer.
        mem_buffer[index + offset] = cast(*datum).unwrap();
    }
}

/// Write the `i`th sample of a batch into a SharedTensor.
///
/// The size of a single sample is inferred from the first
/// dimension of the SharedTensor, which is assumed to be
/// the batch size.
///
/// Allocates memory on a Native Backend if necessary.
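///
/// # Example
///
/// A minimal sketch (crate paths are assumed, not compiled as a doc test);
/// writes two samples of size 3 into a `[2, 3]` tensor:
///
/// ```ignore
/// use juice::co::prelude::*;
/// use juice::util::write_batch_sample;
///
/// let mut batch = SharedTensor::<f32>::new(&[2, 3]);
/// write_batch_sample(&mut batch, &[1.0f32, 2.0, 3.0], 0);
/// write_batch_sample(&mut batch, &[4.0f32, 5.0, 6.0], 1);
/// ```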
pub fn write_batch_sample<T: NumCast + ::std::marker::Copy>(tensor: &mut SharedTensor<f32>, data: &[T], i: usize) {
    let native_backend = native_backend();
    let tensor_desc = tensor.desc();
    let batch_size = tensor_desc[0];
    let batch_sample_size = tensor_desc.size();
    let sample_size = batch_sample_size / batch_size;

    write_to_memory_offset(
        tensor.write_only(native_backend.device()).unwrap(),
        data,
        i * sample_size,
    );
}

/// Create a Coaster SharedTensor for a scalar value.
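///
/// # Example
///
/// A minimal sketch (crate path assumed, not compiled as a doc test):
///
/// ```ignore
/// use juice::util::native_scalar;
///
/// // A `[1]`-shaped SharedTensor holding the value 2.0.
/// let two = native_scalar(2f32);
/// ```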
pub fn native_scalar<T: NumCast + ::std::marker::Copy>(scalar: T) -> SharedTensor<T> {
    let native = native_backend();
    let mut shared_scalar = SharedTensor::<T>::new(&[1]);
    write_to_memory(shared_scalar.write_only(native.device()).unwrap(), &[scalar]);
    shared_scalar
}

/// Casts a `Vec<usize>` to a `Vec<i32>`.
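///
/// # Example
///
/// A minimal sketch (crate path assumed, not compiled as a doc test):
///
/// ```ignore
/// use juice::util::cast_vec_usize_to_i32;
///
/// assert_eq!(cast_vec_usize_to_i32(vec![1, 2, 3]), vec![1i32, 2, 3]);
/// ```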
pub fn cast_vec_usize_to_i32(input: Vec<usize>) -> Vec<i32> {
    input.iter().map(|&e| e as i32).collect()
}

/// Extends backends that provide `Axpy` and `Scal` with a combined `axpby` operation.
pub trait Axpby<F>: Axpy<F> + Scal<F> {
    /// Performs the operation `y := a * x + b * y`.
    ///
    /// Consists of a `scal(b, y)` followed by an `axpy(a, x, y)`.
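    ///
    /// # Example
    ///
    /// A minimal sketch; assumes a backend whose BLAS plugin provides
    /// `Axpy<f32>` and `Scal<f32>` (crate paths are assumed, block is not
    /// compiled as a doc test):
    ///
    /// ```ignore
    /// use juice::co::prelude::*;
    /// use juice::util::{native_backend, native_scalar, write_to_memory, Axpby};
    ///
    /// let backend = native_backend();
    /// let a = native_scalar(2f32);
    /// let b = native_scalar(3f32);
    /// let mut x = SharedTensor::<f32>::new(&[2]);
    /// let mut y = SharedTensor::<f32>::new(&[2]);
    /// write_to_memory(x.write_only(backend.device()).unwrap(), &[1f32, 2.0]);
    /// write_to_memory(y.write_only(backend.device()).unwrap(), &[10f32, 20.0]);
    /// backend.axpby(&a, &x, &b, &mut y).unwrap(); // y is now [32.0, 64.0]
    /// ```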
    fn axpby(
        &self,
        a: &SharedTensor<F>,
        x: &SharedTensor<F>,
        b: &SharedTensor<F>,
        y: &mut SharedTensor<F>,
    ) -> Result<(), crate::co::error::Error> {
        self.scal(b, y)?;
        self.axpy(a, x, y)?;
        Ok(())
    }
}

impl<T: Axpy<f32> + Scal<f32>> Axpby<f32> for T {}

/// Encapsulates all traits required by Solvers.
pub trait SolverOps<F>: LayerOps<F> + Axpby<F> + Dot<F> + Copy<F> {}

impl<T: LayerOps<f32> + Axpby<f32> + Dot<f32> + Copy<f32>> SolverOps<f32> for T {}

/// Encapsulates all traits used in Layers.
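///
/// # Example
///
/// A hypothetical sketch of using the alias as a generic bound; `forward_all`
/// is illustrative and not part of the crate (block is not compiled as a doc
/// test). `SolverOps` is used the same way inside solvers:
///
/// ```ignore
/// use juice::co::prelude::*;
/// use juice::util::LayerOps;
///
/// fn forward_all<B: IBackend + LayerOps<f32>>(backend: &B) {
///     // `backend` provides convolution, pooling, activations, GEMM, ...
/// }
/// ```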
pub trait LayerOps<F>:
    conn::Convolution<F>
    + conn::Rnn<F>
    + conn::Pooling<F>
    + conn::Relu<F>
    + conn::ReluPointwise<F>
    + conn::Sigmoid<F>
    + conn::SigmoidPointwise<F>
    + conn::Tanh<F>
    + conn::TanhPointwise<F>
    + conn::Softmax<F>
    + conn::LogSoftmax<F>
    + conn::Dropout<F>
    + Gemm<F>
    + Axpby<F>
    + Copy<F>
{
}

impl<
        T: conn::Convolution<f32>
            + conn::Rnn<f32>
            + conn::Pooling<f32>
            + conn::Relu<f32>
            + conn::ReluPointwise<f32>
            + conn::Sigmoid<f32>
            + conn::SigmoidPointwise<f32>
            + conn::Tanh<f32>
            + conn::TanhPointwise<f32>
            + conn::Softmax<f32>
            + conn::LogSoftmax<f32>
            + conn::Dropout<f32>
            + Gemm<f32>
            + Axpby<f32>
            + Copy<f32>,
    > LayerOps<f32> for T
{
}