Trait coaster_nn::Rnn

source ·
/// Provides the functionality a backend needs to support RNN operations.
///
/// `Self::CRNN` is not declared in this trait; presumably it is the RNN
/// configuration type supplied by the `NN<F>` supertrait — TODO confirm.
///
/// NOTE(review): judging from the signatures only, the expected call order
/// appears to be: create a configuration with `new_rnn_config`, obtain the
/// weight description with `generate_rnn_weight_description`, then run
/// `rnn_forward` / `rnn_backward_data` / `rnn_backward_weights` with that
/// configuration. Confirm against the backend implementations.
pub trait Rnn<F>: NN<F> {
    // Required methods
    /// Creates an RNN configuration (`Self::CRNN`) from the given parameters.
    ///
    /// `src` is the input tensor the configuration is built for.
    /// NOTE(review): presumably its shape is consulted by the backend — confirm.
    /// `dropout_probability` and `dropout_seed` are optional; how `None` is
    /// handled is backend-specific and not visible here.
    ///
    /// # Errors
    /// Returns `Error` if the backend fails to build the configuration.
    fn new_rnn_config(
        &self,
        src: &SharedTensor<F>,
        dropout_probability: Option<f32>,
        dropout_seed: Option<u64>,
        sequence_length: i32,
        network_mode: RnnNetworkMode,
        input_mode: RnnInputMode,
        direction_mode: DirectionMode,
        algorithm: RnnAlgorithm,
        hidden_size: i32,
        num_layers: i32,
        batch_size: i32
    ) -> Result<Self::CRNN, Error>;
    /// Generates the weight description for the RNN described by
    /// `rnn_config`, given `input_size`.
    ///
    /// Returns a `Vec<usize>`; presumably the dimensions of the weight tensor
    /// to allocate before calling `rnn_forward` — TODO confirm.
    ///
    /// # Errors
    /// Returns `Error` if the backend fails to produce the description.
    fn generate_rnn_weight_description(
        &self,
        rnn_config: &Self::CRNN,
        input_size: i32
    ) -> Result<Vec<usize>, Error>;
    /// Runs the RNN configured by `rnn_config` on `src`, writing the result
    /// into `output`.
    ///
    /// # Arguments
    /// * `src` - input tensor (read-only).
    /// * `output` - receives the result (taken as `&mut`).
    /// * `rnn_config` - configuration created by `new_rnn_config`.
    /// * `weight` - previously initialised tensor holding the weights.
    /// * `workspace` - mutable byte buffer; presumably backend scratch space —
    ///   TODO confirm the required size.
    ///
    /// # Errors
    /// Returns `Error` if the backend operation fails.
    fn rnn_forward(
        &self,
        src: &SharedTensor<F>,
        output: &mut SharedTensor<F>,
        rnn_config: &Self::CRNN,
        weight: &SharedTensor<F>,
        workspace: &mut SharedTensor<u8>
    ) -> Result<(), Error>;
    /// Calculates RNN gradients for input/hidden/cell.
    ///
    /// Reads `src`, `output`, `output_gradient` and `weight`, and writes into
    /// `src_gradient` (the only mutable tensor argument besides `workspace`).
    /// NOTE(review): presumably `src_gradient` receives the gradient with
    /// respect to `src` — confirm against the implementations.
    ///
    /// # Errors
    /// Returns `Error` if the backend operation fails.
    fn rnn_backward_data(
        &self,
        src: &SharedTensor<F>,
        src_gradient: &mut SharedTensor<F>,
        output: &SharedTensor<F>,
        output_gradient: &SharedTensor<F>,
        rnn_config: &Self::CRNN,
        weight: &SharedTensor<F>,
        workspace: &mut SharedTensor<u8>
    ) -> Result<(), Error>;
    /// Calculates RNN gradients for the weights.
    ///
    /// Reads `src` and `output`, and writes into `filter` (taken as `&mut`).
    /// NOTE(review): presumably `filter` receives the weight gradients —
    /// confirm. Unlike `rnn_backward_data`, no `weight` tensor is passed here.
    ///
    /// # Errors
    /// Returns `Error` if the backend operation fails.
    fn rnn_backward_weights(
        &self,
        src: &SharedTensor<F>,
        output: &SharedTensor<F>,
        filter: &mut SharedTensor<F>,
        rnn_config: &Self::CRNN,
        workspace: &mut SharedTensor<u8>
    ) -> Result<(), Error>;
}
Expand description

Provides the functionality a Backend needs to support RNN operations.

Required Methods§

source

fn new_rnn_config( &self, src: &SharedTensor<F>, dropout_probability: Option<f32>, dropout_seed: Option<u64>, sequence_length: i32, network_mode: RnnNetworkMode, input_mode: RnnInputMode, direction_mode: DirectionMode, algorithm: RnnAlgorithm, hidden_size: i32, num_layers: i32, batch_size: i32 ) -> Result<Self::CRNN, Error>

Create an RNN configuration (`Self::CRNN`) from the given parameters.

source

fn generate_rnn_weight_description( &self, rnn_config: &Self::CRNN, input_size: i32 ) -> Result<Vec<usize>, Error>

Generate the weight description (returned as a `Vec<usize>`) for the RNN described by `rnn_config`.

source

fn rnn_forward( &self, src: &SharedTensor<F>, output: &mut SharedTensor<F>, rnn_config: &Self::CRNN, weight: &SharedTensor<F>, workspace: &mut SharedTensor<u8> ) -> Result<(), Error>

Run the RNN forward pass on `src` and write the results into `output`.

§Arguments
  • `weight`: previously initialised tensor holding the network weights
source

fn rnn_backward_data( &self, src: &SharedTensor<F>, src_gradient: &mut SharedTensor<F>, output: &SharedTensor<F>, output_gradient: &SharedTensor<F>, rnn_config: &Self::CRNN, weight: &SharedTensor<F>, workspace: &mut SharedTensor<u8> ) -> Result<(), Error>

Calculates RNN Gradients for Input/Hidden/Cell

source

fn rnn_backward_weights( &self, src: &SharedTensor<F>, output: &SharedTensor<F>, filter: &mut SharedTensor<F>, rnn_config: &Self::CRNN, workspace: &mut SharedTensor<u8> ) -> Result<(), Error>

Calculates RNN Gradients for Weights

Object Safety§

This trait is not object safe.

Implementations on Foreign Types§

source§

impl<T> Rnn<T> for Backend<Cuda>
where T: Float + DataTypeInfo,

source§

fn rnn_forward( &self, src: &SharedTensor<T>, output: &mut SharedTensor<T>, rnn_config: &Self::CRNN, weight: &SharedTensor<T>, workspace: &mut SharedTensor<u8> ) -> Result<(), Error>

Run the RNN forward pass on `src` and write the results into `output`.

source§

fn generate_rnn_weight_description( &self, rnn_config: &Self::CRNN, input_size: i32 ) -> Result<Vec<usize>, Error>

source§

fn new_rnn_config( &self, src: &SharedTensor<T>, dropout_probability: Option<f32>, dropout_seed: Option<u64>, sequence_length: i32, network_mode: RnnNetworkMode, input_mode: RnnInputMode, direction_mode: DirectionMode, algorithm: RnnAlgorithm, hidden_size: i32, num_layers: i32, batch_size: i32 ) -> Result<Self::CRNN, Error>

source§

fn rnn_backward_data( &self, src: &SharedTensor<T>, src_gradient: &mut SharedTensor<T>, output: &SharedTensor<T>, output_gradient: &SharedTensor<T>, rnn_config: &Self::CRNN, weight: &SharedTensor<T>, workspace: &mut SharedTensor<u8> ) -> Result<(), Error>

source§

fn rnn_backward_weights( &self, src: &SharedTensor<T>, output: &SharedTensor<T>, filter: &mut SharedTensor<T>, rnn_config: &Self::CRNN, workspace: &mut SharedTensor<u8> ) -> Result<(), Error>

source§

impl<T> Rnn<T> for Backend<Native>
where T: Float + Default + Copy + PartialOrd + Bounded,

source§

fn new_rnn_config( &self, src: &SharedTensor<T>, dropout_probability: Option<f32>, dropout_seed: Option<u64>, sequence_length: i32, network_mode: RnnNetworkMode, input_mode: RnnInputMode, direction_mode: DirectionMode, algorithm: RnnAlgorithm, hidden_size: i32, num_layers: i32, batch_size: i32 ) -> Result<Self::CRNN, Error>

source§

fn generate_rnn_weight_description( &self, rnn_config: &Self::CRNN, input_size: i32 ) -> Result<Vec<usize>, Error>

source§

fn rnn_forward( &self, src: &SharedTensor<T>, output: &mut SharedTensor<T>, rnn_config: &Self::CRNN, weight: &SharedTensor<T>, workspace: &mut SharedTensor<u8> ) -> Result<(), Error>

source§

fn rnn_backward_data( &self, src: &SharedTensor<T>, src_gradient: &mut SharedTensor<T>, output: &SharedTensor<T>, output_gradient: &SharedTensor<T>, rnn_config: &Self::CRNN, weight: &SharedTensor<T>, workspace: &mut SharedTensor<u8> ) -> Result<(), Error>

source§

fn rnn_backward_weights( &self, src: &SharedTensor<T>, output: &SharedTensor<T>, filter: &mut SharedTensor<T>, rnn_config: &Self::CRNN, workspace: &mut SharedTensor<u8> ) -> Result<(), Error>

Implementors§