video2music / model /video_regression.py
kjysmu's picture
add files
4e46a55
import torch
import torch.nn as nn
from torch.nn.modules.normalization import LayerNorm
import random
import numpy as np
from utilities.constants import *
from utilities.device import get_device
from datetime import datetime
import torch.nn.functional as F
class VideoRegression(nn.Module):
    """Recurrent regressor mapping per-step video features to two values.

    All input feature groups are concatenated along the last axis, passed
    through one recurrent network (chosen by ``regModel``) and projected to
    two outputs per step by a linear layer.

    Args:
        n_layers: Number of stacked recurrent layers.
        d_model: Hidden size of each recurrent layer (per direction).
        dropout: Drop probability applied to the concatenated input features
            in ``forward`` (it is not passed to the recurrent layers).
        max_sequence_video: Stored as ``max_seq_video``; not used by this
            class itself.
        total_vf_dim: Input size of the recurrent layers. It has to equal the
            width of the concatenated feature vector built in ``forward``.
        regModel: Which recurrent branch ``forward`` uses: ``"bilstm"``,
            ``"bigru"``, ``"lstm"`` or ``"gru"``.
    """

    def __init__(self, n_layers=2, d_model=64, dropout=0.1, max_sequence_video=300, total_vf_dim=0, regModel="bilstm"):
        super().__init__()
        self.nlayers = n_layers
        self.d_model = d_model
        self.dropout = dropout
        self.max_seq_video = max_sequence_video
        self.total_vf_dim = total_vf_dim
        self.regModel = regModel

        # Every variant is instantiated regardless of ``regModel``, and in
        # this exact order, so state_dict keys and seeded initialisation stay
        # compatible with existing checkpoints.
        self.bilstm = nn.LSTM(self.total_vf_dim, self.d_model, self.nlayers, bidirectional=True)
        self.bigru = nn.GRU(self.total_vf_dim, self.d_model, self.nlayers, bidirectional=True)
        # Bidirectional outputs concatenate both directions -> 2 * d_model.
        self.bifc = nn.Linear(self.d_model * 2, 2)

        self.lstm = nn.LSTM(self.total_vf_dim, self.d_model, self.nlayers)
        self.gru = nn.GRU(self.total_vf_dim, self.d_model, self.nlayers)
        self.fc = nn.Linear(self.d_model, 2)

        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    def _rnn_and_head(self):
        """Return the (recurrent module, linear head) pair for ``self.regModel``.

        Raises:
            ValueError: If ``self.regModel`` names no known branch.
        """
        branches = {
            "bilstm": (self.bilstm, self.bifc),
            "bigru": (self.bigru, self.bifc),
            "lstm": (self.lstm, self.fc),
            "gru": (self.gru, self.fc),
        }
        try:
            return branches[self.regModel]
        except KeyError:
            raise ValueError(
                f"Unsupported regModel {self.regModel!r}; "
                f"expected one of {sorted(branches)}"
            ) from None

    def forward(self, feature_semantic_list, feature_scene_offset, feature_motion, feature_emotion):
        """Regress two values per step from the video features.

        The layout is presumably (batch, seq, feature): the recurrent layers
        use PyTorch's default sequence-first layout and the first two axes
        are swapped before and after them -- confirm against the caller.

        Args:
            feature_semantic_list: Sequence of 3-D tensors joined along the
                last axis.
            feature_scene_offset: 2-D tensor; gets a trailing axis of size 1.
            feature_motion: 2-D tensor; gets a trailing axis of size 1.
            feature_emotion: 3-D tensor appended as-is.

        Returns:
            Tensor with the same first two axes as the inputs and a last
            axis of size 2.

        Raises:
            ValueError: If ``self.regModel`` is not a supported model name.
        """
        # Fail early with a clear message instead of an UnboundLocalError
        # at the return statement.
        rnn, head = self._rnn_and_head()

        ### Video (SemanticList + SceneOffset + Motion + Emotion) (ENCODER) ###
        # One concatenation instead of re-allocating the growing tensor once
        # per feature group; order and values are unchanged.
        pieces = [semantic.float() for semantic in feature_semantic_list]
        pieces.append(feature_scene_offset.unsqueeze(-1).float())
        pieces.append(feature_motion.unsqueeze(-1).float())
        pieces.append(feature_emotion.float())
        vf_concat = torch.cat(pieces, dim=-1)

        # Swap the first two axes for the (batch_first=False) recurrent layers.
        vf_concat = vf_concat.permute(1, 0, 2)
        vf_concat = F.dropout(vf_concat, p=self.dropout, training=self.training)

        out, _ = rnn(vf_concat)
        # Restore the original axis order before the per-step projection.
        out = out.permute(1, 0, 2)
        return head(out)