main.py
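"""Entry point for training and evaluating the calibration net, generator and
discriminator on the LIDC or Cityscapes19 datasets.

All hyperparameters are set through the command-line arguments defined in
parse() below; --mode selects between 'train' and 'test'.
"""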
from data.Cityscapes19 import Cityscapes19
from data.LIDC import LIDC
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import MultiStepLR
from models.discriminators.GeneralDiscriminator import GeneralDiscriminator
from models.generators.GeneralGenerator import GeneralGenerator
from models.losses.GeneralLoss import GeneralLoss
from utils.general_utils import *
from utils.model_utils import find_model, load_models_and_state
from utils.constants import *
from utils.data_utils import get_transforms_list
import argparse
from training.train import TrainingProcess
from testing.test import evaluation
import torchvision.transforms as transforms
import torch
import sys
import os
import wandb
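

# Helper for robust boolean CLI flags: argparse's `type=bool` converts any
# non-empty string (including "False") to True, so the boolean arguments
# defined in parse() below use this explicit parser instead.
def str2bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "1"):
        return True
    if v.lower() in ("no", "false", "f", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {v!r}")
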
def load_data(dataset_name, mode, batch_size, args):
    data = None

    if args.class_flip:
        assert args.dataset == "CITYSCAPES19" and LABELS_CHANNELS == 25, \
            "To run on the modified Cityscapes dataset, set LABELS_CHANNELS = 25"
    if LABELS_CHANNELS == 25:
        assert args.dataset == "CITYSCAPES19" and args.class_flip, \
            "LABELS_CHANNELS = 25 is only valid for CITYSCAPES19 with class flip active"

    # compose transform object
    transforms_list = get_transforms_list(mode, args)
    transform = transforms.Compose(transforms_list)

    # initialize dataset
    if dataset_name == "CITYSCAPES19":
        dataset = Cityscapes19(mode, transform=transform)
    elif dataset_name == "LIDC":
        dataset = LIDC(mode, transform=transform)
    else:
        raise ValueError("Unexpected dataset name.")

    shuffle = (mode == "train")
    if mode in ("train", "val", "test"):
        data = DataLoader(dataset, shuffle=shuffle, batch_size=batch_size,
                          drop_last=True, pin_memory=True, num_workers=8)
    else:
        raise Exception(f"{mode} is not a valid dataset mode")

    print(f"finished loading {len(data)} batches for {mode} dataset")
    return data
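
# Example (illustrative): load_data("LIDC", "train", args.batch_size, args)
# returns a shuffled DataLoader over the LIDC training split with drop_last
# and pin_memory enabled.
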
def main(args):
    if not args.debug:
        # initialize logger
        wandb.init(config=args)

    if args.mode == "test":
        # Cityscapes results are reported on the val set
        dataloader_test = load_data(args.dataset, "val" if args.dataset == "CITYSCAPES19" else "test", args.batch_size_plotting, args)
    else:
        dataloader_train = load_data(args.dataset, "train", args.batch_size, args)
        # For Cityscapes, the cities of Darmstadt, Mönchengladbach and Ulm are moved from the
        # training set to the test set and used for validation (as done in the Probabilistic U-Net paper)
        dataloader_validate = load_data(args.dataset, "test" if args.dataset == "CITYSCAPES19" else "val", args.batch_size_plotting, args)

    # get models
    calibration_net = find_model(GEN_DIR, args.calibration_net,
                                 device=DEVICE,
                                 imsize=args.imsize,
                                 n_channels_in=INPUT_CHANNELS if args.dataset != "CITYSCAPES19" else INPUT_CHANNELS + LABELS_CHANNELS,
                                 n_channels_out=LABELS_CHANNELS,
                                 use_dropout=args.dropout,
                                 n_hidden=args.n_hidden_cal,
                                 temperature=args.temperature).to(DEVICE)

    generator = find_model(GEN_DIR, args.generator,
                           device=DEVICE,
                           imsize=args.imsize,
                           n_channels_in=INPUT_CHANNELS if args.calibration_net == "EmptyCalNet" else INPUT_CHANNELS + LABELS_CHANNELS,
                           n_channels_out=LABELS_CHANNELS,
                           z_dim=args.z_dim,
                           use_dropout=args.dropout,
                           n_hidden=args.n_hidden_gen).to(DEVICE)

    discriminator = find_model(DIS_DIR, args.discriminator,
                               device=DEVICE,
                               imsize=args.imsize,
                               n_channels_in=LABELS_CHANNELS if args.discriminator == "AEDiscriminator" else INPUT_CHANNELS + LABELS_CHANNELS,
                               use_dropout=args.dropout,
                               n_hidden=args.n_hidden_dis).to(DEVICE)

    if args.pretrained:
        # load in state dicts
        load_models_and_state(discriminator,
                              generator,
                              calibration_net,
                              args.models_to_load,
                              args.pretrained_model_suffix,
                              args.pretrained_model_date)

    # assertions
    assert_type(GeneralGenerator, generator)
    assert_type(GeneralDiscriminator, discriminator)
    assert_type(GeneralGenerator, calibration_net)

    # train or test
    if args.mode == "train":
        # init optimizers
        gen_optimizer = torch.optim.Adam(params=generator.parameters(), lr=args.learning_rate_gen, betas=args.opt_betas, weight_decay=args.opt_weight_decay)
        dis_optimizer = torch.optim.Adam(params=discriminator.parameters(), lr=args.learning_rate_dis, betas=args.opt_betas, weight_decay=args.opt_weight_decay)
        cal_optimizer = torch.optim.Adam(params=calibration_net.parameters(), lr=args.learning_rate_cal, betas=args.opt_betas, weight_decay=args.opt_weight_decay)

        # init lr schedulers
        gen_lr_scheduler = MultiStepLR(gen_optimizer, milestones=args.lr_step_gen[0], gamma=args.lr_step_gen[1])
        dis_lr_scheduler = MultiStepLR(dis_optimizer, milestones=args.lr_step_dis[0], gamma=args.lr_step_dis[1])
        cal_lr_scheduler = MultiStepLR(cal_optimizer, milestones=args.lr_step_cal[0], gamma=args.lr_step_cal[1])

        # define loss functions
        if args.loss_gen != TOTAL_G_LOSS:
            print(f"{PRINTCOLOR_RED} WARNING: running with one generator-loss only: {args.loss_gen} {PRINTCOLOR_END}")
        weights_loss_functions = get_loss_weights(args)
        loss_gen = find_model(LOSS_DIR, TOTAL_G_LOSS, **weights_loss_functions)
        loss_dis = find_model(LOSS_DIR, TOTAL_D_LOSS, **weights_loss_functions)

        # assertions
        assert_type(GeneralLoss, loss_dis)
        assert_type(GeneralLoss, loss_gen)

        # define process
        train_progress = TrainingProcess(generator,
                                         discriminator,
                                         calibration_net,
                                         dataloader_train,
                                         dataloader_validate,
                                         gen_optimizer,
                                         dis_optimizer,
                                         cal_optimizer,
                                         gen_lr_scheduler,
                                         dis_lr_scheduler,
                                         cal_lr_scheduler,
                                         loss_gen,
                                         loss_dis,
                                         args)

        # train
        trained_successfully = train_progress.train()

    elif args.mode == "test":
        # load in state dicts
        load_models_and_state(discriminator,
                              generator,
                              calibration_net,
                              args.test_models_to_load,
                              args.test_model_suffix,
                              args.test_model_date)

        # run test
        evaluation(dataloader_test, generator, calibration_net, discriminator, args,
                   number_of_batches=20000, visualize=False, save=False, load=False)

    else:
        raise Exception(f"Unrecognized train/test mode: {args.mode}")
def parse():
    parser = argparse.ArgumentParser()

    # Run mode ----------------------------------------------------------------------------------------------------------------
    parser.add_argument('--mode', default="test", type=str, help="'train' or 'test'")

    # debug
    parser.add_argument('--debug', default=True, type=str2bool, help="If false, does not save directories")
    parser.add_argument('--timing', type=str2bool, default=False, help='are we measuring efficiency?')

    # Pre-training arguments --------------------------------------------------------------------------------------------------
    parser.add_argument('--pretrained', type=str2bool, default=False, help='Determines whether to load a trained model')
    parser.add_argument('--resume', type=str2bool, default=False, help='Determines whether to resume training on the pre-trained model')
    parser.add_argument('--pretrained_model_date', type=str, default="LIDC", help='date_stamp string for which model to load')
    parser.add_argument('--pretrained_model_suffix', type=str, default="Best_Model", help='filename string for which model to load')
    parser.add_argument('--models_to_load', nargs='+', type=str, default=["calibration_net"], help='names of the models to load ("calibration_net", "generator", "discriminator")')

    # Training arguments ------------------------------------------------------------------------------------------------------
    parser.add_argument('--epochs', default=1000000, type=int, help='max number of epochs')
    parser.add_argument('--max_training_minutes', type=int, default=7197, help='after which the process is killed automatically')
    parser.add_argument('--eval_freq', type=int, default=10, help='Frequency (batch-wise) of evaluation')
    parser.add_argument('--plot_freq', type=int, default=400, help='Frequency (batch-wise) of plotting pictures')
    parser.add_argument('--saving_freq', type=int, default=50, help='Frequency (epoch-wise) of saving models')
    parser.add_argument('--device', default="cuda", type=str, help='device')
    parser.add_argument('--learning_rate_cal', type=float, default=2e-4, help='Learning rate for the calibration net')
    parser.add_argument('--learning_rate_gen', type=float, default=2e-4, help='Learning rate for the generator')
    parser.add_argument('--learning_rate_dis', type=float, default=1e-5, help='Learning rate for the discriminator')
    parser.add_argument('--lr_step_cal', default=([250, 500], 0.5), help='Tuple of (epochs at which the learning rate is changed, change factor)')
    parser.add_argument('--lr_step_gen', default=([100, 300, 500], 0.5), help='Tuple of (epochs at which the learning rate is changed, change factor)')
    parser.add_argument('--lr_step_dis', default=([30, 300, 500], 0.5), help='Tuple of (epochs at which the learning rate is changed, change factor)')
    parser.add_argument('--opt_betas', type=tuple, default=(0.5, 0.99), help='beta1 and beta2 for ADAM optimization')
    parser.add_argument('--opt_weight_decay', type=float, default=5e-4, help='weight decay for ADAM optimization')
    parser.add_argument('--schedule', type=tuple, default=(50, 200), help='Alternate discriminator training for specified intervals - (update, freeze)')
    parser.add_argument('--DiscDAccCap', default=0.85, type=float, help="cap the discriminator accuracy at the given value")

    # Test arguments ----------------------------------------------------------------------------------------------------------
    parser.add_argument('--test_model_date', default="LIDC", type=str, help='date_stamp string for which model to load')
    parser.add_argument('--test_model_suffix', default="LIDC_CAR_Model", type=str, help='filename string for which model to load')
    parser.add_argument('--test_models_to_load', nargs='+', type=str, default=["calibration_net", "generator"])

    # Model arguments ---------------------------------------------------------------------------------------------------------
    parser.add_argument('--calibration_net', default="SegNetCalNet", type=str, help="name of the calibration net class")
    parser.add_argument('--generator', default="UNetGenerator", type=str, help="name of the generator class")
    parser.add_argument('--discriminator', default="PixelDiscriminator", type=str, help="name of the discriminator class")
    parser.add_argument('--n_generator_samples_test', default=16, type=int, help="no. of samples to be used in plots and stats")
    parser.add_argument('--z_dim', default=8, type=int, help='dimensionality of the latent code space')
    parser.add_argument('--n_hidden_cal', type=int, default=64, help='features in the first hidden layer')
    parser.add_argument('--n_hidden_gen', type=int, default=32, help='features in the first hidden layer')
    parser.add_argument('--n_hidden_dis', type=int, default=64, help='features in the first hidden layer')
    parser.add_argument('--temperature', type=torch.Tensor, default=torch.ones(LABELS_CHANNELS), help='specifies the magnitude of temperature scaling for the calibration net at test-time')
    parser.add_argument('--dropout', type=str2bool, default=False, help='specifies whether to use dropout')

    # Loss arguments ----------------------------------------------------------------------------------------------------------
    parser.add_argument('--loss_gen', default=TOTAL_G_LOSS, type=str, help="Overrides the generator-loss hyperparameters if not the total loss")
    parser.add_argument('--loss_dis', default=TOTAL_D_LOSS, type=str, help="name of the discriminator-loss class")

    # hyperparams for Calibration net loss
    parser.add_argument('--CalNetLoss_weight', default=1, type=float, help="weight hyperparameter for the calibration net loss")

    # hyperparams for Refinement net loss
    parser.add_argument('--NonSaturatingGLoss_weight', default=10, type=float, help="weight hyperparameter for this generator loss")
    parser.add_argument('--PixelLoss_weight', default=0, type=float, help="weight hyperparameter for this generator loss")
    parser.add_argument('--CalLoss_weight', default=5, type=float, help="weight hyperparameter for this generator loss")
    parser.add_argument('--ComplexityLoss_weight', default=0, type=float, help="weight hyperparameter for this generator loss")
    parser.add_argument('--n_cal_samples', default=5, type=int, help="Number of samples to use for the prediction average in the cal loss")

    # hyperparams for Discriminator loss
    parser.add_argument('--DefaultDLoss_weight', default=1, type=float, help="weight hyperparameter for this discriminator loss")
    parser.add_argument('--label_smoothing', default=False, type=str2bool, help="specifies whether to use label smoothing")

    # Data arguments ----------------------------------------------------------------------------------------------------------
    parser.add_argument('--batch-size', type=int, default=32, help='Size of the batches loaded by the data loader')
    parser.add_argument('--batch-size-plotting', type=int, default=5, help='Size of the validation batch')
    parser.add_argument('--dataset', type=str, default='LIDC', help='LIDC, CITYSCAPES19 or CITYSCAPES35')
    parser.add_argument('--class_flip', type=str2bool, default=False, help="Specifies whether to randomly flip classes in CITYSCAPES")
    # flip only the road class, or the 5 classes flipped in the Probabilistic U-Net paper
    parser.add_argument('--flip_experiment', type=str, default='DEEP', help="ROAD or DEEP")
    parser.add_argument('--crop', type=str2bool, default=True, help='Specifies whether to randomly crop the dataset images')
    parser.add_argument('--resize', type=str2bool, default=True, help='Specifies whether to resize the dataset images')
    parser.add_argument('--noise', type=str2bool, default=False, help='Specifies whether to add random noise to the labels')
    parser.add_argument('--imsize', type=tuple,
                        default=(IMSIZE[0], IMSIZE[1]) if parser.get_default('resize') else eval(f"{parser.get_default('dataset')}_IMSIZE"),
                        help='Tuple of the selected dataset image dimensions (specified in constants)')

    assert len(parser.get_default('imsize')) == 2, f"Invalid IMSIZE values: {parser.get_default('imsize')}"

    return parser.parse_args()
if __name__ == '__main__':
    print("cuda version:", torch.version.cuda, "pytorch version:", torch.__version__, "python version:", sys.version)
    print("Working directory:", os.getcwd())

    args = parse()
    main(args)
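
# Example invocations (illustrative; assumes the datasets are prepared where
# the data loaders expect them):
#   python main.py --mode train --dataset LIDC --batch-size 32
#   python main.py --mode test --dataset LIDC --test_model_suffix LIDC_CAR_Model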