Add MHAAddNorm
carrotflakes committed Sep 1, 2023
1 parent f7cc1b2 commit 01ad9cb
Showing 1 changed file with 34 additions and 16 deletions.
src/nn/attention.rs
@@ -12,21 +12,18 @@ pub struct MultiHeadAttention {
    key_proj: Linear,
    query_proj: Linear,
    value_proj: Linear,

-    dense: Linear,
-    norm: Normalization,
    // TODO: dropout
}

impl MultiHeadAttention {
    pub fn new(
        embed_dim: usize,
        num_heads: usize,
-        layer_norm_eps: f32,
        w: impl Initializer<ParamNDA> + Scope,
        b: impl Initializer<ParamNDA> + Scope,
-        opt: impl Optimizer<NDArray> + Clone,
    ) -> Self {
        assert!(embed_dim % num_heads == 0);

        MultiHeadAttention {
            head_dim: embed_dim / num_heads,
            num_heads,
@@ -48,13 +45,6 @@ impl MultiHeadAttention {
w.scope("value_proj"),
Some(b.scope("value_proj")),
),
dense: Linear::new(
embed_dim,
embed_dim,
w.scope("dense"),
Some(b.scope("dense")),
),
norm: Normalization::new(vec![1], layer_norm_eps, opt),
}
}

@@ -82,10 +72,9 @@ impl MultiHeadAttention {
        let attention_value = attention.matmul(&value);

        // (N, num_heads, L, head_dim) -> (N, L, num_heads * head_dim)
-        let attention_value = self.merge_heads(attention_value);
+        let y = self.merge_heads(attention_value);

-        let y = self.dense.call(attention_value, train);
-        self.norm.call(&y + x, train)
+        y
    }

    fn separate_heads(&self, features: ComputedNDA) -> ComputedNDA {
@@ -124,6 +113,35 @@ impl MultiHeadAttention {
    }
}

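+/// Multi-head attention followed by a dense projection, a residual connection, and layer normalization.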
+pub struct MHAAddNorm {
+    attention: MultiHeadAttention,
+    dense: Linear,
+    norm: Normalization,
+}
+
+impl MHAAddNorm {
+    pub fn new(
+        dim: usize,
+        num_heads: usize,
+        layer_norm_eps: f32,
+        w: impl Initializer<ParamNDA> + Scope,
+        b: impl Initializer<ParamNDA> + Scope,
+        opt: impl Optimizer<NDArray> + Clone,
+    ) -> Self {
+        Self {
+            attention: MultiHeadAttention::new(dim, num_heads, w.scope("mha"), b.scope("mha")),
+            dense: Linear::new(dim, dim, w.scope("dense"), Some(b.scope("dense"))),
+            norm: Normalization::new(vec![1], layer_norm_eps, opt),
+        }
+    }
+
+    pub fn call(&self, x: &ComputedNDA, attn_mask: &ComputedNDA, train: bool) -> ComputedNDA {
+        let y = self.attention.call(x, attn_mask, train);
+        let y = self.dense.call(y, train);
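+        // Residual connection (&y + x) followed by layer normalization (post-norm).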
+        self.norm.call(&y + x, train)
+    }
+}

#[test]
fn test() {
    use ndarray_rand::rand_distr::Uniform;
@@ -133,7 +151,7 @@ fn test() {
        optimizers::Adam::new(),
    );

-    let mha = MultiHeadAttention::new(
+    let mha = MHAAddNorm::new(
        64,
        4,
        1e-5,

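The resulting block computes the standard post-norm Transformer sub-layer, y = LayerNorm(x + Dense(MHA(x))), with the output projection moved out of MultiHeadAttention into MHAAddNorm. A minimal usage sketch, assuming the initializers w and b and the tensors x and attn_mask are set up as in the test above; the scope name and shapes are illustrative, not taken from the commit:

    let mha = MHAAddNorm::new(
        64,   // embed_dim; MultiHeadAttention asserts embed_dim % num_heads == 0
        4,    // num_heads
        1e-5, // layer_norm_eps for the final layer norm
        w.scope("mha_add_norm"), // hypothetical scope name
        b.scope("mha_add_norm"),
        optimizers::Adam::new(),
    );
    // x: (N, L, 64); attn_mask broadcastable to the attention scores (assumed shapes)
    let y = mha.call(&x, &attn_mask, true); // LayerNorm(x + Dense(MHA(x)))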