-
Notifications
You must be signed in to change notification settings - Fork 10
/
learn_concepts_multimodal.py
176 lines (149 loc) · 6.94 KB
/
learn_concepts_multimodal.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
import requests
import os
import pickle
import torch
import clip
import argparse
import numpy as np
from tqdm import tqdm
def config():
parser = argparse.ArgumentParser()
parser.add_argument("--out-dir", required=True, type=str)
parser.add_argument("--classes", default="cifar10", type=str)
parser.add_argument("--backbone-name", default="clip:RN50", type=str)
parser.add_argument("--device", default="cuda", type=str)
parser.add_argument("--recurse", default=1, type=int, help="How many times to recurse on the conceptnet graph")
return parser.parse_args()
def get_single_concept_data(cls_name):
if cls_name in concept_cache:
return concept_cache[cls_name]
all_concepts = []
# Has relations
has_query = "https://api.conceptnet.io/query?node=/c/en/{}&rel=/r/HasA&start=/c/en/{}"
obj = requests.get(has_query.format(cls_name, cls_name)).json()
for edge in obj["edges"]:
all_concepts.append(edge['end']['label'])
# Made of relations
madeof_query = "https://api.conceptnet.io/query?node=/c/en/{}&rel=/r/MadeOf&start=/c/en/{}"
obj = requests.get(madeof_query.format(cls_name, cls_name)).json()
for edge in obj["edges"]:
all_concepts.append(edge['end']['label'])
# Properties of things
property_query = "https://api.conceptnet.io/query?node=/c/en/{}&rel=/r/HasProperty&start=/c/en/{}"
obj = requests.get(property_query.format(cls_name, cls_name)).json()
for edge in obj["edges"]:
all_concepts.append(edge['end']['label'])
# Categorization concepts
is_query = "https://api.conceptnet.io/query?node=/c/en/{}&rel=/r/IsA&start=/c/en/{}"
obj = requests.get(is_query.format(cls_name, cls_name)).json()
for edge in obj["edges"]:
if edge["weight"] <= 1:
continue
all_concepts.append(edge['end']['label'])
# Parts of things
parts_query = "https://api.conceptnet.io/query?node=/c/en/{}&rel=/r/PartOf&end=/c/en/{}"
obj = requests.get(parts_query.format(cls_name, cls_name)).json()
for edge in obj["edges"]:
all_concepts.append(edge['start']['label'])
all_concepts = [c.lower() for c in all_concepts]
# Drop the "a " for concepts defined like "a {concept}".
all_concepts = [c.replace("a ", "") for c in all_concepts]
# Drop all empty concepts.
all_concepts = [c for c in all_concepts if c!=""]
# Make each concept unique in the set.
all_concepts = set(all_concepts)
concept_cache[cls_name] = all_concepts
return all_concepts
def get_concept_data(all_classes):
all_concepts = set()
# Collect concepts that are relevant to each class
for cls_name in all_classes:
print(f"Pulling concepts for {cls_name}")
all_concepts |= get_single_concept_data(cls_name)
return all_concepts
def clean_concepts(scenario_concepts):
"""
Clean the plurals, trailing whitespaces etc.
"""
from nltk.stem.wordnet import WordNetLemmatizer
import nltk
# We use nltk to handle plurals, multiples of the same words etc.
nltk.download("wordnet")
nltk.download("omw-1.4")
Lem = WordNetLemmatizer()
scenario_concepts_rec = []
for c_prev in scenario_concepts:
c = c_prev
c = c.strip()
c_subwords = c.split(" ")
# If a concept is made of more than 2 words, we drop it.
if len(c_subwords) > 2:
print("skipping long concept", c_prev)
continue
# Lemmatize words to help eliminate non-unique concepts etc.
for i, csw in enumerate(c_subwords):
c_subwords[i] = Lem.lemmatize(csw)
lemword = " ".join(c_subwords)
if c_prev == lemword:
scenario_concepts_rec.append(c)
else:
if lemword in scenario_concepts:
print(c, lemword)
else:
scenario_concepts_rec.append(c)
scenario_concepts_rec = list(set(scenario_concepts_rec))
return scenario_concepts_rec
@torch.no_grad()
def learn_conceptbank(args, concept_list, scenario):
concept_dict = {}
for concept in tqdm(concept_list):
# Note: You can try other forms of prompting, e.g. "photo of {concept}" etc. here.
text = clip.tokenize(f"{concept}").to("cuda")
text_features = model.encode_text(text).cpu().numpy()
text_features = text_features/np.linalg.norm(text_features)
# store concept vectors in a dictionary. Adding the additional terms to be consistent with the
# `ConceptBank` class (see `concepts/concept_utils.py`).
concept_dict[concept] = (text_features, None, None, 0, {})
print(f"# concepts: {len(concept_dict)}")
concept_dict_path = os.path.join(args.out_dir, f"multimodal_concept_{args.backbone_name}_{scenario}_recurse:{args.recurse}.pkl")
pickle.dump(concept_dict, open(concept_dict_path, 'wb'))
print(f"Dumped to : {concept_dict_path}")
if __name__ == "__main__":
args = config()
model, _ = clip.load(args.backbone_name.split(":")[1], device=args.device, download_root=args.out_dir)
concept_cache = {}
if args.classes == "cifar10":
# Pull CIFAR10 to get the class names.
from torchvision import datasets
cifar10_ds = datasets.CIFAR10(root=args.out_dir, train=True, download=True)
# Get the class names.
all_classes = list(cifar10_ds.classes)
# Get the names of all concepts.
all_concepts = get_concept_data(all_classes)
# Clean the concepts for uniques, plurals etc.
all_concepts = clean_concepts(all_concepts)
all_concepts = list(set(all_concepts).difference(set(all_classes)))
# If we'd like to recurse in the conceptnet graph, specify `recurse > 1`.
for i in range(1, args.recurse):
all_concepts = get_concept_data(all_concepts)
all_concepts = list(set(all_concepts))
all_concepts = clean_concepts(all_concepts)
all_concepts = list(set(all_concepts).difference(set(all_classes)))
# Generate the concept bank.
learn_conceptbank(args, all_concepts, args.classes)
elif args.classes == "cifar100":
from torchvision import datasets
cifar100_ds = datasets.CIFAR100(root=args.out_dir, train=True, download=True)
all_classes = list(cifar100_ds.classes)
all_concepts = get_concept_data(all_classes)
all_concepts = clean_concepts(all_concepts)
all_concepts = list(set(all_concepts).difference(set(all_classes)))
# If we'd like to recurse in the conceptnet graph, specify `recurse > 1`.
for i in range(1, args.recurse):
all_concepts = get_concept_data(all_concepts)
all_concepts = list(set(all_concepts))
all_concepts = clean_concepts(all_concepts)
all_concepts = list(set(all_concepts).difference(set(all_classes)))
learn_conceptbank(args, all_concepts, args.classes)
else:
raise ValueError(f"Unknown classes: {args.classes}. Define your dataset here!")