Trait coaster_nn::Tanh

source ·
pub trait Tanh<F>: NN<F> {
    // Required methods
    fn tanh(
        &self,
        x: &SharedTensor<F>,
        result: &mut SharedTensor<F>
    ) -> Result<(), Error>;
    fn tanh_grad(
        &self,
        x: &SharedTensor<F>,
        x_diff: &SharedTensor<F>,
        result: &SharedTensor<F>,
        result_diff: &mut SharedTensor<F>
    ) -> Result<(), Error>;
}
Expand description

Provides the functionality for a `Backend` to support tanh (hyperbolic tangent) operations.

Required Methods§

source

fn tanh( &self, x: &SharedTensor<F>, result: &mut SharedTensor<F> ) -> Result<(), Error>

Computes the [hyperbolic tangent](https://en.wikipedia.org/wiki/Hyperbolic_function) over the input tensor `x`.

Saves the result to `result`.

source

fn tanh_grad( &self, x: &SharedTensor<F>, x_diff: &SharedTensor<F>, result: &SharedTensor<F>, result_diff: &mut SharedTensor<F> ) -> Result<(), Error>

Computes the gradient of the [hyperbolic tangent](https://en.wikipedia.org/wiki/Hyperbolic_function) over the input tensor `x`.

Saves the result to `result_diff`.

Object Safety§

This trait is not object safe.

Implementations on Foreign Types§

source§

impl Tanh<f32> for Backend<Native>

source§

fn tanh( &self, x: &SharedTensor<f32>, result: &mut SharedTensor<f32> ) -> Result<(), Error>

source§

fn tanh_grad( &self, x: &SharedTensor<f32>, x_diff: &SharedTensor<f32>, result: &SharedTensor<f32>, result_diff: &mut SharedTensor<f32> ) -> Result<(), Error>

source§

impl Tanh<f64> for Backend<Native>

source§

fn tanh( &self, x: &SharedTensor<f64>, result: &mut SharedTensor<f64> ) -> Result<(), Error>

source§

fn tanh_grad( &self, x: &SharedTensor<f64>, x_diff: &SharedTensor<f64>, result: &SharedTensor<f64>, result_diff: &mut SharedTensor<f64> ) -> Result<(), Error>

source§

impl<T> Tanh<T> for Backend<Cuda>
where T: Float + Default + DataTypeInfo,

source§

fn tanh( &self, x: &SharedTensor<T>, result: &mut SharedTensor<T> ) -> Result<(), Error>

source§

fn tanh_grad( &self, x: &SharedTensor<T>, x_diff: &SharedTensor<T>, result: &SharedTensor<T>, result_diff: &mut SharedTensor<T> ) -> Result<(), Error>

Implementors§