use crate::co::SharedTensor;
use crate::util::native_backend;
use std::collections::VecDeque;
use std::fmt;
/// Records `(prediction, target)` pairs for a classifier and reports accuracy over them.
///
/// By default every sample is kept. When a capacity is set via
/// [`ConfusionMatrix::set_capacity`], only the most recent `capacity` samples are
/// retained; the oldest ones are evicted first.
#[derive(Debug)]
pub struct ConfusionMatrix {
    /// Number of classes; also the row width used when decoding network output.
    num_classes: usize,
    /// Maximum number of retained samples, or `None` for unbounded.
    capacity: Option<usize>,
    /// Recorded samples, oldest at the front, newest at the back.
    samples: VecDeque<Sample>,
}

impl ConfusionMatrix {
    /// Creates an empty, unbounded matrix for `num_classes` classes.
    pub fn new(num_classes: usize) -> ConfusionMatrix {
        ConfusionMatrix {
            num_classes,
            capacity: None,
            samples: VecDeque::new(),
        }
    }

    /// Records a single sample.
    ///
    /// If a capacity is set, the oldest samples are dropped so that at most
    /// `capacity` samples remain afterwards (a capacity of `0` keeps nothing).
    pub fn add_sample(&mut self, prediction: usize, target: usize) {
        self.samples.push_back(Sample { prediction, target });
        self.enforce_capacity();
    }

    /// Records samples pairwise from `predictions` and `targets`.
    ///
    /// Pairs are formed with `zip`, so if the slices differ in length the extra
    /// trailing elements of the longer slice are ignored.
    pub fn add_samples(&mut self, predictions: &[usize], targets: &[usize]) {
        for (&prediction, &target) in predictions.iter().zip(targets) {
            self.add_sample(prediction, target);
        }
    }

    /// Decodes network output into one predicted class index per row.
    ///
    /// The tensor is read on the native device and its `f32` data is split into
    /// consecutive rows of `num_classes` values. For each row, the index of the
    /// largest value is returned; if several values tie for the maximum, the last
    /// one wins. A trailing partial row, if any, is decoded the same way.
    ///
    /// # Panics
    ///
    /// Panics if `num_classes` is zero, or if the tensor cannot be read on the
    /// native device.
    pub fn get_predictions(&self, network_out: &mut SharedTensor<f32>) -> Vec<usize> {
        assert!(
            self.num_classes > 0,
            "ConfusionMatrix: num_classes must be non-zero to decode predictions"
        );
        let native_output = network_out
            .read(native_backend().device())
            .expect("failed to read network output on the native device");
        let predictions: Vec<usize> = native_output
            .as_slice::<f32>()
            .chunks(self.num_classes)
            .map(Self::argmax)
            .collect();
        predictions
    }

    /// Index of the largest value in `row`; among equal maxima the last index wins.
    ///
    /// Values that cannot be compared (NaN) are treated as equal to everything.
    /// A linear `max_by` scan is used instead of sorting: `sort_by` with this
    /// non-total comparator may panic on recent Rust versions.
    fn argmax(row: &[f32]) -> usize {
        row.iter()
            .enumerate()
            .max_by(|&(_, a), &(_, b)| a.partial_cmp(b).unwrap_or(::std::cmp::Ordering::Equal))
            .map(|(index, _)| index)
            // `chunks` never yields an empty slice, so a maximum always exists.
            .expect("argmax called on an empty row")
    }

    /// Sets (or clears, with `None`) the maximum number of retained samples.
    ///
    /// Lowering the capacity immediately drops the oldest excess samples.
    pub fn set_capacity(&mut self, capacity: Option<usize>) {
        self.capacity = capacity;
        self.enforce_capacity();
    }

    /// All currently retained samples, oldest first.
    pub fn samples(&self) -> &VecDeque<Sample> {
        &self.samples
    }

    /// Accuracy over the currently retained samples.
    pub fn accuracy(&self) -> Accuracy {
        let num_samples = self.samples.len();
        let num_correct = self.samples.iter().filter(|s| s.correct()).count();
        Accuracy {
            num_samples,
            num_correct,
        }
    }

    /// Drops the oldest samples until at most `capacity` remain (no-op when unbounded).
    fn enforce_capacity(&mut self) {
        if let Some(capacity) = self.capacity {
            let excess = self.samples.len().saturating_sub(capacity);
            self.samples.drain(..excess);
        }
    }
}
/// A single recorded classification: the predicted class and the expected one.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Sample {
    prediction: usize,
    target: usize,
}

impl Sample {
    /// The class the model predicted.
    pub fn prediction(&self) -> usize {
        self.prediction
    }

    /// The expected (ground-truth) class.
    pub fn target(&self) -> usize {
        self.target
    }

    /// Returns `true` when the prediction equals the target.
    pub fn correct(&self) -> bool {
        self.prediction == self.target
    }
}

impl fmt::Display for Sample {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "Prediction: {:?}, Target: {:?}", self.prediction, self.target)
    }
}
/// Count of correct predictions out of a number of samples.
///
/// Formats as `correct/total = pct%`, e.g. `3/4 = 75.00%`.
#[derive(Debug, Clone, Copy)]
pub struct Accuracy {
    num_samples: usize,
    num_correct: usize,
}

impl Accuracy {
    /// Percentage of correct samples in `[0, 100]`.
    ///
    /// Returns `0.0` when there are no samples, instead of the `NaN` that `0 / 0`
    /// would produce.
    fn ratio(&self) -> f32 {
        if self.num_samples == 0 {
            return 0.0;
        }
        (self.num_correct as f32) / (self.num_samples as f32) * 100f32
    }
}

impl fmt::Display for Accuracy {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(
            f,
            "{:?}/{:?} = {:.2?}%",
            self.num_correct,
            self.num_samples,
            self.ratio()
        )
    }
}