/// Provides the functionality for a backend to support convolution operations.
///
/// Every operation takes a configuration of type `Self::CC`, an associated
/// type that is not declared here and therefore comes from the [`NN`]
/// supertrait. Build it once with
/// [`new_convolution_config`](Convolution::new_convolution_config) and pass
/// it to the other operations.
pub trait Convolution<F>: NN<F> {
    /// Creates a new convolution config, which needs to be passed to further
    /// convolution operations.
    ///
    /// As their types suggest, `algo_fwd`, `algo_bwd_filter` and
    /// `algo_bwd_data` pick the algorithm for the forward pass, the filter
    /// gradient and the data gradient. NOTE(review): `stride` and
    /// `zero_padding` look like one entry per spatial dimension. Confirm the
    /// expected length and ordering against the backend implementations.
    fn new_convolution_config(
        &self,
        src: &SharedTensor<F>,
        dest: &SharedTensor<F>,
        filter: &SharedTensor<F>,
        algo_fwd: ConvForwardAlgo,
        algo_bwd_filter: ConvBackwardFilterAlgo,
        algo_bwd_data: ConvBackwardDataAlgo,
        stride: &[i32],
        zero_padding: &[i32],
    ) -> Result<Self::CC, Error>;

    /// Computes a [CNN convolution][convolution] over the input tensor `x`
    /// using `filter`.
    ///
    /// Saves the result to `result`.
    ///
    /// `workspace` is a byte buffer handed to the backend, presumably as
    /// scratch memory for the chosen algorithm. The native backend's
    /// implementation of this method ignores it.
    ///
    /// [convolution]: https://en.wikipedia.org/wiki/Convolutional_neural_network
    fn convolution(
        &self,
        filter: &SharedTensor<F>,
        x: &SharedTensor<F>,
        result: &mut SharedTensor<F>,
        workspace: &mut SharedTensor<u8>,
        config: &Self::CC,
    ) -> Result<(), Error>;

    /// Computes the gradient of a [CNN convolution][convolution] with respect
    /// to the filter.
    ///
    /// Saves the result to `filter_diff`.
    ///
    /// NOTE(review): the naming suggests `src_data` is the forward input and
    /// `dest_diff` is the gradient w.r.t. the forward output. Confirm.
    ///
    /// [convolution]: https://en.wikipedia.org/wiki/Convolutional_neural_network
    fn convolution_grad_filter(
        &self,
        src_data: &SharedTensor<F>,
        dest_diff: &SharedTensor<F>,
        filter_diff: &mut SharedTensor<F>,
        workspace: &mut SharedTensor<u8>,
        config: &Self::CC,
    ) -> Result<(), Error>;

    /// Computes the gradient of a [CNN convolution][convolution] with respect
    /// to the data.
    ///
    /// Saves the result to `result_diff`.
    ///
    /// NOTE(review): `x_diff` is presumably the incoming gradient that is
    /// propagated back through `filter`. Confirm against the implementations.
    ///
    /// [convolution]: https://en.wikipedia.org/wiki/Convolutional_neural_network
    fn convolution_grad_data(
        &self,
        filter: &SharedTensor<F>,
        x_diff: &SharedTensor<F>,
        result_diff: &mut SharedTensor<F>,
        workspace: &mut SharedTensor<u8>,
        config: &Self::CC,
    ) -> Result<(), Error>;
}
Description

Provides the functionality for a `Backend` to support convolution operations.

Required Methods

source

fn new_convolution_config( &self, src: &SharedTensor<F>, dest: &SharedTensor<F>, filter: &SharedTensor<F>, algo_fwd: ConvForwardAlgo, algo_bwd_filter: ConvBackwardFilterAlgo, algo_bwd_data: ConvBackwardDataAlgo, stride: &[i32], zero_padding: &[i32] ) -> Result<Self::CC, Error>

Creates a new `ConvolutionConfig`, which must be passed to subsequent convolution operations.

source

fn convolution( &self, filter: &SharedTensor<F>, x: &SharedTensor<F>, result: &mut SharedTensor<F>, workspace: &mut SharedTensor<u8>, config: &Self::CC ) -> Result<(), Error>

Computes a [CNN convolution](https://en.wikipedia.org/wiki/Convolutional_neural_network) over the input tensor `x`.

Saves the result to `result`.

source

fn convolution_grad_filter( &self, src_data: &SharedTensor<F>, dest_diff: &SharedTensor<F>, filter_diff: &mut SharedTensor<F>, workspace: &mut SharedTensor<u8>, config: &Self::CC ) -> Result<(), Error>

Computes the gradient of a [CNN convolution](https://en.wikipedia.org/wiki/Convolutional_neural_network) with respect to the filter.

Saves the result to `filter_diff`.

source

fn convolution_grad_data( &self, filter: &SharedTensor<F>, x_diff: &SharedTensor<F>, result_diff: &mut SharedTensor<F>, workspace: &mut SharedTensor<u8>, config: &Self::CC ) -> Result<(), Error>

Computes the gradient of a [CNN convolution](https://en.wikipedia.org/wiki/Convolutional_neural_network) over the input tensor `x` with respect to the data.

Saves the result to `result_diff`.

Object Safety

This trait is not object safe.

Implementations on Foreign Types

source§

impl<T> Convolution<T> for Backend<Cuda>
where T: Float + DataTypeInfo,

source§

fn new_convolution_config( &self, src: &SharedTensor<T>, dest: &SharedTensor<T>, filter: &SharedTensor<T>, algo_fwd: ConvForwardAlgo, algo_bwd_filter: ConvBackwardFilterAlgo, algo_bwd_data: ConvBackwardDataAlgo, stride: &[i32], zero_padding: &[i32] ) -> Result<Self::CC, Error>

source§

fn convolution( &self, filter: &SharedTensor<T>, x: &SharedTensor<T>, result: &mut SharedTensor<T>, workspace: &mut SharedTensor<u8>, config: &Self::CC ) -> Result<(), Error>

source§

fn convolution_grad_filter( &self, src_data: &SharedTensor<T>, dest_diff: &SharedTensor<T>, filter_diff: &mut SharedTensor<T>, workspace: &mut SharedTensor<u8>, config: &Self::CC ) -> Result<(), Error>

source§

fn convolution_grad_data( &self, filter: &SharedTensor<T>, x_diff: &SharedTensor<T>, result_diff: &mut SharedTensor<T>, workspace: &mut SharedTensor<u8>, config: &Self::CC ) -> Result<(), Error>

source§

impl<T> Convolution<T> for Backend<Native>
where T: Add<T, Output = T> + Mul<T, Output = T> + Default + Copy,

source§

fn new_convolution_config( &self, src: &SharedTensor<T>, dest: &SharedTensor<T>, filter: &SharedTensor<T>, algo_fwd: ConvForwardAlgo, algo_bwd_filter: ConvBackwardFilterAlgo, algo_bwd_data: ConvBackwardDataAlgo, stride: &[i32], zero_padding: &[i32] ) -> Result<Self::CC, Error>

source§

fn convolution( &self, filter: &SharedTensor<T>, x: &SharedTensor<T>, result: &mut SharedTensor<T>, _workspace: &mut SharedTensor<u8>, config: &Self::CC ) -> Result<(), Error>

source§

fn convolution_grad_filter( &self, src_data: &SharedTensor<T>, dest_diff: &SharedTensor<T>, filter_diff: &mut SharedTensor<T>, workspace: &mut SharedTensor<u8>, config: &Self::CC ) -> Result<(), Error>

source§

fn convolution_grad_data( &self, filter: &SharedTensor<T>, x_diff: &SharedTensor<T>, result_diff: &mut SharedTensor<T>, workspace: &mut SharedTensor<u8>, config: &Self::CC ) -> Result<(), Error>

Implementors