From 261736a078e217b82f169417fb9743c90936432d Mon Sep 17 00:00:00 2001
From: carrotflakes
Date: Tue, 12 Mar 2024 20:54:32 +0900
Subject: [PATCH] Fix normalization

---
 examples/gru.rs         |  7 ++++---
 src/nn/attention.rs     |  2 +-
 src/nn/normalization.rs | 34 ++++++++++++++++++++++++----------
 3 files changed, 29 insertions(+), 14 deletions(-)

diff --git a/examples/gru.rs b/examples/gru.rs
index 67c4959..778d7a3 100644
--- a/examples/gru.rs
+++ b/examples/gru.rs
@@ -37,13 +37,16 @@ fn main() {
     println!("data size: {}", data.len());
     println!("vocab size: {}", vocab_size);

+    let embedding_size = 64;
+    let state_size = 128;
+
     // let optimizer = optimizers::SGDOptimizer::new();
     // let lr = 0.1;
     let optimizer = Arc::new(Mutex::new(optimizers::Adam::new()));
     // let optimizer = optimizers::WithRegularization::new(optimizer, regularizers::L2::new(0.001));
     let lr = 0.0001;

-    let norm = normalization::Normalization::new(vec![0, 1], 0.001, optimizers::Adam::new());
+    let norm = normalization::Normalization::new(vec![1], vec![state_size], 0.001, optimizers::Adam::new());

     let init_kernel = InitializerWithSharedOptimizer::new(
         RandomInitializer::new(Normal::new(0., 0.1).unwrap()),
@@ -54,8 +57,6 @@ fn main() {
         optimizer.clone(),
     );

-    let embedding_size = 64;
-    let state_size = 128;
     let embedding = Embedding::new(embedding_size, vocab_size, init_kernel.scope("embedding"));
     let model = Gru::new(embedding_size, state_size, init_kernel.scope("gru"));
     let linear = Linear::new(
diff --git a/src/nn/attention.rs b/src/nn/attention.rs
index a88fbfc..0b0dd6d 100644
--- a/src/nn/attention.rs
+++ b/src/nn/attention.rs
@@ -131,7 +131,7 @@ impl MHAAddNorm {
         Self {
             attention: MultiHeadAttention::new(dim, num_heads, w.scope("mha"), b.scope("mha")),
             dense: Linear::new(dim, dim, w.scope("dense"), Some(b.scope("dense"))),
-            norm: Normalization::new(vec![1], layer_norm_eps, opt),
+            norm: Normalization::new(vec![1], vec![dim], layer_norm_eps, opt),
         }
     }

diff --git a/src/nn/normalization.rs b/src/nn/normalization.rs
index 0564a7e..decda26 100644
--- a/src/nn/normalization.rs
+++ b/src/nn/normalization.rs
@@ -5,18 +5,31 @@ use crate::{ndarray_util::map_axes_keep_dim, *};

 // TODO: infer time
 pub struct Normalization {
-    pub axes: Vec<usize>,
+    pub axis: Vec<usize>,
     pub gamma: ParamNDA,
     pub beta: ParamNDA,
     pub eps: f32, // 0.001
 }

 impl Normalization {
-    pub fn new(axes: Vec<usize>, eps: f32, optimizer: impl Optimizer + Clone) -> Self {
+    pub fn new(
+        axis: Vec<usize>,
+        bias_shape: Vec<usize>,
+        eps: f32,
+        optimizer: impl Optimizer + Clone,
+    ) -> Self {
         Self {
-            axes,
-            gamma: ParamNDA::new(scalar(1.0), "normalization".into(), optimizer.clone()),
-            beta: ParamNDA::new(scalar(0.0), "normalization".into(), optimizer.clone()),
+            axis,
+            gamma: ParamNDA::new(
+                NDArray::from_elem(bias_shape.clone(), 1.0),
+                "normalization".into(),
+                optimizer.clone(),
+            ),
+            beta: ParamNDA::new(
+                NDArray::from_elem(bias_shape, 0.0),
+                "normalization".into(),
+                optimizer.clone(),
+            ),
             eps,
         }
     }
@@ -27,10 +40,11 @@ impl Layer for Normalization {
     type Output = ComputedNDA;

     fn call(&self, x: Self::Input, _train: bool) -> Self::Output {
-        let mean = map_axes_keep_dim(&*x, &self.axes, |x| x.mean_axis(Axis(1)).unwrap());
-        let var = map_axes_keep_dim(&*x, &self.axes, |x| x.var_axis(Axis(1), 1.0));
+        let mean = map_axes_keep_dim(&*x, &self.axis, |x| x.mean_axis(Axis(1)).unwrap());
+        let var = map_axes_keep_dim(&*x, &self.axis, |x| x.var_axis(Axis(1), 1.0));
+        let shape = x.shape().to_vec();
         (x - ComputedNDA::new(mean.into_ndarray()))
-            * (self.gamma.get()
+            * (self.gamma.get().broadcast(shape)
                 / ComputedNDA::new((var + self.eps).map(|x| x.sqrt()).into_ndarray()))
             + self.beta.get()
     }
@@ -43,7 +57,7 @@
 #[test]
 fn test() {
     let x = ComputedNDA::new(ndarray::array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0].into_ndarray());
-    let bn = Normalization::new(vec![0], 0.001, optimizers::Adam::new());
+    let bn = Normalization::new(vec![0], vec![6], 0.001, optimizers::Adam::new());
     let y = bn.call(x, false);
     assert!((y.mean().unwrap() - 0.0).abs() < 1e-6);
     assert!((y.var(1.0) - 1.0).abs() < 0.01);
@@ -56,7 +70,7 @@ fn test() {
             .unwrap()
             .into_ndarray(),
     );
-    let bn = Normalization::new(vec![1, 2], 0.001, optimizers::Adam::new());
+    let bn = Normalization::new(vec![1, 2], vec![3, 4, 1], 0.001, optimizers::Adam::new());
     let y = bn.call(x.clone(), false);
     dbg!(&*y);
     assert_eq!(x.shape(), y.shape());
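
A minimal usage sketch of the updated constructor, for review only; it restates the calls shown in the hunks above. Here dim is a placeholder feature size and x is assumed to be a ComputedNDA of shape (batch, dim):

    // Layer normalization over axis 1 of a (batch, dim) input.
    // The second argument (here vec![dim]) is the bias_shape introduced by this
    // patch: the shape of the learnable gamma/beta, which were scalars before
    // and are now one-scale-per-feature arrays. Inside call(), gamma is
    // broadcast back to the full input shape before the division.
    let dim = 128; // placeholder
    let norm = Normalization::new(vec![1], vec![dim], 0.001, optimizers::Adam::new());
    let y = norm.call(x, false); // output keeps the shape of x, as the new test asserts

The old scalar gamma/beta could only scale every feature uniformly; passing an explicit bias_shape makes the parameterization match standard layer norm.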