-
Notifications
You must be signed in to change notification settings - Fork 48
/
model.py
493 lines (411 loc) · 22.5 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
'''
Tensorflow implementation of AutoInt described in:
AutoInt: Automatic Feature Interaction Learning via Self-Attentive Neural Networks.
author: Chence Shi
email: chenceshi@pku.edu.cn
'''
import os
import numpy as np
import tensorflow as tf
from time import time
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.metrics import roc_auc_score, log_loss
from tensorflow.contrib.layers.python.layers import batch_norm as batch_norm
'''
The following two functions are adapted from Kyubyong Park's implementation of the Transformer.
We slightly modified the code to suit our work (added ReLU activations; removed key masking and the causality mask).
June 2017 by kyubyong park.
kbpark.linguist@gmail.com.
https://www.github.com/kyubyong/transformer
'''
def normalize(inputs, epsilon=1e-8):
    '''
    Layer-normalize ``inputs`` over its last axis.

    Each call creates two fresh trainable variables sized to the last
    dimension: a shift initialized to zeros and a gain initialized to ones.

    Args:
        inputs: A tensor with 2 or more dimensions whose last dimension is
            statically known (it sizes the gain/shift variables).
        epsilon: Small constant added to the variance to prevent division by zero.

    Returns:
        A tensor with the same shape and dtype as ``inputs``.
    '''
    last_axis_shape = inputs.get_shape()[-1:]
    mu, var = tf.nn.moments(inputs, [-1], keep_dims=True)
    # Creation order (shift first, then gain) is kept so auto-generated
    # variable names stay stable for existing checkpoints.
    shift = tf.Variable(tf.zeros(last_axis_shape))
    gain = tf.Variable(tf.ones(last_axis_shape))
    centered_scaled = (inputs - mu) / ((var + epsilon) ** (.5))
    return gain * centered_scaled + shift
def multihead_attention(queries,
                        keys,
                        values,
                        num_units=None,
                        num_heads=1,
                        dropout_keep_prob=1,
                        is_training=True,
                        has_residual=True):
    '''
    Multi-head scaled dot-product attention with ReLU projections.

    Q, K and V are produced by dense layers with ReLU activation. There is no
    key masking and no causality mask. The heads are split along the last axis
    and stacked on axis 0, attended independently, then restored.

    Args:
        queries, keys, values: rank-3 tensors. The split on axis 2 and the
            ``[0, 2, 1]`` transpose require rank 3; presumably
            (batch, fields, dim) -- confirm against callers.
        num_units: width of the Q/K/V projections and of the output. Defaults
            to the static last dimension of ``queries``. Must be divisible by
            ``num_heads``, because ``tf.split`` needs equal parts.
        num_heads: number of attention heads.
        dropout_keep_prob: keep probability for dropout on the attention weights.
        is_training: bool or bool tensor; dropout is applied only when true.
        has_residual: if true, a separate ReLU projection of ``values`` is
            added to the attention output.

    Returns:
        Tensor whose last dimension is ``num_units``, after a final ReLU and
        layer normalization.
    '''
    if num_units is None:
        # ``as_list`` is a method; indexing the bound method (``as_list[-1]``)
        # raised TypeError whenever num_units was left at its default.
        num_units = queries.get_shape().as_list()[-1]

    # Linear projections
    Q = tf.layers.dense(queries, num_units, activation=tf.nn.relu)
    K = tf.layers.dense(keys, num_units, activation=tf.nn.relu)
    V = tf.layers.dense(values, num_units, activation=tf.nn.relu)
    if has_residual:
        V_res = tf.layers.dense(values, num_units, activation=tf.nn.relu)

    # Split the last axis into heads and stack the heads along axis 0
    Q_ = tf.concat(tf.split(Q, num_heads, axis=2), axis=0)
    K_ = tf.concat(tf.split(K, num_heads, axis=2), axis=0)
    V_ = tf.concat(tf.split(V, num_heads, axis=2), axis=0)

    # Multiplication
    weights = tf.matmul(Q_, tf.transpose(K_, [0, 2, 1]))

    # Scale by sqrt of the per-head key width
    weights = weights / (K_.get_shape().as_list()[-1] ** 0.5)

    # Activation
    weights = tf.nn.softmax(weights)

    # Dropouts
    weights = tf.layers.dropout(weights, rate=1 - dropout_keep_prob,
                                training=tf.convert_to_tensor(is_training))

    # Weighted sum
    outputs = tf.matmul(weights, V_)

    # Restore shape: move the heads from axis 0 back onto the last axis
    outputs = tf.concat(tf.split(outputs, num_heads, axis=0), axis=2)

    # Residual connection
    if has_residual:
        outputs += V_res

    outputs = tf.nn.relu(outputs)
    # Normalize
    outputs = normalize(outputs)

    return outputs
class AutoInt():
    '''
    AutoInt model built on a private TensorFlow 1.x graph.

    Feature embeddings go through ``blocks`` stacked multi-head attention
    layers, are flattened, and are projected to one output. A first-order
    (wide) term and a feed-forward (deep) tower can optionally be added to
    that output.
    '''

    def __init__(self, args, feature_size, run_cnt):
        '''
        Copy hyper-parameters from ``args`` and build the graph and session.

        Args:
            args: hyper-parameter namespace; every attribute read below must exist.
            feature_size: n, number of rows of the embedding table.
            run_cnt: appended to ``args.save_path`` to form this run's
                checkpoint directory.
        '''
        self.feature_size = feature_size           # denote as n, dimension of concatenated features
        self.field_size = args.field_size          # denote as M, number of total feature fields
        self.embedding_size = args.embedding_size  # denote as d, size of the feature embedding
        self.blocks = args.blocks                  # number of the blocks
        self.heads = args.heads                    # number of the heads
        self.block_shape = args.block_shape        # num_units for each attention block
        self.output_size = args.block_shape[-1]
        self.has_residual = args.has_residual
        self.has_wide = args.has_wide              # whether to add wide part
        self.deep_layers = args.deep_layers        # hidden sizes of the deep tower; None disables it
        self.batch_norm = args.batch_norm
        self.batch_norm_decay = args.batch_norm_decay
        self.drop_keep_prob = args.dropout_keep_prob
        self.l2_reg = args.l2_reg
        self.epoch = args.epoch
        self.batch_size = args.batch_size
        self.learning_rate = args.learning_rate
        self.learning_rate_wide = args.learning_rate_wide
        self.optimizer_type = args.optimizer_type
        self.save_path = args.save_path + str(run_cnt) + '/'
        self.is_save = args.is_save
        if args.is_save and not os.path.exists(self.save_path):
            os.makedirs(self.save_path)
        self.verbose = args.verbose
        self.random_seed = args.random_seed
        self.loss_type = args.loss_type
        self.eval_metric = roc_auc_score
        # A checkpoint is written only when validation loss drops below this.
        self.best_loss = 1.0
        self.greater_is_better = args.greater_is_better
        self.train_result, self.valid_result = [], []
        self.train_loss, self.valid_loss = [], []

        self._init_graph()

    def _init_graph(self):
        '''
        Build placeholders, model, loss, optimizer, saver and session in ``self.graph``.

        Raises:
            ValueError: if ``loss_type`` or ``optimizer_type`` is not supported.
        '''
        self.graph = tf.Graph()
        with self.graph.as_default():
            tf.set_random_seed(self.random_seed)

            self.feat_index = tf.placeholder(tf.int32, shape=[None, None],
                                             name="feat_index")  # None * M
            self.feat_value = tf.placeholder(tf.float32, shape=[None, None],
                                             name="feat_value")  # None * M
            self.label = tf.placeholder(tf.float32, shape=[None, 1], name="label")  # None * 1
            # Three keep probabilities: [0] attention weights,
            # [1] feature embeddings, [2] deep-tower hidden layers.
            self.dropout_keep_prob = tf.placeholder(tf.float32, shape=[None], name="dropout_keep_prob")
            self.train_phase = tf.placeholder(tf.bool, name="train_phase")

            self.weights = self._initialize_weights()

            # model
            self.embeddings = tf.nn.embedding_lookup(self.weights["feature_embeddings"],
                                                     self.feat_index)  # None * M * d
            feat_value = tf.reshape(self.feat_value, shape=[-1, self.field_size, 1])
            self.embeddings = tf.multiply(self.embeddings, feat_value)  # None * M * d
            self.embeddings = tf.nn.dropout(self.embeddings, self.dropout_keep_prob[1])  # None * M * d
            if self.has_wide:
                self.y_first_order = tf.nn.embedding_lookup(self.weights["feature_bias"],
                                                            self.feat_index)  # None * M * 1
                self.y_first_order = tf.reduce_sum(tf.multiply(self.y_first_order, feat_value), 1)  # None * 1

            # joint training with feedforward nn
            if self.deep_layers is not None:
                self.y_dense = tf.reshape(self.embeddings, shape=[-1, self.field_size * self.embedding_size])
                for i in range(len(self.deep_layers)):
                    self.y_dense = tf.add(tf.matmul(self.y_dense, self.weights["layer_%d" % i]),
                                          self.weights["bias_%d" % i])  # None * layer[i]
                    if self.batch_norm:
                        self.y_dense = self.batch_norm_layer(self.y_dense, train_phase=self.train_phase,
                                                             scope_bn="bn_%d" % i)
                    self.y_dense = tf.nn.relu(self.y_dense)
                    self.y_dense = tf.nn.dropout(self.y_dense, self.dropout_keep_prob[2])
                self.y_dense = tf.add(tf.matmul(self.y_dense, self.weights["prediction_dense"]),
                                      self.weights["prediction_bias_dense"], name='logits_dense')  # None * 1

            # ---------- main part of AutoInt-------------------
            self.y_deep = self.embeddings  # None * M * d
            for i in range(self.blocks):
                self.y_deep = multihead_attention(queries=self.y_deep,
                                                  keys=self.y_deep,
                                                  values=self.y_deep,
                                                  num_units=self.block_shape[i],
                                                  num_heads=self.heads,
                                                  dropout_keep_prob=self.dropout_keep_prob[0],
                                                  is_training=self.train_phase,
                                                  has_residual=self.has_residual)

            self.flat = tf.reshape(self.y_deep,
                                   shape=[-1, self.output_size * self.field_size])
            self.out = tf.add(tf.matmul(self.flat, self.weights["prediction"]),
                              self.weights["prediction_bias"], name='logits')  # None * 1
            if self.has_wide:
                self.out += self.y_first_order
            if self.deep_layers is not None:
                self.out += self.y_dense

            # ---------- Compute the loss ----------
            if self.loss_type == "logloss":
                self.out = tf.nn.sigmoid(self.out, name='pred')
                self.loss = tf.losses.log_loss(self.label, self.out)
            elif self.loss_type == "mse":
                self.loss = tf.nn.l2_loss(tf.subtract(self.label, self.out))
            else:
                raise ValueError("Unsupported loss_type: %r" % (self.loss_type,))

            # l2 regularization, applied only to the deep-tower weight matrices
            if self.l2_reg > 0 and self.deep_layers is not None:
                for i in range(len(self.deep_layers)):
                    self.loss += tf.contrib.layers.l2_regularizer(
                        self.l2_reg)(self.weights["layer_%d" % i])

            self.global_step = tf.Variable(0, name="global_step", trainable=False)
            # var1: everything except the wide weights; var2: the wide weights.
            # var2 is taken from the weights dict instead of a positional index
            # into tf.trainable_variables(), which depended on creation order.
            self.var1 = [v for v in tf.trainable_variables() if v.name != 'feature_bias:0']
            self.var2 = [self.weights["feature_bias"]] if self.has_wide else []

            # optimizer
            # With a wide part and adam, the deep variables use Adam and the
            # wide weights use plain gradient descent with their own learning rate.
            if self.optimizer_type == "adam":
                if self.has_wide:
                    optimizer1 = tf.train.AdamOptimizer(learning_rate=self.learning_rate,
                                                        beta1=0.9, beta2=0.999, epsilon=1e-8)
                    optimizer2 = tf.train.GradientDescentOptimizer(learning_rate=self.learning_rate_wide)
                    var_list1 = self.var1
                    var_list2 = self.var2
                    grads = tf.gradients(self.loss, var_list1 + var_list2)
                    grads1 = grads[:len(var_list1)]
                    grads2 = grads[len(var_list1):]
                    train_op1 = optimizer1.apply_gradients(zip(grads1, var_list1), global_step=self.global_step)
                    train_op2 = optimizer2.apply_gradients(zip(grads2, var_list2))
                    self.optimizer = tf.group(train_op1, train_op2)
                else:
                    self.optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate,
                                                            beta1=0.9, beta2=0.999, epsilon=1e-8).\
                        minimize(self.loss, global_step=self.global_step)
            # The branches below now pass global_step too; previously it never
            # advanced, so every checkpoint was saved with step 0.
            elif self.optimizer_type == "adagrad":
                self.optimizer = tf.train.AdagradOptimizer(learning_rate=self.learning_rate,
                                                           initial_accumulator_value=1e-8).\
                    minimize(self.loss, global_step=self.global_step)
            elif self.optimizer_type == "gd":
                self.optimizer = tf.train.GradientDescentOptimizer(learning_rate=self.learning_rate).\
                    minimize(self.loss, global_step=self.global_step)
            elif self.optimizer_type == "momentum":
                self.optimizer = tf.train.MomentumOptimizer(learning_rate=self.learning_rate, momentum=0.95).\
                    minimize(self.loss, global_step=self.global_step)
            else:
                raise ValueError("Unsupported optimizer_type: %r" % (self.optimizer_type,))

            # init
            self.saver = tf.train.Saver(max_to_keep=5)
            init = tf.global_variables_initializer()
            self.sess = self._init_session()
            self.sess.run(init)
            self.count_param()

    def count_param(self):
        '''Print the total trainable parameter count and the count beyond the embedding table.'''
        # Read from this model's graph explicitly rather than the default graph.
        with self.graph.as_default():
            trainable = tf.trainable_variables()
        k = np.sum([np.prod(v.get_shape().as_list()) for v in trainable])
        print("total parameters :%d" % k)
        print("extra parameters : %d" % (k - self.feature_size * self.embedding_size))

    def _init_session(self):
        '''Create a session that allows soft placement and on-demand GPU memory growth.'''
        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        return tf.Session(config=config)

    def _initialize_weights(self):
        '''
        Create all model variables and return them keyed by name.

        Creation order is significant: auto-named variables (the deep-tower
        layers/biases) get names from their position.
        '''
        weights = dict()

        # embeddings
        weights["feature_embeddings"] = tf.Variable(
            tf.random_normal([self.feature_size, self.embedding_size], 0.0, 0.01),
            name="feature_embeddings")  # feature_size(n) * d

        if self.has_wide:
            weights["feature_bias"] = tf.Variable(
                tf.random_normal([self.feature_size, 1], 0.0, 0.001),
                name="feature_bias")  # feature_size(n) * 1

        # Width of the flattened attention output fed to the prediction layer.
        input_size = self.output_size * self.field_size

        # dense layers (Glorot-scaled normal initialization via numpy)
        if self.deep_layers is not None:
            num_layer = len(self.deep_layers)
            layer0_size = self.field_size * self.embedding_size
            glorot = np.sqrt(2.0 / (layer0_size + self.deep_layers[0]))
            weights["layer_0"] = tf.Variable(
                np.random.normal(loc=0, scale=glorot, size=(layer0_size, self.deep_layers[0])), dtype=np.float32)
            weights["bias_0"] = tf.Variable(np.random.normal(loc=0, scale=glorot, size=(1, self.deep_layers[0])),
                                            dtype=np.float32)  # 1 * layers[0]
            for i in range(1, num_layer):
                glorot = np.sqrt(2.0 / (self.deep_layers[i - 1] + self.deep_layers[i]))
                weights["layer_%d" % i] = tf.Variable(
                    np.random.normal(loc=0, scale=glorot, size=(self.deep_layers[i - 1], self.deep_layers[i])),
                    dtype=np.float32)  # layers[i-1] * layers[i]
                weights["bias_%d" % i] = tf.Variable(
                    np.random.normal(loc=0, scale=glorot, size=(1, self.deep_layers[i])),
                    dtype=np.float32)  # 1 * layer[i]
            glorot = np.sqrt(2.0 / (self.deep_layers[-1] + 1))
            weights["prediction_dense"] = tf.Variable(
                np.random.normal(loc=0, scale=glorot, size=(self.deep_layers[-1], 1)),
                dtype=np.float32, name="prediction_dense")
            weights["prediction_bias_dense"] = tf.Variable(
                np.random.normal(), dtype=np.float32, name="prediction_bias_dense")

        # ---------- prediction weight ------------------
        glorot = np.sqrt(2.0 / (input_size + 1))
        weights["prediction"] = tf.Variable(
            np.random.normal(loc=0, scale=glorot, size=(input_size, 1)),
            dtype=np.float32, name="prediction")
        weights["prediction_bias"] = tf.Variable(
            np.random.normal(), dtype=np.float32, name="prediction_bias")

        return weights

    def batch_norm_layer(self, x, train_phase, scope_bn):
        '''
        Batch-normalize ``x``, switching between training and inference on ``train_phase``.

        Both branches are built in the same scope; the inference branch reuses
        the training branch's variables.
        '''
        bn_train = batch_norm(x, decay=self.batch_norm_decay, center=True, scale=True, updates_collections=None,
                              is_training=True, reuse=None, trainable=True, scope=scope_bn)
        bn_inference = batch_norm(x, decay=self.batch_norm_decay, center=True, scale=True, updates_collections=None,
                                  is_training=False, reuse=True, trainable=True, scope=scope_bn)
        z = tf.cond(train_phase, lambda: bn_train, lambda: bn_inference)
        return z

    def get_batch(self, Xi, Xv, y, batch_size, index):
        '''
        Return the ``index``-th slice of size ``batch_size`` of ``Xi``, ``Xv`` and ``y``.

        Labels are wrapped as ``[[y0], [y1], ...]`` to match the ``None * 1``
        label placeholder. The end is clamped to ``len(y)``; past the end,
        empty slices are returned.
        '''
        start = index * batch_size
        end = min((index + 1) * batch_size, len(y))
        return Xi[start:end], Xv[start:end], [[y_] for y_ in y[start:end]]

    # shuffle three lists simultaneously
    def shuffle_in_unison_scary(self, a, b, c):
        '''Shuffle ``a``, ``b`` and ``c`` in place with the same permutation (same numpy RNG state).'''
        rng_state = np.random.get_state()
        np.random.shuffle(a)
        np.random.set_state(rng_state)
        np.random.shuffle(b)
        np.random.set_state(rng_state)
        np.random.shuffle(c)

    def fit_on_batch(self, Xi, Xv, y):
        '''Run one optimizer step on a batch (dropout enabled) and return ``(global_step, loss)``.'''
        feed_dict = {self.feat_index: Xi,
                     self.feat_value: Xv,
                     self.label: y,
                     self.dropout_keep_prob: self.drop_keep_prob,
                     self.train_phase: True}
        step, loss, _ = self.sess.run((self.global_step, self.loss, self.optimizer), feed_dict=feed_dict)
        return step, loss

    # Since the train data is very large, they can not be fit into the memory at the same time.
    # We separate the whole train data into several files and call "fit_once" for each file.
    def fit_once(self, Xi_train, Xv_train, y_train,
                 epoch, file_count, Xi_valid=None,
                 Xv_valid=None, y_valid=None,
                 early_stopping=False):
        '''
        Train on one data file, evaluate, and optionally checkpoint.

        The training arrays are shuffled in place. Only full batches are used;
        leftover samples are dropped. Train metrics, and validation metrics
        when ``Xv_valid`` is given, are appended to the history lists. A
        checkpoint is saved when validation loss improves and ``is_save`` is set.

        Returns:
            False if early stopping triggered, True otherwise.
        '''
        has_valid = Xv_valid is not None
        last_step = 0
        t1 = time()
        self.shuffle_in_unison_scary(Xi_train, Xv_train, y_train)
        total_batch = len(y_train) // self.batch_size
        for i in range(total_batch):
            Xi_batch, Xv_batch, y_batch = self.get_batch(Xi_train, Xv_train, y_train, self.batch_size, i)
            step, loss = self.fit_on_batch(Xi_batch, Xv_batch, y_batch)
            last_step = step

        # evaluate training and validation datasets
        train_result, train_loss = self.evaluate(Xi_train, Xv_train, y_train)
        self.train_result.append(train_result)
        self.train_loss.append(train_loss)
        if has_valid:
            valid_result, valid_loss = self.evaluate(Xi_valid, Xv_valid, y_valid)
            self.valid_result.append(valid_result)
            self.valid_loss.append(valid_loss)
            if valid_loss < self.best_loss and self.is_save:
                old_loss = self.best_loss
                self.best_loss = valid_loss
                self.saver.save(self.sess, self.save_path + 'model.ckpt', global_step=last_step)
                print("[%d-%d] model saved!. Valid loss is improved from %.4f to %.4f"
                      % (epoch, file_count, old_loss, self.best_loss))

        # NOTE(review): the 9 hard-codes the number of data files per epoch
        # when computing a global progress counter -- confirm against the caller.
        if self.verbose > 0 and ((epoch - 1) * 9 + file_count) % self.verbose == 0:
            if has_valid:
                print("[%d-%d] train-result=%.4f, train-logloss=%.4f, valid-result=%.4f, valid-logloss=%.4f [%.1f s]" % (epoch, file_count, train_result, train_loss, valid_result, valid_loss, time() - t1))
            else:
                print("[%d-%d] train-result=%.4f [%.1f s]"
                      % (epoch, file_count, train_result, time() - t1))

        if has_valid and early_stopping and self.training_termination(self.valid_loss):
            return False
        return True

    def training_termination(self, valid_result):
        '''
        Decide whether to stop early.

        Once more than five values are recorded, returns True if the last five
        worsened strictly four times in a row. Worse means lower when
        ``greater_is_better`` is set, higher otherwise.
        '''
        if len(valid_result) <= 5:
            return False
        recent = valid_result[-5:]
        pairs = list(zip(recent, recent[1:]))  # (earlier, later)
        if self.greater_is_better:
            return all(later < earlier for earlier, later in pairs)
        return all(later > earlier for earlier, later in pairs)

    def predict(self, Xi, Xv):
        """
        :param Xi: list of list of feature indices of each sample in the dataset
        :param Xv: list of list of feature values of each sample in the dataset
        :return: 1-D array of model outputs (sigmoid probabilities when loss_type
                 is "logloss"), or None if the input is empty
        """
        # dummy y: get_batch needs labels; the output does not depend on them
        dummy_y = [1] * len(Xi)
        batch_index = 0
        batch_preds = []
        Xi_batch, Xv_batch, y_batch = self.get_batch(Xi, Xv, dummy_y, self.batch_size, batch_index)
        while len(Xi_batch) > 0:
            num_batch = len(y_batch)
            feed_dict = {self.feat_index: Xi_batch,
                         self.feat_value: Xv_batch,
                         self.label: y_batch,
                         self.dropout_keep_prob: [1.0] * len(self.drop_keep_prob),
                         self.train_phase: False}
            batch_out = self.sess.run(self.out, feed_dict=feed_dict)
            batch_preds.append(np.reshape(batch_out, (num_batch,)))
            batch_index += 1
            Xi_batch, Xv_batch, y_batch = self.get_batch(Xi, Xv, dummy_y, self.batch_size, batch_index)

        # Concatenate once at the end instead of once per batch.
        if not batch_preds:
            return None
        return np.concatenate(batch_preds)

    def evaluate(self, Xi, Xv, y):
        """
        :param Xi: list of list of feature indices of each sample in the dataset
        :param Xv: list of list of feature values of each sample in the dataset
        :param y: label of each sample in the dataset
        :return: (eval_metric(y, y_pred), log_loss(y, y_pred))
        """
        y_pred = self.predict(Xi, Xv)
        # Clip away 0 and 1 so log_loss stays finite.
        y_pred = np.clip(y_pred, 1e-6, 1 - 1e-6)
        return self.eval_metric(y, y_pred), log_loss(y, y_pred)

    def restore(self, save_path=None):
        '''Restore the latest checkpoint found in ``save_path`` (default: this run's save path), if any.'''
        if save_path is None:
            save_path = self.save_path
        ckpt = tf.train.get_checkpoint_state(save_path)
        if ckpt and ckpt.model_checkpoint_path:
            self.saver.restore(self.sess, ckpt.model_checkpoint_path)
            if self.verbose > 0:
                print("restored from %s" % (save_path))