Commit

Fix normalization
carrotflakes committed Mar 12, 2024
1 parent 5ed3b2e commit 261736a
Showing 3 changed files with 29 additions and 14 deletions.
examples/gru.rs (7 changes: 4 additions & 3 deletions)
@@ -37,13 +37,16 @@ fn main() {
     println!("data size: {}", data.len());
     println!("vocab size: {}", vocab_size);
 
+    let embedding_size = 64;
+    let state_size = 128;
+
     // let optimizer = optimizers::SGDOptimizer::new();
     // let lr = 0.1;
     let optimizer = Arc::new(Mutex::new(optimizers::Adam::new()));
     // let optimizer = optimizers::WithRegularization::new(optimizer, regularizers::L2::new(0.001));
     let lr = 0.0001;
 
-    let norm = normalization::Normalization::new(vec![0, 1], 0.001, optimizers::Adam::new());
+    let norm = normalization::Normalization::new(vec![1], vec![state_size], 0.001, optimizers::Adam::new());
 
     let init_kernel = InitializerWithSharedOptimizer::new(
         RandomInitializer::new(Normal::new(0., 0.1).unwrap()),
@@ -54,8 +57,6 @@ fn main() {
         optimizer.clone(),
     );
 
-    let embedding_size = 64;
-    let state_size = 128;
     let embedding = Embedding::new(embedding_size, vocab_size, init_kernel.scope("embedding"));
     let model = Gru::new(embedding_size, state_size, init_kernel.scope("gru"));
     let linear = Linear::new(
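Note on the call-site change above: the new second argument is the shape given to the learned gamma/beta parameters (see the normalization.rs diff below). A minimal before/after sketch of the gru.rs call, assuming the normalized GRU activations carry the state dimension on axis 1 so that vec![state_size] yields one scale and shift per state unit:

    // Before: gamma/beta were scalars, so only the axes and eps were passed.
    // let norm = normalization::Normalization::new(vec![0, 1], 0.001, optimizers::Adam::new());

    // After: the shape of gamma/beta is passed explicitly.
    let norm = normalization::Normalization::new(
        vec![1],                 // axes to normalize over
        vec![state_size],        // shape of the learned gamma/beta
        0.001,                   // eps
        optimizers::Adam::new(),
    );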
src/nn/attention.rs (2 changes: 1 addition & 1 deletion)
@@ -131,7 +131,7 @@ impl MHAAddNorm {
         Self {
             attention: MultiHeadAttention::new(dim, num_heads, w.scope("mha"), b.scope("mha")),
             dense: Linear::new(dim, dim, w.scope("dense"), Some(b.scope("dense"))),
-            norm: Normalization::new(vec![1], layer_norm_eps, opt),
+            norm: Normalization::new(vec![1], vec![dim], layer_norm_eps, opt),
         }
     }
 
src/nn/normalization.rs (34 changes: 24 additions & 10 deletions)
@@ -5,18 +5,31 @@ use crate::{ndarray_util::map_axes_keep_dim, *};
 // TODO: infer time
 
 pub struct Normalization {
-    pub axes: Vec<usize>,
+    pub axis: Vec<usize>,
     pub gamma: ParamNDA,
     pub beta: ParamNDA,
     pub eps: f32, // 0.001
 }
 
 impl Normalization {
-    pub fn new(axes: Vec<usize>, eps: f32, optimizer: impl Optimizer<NDArray> + Clone) -> Self {
+    pub fn new(
+        axis: Vec<usize>,
+        bias_shape: Vec<usize>,
+        eps: f32,
+        optimizer: impl Optimizer<NDArray> + Clone,
+    ) -> Self {
         Self {
-            axes,
-            gamma: ParamNDA::new(scalar(1.0), "normalization".into(), optimizer.clone()),
-            beta: ParamNDA::new(scalar(0.0), "normalization".into(), optimizer.clone()),
+            axis,
+            gamma: ParamNDA::new(
+                NDArray::from_elem(bias_shape.clone(), 1.0),
+                "normalization".into(),
+                optimizer.clone(),
+            ),
+            beta: ParamNDA::new(
+                NDArray::from_elem(bias_shape, 0.0),
+                "normalization".into(),
+                optimizer.clone(),
+            ),
             eps,
         }
     }
@@ -27,10 +40,11 @@ impl Layer for Normalization {
     type Output = ComputedNDA;
 
     fn call(&self, x: Self::Input, _train: bool) -> Self::Output {
-        let mean = map_axes_keep_dim(&*x, &self.axes, |x| x.mean_axis(Axis(1)).unwrap());
-        let var = map_axes_keep_dim(&*x, &self.axes, |x| x.var_axis(Axis(1), 1.0));
+        let mean = map_axes_keep_dim(&*x, &self.axis, |x| x.mean_axis(Axis(1)).unwrap());
+        let var = map_axes_keep_dim(&*x, &self.axis, |x| x.var_axis(Axis(1), 1.0));
+        let shape = x.shape().to_vec();
         (x - ComputedNDA::new(mean.into_ndarray()))
-            * (self.gamma.get()
+            * (self.gamma.get().broadcast(shape)
                 / ComputedNDA::new((var + self.eps).map(|x| x.sqrt()).into_ndarray()))
             + self.beta.get()
     }
@@ -43,7 +57,7 @@
 #[test]
 fn test() {
     let x = ComputedNDA::new(ndarray::array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0].into_ndarray());
-    let bn = Normalization::new(vec![0], 0.001, optimizers::Adam::new());
+    let bn = Normalization::new(vec![0], vec![6], 0.001, optimizers::Adam::new());
     let y = bn.call(x, false);
     assert!((y.mean().unwrap() - 0.0).abs() < 1e-6);
     assert!((y.var(1.0) - 1.0).abs() < 0.01);
@@ -56,7 +70,7 @@ fn test() {
         .unwrap()
         .into_ndarray(),
     );
-    let bn = Normalization::new(vec![1, 2], 0.001, optimizers::Adam::new());
+    let bn = Normalization::new(vec![1, 2], vec![3, 4, 1], 0.001, optimizers::Adam::new());
     let y = bn.call(x.clone(), false);
     dbg!(&*y);
     assert_eq!(x.shape(), y.shape());
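For reference, here is a self-contained sketch of the computation the fixed layer performs, written directly against the ndarray crate rather than this repository's NDArray/ComputedNDA wrappers (the layer_norm helper and the shapes are illustrative, not part of the commit): the input is normalized over the feature axis, then scaled and shifted by per-feature gamma/beta, which is why the parameters now need an explicit shape and a broadcast instead of being scalars.

    use ndarray::{array, Array1, Array2, Axis};

    // Normalize each row of x, then apply per-feature scale (gamma) and shift (beta).
    fn layer_norm(x: &Array2<f32>, gamma: &Array1<f32>, beta: &Array1<f32>, eps: f32) -> Array2<f32> {
        // Per-row mean and variance, kept as [batch, 1] so they broadcast over the feature axis.
        let mean = x.mean_axis(Axis(1)).unwrap().insert_axis(Axis(1));
        let var = x.var_axis(Axis(1), 1.0).insert_axis(Axis(1));
        // gamma and beta have shape [features] and broadcast over the batch axis,
        // mirroring the explicit self.gamma.get().broadcast(shape) in the diff above.
        (x - &mean) / &(var + eps).mapv(f32::sqrt) * gamma + beta
    }

    fn main() {
        let x = array![[1.0_f32, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let (gamma, beta) = (Array1::<f32>::ones(3), Array1::<f32>::zeros(3));
        let y = layer_norm(&x, &gamma, &beta, 1e-3);
        println!("{y}"); // each row ends up with mean ~0 and variance ~1
    }

With gamma all ones and beta all zeros this reduces to plain normalization, which is what the updated test asserts (mean ~0, variance ~1); training then adjusts the per-feature parameters.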
