-
Notifications
You must be signed in to change notification settings - Fork 0
/
ace_helpers.py
440 lines (383 loc) · 16 KB
/
ace_helpers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
""" collection of various helper functions for running ACE"""
from multiprocessing import dummy as multiprocessing
import os
from matplotlib import pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np
from PIL import Image
from skimage.segmentation import mark_boundaries
from sklearn import linear_model
from sklearn.model_selection import cross_val_score
def load_image_from_file(filename, shape):
"""Given a filename, try to open the file. If failed, return None.
Args:
filename: location of the image file
shape: the shape of the image file to be scaled
Returns:
the image if succeeds, None if fails.
Rasies:
exception if the image was not the right shape.
"""
if not os.path.exists(filename):
print('Cannot find file: {}'.format(filename))
return None
try:
# print(f"Loading image from file: {filename}")
img = np.array(Image.open(filename).resize(shape, Image.BILINEAR))
img = np.float32(img) / 255.0
if not (len(img.shape) == 3 and img.shape[2] == 3):
return None
else:
return img
except Exception as e:
print(e)
return None
return img
def load_images_from_files(filenames, max_imgs=500, return_filenames=False,
do_shuffle=True, run_parallel=True,
shape=(299, 299),
num_workers=100
):
"""Return image arrays from filenames.
Args:
filenames: locations of image files.
max_imgs: maximum number of images from filenames.
return_filenames: return the succeeded filenames or not
do_shuffle: before getting max_imgs files, shuffle the names or not
run_parallel: get images in parallel or not
shape: desired shape of the image
num_workers: number of workers in parallelization.
Returns:
image arrays and succeeded filenames if return_filenames=True.
"""
imgs = []
# First shuffle a copy of the filenames.
filenames = filenames[:]
if do_shuffle:
np.random.shuffle(filenames)
if return_filenames:
final_filenames = []
if run_parallel:
pool = multiprocessing.Pool(num_workers)
imgs = pool.map(lambda filename: load_image_from_file(filename, shape),
filenames[:max_imgs])
if return_filenames:
final_filenames = [f for i, f in enumerate(filenames[:max_imgs])
if imgs[i] is not None]
imgs = [img for img in imgs if img is not None]
else:
for filename in filenames:
img = load_image_from_file(filename, shape)
if img is not None:
imgs.append(img)
if return_filenames:
final_filenames.append(filename)
if len(imgs) >= max_imgs:
break
if return_filenames:
return np.array(imgs), final_filenames
else:
return np.array(imgs)
def flat_profile(cd, images, bottlenecks=None):
"""Returns concept profile of given images.
Given a ConceptDiscovery class instance and a set of images, and desired
bottleneck layers, calculates the profile of each image with all concepts and
returns a profile vector
Args:
cd: The concept discovery class instance
images: The images for which the concept profile is calculated
bottlenecks: Bottleck layers where the profile is calculated. If None, cd
bottlenecks will be used.
Returns:
The concepts profile of input images using discovered concepts in
all bottleneck layers.
Raises:
ValueError: If bottlenecks is not in right format.
"""
profiles = []
if bottlenecks is None:
bottlenecks = list(cd.dic.keys())
if isinstance(bottlenecks, str):
bottlenecks = [bottlenecks]
elif not isinstance(bottlenecks, list) and not isinstance(bottlenecks, tuple):
raise ValueError('Invalid bottlenecks parameter!')
for bn in bottlenecks:
profiles.append(cd.find_profile(str(bn), images).reshape((len(images), -1)))
profile = np.concatenate(profiles, -1)
return profile
def cross_val(a, b, methods):
"""Performs cross validation for a binary classification task.
Args:
a: First class data points as rows
b: Second class data points as rows
methods: The sklearn classification models to perform cross-validation on
Returns:
The best performing trained binary classification odel
"""
x, y = binary_dataset(a, b)
best_acc = 0.
if isinstance(methods, str):
methods = [methods]
best_acc = 0.
for method in methods:
temp_acc = 0.
params = [10 ** e for e in [-4, -3, -2, -1, 0, 1, 2, 3]]
for param in params:
clf = give_classifier(method, param)
acc = cross_val_score(clf, x, y, cv=min(100, max(2, int(len(y) / 10))))
if np.mean(acc) > temp_acc:
temp_acc = np.mean(acc)
best_param = param
if temp_acc > best_acc:
best_acc = temp_acc
final_clf = give_classifier(method, best_param)
final_clf.fit(x, y)
return final_clf, best_acc
def give_classifier(method, param):
"""Returns an sklearn classification model.
Args:
method: Name of the sklearn classification model
param: Hyperparameters of the sklearn model
Returns:
An untrained sklearn classification model
Raises:
ValueError: if the model name is invalid.
"""
if method == 'logistic':
return linear_model.LogisticRegression(C=param)
elif method == 'sgd':
return linear_model.SGDClassifier(alpha=param)
else:
raise ValueError('Invalid model!')
def binary_dataset(pos, neg, balanced=True):
"""Creates a binary dataset given instances of two classes.
Args:
pos: Data points of the first class as rows
neg: Data points of the second class as rows
balanced: If true, it creates a balanced binary dataset.
Returns:
The data points of the created data set as rows and the corresponding labels
"""
if balanced:
min_len = min(neg.shape[0], pos.shape[0])
ridxs = np.random.permutation(np.arange(2 * min_len))
x = np.concatenate([neg[:min_len], pos[:min_len]], 0)[ridxs]
y = np.concatenate([np.zeros(min_len), np.ones(min_len)], 0)[ridxs]
else:
ridxs = np.random.permutation(np.arange(len(neg) + len(pos)))
x = np.concatenate([neg, pos], 0)[ridxs]
y = np.concatenate(
[np.zeros(neg.shape[0]), np.ones(pos.shape[0])], 0)[ridxs]
return x, y
def plot_concepts(cd, bn, num=10, address=None, mode='diverse', concepts=None):
"""Plots examples of discovered concepts.
Args:
cd: The concept discovery instance
bn: Bottleneck layer name
num: Number of images to print out of each concept
address: If not None, saves the output to the address as a .PNG image
mode: If 'diverse', it prints one example of each of the target class images
is coming from. If 'radnom', randomly samples exmples of the concept. If
'max', prints out the most activating examples of that concept.
concepts: If None, prints out examples of all discovered concepts.
Otherwise, it should be either a list of concepts to print out examples of
or just one concept's name
Raises:
ValueError: If the mode is invalid.
"""
if concepts is None:
concepts = cd.dic[bn]['concepts']
elif not isinstance(concepts, list) and not isinstance(concepts, tuple):
concepts = [concepts]
num_concepts = len(concepts)
plt.rcParams['figure.figsize'] = num * 2.1, 4.3 * num_concepts
fig = plt.figure(figsize=(num * 2, 4 * num_concepts))
outer = gridspec.GridSpec(num_concepts, 1, wspace=0., hspace=0.3)
for n, concept in enumerate(concepts):
inner = gridspec.GridSpecFromSubplotSpec(
2, num, subplot_spec=outer[n], wspace=0, hspace=0.1)
concept_images = cd.dic[bn][concept]['images']
concept_patches = cd.dic[bn][concept]['patches']
concept_image_numbers = cd.dic[bn][concept]['image_numbers']
if mode == 'max':
idxs = np.arange(len(concept_images))
elif mode == 'random':
idxs = np.random.permutation(np.arange(len(concept_images)))
elif mode == 'diverse':
idxs = []
while True:
seen = set()
for idx in range(len(concept_images)):
if concept_image_numbers[idx] not in seen and idx not in idxs:
seen.add(concept_image_numbers[idx])
idxs.append(idx)
if len(idxs) == len(concept_images):
break
else:
raise ValueError('Invalid mode!')
idxs = idxs[:num]
for i, idx in enumerate(idxs):
ax = plt.Subplot(fig, inner[i])
ax.imshow(concept_images[idx])
ax.set_xticks([])
ax.set_yticks([])
if i == int(num / 2):
ax.set_title(concept)
ax.grid(False)
fig.add_subplot(ax)
ax = plt.Subplot(fig, inner[i + num])
mask = 1 - (np.mean(concept_patches[idx] == float(
cd.average_image_value) / 255, -1) == 1)
image = cd.discovery_images[concept_image_numbers[idx]]
ax.imshow(mark_boundaries(image, mask, color=(1, 1, 0), mode='thick'))
ax.set_xticks([])
ax.set_yticks([])
ax.set_title(str(concept_image_numbers[idx]))
ax.grid(False)
fig.add_subplot(ax)
plt.suptitle(bn)
if address is not None:
# tf
# with tf.gfile.Open(address + bn + '.png', 'w') as f:
# fig.savefig(f)
# new
# with tf.gfile.Open(address + bn + '.png', 'w') as f:
fig.savefig(address + bn + '.png')
plt.clf()
plt.close(fig)
def cosine_similarity(a, b):
"""Cosine similarity of two vectors."""
assert a.shape == b.shape, 'Two vectors must have the same dimensionality'
a_norm, b_norm = np.linalg.norm(a), np.linalg.norm(b)
if a_norm * b_norm == 0:
return 0.
cos_sim = np.sum(a * b) / (a_norm * b_norm)
return cos_sim
def similarity(cd, num_random_exp=None, num_workers=25):
"""Returns cosine similarity of all discovered concepts.
Args:
cd: The ConceptDiscovery module for discovered conceps.
num_random_exp: If None, calculates average similarity using all the class's
random concepts. If a number, uses that many random counterparts.
num_workers: If greater than 0, runs the function in parallel.
Returns:
A similarity dict in the form of {(concept1, concept2):[list of cosine
similarities]}
"""
def concepts_similarity(cd, concepts, rnd, bn):
"""Calcualtes the cosine similarity of concept cavs.
This function calculates the pairwise cosine similarity of all concept cavs
versus an specific random concept
Args:
cd: The ConceptDiscovery instance
concepts: List of concepts to calculate similarity for
rnd: a random counterpart
bn: bottleneck layer the concepts belong to
Returns:
A dictionary of cosine similarities in the form of
{(concept1, concept2): [list of cosine similarities], ...}
"""
similarity_dic = {}
for c1 in concepts:
cav1 = cd.load_cav_direction(c1, rnd, bn)
for c2 in concepts:
if (c1, c2) in similarity_dic.keys():
continue
cav2 = cd.load_cav_direction(c2, rnd, bn)
similarity_dic[(c1, c2)] = cosine_similarity(cav1, cav2)
similarity_dic[(c2, c1)] = similarity_dic[(c1, c2)]
return similarity_dic
similarity_dic = {bn: {} for bn in cd.bottlenecks}
if num_random_exp is None:
num_random_exp = cd.num_random_exp
randoms = ['random500_{}'.format(i) for i in np.arange(num_random_exp)]
concepts = {}
for bn in cd.bottlenecks:
concepts[bn] = [cd.target_class, cd.random_concept] + cd.dic[bn]['concepts']
for bn in cd.bottlenecks:
concept_pairs = [(c1, c2) for c1 in concepts[bn] for c2 in concepts[bn]]
similarity_dic[bn] = {pair: [] for pair in concept_pairs}
def t_func(rnd):
return concepts_similarity(cd, concepts[bn], rnd, bn)
if num_workers:
pool = multiprocessing.Pool(num_workers)
sims = pool.map(lambda rnd: t_func(rnd), randoms)
else:
sims = [t_func(rnd) for rnd in randoms]
while sims:
sim = sims.pop()
for pair in concept_pairs:
similarity_dic[bn][pair].append(sim[pair])
return similarity_dic
def save_ace_report(cd, accs, scores, address):
"""Saves TCAV scores.
Saves the average CAV accuracies and average TCAV scores of the concepts
discovered in ConceptDiscovery instance.
Args:
cd: The ConceptDiscovery instance.
accs: The cav accuracy dictionary returned by cavs method of the
ConceptDiscovery instance
scores: The tcav score dictionary returned by tcavs method of the
ConceptDiscovery instance
address: The address to save the text file in.
"""
report = '\n\n\t\t\t ---CAV accuracies---'
for bn in cd.bottlenecks:
report += '\n'
for concept in cd.dic[bn]['concepts']:
report += '\n' + bn + ':' + concept + ':' + str(
np.mean(accs[bn][concept]))
with open(address + 'CAV_ace_results.txt', 'w') as f:
f.write(report)
report = '\n\n\t\t\t ---TCAV scores---'
for bn in cd.bottlenecks:
report += '\n'
for concept in cd.dic[bn]['concepts']:
pvalue = cd.do_statistical_testings(
scores[bn][concept], scores[bn][cd.random_concept])
report += '\n{}:{}:{},{}'.format(bn, concept,
np.mean(scores[bn][concept]), pvalue)
with open(address + 'TCAV_ace_results.txt', 'w') as f:
f.write(report)
def save_concepts(cd, concepts_dir):
"""Saves discovered concept's images or patches.
Args:
cd: The ConceptDiscovery instance the concepts of which we want to save
concepts_dir: The directory to save the concept images
"""
for bn in cd.bottlenecks:
for concept in cd.dic[bn]['concepts']:
patches_dir = os.path.join(concepts_dir, bn + '_' + concept + '_patches')
images_dir = os.path.join(concepts_dir, bn + '_' + concept)
patches = (np.clip(cd.dic[bn][concept]['patches'], 0, 1) * 256).astype(
np.uint8)
images = (np.clip(cd.dic[bn][concept]['images'], 0, 1) * 256).astype(
np.uint8)
os.makedirs(patches_dir)
os.makedirs(images_dir)
image_numbers = cd.dic[bn][concept]['image_numbers']
image_addresses, patch_addresses = [], []
for i in range(len(images)):
image_name = '0' * int(np.ceil(2 - np.log10(i + 1))) + '{}_{}'.format(
i + 1, image_numbers[i])
patch_addresses.append(os.path.join(patches_dir, image_name + '.png'))
image_addresses.append(os.path.join(images_dir, image_name + '.png'))
save_images(patch_addresses, patches)
save_images(image_addresses, images)
def save_images(addresses, images):
"""Save images in the addresses.
Args:
addresses: The list of addresses to save the images as or the address of the
directory to save all images in. (list or str)
images: The list of all images in numpy uint8 format.
"""
if not isinstance(addresses, list):
image_addresses = []
for i, image in enumerate(images):
image_name = '0' * (3 - int(np.log10(i + 1))) + str(i + 1) + '.png'
image_addresses.append(os.path.join(addresses, image_name))
addresses = image_addresses
assert len(addresses) == len(images), 'Invalid number of addresses'
for address, image in zip(addresses, images):
# with tf.gfile.Open(address, 'w') as f:
Image.fromarray(image).save(address, format='PNG')