//! Provides configuration of weights and their initialization.
use crate::capnp_util::*;
use crate::co::{ITensorDesc, SharedTensor};
use crate::juice_capnp::weight_config as capnp_config;
use crate::util::native_backend;
use rand::{self, prelude::*};
#[derive(Debug, Clone)]
/// Specifies training configuration for a weight blob.
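///
/// # Examples
///
/// A minimal construction sketch (the field values are illustrative, not
/// recommendations); unspecified fields fall back to their defaults:
///
/// ```rust,ignore
/// let config = WeightConfig {
///     name: "shared_weights".to_owned(),
///     share_mode: DimCheckMode::Permissive,
///     ..WeightConfig::default()
/// };
/// // No explicit multiplier was set, so the default of 1.0 applies.
/// assert_eq!(config.lr_mult(), 1.0f32);
/// ```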
pub struct WeightConfig {
/// The name of the weight blob -- useful for sharing weights among
/// layers, but never required otherwise. To share a weight between two
/// layers, give it a (non-empty) name.
///
/// Default: ""
pub name: String,
/// Whether to require shared weights to have the same shape, or just the same
    /// count.
///
/// Default: DimCheckMode::Strict
pub share_mode: DimCheckMode,
/// The multiplier on the global learning rate for this parameter.
///
/// Default: 1.0f32
pub lr_mult: Option<f32>,
/// The multiplier on the global weight decay for this parameter.
///
/// Default: 1.0f32
pub decay_mult: Option<f32>,
/// The filler that initializes the weights in the weight blob.
///
/// Default: None
pub filler: Option<FillerType>,
}
impl Default for WeightConfig {
fn default() -> WeightConfig {
WeightConfig {
name: "".to_owned(),
share_mode: DimCheckMode::Strict,
lr_mult: None,
decay_mult: None,
filler: None,
}
}
}
impl WeightConfig {
/// Checks dimensions of two blobs according to the `share_mode`.
/// Returns an error if there is a count/shape mismatch.
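    ///
    /// # Examples
    ///
    /// A sketch of a permissive check, assuming `SharedTensor::new` can create
    /// a tensor from a shape slice:
    ///
    /// ```rust,ignore
    /// let a = SharedTensor::<f32>::new(&[2, 3]);
    /// let b = SharedTensor::<f32>::new(&[3, 2]);
    /// let config = WeightConfig {
    ///     share_mode: DimCheckMode::Permissive,
    ///     ..WeightConfig::default()
    /// };
    /// // Shapes differ, but both tensors hold 6 elements, so the permissive check passes.
    /// assert!(config
    ///     .check_dimensions(&a, &b, "w".into(), "owner".into(), "sharer".into())
    ///     .is_ok());
    /// ```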
pub fn check_dimensions<T>(
&self,
tensor_one: &SharedTensor<T>,
tensor_two: &SharedTensor<T>,
param_name: String,
owner_name: String,
layer_name: String,
) -> Result<(), String> {
match self.share_mode {
// Permissive dimension checking -- only check counts are the same.
DimCheckMode::Permissive => {
if tensor_one.desc().size() != tensor_two.desc().size() {
                    return Err(format!(
                        "Cannot share weight '{}' owned by layer '{}' with layer '{}'; \
                         count mismatch. Owner layer weight shape is {:?}; \
                         sharing layer weight shape is {:?}",
                        param_name,
                        owner_name,
                        layer_name,
                        tensor_two.desc(),
                        tensor_one.desc()
                    ));
}
}
// Strict dimension checking -- all dims must be the same.
DimCheckMode::Strict => {
if tensor_one.desc() != tensor_two.desc() {
                    return Err(format!(
                        "Cannot share weight '{}' owned by layer '{}' with layer '{}'; \
                         shape mismatch. Owner layer weight shape is {:?}; \
                         sharing layer expects weight shape {:?}",
                        param_name,
                        owner_name,
                        layer_name,
                        tensor_two.desc(),
                        tensor_one.desc()
                    ));
}
}
}
Ok(())
}
/// The multiplier on the global learning rate for this weight blob.
pub fn lr_mult(&self) -> f32 {
match self.lr_mult {
Some(val) => val,
None => 1.0f32,
}
}
/// The multiplier on the global weight decay for this weight blob.
pub fn decay_mult(&self) -> f32 {
match self.decay_mult {
Some(val) => val,
None => 1.0f32,
}
}
}
impl<'a> CapnpWrite<'a> for WeightConfig {
type Builder = capnp_config::Builder<'a>;
/// Write the WeightConfig into a capnp message.
fn write_capnp(&self, builder: &mut Self::Builder) {
// TODO: incomplete since WeightConfig isn't really used internally in Juice at the moment.
builder.reborrow().set_name(&self.name);
}
}
impl<'a> CapnpRead<'a> for WeightConfig {
type Reader = capnp_config::Reader<'a>;
fn read_capnp(reader: Self::Reader) -> Self {
// TODO: incomplete since WeightConfig isn't really used internally in Juice at the moment.
let name = reader.get_name().unwrap().to_owned();
WeightConfig {
            name,
..Self::default()
}
}
}
#[derive(Debug, Copy, Clone)]
/// Enum for specifying the behaviour of shared weights.
pub enum DimCheckMode {
/// Strict requires that shapes match.
Strict,
/// Permissive requires only the count of weights to match.
Permissive,
}
#[derive(Debug, Copy, Clone)]
/// Enum for specifying the type of filler.
pub enum FillerType {
/// Fills the weight blob with a constant `value` (all values are the same).
Constant {
/// The value that will be used to fill the blob.
value: f32,
},
/// Fills the weight blobs based on the paper:
///
/// `[Bengio and Glorot 2010]: Understanding the difficulty of training deep feedforward neural networks.`
///
/// Also known as Xavier filler.
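    ///
    /// Weights are drawn uniformly from the closed interval
    /// `[-sqrt(6 / (input_size + output_size)), sqrt(6 / (input_size + output_size))]`,
    /// matching the implementation in `fill_glorot` below.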
Glorot {
/// Number of input nodes for each output.
input_size: usize,
/// Number of output nodes for each input.
output_size: usize,
},
}
impl FillerType {
    /// Uses the filler specified by this `FillerType` to fill the values in a `SharedTensor`.
///
/// This filling of weights is usually done directly after creation of the weight blob.
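    ///
    /// # Examples
    ///
    /// A sketch, again assuming a shape-based `SharedTensor` constructor:
    ///
    /// ```rust,ignore
    /// let mut weights = SharedTensor::<f32>::new(&[10, 20]);
    /// FillerType::Glorot { input_size: 10, output_size: 20 }.fill(&mut weights);
    /// ```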
pub fn fill(&self, weight: &mut SharedTensor<f32>) {
match *self {
FillerType::Constant { value } => Self::fill_constant(weight, value),
FillerType::Glorot {
input_size,
output_size,
} => Self::fill_glorot(weight, input_size, output_size),
}
}
/// Directly use the [Constant Filler](#variant.Constant).
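    ///
    /// ```rust,ignore
    /// // Sketch: zero-initialize a bias blob (`bias` is an existing SharedTensor<f32>).
    /// FillerType::fill_constant(&mut bias, 0.0);
    /// ```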
pub fn fill_constant(weight: &mut SharedTensor<f32>, value: f32) {
let native = native_backend();
let native_weight = weight.write_only(native.device()).unwrap();
for e in native_weight.as_mut_slice::<f32>() {
*e = value;
}
}
/// Directly use the [Glorot Filler](#variant.Glorot).
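    ///
    /// With the `deterministic` feature enabled, sampling uses a fixed-seed
    /// `StdRng`, so repeated runs produce identical weights; otherwise a
    /// thread-local RNG is used.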
pub fn fill_glorot(weight: &mut SharedTensor<f32>, num_inputs: usize, num_outputs: usize) {
let native = native_backend();
let native_weight = weight.write_only(native.device()).unwrap();
let init_range = (6.0f32 / (num_inputs as f32 + num_outputs as f32)).sqrt();
#[cfg(feature = "deterministic")]
let mut rng = rand::rngs::StdRng::seed_from_u64(2301); // Arbitrary seed.
#[cfg(not(feature = "deterministic"))]
let mut rng = thread_rng();
let between = rand::distributions::Uniform::from(-init_range..=init_range);
for e in native_weight.as_mut_slice::<f32>() {
*e = between.sample(&mut rng);
}
}
}