# AnsenH's picture
# feat: add our model
# 24615d9
import torch
import torch.nn as nn
import torch.nn.functional as F
class ScoringModel(nn.Module):
    """Feed-forward network that maps a flattened clip feature vector to one scalar.

    The expected input width is ``input_dim * frames_per_clip``. With
    ``num_hidden_layers == 0`` the model is a single linear projection to one
    output. Otherwise it is a stack of ``num_hidden_layers`` blocks of
    ``Linear -> ReLU -> BatchNorm1d`` (each ``hidden_dim`` wide) followed by a
    final ``Linear(hidden_dim, 1)``.

    Args:
        frames_per_clip: Number of frames whose features are concatenated
            into one input vector.
        input_dim: Feature width contributed by each frame.
        hidden_dim: Width of every hidden layer; unused when
            ``num_hidden_layers`` is 0.
        num_hidden_layers: Number of hidden blocks; must be >= 0.

    Raises:
        ValueError: If ``num_hidden_layers`` is negative.
    """

    def __init__(self, frames_per_clip: int, input_dim: int, hidden_dim: int, num_hidden_layers: int):
        super().__init__()
        # A negative count used to fall through to the hidden-layer branch and
        # silently build a one-hidden-layer network; reject it explicitly.
        if num_hidden_layers < 0:
            raise ValueError(
                f"num_hidden_layers must be >= 0, got {num_hidden_layers}"
            )

        self.frames_per_clip = frames_per_clip
        in_features = input_dim * frames_per_clip

        if num_hidden_layers == 0:
            # Degenerate case: plain linear scorer, no hidden representation.
            self.model = nn.Linear(in_features, 1)
        else:
            # Widths seen by consecutive hidden blocks:
            # in_features -> hidden_dim -> ... -> hidden_dim.
            widths = [in_features] + [hidden_dim] * num_hidden_layers
            modules = []
            for fan_in, fan_out in zip(widths, widths[1:]):
                # Layer order (ReLU before BatchNorm) is kept exactly as
                # before so existing state_dict keys still line up.
                modules.extend([
                    nn.Linear(fan_in, fan_out),
                    nn.ReLU(True),
                    nn.BatchNorm1d(fan_out),
                ])
            modules.append(nn.Linear(hidden_dim, 1))
            self.model = nn.Sequential(*modules)

    def forward(self, x):
        """Apply the network to ``x`` and return its output.

        ``x`` is passed to the underlying layers unchanged, so its last
        dimension must equal ``input_dim * frames_per_clip``. The module's
        own smoke test feeds a ``(batch_size, input_dim * frames_per_clip)``
        tensor and expects a ``(batch_size, 1)`` result.
        """
        return self.model(x)
if __name__ == '__main__':
    # Smoke test: run one random batch through a linear-only scorer
    # (zero hidden layers) and print the model and the output size.
    batch_size = 8
    input_dim = 512
    frames_per_clip = 3
    hidden_dim = 0
    num_hidden_layers = 0

    # The random batch is drawn before the model is built, as in the
    # original, so the RNG is consumed in the same order.
    flat_width = input_dim * frames_per_clip
    dummy_batch = torch.rand(batch_size, flat_width)

    scorer = ScoringModel(
        frames_per_clip=frames_per_clip,
        input_dim=input_dim,
        hidden_dim=hidden_dim,
        num_hidden_layers=num_hidden_layers,
    )
    print(scorer)

    scores = scorer(dummy_batch)
    print(scores.shape)  # should be (batch_size, 1)