-
Notifications
You must be signed in to change notification settings - Fork 10
/
train_pcbm.py
125 lines (98 loc) · 5.78 KB
/
train_pcbm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import argparse
import os
import pickle
import numpy as np
import torch
from sklearn.linear_model import SGDClassifier
from sklearn.metrics import roc_auc_score
from data import get_dataset
from concepts import ConceptBank
from models import PosthocLinearCBM, get_model
from training_tools import load_or_compute_projections
def config(argv=None):
    """Build the command-line parser for PCBM training and parse arguments.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) lets
            argparse read ``sys.argv[1:]``, exactly as before; passing a list
            allows programmatic use and testing.

    Returns:
        argparse.Namespace holding the options declared below.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--concept-bank", required=True, type=str, help="Path to the concept bank")
    parser.add_argument("--out-dir", required=True, type=str, help="Output folder for model/run info.")
    parser.add_argument("--dataset", default="cub", type=str)
    parser.add_argument("--backbone-name", default="resnet18_cub", type=str)
    parser.add_argument("--device", default="cuda", type=str)
    parser.add_argument("--seed", default=42, type=int, help="Random seed")
    parser.add_argument("--batch-size", default=64, type=int)
    parser.add_argument("--num-workers", default=4, type=int)
    # Passed to SGDClassifier as `l1_ratio` (the L1/L2 mixing weight) in run_linear_probe.
    parser.add_argument("--alpha", default=0.99, type=float, help="Sparsity coefficient for elastic net.")
    # Passed to SGDClassifier as `alpha` (the overall penalty multiplier) in run_linear_probe.
    parser.add_argument("--lam", default=1e-5, type=float, help="Regularization strength.")
    # NOTE(review): run_linear_probe does not pass args.lr to SGDClassifier; it may be
    # read elsewhere (e.g. project helpers) -- verify before relying on or removing it.
    parser.add_argument("--lr", default=1e-3, type=float)
    return parser.parse_args(argv)
def _masked_accuracy(labels, predictions, mask=None):
    """Return the fraction of positions where ``labels == predictions``.

    When ``mask`` is given, only the selected positions are scored. If nothing
    is selected (e.g. a class absent from a split), return NaN explicitly
    instead of letting ``np.mean`` warn about an empty slice.
    """
    labels = np.asarray(labels)
    predictions = np.asarray(predictions)
    if mask is not None:
        labels = labels[mask]
        predictions = predictions[mask]
    if labels.size == 0:
        return np.nan
    return np.mean(labels == predictions)


def run_linear_probe(args, train_data, test_data):
    """Fit an elastic-net logistic-regression probe on concept projections.

    Args:
        args: Parsed CLI namespace. Reads ``seed``, ``lam`` (passed to
            SGDClassifier as ``alpha``) and ``alpha`` (passed as ``l1_ratio``).
        train_data: ``(features, labels)`` tuple used to fit the classifier.
        test_data: ``(features, labels)`` tuple used for evaluation.

    Returns:
        ``(run_info, coef, intercept)``. ``run_info`` contains ``train_acc`` and
        ``test_acc`` (percentages), ``cls_acc`` (per-class accuracies as
        fractions in [0, 1], NaN for a class with no samples in that split)
        and, for binary problems, ``test_auc``/``train_auc`` when the split
        contains both classes. ``coef``/``intercept`` are the classifier's
        fitted ``coef_``/``intercept_``.
    """
    train_features, train_labels = train_data
    test_features, test_labels = test_data
    # We converged to using SGDClassifier.
    # It's fine to use other modules here, this seemed like the most pedagogical option.
    # We experimented with torch modules etc., and results are mostly parallel.
    classifier = SGDClassifier(random_state=args.seed, loss="log_loss",
                               alpha=args.lam, l1_ratio=args.alpha, verbose=0,
                               penalty="elasticnet", max_iter=10000)
    classifier.fit(train_features, train_labels)

    train_predictions = classifier.predict(train_features)
    predictions = classifier.predict(test_features)
    train_accuracy = _masked_accuracy(train_labels, train_predictions) * 100.
    test_accuracy = _masked_accuracy(test_labels, predictions) * 100.

    # Class-level accuracies (fractions, unlike the percentage-valued overall
    # accuracies above). Can later be used to understand what classes are
    # lacking some concepts.
    cls_acc = {"train": {}, "test": {}}
    for lbl in np.unique(train_labels):
        cls_acc["test"][lbl] = _masked_accuracy(test_labels, predictions, test_labels == lbl)
        cls_acc["train"][lbl] = _masked_accuracy(train_labels, train_predictions, train_labels == lbl)
        print(f"{lbl}: {cls_acc['test'][lbl]}")

    run_info = {"train_acc": train_accuracy, "test_acc": test_accuracy,
                "cls_acc": cls_acc,
                }

    # Binary task: report ROC-AUC on the positive-class margin. Gating on the
    # fitted classes (instead of `test_labels.max() == 1`) also covers label
    # sets other than {0, 1} and never hands 2-D multiclass scores to a binary
    # metric. roc_auc_score raises ValueError when a split holds a single
    # class, so such splits are skipped.
    if len(classifier.classes_) == 2:
        for split, features, labels in (("test", test_features, test_labels),
                                        ("train", train_features, train_labels)):
            if len(np.unique(labels)) == 2:
                run_info[f"{split}_auc"] = roc_auc_score(labels, classifier.decision_function(features))
    return run_info, classifier.coef_, classifier.intercept_
def main(args, concept_bank, backbone, preprocess):
    """Train a post-hoc CBM on top of ``backbone`` and save the results.

    Args:
        args: Parsed CLI namespace (see ``config``). Reads concept_bank,
            out_dir, dataset, backbone_name, device, seed, lam and alpha here;
            get_dataset / load_or_compute_projections may read more.
        concept_bank: ConceptBank passed to PosthocLinearCBM.
        backbone: Backbone model passed to load_or_compute_projections.
        preprocess: Preprocessing passed to get_dataset.

    Side effects:
        Creates ``args.out_dir`` if needed, then writes
        ``pcbm_<run>.ckpt`` (``torch.save`` of the whole PosthocLinearCBM)
        and ``run_info-pcbm_<run>.pkl`` (the pickled run_info dict) into it.
    """
    # Make sure the output folder exists before anything tries to write into it.
    os.makedirs(args.out_dir, exist_ok=True)
    train_loader, test_loader, idx_to_class, classes = get_dataset(args, preprocess)
    # Get a clean conceptbank string: the file name without directory or final
    # extension, e.g. /../../cub_resnet-cub_0.1_100.pkl -> cub_resnet-cub_0.1_100,
    # i.e. a bank learned with 100 samples per concept and C=0.1 for the SVM.
    # splitext keeps the dot in "0.1" (splitting on the first "." would truncate
    # to "cub_resnet-cub_0" and make banks with different C collide).
    # See `learn_concepts_dataset.py` for details.
    conceptbank_source = os.path.splitext(os.path.basename(args.concept_bank))[0]
    num_classes = len(classes)
    # Initialize the PCBM module.
    posthoc_layer = PosthocLinearCBM(concept_bank, backbone_name=args.backbone_name, idx_to_class=idx_to_class, n_classes=num_classes)
    posthoc_layer = posthoc_layer.to(args.device)
    # We compute the projections and save to the output directory. This is to save time in tuning hparams / analyzing projections.
    train_embs, train_projs, train_lbls, test_embs, test_projs, test_lbls = load_or_compute_projections(args, backbone, posthoc_layer, train_loader, test_loader)
    run_info, weights, bias = run_linear_probe(args, (train_projs, train_lbls), (test_projs, test_lbls))
    # Convert from the SGDClassifier module to PCBM module.
    posthoc_layer.set_weights(weights=weights, bias=bias)

    # Both output files share one run name and live directly in out_dir.
    # Building them from the bare name (rather than string-replacing on the
    # full model path) avoids joining out_dir twice and avoids rewriting
    # "pcbm"/".ckpt" substrings that happen to appear in the directory name.
    run_name = (f"pcbm_{args.dataset}__{args.backbone_name}__{conceptbank_source}"
                f"__lam:{args.lam}__alpha:{args.alpha}__seed:{args.seed}")
    model_path = os.path.join(args.out_dir, f"{run_name}.ckpt")
    torch.save(posthoc_layer, model_path)

    run_info_file = os.path.join(args.out_dir, f"run_info-{run_name}.pkl")
    with open(run_info_file, "wb") as f:
        pickle.dump(run_info, f)

    if num_classes > 1:
        # Prints the Top-5 Concept Weights for each class.
        print(posthoc_layer.analyze_classifier(k=5))
    print(f"Model saved to : {model_path}")
    print(run_info)
if __name__ == "__main__":
    args = config()
    # NOTE(review): pickle.load can execute arbitrary code embedded in the file;
    # only load concept banks from trusted sources.
    with open(args.concept_bank, "rb") as f:
        all_concepts = pickle.load(f)
    print(f"Bank path: {args.concept_bank}. {len(all_concepts)} concepts will be used.")
    concept_bank = ConceptBank(all_concepts, args.device)
    # Get the backbone from the model zoo.
    backbone, preprocess = get_model(args, backbone_name=args.backbone_name)
    backbone = backbone.to(args.device)
    backbone.eval()
    main(args, concept_bank, backbone, preprocess)