use crate::capnp_util::*;
use crate::co::{IBackend, SharedTensor};
use crate::juice_capnp::reshape_config as capnp_config;
use crate::layer::*;
use crate::util::ArcLock;
/// Layer that resizes its output blob to a fixed shape taken from its configuration.
#[derive(Clone, Debug)]
pub struct Reshape {
    /// Target shape; `reshape` resizes the first output data and gradient tensors to it.
    shape: Vec<usize>,
}

impl Reshape {
    /// Build a `Reshape` layer by copying the target shape out of `config`.
    pub fn from_config(config: &ReshapeConfig) -> Reshape {
        let shape = config.shape.to_vec();
        Reshape { shape }
    }
}
impl<B: IBackend> ILayer<B> for Reshape {
    /// Reshape reports that it computes in place.
    ///
    /// NOTE(review): presumably this lets the framework reuse the input blob as
    /// the output blob. That would be consistent with `compute_output` and
    /// `compute_input_gradient` having empty bodies — confirm against the
    /// `ILayer` documentation.
    fn compute_in_place(&self) -> bool {
        true
    }

    /// Returns `false`.
    ///
    /// NOTE(review): presumably this tells the framework not to create output
    /// blobs automatically for this layer — confirm against `ILayer`.
    fn auto_output_blobs(&self) -> bool {
        false
    }

    /// Resize the first output data tensor and the first output gradient tensor
    /// to the configured shape.
    ///
    /// All other arguments are unused; they are `_`-prefixed to silence
    /// unused-variable warnings.
    ///
    /// # Panics
    ///
    /// Panics if `output_data` or `output_gradient` is empty, if either
    /// tensor's lock is poisoned, or if resizing either tensor fails.
    fn reshape(
        &mut self,
        _backend: ::std::rc::Rc<B>,
        _input_data: &mut Vec<ArcLock<SharedTensor<f32>>>,
        _input_gradient: &mut Vec<ArcLock<SharedTensor<f32>>>,
        _weights_data: &mut Vec<ArcLock<SharedTensor<f32>>>,
        _weights_gradient: &mut Vec<ArcLock<SharedTensor<f32>>>,
        output_data: &mut Vec<ArcLock<SharedTensor<f32>>>,
        output_gradient: &mut Vec<ArcLock<SharedTensor<f32>>>,
    ) {
        output_data[0]
            .write()
            .expect("Reshape: output data lock is poisoned")
            .resize(&self.shape)
            .expect("Reshape: failed to resize output data to the configured shape");
        output_gradient[0]
            .write()
            .expect("Reshape: output gradient lock is poisoned")
            .resize(&self.shape)
            .expect("Reshape: failed to resize output gradient to the configured shape");
    }
}
impl<B: IBackend> ComputeOutput<f32, B> for Reshape {
    /// Forward pass: intentionally a no-op.
    ///
    /// The layer reports `compute_in_place() == true`, and its `reshape` step
    /// already resizes the output tensor, so no data is copied here. All
    /// arguments are unused and `_`-prefixed to silence warnings.
    ///
    /// NOTE(review): correctness relies on the input and output blobs being the
    /// same tensor when the layer runs in place — confirm against the framework.
    fn compute_output(
        &self,
        _backend: &B,
        _weights: &[&SharedTensor<f32>],
        _input_data: &[&SharedTensor<f32>],
        _output_data: &mut [&mut SharedTensor<f32>],
    ) {
    }
}
impl<B: IBackend> ComputeInputGradient<f32, B> for Reshape {
    /// Backward pass: intentionally a no-op.
    ///
    /// A reshape does not change element values, so no gradient computation is
    /// performed here. All arguments are unused and `_`-prefixed to silence
    /// warnings.
    ///
    /// NOTE(review): this relies on the output and input gradient blobs being
    /// the same tensor when the layer runs in place (`compute_in_place` returns
    /// `true`) — confirm against the framework.
    fn compute_input_gradient(
        &self,
        _backend: &B,
        _weights_data: &[&SharedTensor<f32>],
        _output_data: &[&SharedTensor<f32>],
        _output_gradients: &[&SharedTensor<f32>],
        _input_data: &[&SharedTensor<f32>],
        _input_gradients: &mut [&mut SharedTensor<f32>],
    ) {
    }
}
/// Uses the trait's default methods unchanged; `Reshape` holds no state beyond its target shape.
impl<B> ComputeParametersGradient<f32, B> for Reshape where B: IBackend {}
/// Configuration for a `Reshape` layer.
#[derive(Clone, Debug)]
pub struct ReshapeConfig {
    /// Target shape that the layer's output is resized to.
    pub shape: Vec<usize>,
}

impl ReshapeConfig {
    /// Create a config whose target shape is a copy of `shape`.
    pub fn of_shape(shape: &[usize]) -> ReshapeConfig {
        let shape = shape.to_vec();
        Self { shape }
    }
}
impl<'a> CapnpWrite<'a> for ReshapeConfig {
    type Builder = capnp_config::Builder<'a>;

    /// Write the configured shape into the builder's `shape` list.
    ///
    /// The list is initialised with one slot per dimension, and each `usize`
    /// dimension is stored as a `u64`.
    fn write_capnp(&self, builder: &mut Self::Builder) {
        let dims = &self.shape;
        let mut list = builder.reborrow().init_shape(dims.len() as u32);
        for (idx, dim) in dims.iter().copied().enumerate() {
            list.set(idx as u32, dim as u64);
        }
    }
}
impl<'a> CapnpRead<'a> for ReshapeConfig {
    type Reader = capnp_config::Reader<'a>;

    /// Rebuild a `ReshapeConfig` from its capnp `shape` list.
    ///
    /// Each stored `u64` dimension is converted back to `usize`.
    ///
    /// # Panics
    ///
    /// Panics if the `shape` list cannot be read from the message.
    fn read_capnp(reader: Self::Reader) -> Self {
        let read_shape = reader
            .get_shape()
            .expect("ReshapeConfig: failed to read `shape` list from capnp message");
        // `collect` sizes the Vec from the range's exact length, so no regrowth occurs.
        let shape = (0..read_shape.len())
            .map(|i| read_shape.get(i) as usize)
            .collect();
        ReshapeConfig { shape }
    }
}
/// Wrap a `ReshapeConfig` in `LayerType::Reshape`.
///
/// Implemented as `From` rather than `Into` (clippy `from_over_into`). The
/// standard library's blanket impl still provides
/// `Into<LayerType> for ReshapeConfig`, so existing `.into()` callers keep working.
impl From<ReshapeConfig> for LayerType {
    fn from(config: ReshapeConfig) -> LayerType {
        LayerType::Reshape(config)
    }
}