From 01ad9cb4041c2d2907ec63d5a4f765140feea5cd Mon Sep 17 00:00:00 2001
From: carrotflakes
Date: Sat, 2 Sep 2023 02:12:55 +0900
Subject: [PATCH] Add MHAAddNorm

---
 src/nn/attention.rs | 50 ++++++++++++++++++++++++++++++---------------
 1 file changed, 34 insertions(+), 16 deletions(-)

diff --git a/src/nn/attention.rs b/src/nn/attention.rs
index 1756d87..a88fbfc 100644
--- a/src/nn/attention.rs
+++ b/src/nn/attention.rs
@@ -12,9 +12,6 @@ pub struct MultiHeadAttention {
     key_proj: Linear,
     query_proj: Linear,
     value_proj: Linear,
-
-    dense: Linear,
-    norm: Normalization,
 
     // TODO: dropout
 }
@@ -22,11 +19,11 @@ impl MultiHeadAttention {
     pub fn new(
         embed_dim: usize,
         num_heads: usize,
-        layer_norm_eps: f32,
         w: impl Initializer + Scope,
         b: impl Initializer + Scope,
-        opt: impl Optimizer + Clone,
     ) -> Self {
+        assert!(embed_dim % num_heads == 0);
+
         MultiHeadAttention {
             head_dim: embed_dim / num_heads,
             num_heads,
@@ -48,13 +45,6 @@ impl MultiHeadAttention {
                 w.scope("value_proj"),
                 Some(b.scope("value_proj")),
             ),
-            dense: Linear::new(
-                embed_dim,
-                embed_dim,
-                w.scope("dense"),
-                Some(b.scope("dense")),
-            ),
-            norm: Normalization::new(vec![1], layer_norm_eps, opt),
         }
     }
 
@@ -82,10 +72,9 @@ impl MultiHeadAttention {
         let attention_value = attention.matmul(&value);
 
         // (N, num_heads, L, head_dim) -> (N, L, num_heads * head_dim)
-        let attention_value = self.merge_heads(attention_value);
+        let y = self.merge_heads(attention_value);
 
-        let y = self.dense.call(attention_value, train);
-        self.norm.call(&y + x, train)
+        y
     }
 
     fn separate_heads(&self, features: ComputedNDA) -> ComputedNDA {
@@ -124,6 +113,35 @@ impl MultiHeadAttention {
     }
 }
 
+pub struct MHAAddNorm {
+    attention: MultiHeadAttention,
+    dense: Linear,
+    norm: Normalization,
+}
+
+impl MHAAddNorm {
+    pub fn new(
+        dim: usize,
+        num_heads: usize,
+        layer_norm_eps: f32,
+        w: impl Initializer + Scope,
+        b: impl Initializer + Scope,
+        opt: impl Optimizer + Clone,
+    ) -> Self {
+        Self {
+            attention: MultiHeadAttention::new(dim, num_heads, w.scope("mha"), b.scope("mha")),
+            dense: Linear::new(dim, dim, w.scope("dense"), Some(b.scope("dense"))),
+            norm: Normalization::new(vec![1], layer_norm_eps, opt),
+        }
+    }
+
+    pub fn call(&self, x: &ComputedNDA, attn_mask: &ComputedNDA, train: bool) -> ComputedNDA {
+        let y = self.attention.call(x, attn_mask, train);
+        let y = self.dense.call(y, train);
+        self.norm.call(&y + x, train)
+    }
+}
+
 #[test]
 fn test() {
     use ndarray_rand::rand_distr::Uniform;
@@ -133,7 +151,7 @@ fn test() {
         optimizers::Adam::new(),
     );
 
-    let mha = MultiHeadAttention::new(
+    let mha = MHAAddNorm::new(
         64,
         4,
         1e-5,
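
Note (not part of the patch): after this change, MultiHeadAttention returns the merged-head attention output directly, while the output projection, residual connection, and layer normalization move into the new MHAAddNorm wrapper, i.e. the standard post-norm transformer block norm(x + dense(attention(x))). Below is a minimal usage sketch; the values w_init, b_init, x, and attn_mask are hypothetical stand-ins for objects constructed the same way as in the updated test, not APIs defined by this patch.

// Sketch only: w_init and b_init are assumed to satisfy `Initializer + Scope`,
// x is assumed to have shape (N, L, 64), and attn_mask is assumed to broadcast
// against the attention scores, as in MultiHeadAttention::call.
let block = MHAAddNorm::new(
    64,                      // embed_dim; the inner MultiHeadAttention asserts 64 % num_heads == 0
    4,                       // num_heads, so head_dim = 64 / 4 = 16
    1e-5,                    // layer_norm_eps, passed to the inner Normalization
    w_init,                  // weight initializer (hypothetical stand-in)
    b_init,                  // bias initializer (hypothetical stand-in)
    optimizers::Adam::new(), // optimizer consumed by Normalization::new
);

// Computes norm(x + dense(attention(x, attn_mask))); `true` enables train mode.
let y = block.call(&x, &attn_mask, true);

One consequence of the split worth noting for reviewers: the bare MultiHeadAttention no longer applies any projection, residual, or normalization, so callers that want the previous behaviour should construct MHAAddNorm instead, as the updated test now does.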