Skip to content

Commit

Permalink
more shape explications
Browse files Browse the repository at this point in the history
  • Loading branch information
stefanasandei committed Sep 28, 2024
1 parent a0cc0d5 commit 391d05a
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 14 deletions.
14 changes: 2 additions & 12 deletions notebooks/dataset.ipynb

Large diffs are not rendered by default.

Binary file added notebooks/frame_orientation.pdf
Binary file not shown.
Binary file added notebooks/speed_sensors.pdf
Binary file not shown.
11 changes: 9 additions & 2 deletions src/modules/steer.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,10 @@ def forward(self, x):
x = torch.cat((cls_token, x), dim=1)
# basically (B*T, H*W+1, C); add a prefix token

# (120, 197, 192); (197, 192); (1, 30, 192)
# (B*T, H*W+1, C); (H*W+1, C); (1, T, C)
# (x.shape, self.pos_embd.shape, self.temp_embd.shape)

# 2. positional embedding
# same shape: (B*T, H*W+1, C)
x = x + self.pos_embd # just add the spatial info
Expand All @@ -191,6 +195,9 @@ def forward(self, x):
x = torch.cat((cls_tokens, x), dim=1)
# (B, T', C); new linear sequence of embedding tokens

# T' = 5881 = 30 * (224/16) * (224/16) + 1 (cls token)
# = T * (H/P) * (W/P) + 1 (cls token); and C = 192; embedding size

x = self.drop(x)
return x

Expand Down Expand Up @@ -306,8 +313,8 @@ def forward(self, hidden):

# let's test the model
if __name__ == "__main__":
B, T, HW = 8, 30, 224
past_frames = torch.randn((B, 3, T, HW, HW), device="cuda")
B, T, HW = 4, 30, 224
past_frames = torch.randn((B, T, 3, HW, HW), device="cuda")
past_xyz = torch.randn((B, T, 3), device="cuda")

model = SteerNet(n_frames=T, img_size=HW).to("cuda")
Expand Down

0 comments on commit 391d05a

Please sign in to comment.