Trait coaster_nn::LRN

source ·
/// Provides the functionality for a Backend to support Local Response
/// Normalization (LRN) operations.
///
/// Implementors must also implement [`NN`]; the `CLRN` configuration type
/// used by every method below is the associated type reached through
/// `Self::CLRN` (this trait declares no associated types of its own).
pub trait LRN<F>: NN<F> {
    // Required methods

    /// Creates a new Local Response Normalization (LRN) configuration.
    ///
    /// The returned `Self::CLRN` value is what [`LRN::lrn`] and
    /// [`LRN::lrn_grad`] take as their `config` argument.
    ///
    /// NOTE(review): the meaning of `n`, `alpha`, `beta` and `k` is defined by
    /// the backend and is not established here. Presumably `n` is the
    /// normalization window size and `alpha`, `beta`, `k` are the scale,
    /// exponent and additive constant of the usual LRN formula. Confirm this
    /// against the backend implementation before relying on it.
    ///
    /// # Errors
    ///
    /// Returns `Err` if the backend fails to build the configuration.
    fn new_lrn_config(
        &self,
        n: u32,
        alpha: f64,
        beta: f64,
        k: f64
    ) -> Result<Self::CLRN, Error>;

    /// Computes a Local Response Normalization (LRN) over the input tensor `x`.
    ///
    /// The output is written into `result`; `x` is only read. `config` should
    /// be a value obtained from [`LRN::new_lrn_config`].
    ///
    /// # Errors
    ///
    /// Returns `Err` if the backend operation fails.
    fn lrn(
        &self,
        x: &SharedTensor<F>,
        result: &mut SharedTensor<F>,
        config: &Self::CLRN
    ) -> Result<(), Error>;

    /// Computes the gradient of a Local Response Normalization (LRN).
    ///
    /// The gradient is written into `result_diff`. The tensors `x`, `x_diff`
    /// and `result` are only read. `config` should be a value obtained from
    /// [`LRN::new_lrn_config`].
    ///
    /// NOTE(review): which of `x` / `result` holds the forward input versus
    /// the forward output, and which gradient `x_diff` carries, is not
    /// established by this declaration. Confirm against the backend
    /// implementation before relying on a particular mapping.
    ///
    /// # Errors
    ///
    /// Returns `Err` if the backend operation fails.
    fn lrn_grad(
        &self,
        x: &SharedTensor<F>,
        x_diff: &SharedTensor<F>,
        result: &SharedTensor<F>,
        result_diff: &mut SharedTensor<F>,
        config: &Self::CLRN
    ) -> Result<(), Error>;
}
Expand description

Provides the functionality for a Backend to support Local Response Normalization operations.

Required Methods§

source

fn new_lrn_config( &self, n: u32, alpha: f64, beta: f64, k: f64 ) -> Result<Self::CLRN, Error>

Creates a new Local Response Normalization (LRN) configuration (`Self::CLRN`), which must be passed to the subsequent LRN operations `lrn` and `lrn_grad`.

source

fn lrn( &self, x: &SharedTensor<F>, result: &mut SharedTensor<F>, config: &Self::CLRN ) -> Result<(), Error>

Computes a Local Response Normalization (LRN) over the input tensor `x`. For background, see the [convolutional neural network](https://en.wikipedia.org/wiki/Convolutional_neural_network) article.

Writes the output into `result`.

source

fn lrn_grad( &self, x: &SharedTensor<F>, x_diff: &SharedTensor<F>, result: &SharedTensor<F>, result_diff: &mut SharedTensor<F>, config: &Self::CLRN ) -> Result<(), Error>

Computes the gradient of a Local Response Normalization (LRN) over the input tensor `x`. For background, see the [convolutional neural network](https://en.wikipedia.org/wiki/Convolutional_neural_network) article.

Writes the computed gradient into `result_diff`.

Object Safety§

This trait is not object safe.

Implementations on Foreign Types§

source§

impl<T> LRN<T> for Backend<Cuda>
where T: Float + Default + DataTypeInfo,

source§

fn new_lrn_config( &self, n: u32, alpha: f64, beta: f64, k: f64 ) -> Result<Self::CLRN, Error>

source§

fn lrn( &self, x: &SharedTensor<T>, result: &mut SharedTensor<T>, config: &Self::CLRN ) -> Result<(), Error>

source§

fn lrn_grad( &self, x: &SharedTensor<T>, x_diff: &SharedTensor<T>, result: &SharedTensor<T>, result_diff: &mut SharedTensor<T>, config: &Self::CLRN ) -> Result<(), Error>

Implementors§