From ca4910dd60ad1a63e351a9f537ab0f696a537dc1 Mon Sep 17 00:00:00 2001 From: Stefan Asandei Date: Mon, 7 Oct 2024 21:54:51 +0300 Subject: [PATCH] dataset estimates --- src/data/dataset.py | 4 ++-- src/dataset_params.py | 40 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 42 insertions(+), 2 deletions(-) create mode 100644 src/dataset_params.py diff --git a/src/data/dataset.py b/src/data/dataset.py index cba1d45..972674c 100644 --- a/src/data/dataset.py +++ b/src/data/dataset.py @@ -36,8 +36,8 @@ def __init__(self, path: str, chunk_num: int, train: bool, device: str, dataset_ # used to reduce dataset size, for better training on low end devices new_dataset_len = int( len(self.frame_paths)*dataset_percentage/100.0) - print( - f"using {new_dataset_len}/{len(self.frame_paths)} frames for training") + # print( + # f"using {new_dataset_len}/{len(self.frame_paths)} frames for training") self.frame_paths = self.frame_paths[:new_dataset_len] # the parent dir of the parent dir diff --git a/src/dataset_params.py b/src/dataset_params.py new file mode 100644 index 0000000..7df5a29 --- /dev/null +++ b/src/dataset_params.py @@ -0,0 +1,40 @@ +import torch +from torch.utils.data import DataLoader +from tabulate import tabulate + +from data.dataset import CommaDataset +from config import cfg + +# params +batch_size = 4 +dataset_percentage = 10 + +device = "cuda" if torch.cuda.is_available() else "cpu" + +# data +train_dataset = CommaDataset( + cfg["data"]["path"], chunk_num=1, train=True, device=device, dataset_percentage=dataset_percentage +) +train_dataloader = DataLoader( + train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=4) + + +def iters_to_time(iters: int, per_batch: float = 0.177) -> float: + # per_batch = 0.089 # pilotnet + sec_per_iter = per_batch*batch_size + sec = iters * sec_per_iter + hours = sec / 3600.0 + return hours + + +one_epoch = len(train_dataloader) +epochs_num = 20 +total_iters = one_epoch * epochs_num + +# print 
results
print(tabulate([
    [f"{dataset_percentage}% of dataset", len(train_dataloader)],
    ["batch size", batch_size],
    ["total epochs", epochs_num],
    ["time (hours)", iters_to_time(total_iters)],
], floatfmt=".2f"))