import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
import json
import os
"""
Using the plotter:
Call it from the command line, and supply it with logdirs to experiments.
Suppose you ran an experiment with name 'test', and you ran 'test' for 10
random seeds. The runner code stored it in the directory structure
data
L test_EnvName_DateTime
L 0
L log.txt
L params.json
L 1
L log.txt
L params.json
.
.
.
L 9
L log.txt
L params.json
To plot learning curves from the experiment, averaged over all random
seeds, call
python plot.py data/test_EnvName_DateTime --value AverageReturn
and voila. To see a different statistics, change what you put in for
the keyword --value. You can also enter /multiple/ values, and it will
make all of them in order.
Suppose you ran two experiments: 'test1' and 'test2'. In 'test2' you tried
a different set of hyperparameters from 'test1', and now you would like
to compare them -- see their learning curves side-by-side. Just call
python plot.py data/test1 data/test2
and it will plot them both! They will be given titles in the legend according
to their exp_name parameters. If you want to use custom legend titles, use
the --legend flag and then provide a title for each logdir.
"""
def plot_data(data, value="AverageReturn"):
    # Concatenate the per-seed DataFrames into one long-format DataFrame.
    if isinstance(data, list):
        data = pd.concat(data, ignore_index=True)
    sns.set(style="darkgrid", font_scale=1.5)
    # tsplot averages over 'Unit' (the random seed) and draws one curve per
    # 'Condition', with a shaded confidence region across seeds.
    sns.tsplot(data=data, time="Iteration", value=value, unit="Unit", condition="Condition")
    plt.legend(loc='best').draggable()
    plt.tight_layout()
    plt.show()
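
# Note: sns.tsplot was removed in newer versions of seaborn. The sketch below is
# not part of the original script (the name plot_data_lineplot is ours); it
# shows one way to get a similar figure with seaborn's lineplot, which averages
# the repeated rows per (Iteration, Condition) -- i.e. the random seeds -- into
# a mean curve with a shaded error band. Styling will not match tsplot exactly.
def plot_data_lineplot(data, value="AverageReturn"):
    if isinstance(data, list):
        data = pd.concat(data, ignore_index=True)
    sns.set(style="darkgrid", font_scale=1.5)
    # One curve per Condition, aggregated over seeds at each Iteration.
    sns.lineplot(data=data, x="Iteration", y=value, hue="Condition")
    plt.legend(loc='best')
    plt.tight_layout()
    plt.show()
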
def get_datasets(fpath, condition=None):
    # Walk the experiment directory; each seed subdirectory contains a log.txt
    # (tab-separated training statistics) and a params.json (the experiment
    # configuration, including exp_name).
    unit = 0
    datasets = []
    for root, _, files in os.walk(fpath):
        if 'log.txt' in files:
            with open(os.path.join(root, 'params.json')) as param_file:
                params = json.load(param_file)
            exp_name = params['exp_name']
            log_path = os.path.join(root, 'log.txt')
            experiment_data = pd.read_table(log_path)
            # Tag every row with the seed index ('Unit') and the legend label
            # ('Condition') so plot_data can aggregate across seeds.
            experiment_data.insert(
                len(experiment_data.columns),
                'Unit',
                unit
            )
            experiment_data.insert(
                len(experiment_data.columns),
                'Condition',
                condition or exp_name
            )
            datasets.append(experiment_data)
            unit += 1
    return datasets
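
# For interactive use (e.g. in a notebook), the loader above can be called
# directly; the path below is a placeholder matching the docstring's example:
#
#     frames = get_datasets('data/test_EnvName_DateTime')
#     df = pd.concat(frames, ignore_index=True)
#     # df then holds the logged columns plus 'Unit' (seed index) and 'Condition'.
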
def main():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('logdir', nargs='*')
    parser.add_argument('--legend', nargs='*')
    parser.add_argument('--value', default='AverageReturn', nargs='*')
    args = parser.parse_args()

    # If custom legend titles are given, there must be exactly one per logdir.
    use_legend = False
    if args.legend is not None:
        assert len(args.legend) == len(args.logdir), \
            "Must give a legend title for each set of experiments."
        use_legend = True

    # Load every seed's data from each logdir, labelling it with either the
    # custom legend title or the experiment's own exp_name.
    data = []
    if use_legend:
        for logdir, legend_title in zip(args.logdir, args.legend):
            data += get_datasets(logdir, legend_title)
    else:
        for logdir in args.logdir:
            data += get_datasets(logdir)

    # --value may be a single string (the default) or a list; plot each one.
    if isinstance(args.value, list):
        values = args.value
    else:
        values = [args.value]
    for value in values:
        plot_data(data, value=value)


if __name__ == "__main__":
    main()