-
Notifications
You must be signed in to change notification settings - Fork 0
/
summarizeModel.py
50 lines (37 loc) · 1.34 KB
/
summarizeModel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import torch
import config
from models import models
def main():
    """Print the UIDCycleGAN sub-networks and smoke-test one forward pass each.

    Reads options via ``config.readArguments(train=True)``, builds
    ``models.UIDCycleGAN`` and places it with ``models.assignOnMultiGpus``.
    It prints the four sub-networks (``netGenS2T``, ``netGenT2S``,
    ``netDisS2T``, ``netDisT2S``). It then feeds a random tensor of shape
    ``(1, opt.inputDim, opt.patchSize, opt.patchSize)`` through each one and
    reports the output size.
    """
    # Read Options
    opt = config.readArguments(train=True)

    # Create Model Instance
    model = models.UIDCycleGAN(opt)
    model = models.assignOnMultiGpus(opt, model)
    # The original code always went through ``model.module``, so
    # assignOnMultiGpus presumably wraps the model (e.g. in DataParallel).
    # Fall back to the bare model when no wrapper was applied.
    # NOTE(review): confirm behaviour on the CPU path (gpuIds == "-1").
    net = getattr(model, "module", model)

    # Create Dummy Input
    dummyInput = torch.randn((1, opt.inputDim, opt.patchSize, opt.patchSize))
    # Assign Device
    if opt.gpuIds != "-1":
        dummyInput = dummyInput.cuda()

    generators = (("Gen. S2T", net.netGenS2T), ("Gen. T2S", net.netGenT2S))
    discriminators = (("Dis. S2T", net.netDisS2T), ("Dis. T2S", net.netDisT2S))

    # Show Model Architecture (all four networks before any forward pass)
    for _, network in generators + discriminators:
        print(network)

    # Show Output. This is only a shape check, so skip building the
    # autograd graph; it saves memory and does not change output sizes.
    with torch.no_grad():
        for tag, generator in generators:
            outputG = generator(dummyInput)
            print(f"Feed-Forward Successful! ({tag})")
            print(f"Output Size (G): {outputG.size()}")
        for tag, discriminator in discriminators:
            # The discriminators take two tensors; the same dummy input is
            # passed as both arguments, as in the original script.
            outputD = discriminator(dummyInput, dummyInput)
            print(f"Feed-Forward Successful! ({tag})")
            print(f"Output Size (D): {outputD.size()}")
# Entry point: run the model summary only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()