Trait coaster_nn::Convolution
source · pub trait Convolution<F>: NN<F> {
// Required methods
fn new_convolution_config(
&self,
src: &SharedTensor<F>,
dest: &SharedTensor<F>,
filter: &SharedTensor<F>,
algo_fwd: ConvForwardAlgo,
algo_bwd_filter: ConvBackwardFilterAlgo,
algo_bwd_data: ConvBackwardDataAlgo,
stride: &[i32],
zero_padding: &[i32]
) -> Result<Self::CC, Error>;
fn convolution(
&self,
filter: &SharedTensor<F>,
x: &SharedTensor<F>,
result: &mut SharedTensor<F>,
workspace: &mut SharedTensor<u8>,
config: &Self::CC
) -> Result<(), Error>;
fn convolution_grad_filter(
&self,
src_data: &SharedTensor<F>,
dest_diff: &SharedTensor<F>,
filter_diff: &mut SharedTensor<F>,
workspace: &mut SharedTensor<u8>,
config: &Self::CC
) -> Result<(), Error>;
fn convolution_grad_data(
&self,
filter: &SharedTensor<F>,
x_diff: &SharedTensor<F>,
result_diff: &mut SharedTensor<F>,
workspace: &mut SharedTensor<u8>,
config: &Self::CC
) -> Result<(), Error>;
}
Expand description
Provides the functionality for a Backend to support Convolution operations.
Required Methods§
sourcefn new_convolution_config(
&self,
src: &SharedTensor<F>,
dest: &SharedTensor<F>,
filter: &SharedTensor<F>,
algo_fwd: ConvForwardAlgo,
algo_bwd_filter: ConvBackwardFilterAlgo,
algo_bwd_data: ConvBackwardDataAlgo,
stride: &[i32],
zero_padding: &[i32]
) -> Result<Self::CC, Error>
fn new_convolution_config( &self, src: &SharedTensor<F>, dest: &SharedTensor<F>, filter: &SharedTensor<F>, algo_fwd: ConvForwardAlgo, algo_bwd_filter: ConvBackwardFilterAlgo, algo_bwd_data: ConvBackwardDataAlgo, stride: &[i32], zero_padding: &[i32] ) -> Result<Self::CC, Error>
Creates a new convolution configuration (`Self::CC`), which must be passed to the subsequent convolution operations.
sourcefn convolution(
&self,
filter: &SharedTensor<F>,
x: &SharedTensor<F>,
result: &mut SharedTensor<F>,
workspace: &mut SharedTensor<u8>,
config: &Self::CC
) -> Result<(), Error>
fn convolution( &self, filter: &SharedTensor<F>, x: &SharedTensor<F>, result: &mut SharedTensor<F>, workspace: &mut SharedTensor<u8>, config: &Self::CC ) -> Result<(), Error>
Computes a [CNN convolution](https://en.wikipedia.org/wiki/Convolutional_neural_network) over the input tensor `x`.
Saves the result to `result`.
sourcefn convolution_grad_filter(
&self,
src_data: &SharedTensor<F>,
dest_diff: &SharedTensor<F>,
filter_diff: &mut SharedTensor<F>,
workspace: &mut SharedTensor<u8>,
config: &Self::CC
) -> Result<(), Error>
fn convolution_grad_filter( &self, src_data: &SharedTensor<F>, dest_diff: &SharedTensor<F>, filter_diff: &mut SharedTensor<F>, workspace: &mut SharedTensor<u8>, config: &Self::CC ) -> Result<(), Error>
Computes the gradient of a [CNN convolution](https://en.wikipedia.org/wiki/Convolutional_neural_network) with respect to the filter.
Saves the result to `filter_diff`.
sourcefn convolution_grad_data(
&self,
filter: &SharedTensor<F>,
x_diff: &SharedTensor<F>,
result_diff: &mut SharedTensor<F>,
workspace: &mut SharedTensor<u8>,
config: &Self::CC
) -> Result<(), Error>
fn convolution_grad_data( &self, filter: &SharedTensor<F>, x_diff: &SharedTensor<F>, result_diff: &mut SharedTensor<F>, workspace: &mut SharedTensor<u8>, config: &Self::CC ) -> Result<(), Error>
Computes the gradient of a [CNN convolution](https://en.wikipedia.org/wiki/Convolutional_neural_network) over the input tensor `x_diff` with respect to the data.
Saves the result to `result_diff`.
Object Safety§
This trait is not object safe.