# model.py
import tensorflow as tf
import numpy as np


class attention(tf.keras.layers.Layer):
    def __init__(self, dim):
        super(attention, self).__init__()
        self.dim = dim
        self.dense_s = tf.keras.layers.Dense(self.dim)
        self.dense_h = tf.keras.layers.Dense(self.dim)

    def call(self, inputs):
        # Split inputs into the decoder state and the encoder (LSTM) outputs
        s = inputs[0]  # (..., depth_s)
        h = inputs[1]  # (..., seq_len, depth_h)
        # Linear FC projections
        s_fi = self.dense_s(s)  # (..., F)
        h_psi = self.dense_h(h)  # (..., seq_len, F)
        # Linear blending < φ(s_i), ψ(h_u) >
        # s is a single vector per batch element, so force a seq_len of 1
        s_fi = tf.expand_dims(s_fi, 1)  # (..., 1, F)
        e = tf.matmul(s_fi, h_psi, transpose_b=True)  # (..., 1, seq_len)
        # Softmax over the encoder time axis
        alpha = tf.nn.softmax(e)  # (..., 1, seq_len)
        # Context vector: attention-weighted sum of the encoder outputs
        c = tf.matmul(alpha, h)  # (..., 1, depth_h)
        c = tf.squeeze(c, 1)  # (..., depth_h)
        return c
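

# Hedged usage sketch (not in the original file): a quick eager shape check of
# the attention layer above on random tensors. The sizes (batch 2, 5 encoder
# frames, depth 8) are arbitrary; call it manually if you want to verify shapes.
def _attention_shape_demo():
    att = attention(8)
    s_demo = tf.random.normal((2, 8))     # decoder state    (batch, depth_s)
    h_demo = tf.random.normal((2, 5, 8))  # encoder outputs  (batch, seq_len, depth_h)
    c_demo = att([s_demo, h_demo])        # context vector   (batch, depth_h)
    assert c_demo.shape == (2, 8)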


class att_rnn(tf.keras.layers.Layer):
    def __init__(self, units):
        super(att_rnn, self).__init__()
        self.units = units
        self.state_size = [self.units, self.units]
        self.attention_context = attention(self.units)
        self.rnn = tf.keras.layers.LSTMCell(self.units)
        self.rnn2 = tf.keras.layers.LSTMCell(self.units)

    def call(self, inputs, states, constants):
        # The encoder output arrives via RNN(..., constants=...) as a one-element
        # list, so drop that leading axis to get (..., seq_len, F)
        h = tf.squeeze(constants, axis=0)
        s = self.rnn(inputs=inputs, states=states)  # [(..., F), [(..., F), (..., F)]]
        s = self.rnn2(inputs=s[0], states=s[1])[1]  # [(..., F), (..., F)]
        c = self.attention_context([s[0], h])  # (..., F)
        out = tf.keras.layers.concatenate([s[0], c], axis=-1)  # (..., F*2)
        return out, [c, s[1]]
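

# Hedged usage sketch (not in the original file): att_rnn is designed to be
# wrapped in tf.keras.layers.RNN with the encoder output passed through
# `constants`, exactly as LAS() does below. All sizes here are arbitrary;
# call it manually if you want to verify shapes.
def _att_rnn_demo():
    cell = att_rnn(16)
    dec_in = tf.random.normal((2, 7, 16))    # decoder inputs   (batch, dec_len, no_tokens)
    enc_out = tf.random.normal((2, 10, 16))  # encoder outputs  (batch, enc_len, F)
    y = tf.keras.layers.RNN(cell, return_sequences=True)(dec_in, constants=enc_out)
    assert y.shape == (2, 7, 32)             # (batch, dec_len, units*2)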


class pBLSTM(tf.keras.layers.Layer):
    def __init__(self, dim):
        super(pBLSTM, self).__init__()
        self.dim = dim
        self.LSTM = tf.keras.layers.LSTM(self.dim, return_sequences=True)
        self.bidi_LSTM = tf.keras.layers.Bidirectional(self.LSTM)

    @tf.function
    def call(self, inputs):
        y = self.bidi_LSTM(inputs)  # (..., seq_len, 2*dim)
        # Zero-pad odd-length sequences so adjacent frames can be paired
        if tf.shape(inputs)[1] % 2 == 1:
            y = tf.keras.layers.ZeroPadding1D(padding=(0, 1))(y)
        # Concatenate every two consecutive frames, halving the time resolution
        y = tf.keras.layers.Reshape(target_shape=(-1, int(self.dim * 4)))(y)  # (..., seq_len//2, 4*dim)
        return y
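

# Hedged usage sketch (not in the original file): the pyramidal BLSTM halves
# the time axis and quadruples the per-frame width relative to dim; odd-length
# inputs are zero-padded first. Sizes are arbitrary; call manually to verify.
def _pblstm_demo():
    layer = pBLSTM(8)
    x_demo = tf.random.normal((2, 9, 8))  # (batch, seq_len, features), odd seq_len
    y_demo = layer(x_demo)
    assert y_demo.shape == (2, 5, 32)     # ((9 + 1) // 2 frames, 4*dim features)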


def LAS(dim, f_1, no_tokens):
    input_1 = tf.keras.Input(shape=(None, f_1))
    input_2 = tf.keras.Input(shape=(None, no_tokens))
    # Listen: lower the time resolution by 8x
    x = pBLSTM(dim // 2)(input_1)  # (..., audio_len//2, dim*2)
    x = pBLSTM(dim // 2)(x)  # (..., audio_len//4, dim*2)
    x = pBLSTM(dim // 4)(x)  # (..., audio_len//8, dim)
    # Attend
    x = tf.keras.layers.RNN(att_rnn(dim), return_sequences=True)(input_2, constants=x)  # (..., seq_len, dim*2)
    # Spell
    x = tf.keras.layers.Dense(dim, activation="relu")(x)  # (..., seq_len, dim)
    x = tf.keras.layers.Dense(no_tokens, activation="softmax")(x)  # (..., seq_len, no_tokens)
    model = tf.keras.Model(inputs=[input_1, input_2], outputs=x)
    return model


model = LAS(256, 256, 16)
model.compile(loss="mse", optimizer="adam")
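

# Hedged end-to-end sketch (not in the original file): feed the compiled model
# random arrays shaped like (audio features, one-hot token sequence) to check
# that the graph builds and a training step runs. All sizes are arbitrary
# placeholders matching LAS(256, 256, 16) above.
if __name__ == "__main__":
    model.summary()
    audio = np.random.rand(4, 80, 256).astype("float32")    # (batch, audio_len, f_1)
    tokens = np.random.rand(4, 12, 16).astype("float32")    # (batch, seq_len, no_tokens)
    targets = np.random.rand(4, 12, 16).astype("float32")   # same shape as the softmax output
    model.fit([audio, tokens], targets, batch_size=2, epochs=1)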