use crate::co::SharedTensor;
use crate::util::native_backend;
use std::collections::VecDeque;
use std::fmt;
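/// Rolling evaluator for regression predictions.
///
/// Collects `(prediction, target)` pairs, optionally bounded by a capacity,
/// and reduces them to a [`RegressionLoss`] such as mean squared error.
///
/// The example below is an illustrative sketch only: it assumes the evaluator
/// is in scope and exercises just the in-memory sample API, so the doctest is
/// marked `ignore` rather than presented as part of the crate's test suite.
///
/// ```ignore
/// let mut evaluator = RegressionEvaluator::new(Some("mse".to_string()));
/// evaluator.add_samples(&[1.0, 3.0], &[0.0, 1.0]);
/// // Squared errors are 1.0 and 4.0, so the mean squared error is 2.5.
/// assert!((evaluator.accuracy().loss() - 2.5).abs() < 1e-6);
/// ```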
#[derive(Debug)]
pub struct RegressionEvaluator {
    /// Metric used to reduce the collected samples, e.g. `"mse"`.
    evaluation_metric: String,
    /// Maximum number of retained samples; `None` means unbounded.
    capacity: Option<usize>,
    /// Collected (prediction, target) pairs, oldest first.
    samples: VecDeque<Sample>,
}
impl RegressionEvaluator {
    /// Creates a new evaluator; defaults to the `"mse"` metric when
    /// `evaluation_metric` is `None`.
    pub fn new(evaluation_metric: Option<String>) -> RegressionEvaluator {
        RegressionEvaluator {
            evaluation_metric: evaluation_metric.unwrap_or_else(|| "mse".to_string()),
            capacity: None,
            samples: VecDeque::new(),
        }
    }
    /// Adds a single (prediction, target) pair, evicting the oldest sample
    /// once the configured capacity is reached.
    pub fn add_sample(&mut self, prediction: f32, target: f32) {
        if let Some(capacity) = self.capacity {
            if self.samples.len() >= capacity {
                self.samples.pop_front();
            }
        }
        self.samples.push_back(Sample { prediction, target });
    }
    /// Adds a batch of samples; predictions and targets are paired by index,
    /// and surplus elements in the longer slice are ignored.
    pub fn add_samples(&mut self, predictions: &[f32], targets: &[f32]) {
        for (&prediction, &target) in predictions.iter().zip(targets.iter()) {
            self.add_sample(prediction, target);
        }
    }
    /// Copies the network output tensor to native (host) memory and returns
    /// the flattened predictions.
    pub fn get_predictions(&self, network_out: &mut SharedTensor<f32>) -> Vec<f32> {
        let native_inferred = network_out.read(native_backend().device()).unwrap();
        native_inferred.as_slice::<f32>().to_vec()
    }
    /// Sets the maximum number of retained samples; `None` removes the limit.
    /// Samples already in excess of a smaller capacity are only evicted as
    /// new samples arrive.
    pub fn set_capacity(&mut self, capacity: Option<usize>) {
        self.capacity = capacity;
    }

    /// Returns the stored samples, oldest first.
    pub fn samples(&self) -> &VecDeque<Sample> {
        &self.samples
    }
    /// Reduces the collected samples with the configured evaluation metric.
    /// Only `"mse"` (mean squared error) is implemented; any other metric
    /// panics.
    pub fn accuracy(&self) -> impl RegressionLoss {
        let num_samples = self.samples.len();
        match self.evaluation_metric.as_str() {
            "mse" => {
                let sum_squared_error = self
                    .samples
                    .iter()
                    .fold(0.0, |acc, sample| acc + (sample.prediction - sample.target).powi(2));
                MeanSquaredErrorAccuracy {
                    num_samples,
                    sum_squared_error,
                }
            }
            metric => unimplemented!("unsupported evaluation metric: {}", metric),
        }
    }
}
/// A single (prediction, target) observation.
#[derive(Debug, Clone, Copy)]
pub struct Sample {
    prediction: f32,
    target: f32,
}
impl fmt::Display for Sample {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "Prediction: {:.2}, Target: {:.2}", self.prediction, self.target)
    }
}
/// Reduction of a set of regression samples to a single scalar loss value.
pub trait RegressionLoss {
    /// Returns the loss over all accumulated samples.
    fn loss(&self) -> f32;
}

/// Accumulated state for the mean squared error metric.
#[derive(Debug, Clone, Copy)]
pub struct MeanSquaredErrorAccuracy {
    num_samples: usize,
    sum_squared_error: f32,
}

impl RegressionLoss for MeanSquaredErrorAccuracy {
    /// Mean squared error: accumulated squared error divided by the number of
    /// samples. Yields NaN when no samples have been collected.
    fn loss(&self) -> f32 {
        self.sum_squared_error / self.num_samples as f32
    }
}
#[allow(trivial_casts)]
impl fmt::Display for dyn RegressionLoss {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, " {:.6}", self.loss())
    }
}
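
// Illustrative unit tests, sketched under the assumption that this module is
// compiled as part of the crate. The sample values below are examples chosen
// for the sketch, not fixtures from the original code base.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn mse_is_mean_of_squared_errors() {
        // `None` falls back to the default "mse" metric.
        let mut evaluator = RegressionEvaluator::new(None);
        // Squared errors: (1 - 0)^2 = 1 and (3 - 1)^2 = 4, so the mean is 2.5.
        evaluator.add_samples(&[1.0, 3.0], &[0.0, 1.0]);
        assert!((evaluator.accuracy().loss() - 2.5).abs() < 1e-6);
    }

    #[test]
    fn capacity_evicts_oldest_samples() {
        let mut evaluator = RegressionEvaluator::new(Some("mse".to_string()));
        evaluator.set_capacity(Some(2));
        // Three samples with capacity 2: the first (error 10.0) is evicted,
        // leaving two exact predictions and therefore a loss of 0.0.
        evaluator.add_samples(&[10.0, 1.0, 2.0], &[0.0, 1.0, 2.0]);
        assert_eq!(evaluator.samples().len(), 2);
        assert!(evaluator.accuracy().loss().abs() < 1e-6);
    }
}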