Trait coaster_nn::Rnn
source · pub trait Rnn<F>: NN<F> {
// Required methods
fn new_rnn_config(
&self,
src: &SharedTensor<F>,
dropout_probability: Option<f32>,
dropout_seed: Option<u64>,
sequence_length: i32,
network_mode: RnnNetworkMode,
input_mode: RnnInputMode,
direction_mode: DirectionMode,
algorithm: RnnAlgorithm,
hidden_size: i32,
num_layers: i32,
batch_size: i32
) -> Result<Self::CRNN, Error>;
fn generate_rnn_weight_description(
&self,
rnn_config: &Self::CRNN,
input_size: i32
) -> Result<Vec<usize>, Error>;
fn rnn_forward(
&self,
src: &SharedTensor<F>,
output: &mut SharedTensor<F>,
rnn_config: &Self::CRNN,
weight: &SharedTensor<F>,
workspace: &mut SharedTensor<u8>
) -> Result<(), Error>;
fn rnn_backward_data(
&self,
src: &SharedTensor<F>,
src_gradient: &mut SharedTensor<F>,
output: &SharedTensor<F>,
output_gradient: &SharedTensor<F>,
rnn_config: &Self::CRNN,
weight: &SharedTensor<F>,
workspace: &mut SharedTensor<u8>
) -> Result<(), Error>;
fn rnn_backward_weights(
&self,
src: &SharedTensor<F>,
output: &SharedTensor<F>,
filter: &mut SharedTensor<F>,
rnn_config: &Self::CRNN,
workspace: &mut SharedTensor<u8>
) -> Result<(), Error>;
}
Expand description
Provides the functionality for a backend to support RNN operations.
Required Methods§
sourcefn new_rnn_config(
&self,
src: &SharedTensor<F>,
dropout_probability: Option<f32>,
dropout_seed: Option<u64>,
sequence_length: i32,
network_mode: RnnNetworkMode,
input_mode: RnnInputMode,
direction_mode: DirectionMode,
algorithm: RnnAlgorithm,
hidden_size: i32,
num_layers: i32,
batch_size: i32
) -> Result<Self::CRNN, Error>
fn new_rnn_config( &self, src: &SharedTensor<F>, dropout_probability: Option<f32>, dropout_seed: Option<u64>, sequence_length: i32, network_mode: RnnNetworkMode, input_mode: RnnInputMode, direction_mode: DirectionMode, algorithm: RnnAlgorithm, hidden_size: i32, num_layers: i32, batch_size: i32 ) -> Result<Self::CRNN, Error>
Creates an RNN configuration (`Self::CRNN`) from the given input tensor and RNN parameters.
sourcefn generate_rnn_weight_description(
&self,
rnn_config: &Self::CRNN,
input_size: i32
) -> Result<Vec<usize>, Error>
fn generate_rnn_weight_description( &self, rnn_config: &Self::CRNN, input_size: i32 ) -> Result<Vec<usize>, Error>
Generates the weight description for the given RNN configuration and input size, returned as a `Vec<usize>`.
sourcefn rnn_forward(
&self,
src: &SharedTensor<F>,
output: &mut SharedTensor<F>,
rnn_config: &Self::CRNN,
weight: &SharedTensor<F>,
workspace: &mut SharedTensor<u8>
) -> Result<(), Error>
fn rnn_forward( &self, src: &SharedTensor<F>, output: &mut SharedTensor<F>, rnn_config: &Self::CRNN, weight: &SharedTensor<F>, workspace: &mut SharedTensor<u8> ) -> Result<(), Error>
Runs the forward pass of the RNN and writes the result into `output`.
§Arguments
weight
Previously initialised weight tensor for the RNN (earlier versions of this documentation described it as a FilterDescriptor for the weights).
sourcefn rnn_backward_data(
&self,
src: &SharedTensor<F>,
src_gradient: &mut SharedTensor<F>,
output: &SharedTensor<F>,
output_gradient: &SharedTensor<F>,
rnn_config: &Self::CRNN,
weight: &SharedTensor<F>,
workspace: &mut SharedTensor<u8>
) -> Result<(), Error>
fn rnn_backward_data( &self, src: &SharedTensor<F>, src_gradient: &mut SharedTensor<F>, output: &SharedTensor<F>, output_gradient: &SharedTensor<F>, rnn_config: &Self::CRNN, weight: &SharedTensor<F>, workspace: &mut SharedTensor<u8> ) -> Result<(), Error>
Calculates the RNN gradients with respect to the input, hidden state, and cell state, writing the input gradient into `src_gradient`.
sourcefn rnn_backward_weights(
&self,
src: &SharedTensor<F>,
output: &SharedTensor<F>,
filter: &mut SharedTensor<F>,
rnn_config: &Self::CRNN,
workspace: &mut SharedTensor<u8>
) -> Result<(), Error>
fn rnn_backward_weights( &self, src: &SharedTensor<F>, output: &SharedTensor<F>, filter: &mut SharedTensor<F>, rnn_config: &Self::CRNN, workspace: &mut SharedTensor<u8> ) -> Result<(), Error>
Calculates the RNN gradients with respect to the weights, writing them into `filter`.
Object Safety§
This trait is not object safe.
Implementations on Foreign Types§
source§impl<T> Rnn<T> for Backend<Cuda>where
T: Float + DataTypeInfo,
impl<T> Rnn<T> for Backend<Cuda>where
T: Float + DataTypeInfo,
source§fn rnn_forward(
&self,
src: &SharedTensor<T>,
output: &mut SharedTensor<T>,
rnn_config: &Self::CRNN,
weight: &SharedTensor<T>,
workspace: &mut SharedTensor<u8>
) -> Result<(), Error>
fn rnn_forward( &self, src: &SharedTensor<T>, output: &mut SharedTensor<T>, rnn_config: &Self::CRNN, weight: &SharedTensor<T>, workspace: &mut SharedTensor<u8> ) -> Result<(), Error>
Runs the forward pass of an RNN and writes the result into `output`.