Skip to content

Commit

Permalink
info about flops
Browse files Browse the repository at this point in the history
  • Loading branch information
stefanasandei committed Sep 22, 2024
1 parent 75c3710 commit 46dc8bd
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 14 deletions.
33 changes: 28 additions & 5 deletions src/info.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,20 @@
from torchinfo import summary
from modules.model import SteerNetWrapped, Seq2SeqWrapped, PilotNetWrapped
import torch
import torch.nn as nn
import argparse
from torchinfo import summary
from thop import profile


def compute_summary(model: nn.Module):
B, T, HW = 1, 11, 224
sample_frames = torch.randn((B, T, 3, HW, HW), device="cuda")
sample_paths = torch.randn((B, T, 3), device="cuda")

flops, _ = profile(model, inputs=(
sample_frames, sample_paths), verbose=False)
print("flops: " + str(flops))

from modules.model import SteerNetWrapped, Seq2SeqWrapped, PilotNetWrapped

if __name__ == "__main__":
parser = argparse.ArgumentParser()
Expand All @@ -10,12 +23,22 @@
required=True)
args = parser.parse_args()

depth = 3
model = None

if args.model == "steer":
print(summary(SteerNetWrapped("cuda")))
depth = 6 # for detailed Mamba layers overview
model = SteerNetWrapped("cuda", compile=False)
elif args.model == "pilotnet":
print(summary(PilotNetWrapped("cuda")))
model = PilotNetWrapped("cuda", compile=False)
elif args.model == "seq2seq":
print(summary(Seq2SeqWrapped("cuda")))
model = Seq2SeqWrapped("cuda", compile=False)

# param count & layers
summary(model, depth=depth)

# tlfops & compute time
compute_summary(model)

# steer: 6.3M
# seq2seq 5.9M (about ~5M is from RegNet)
Expand Down
24 changes: 15 additions & 9 deletions src/modules/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,45 +7,51 @@
from config import cfg


def PilotNetWrapped(device: str, return_dict: bool = True) -> nn.Module:
def PilotNetWrapped(device: str, return_dict: bool = True, compile: bool = True) -> nn.Module:
model = PilotNet().to(device)

if return_dict:
model.compile()
if compile:
model.compile()
model = AVWrapper(model, return_dict=True).to(device)
# can't compile a model returning a dict
else:
model = AVWrapper(model, return_dict=False).to(device)
model.compile()
if compile:
model.compile()

return model


def Seq2SeqWrapped(device: str, return_dict: bool = True) -> nn.Module:
def Seq2SeqWrapped(device: str, return_dict: bool = True, compile: bool = True) -> nn.Module:
# add one more frame, the current one
model = Seq2Seq().to(device)

if return_dict:
model.compile()
if compile:
model.compile()
model = AVWrapper(model, return_dict=True).to(device)
# can't compile a model returning a dict
else:
model = AVWrapper(model, return_dict=False).to(device)
model.compile()
if compile:
model.compile()

return model


def SteerNetWrapped(device: str, return_dict: bool = True) -> nn.Module:
def SteerNetWrapped(device: str, return_dict: bool = True, compile: bool = True) -> nn.Module:
# add one more frame, the current one
model = SteerNet(cfg["model"]["past_steps"]+1).to(device)

if return_dict:
model.compile()
if compile:
model.compile()
model = AVWrapper(model, return_dict=True).to(device)
# can't compile a model returning a dict
else:
model = AVWrapper(model, return_dict=False).to(device)
model.compile()
if compile:
model.compile()

return model

0 comments on commit 46dc8bd

Please sign in to comment.