Skip to content

Commit

Permalink
dataset estimates
Browse files Browse the repository at this point in the history
  • Loading branch information
stefanasandei committed Oct 7, 2024
1 parent 391d05a commit ca4910d
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ def __init__(self, path: str, chunk_num: int, train: bool, device: str, dataset_
# used to reduce dataset size, for better training on low end devices
new_dataset_len = int(
len(self.frame_paths)*dataset_percentage/100.0)
print(
f"using {new_dataset_len}/{len(self.frame_paths)} frames for training")
# print(
# f"using {new_dataset_len}/{len(self.frame_paths)} frames for training")
self.frame_paths = self.frame_paths[:new_dataset_len]

# the parent dir of the parent dir
Expand Down
40 changes: 40 additions & 0 deletions src/dataset_params.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import torch
from torch.utils.data import DataLoader
from tabulate import tabulate

from data.dataset import CommaDataset
from config import cfg

# params
batch_size = 4
dataset_percentage = 10

device = "cuda" if torch.cuda.is_available() else "cpu"

# data
train_dataset = CommaDataset(
cfg["data"]["path"], chunk_num=1, train=True, device=device, dataset_percentage=dataset_percentage
)
train_dataloader = DataLoader(
train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=4)


def iters_to_time(iters: int, per_batch: float = 0.177) -> float:
# per_batch = 0.089 # pilotnet
sec_per_iter = per_batch*batch_size
sec = iters * sec_per_iter
hours = sec / 3600.0
return hours


one_epoch = len(train_dataloader)
epochs_num = 20
total_iters = one_epoch * epochs_num

# print results
print(tabulate([
[f"{dataset_percentage}% of dataset", len(train_dataloader)],
["batch size", batch_size],
[f"total epochs", epochs_num],
["time (hours)", iters_to_time(total_iters)],
], floatfmt=".2f"))

0 comments on commit ca4910d

Please sign in to comment.