1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
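//! Evaluation helpers for classification networks: a [`ConfusionMatrix`] that collects
//! prediction/target pairs and reports their [`Accuracy`].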

use crate::co::SharedTensor;
use crate::util::native_backend;
use std::collections::VecDeque;
use std::fmt;
/// Collects `(prediction, target)` pairs to evaluate a classifier, in the spirit of a
/// [ConfusionMatrix][wiki].
///
/// [wiki]: https://en.wikipedia.org/wiki/Confusion_matrix
#[derive(Debug)]
pub struct ConfusionMatrix {
    /// Number of classes; also the width of each per-sample chunk of network output.
    num_classes: usize,

    /// Upper bound on how many samples are retained (`None` means unbounded).
    capacity: Option<usize>,
    /// Recorded samples; new ones are appended at the back, so the oldest sit at the front.
    samples: VecDeque<Sample>,
}

impl ConfusionMatrix {
    /// Create a ConfusionMatrix that analyzes the prediction of `num_classes` classes.
    ///
    /// The matrix starts empty and without a capacity limit.
    pub fn new(num_classes: usize) -> ConfusionMatrix {
        ConfusionMatrix {
            num_classes,
            capacity: None,
            samples: VecDeque::new(),
        }
    }

    /// Add a sample by providing the expected `target` class and the `prediction`.
    ///
    /// When a capacity is set and already reached, the oldest samples are discarded first so
    /// that at most `capacity` samples are held. With a capacity of `Some(0)` nothing is stored.
    pub fn add_sample(&mut self, prediction: usize, target: usize) {
        match self.capacity {
            // A zero capacity can never hold a sample.
            Some(0) => return,
            // Make room for the new sample by evicting from the front (oldest first).
            Some(capacity) => self.evict_oldest_beyond(capacity - 1),
            None => {}
        }
        self.samples.push_back(Sample { prediction, target });
    }

    /// Add a batch of samples.
    ///
    /// `predictions[i]` is paired with `targets[i]`; if the slices differ in length, the extra
    /// entries of the longer one are ignored. See [add_sample](#method.add_sample).
    pub fn add_samples(&mut self, predictions: &[usize], targets: &[usize]) {
        for (&prediction, &target) in predictions.iter().zip(targets) {
            self.add_sample(prediction, target);
        }
    }

    /// Get the predicted classes from the output of a network.
    ///
    /// `network_out` is read on the native backend and split into consecutive chunks of
    /// `num_classes` values, one chunk per sample of the batch (a trailing shorter chunk is
    /// evaluated as well). The prediction for each sample is the index of the largest value in
    /// its chunk; on ties the highest index wins, and incomparable values (NaN) compare equal.
    ///
    /// # Panics
    ///
    /// Panics if `num_classes` is zero or if the tensor cannot be read on the native backend.
    pub fn get_predictions(&self, network_out: &mut SharedTensor<f32>) -> Vec<usize> {
        assert!(
            self.num_classes > 0,
            "ConfusionMatrix needs at least one class to derive predictions"
        );
        let native_infered = network_out
            .read(native_backend().device())
            .expect("failed to read network output on the native backend");
        native_infered
            .as_slice::<f32>()
            .chunks(self.num_classes)
            .map(Self::index_of_max)
            .collect()
    }

    /// Index of the largest value in `scores` (the last one on ties, 0 for an empty slice).
    fn index_of_max(scores: &[f32]) -> usize {
        scores
            .iter()
            .enumerate()
            // `max_by` returns the last of several equal maxima, i.e. the same element a stable
            // ascending sort would put last, without allocating or requiring a total order.
            .max_by(|&(_, a), &(_, b)| a.partial_cmp(b).unwrap_or(::std::cmp::Ordering::Equal))
            .map_or(0, |(index, _)| index)
    }

    /// Set the `capacity` of the ConfusionMatrix.
    ///
    /// `None` removes the limit. If the new capacity is smaller than the number of samples
    /// currently held, the oldest samples are dropped immediately.
    pub fn set_capacity(&mut self, capacity: Option<usize>) {
        self.capacity = capacity;
        if let Some(capacity) = capacity {
            self.evict_oldest_beyond(capacity);
        }
    }

    /// Drop samples from the front (oldest first) until at most `limit` remain.
    fn evict_oldest_beyond(&mut self, limit: usize) {
        let excess = self.samples.len().saturating_sub(limit);
        self.samples.drain(..excess);
    }

    /// Return all collected samples, oldest first.
    pub fn samples(&self) -> &VecDeque<Sample> {
        &self.samples
    }

    /// Return the accuracy of the collected predictions.
    pub fn accuracy(&self) -> Accuracy {
        Accuracy {
            num_samples: self.samples.len(),
            num_correct: self.samples.iter().filter(|sample| sample.correct()).count(),
        }
    }
}

/// A single prediction Sample: the class that was predicted and the class that was expected.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Sample {
    prediction: usize,
    target: usize,
}

impl Sample {
    /// The predicted class.
    pub fn prediction(&self) -> usize {
        self.prediction
    }

    /// The expected (target) class.
    pub fn target(&self) -> usize {
        self.target
    }

    /// Returns if the prediction is equal to the expected target.
    pub fn correct(&self) -> bool {
        self.prediction == self.target
    }
}

impl fmt::Display for Sample {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "Prediction: {}, Target: {}", self.prediction, self.target)
    }
}

/// The accuracy of the predictions in a ConfusionMatrix.
///
/// Used to print the accuracy; the `Display` output looks like `3/4 = 75.00%`.
#[derive(Debug, Clone, Copy)]
pub struct Accuracy {
    num_samples: usize,
    num_correct: usize,
}

impl Accuracy {
    /// Percentage (0 to 100) of correct samples.
    ///
    /// Returns `0.0` when there are no samples, instead of the `NaN` that `0 / 0` would yield.
    fn ratio(&self) -> f32 {
        if self.num_samples == 0 {
            return 0.0;
        }
        (self.num_correct as f32) / (self.num_samples as f32) * 100f32
    }
}

impl fmt::Display for Accuracy {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(
            f,
            "{}/{} = {:.2}%",
            self.num_correct,
            self.num_samples,
            self.ratio()
        )
    }
}